Example #1
    def forward(self, input, lengths=None, hidden=None):
        """
        Args:
            input (LongTensor): len x batch x nfeat
            lengths (LongTensor): batch
            hidden: Initial hidden state.

        Returns:
            hidden_t (FloatTensor): Pair of layers x batch x rnn_size - final
                                    Encoder state
            outputs (FloatTensor):  len x batch x rnn_size -  Memory bank
        """
        # CHECKS
        s_len, n_batch, n_feats = input.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
        # END CHECKS

        emb = self.embeddings(input)
        s_len, n_batch, vec_size = emb.size()

        if self.encoder_type == "mean":
            # No RNN, just take mean as final state.
            mean = emb.mean(0) \
                   .expand(self.layers, n_batch, vec_size)
            return (mean, mean), emb

        elif self.encoder_type == "transformer":
            # Self-attention transformer.
            out = emb.transpose(0, 1).contiguous()
            for i in range(self.layers):
                out = self.transformer[i](out, input[:, :, 0].transpose(0, 1))
            return Variable(emb.data), out.transpose(0, 1).contiguous()
        else:
            # Standard RNN encoder.
            packed_emb = emb
            if lengths is not None:
                # Lengths data is wrapped inside a Variable.
                lengths = lengths.view(-1).tolist()
                packed_emb = pack(emb, lengths)
            outputs, hidden_t = self.rnn(packed_emb, hidden)
            if lengths:
                outputs = unpack(outputs)[0]
            return hidden_t, outputs
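
A minimal usage sketch for the RNN branch of this encoder forward. The encoder instance, its rnn_size, and the dummy shapes below are assumptions for illustration, not part of the example:

    import torch
    from torch.autograd import Variable

    # `encoder` is an assumed instance of the class this method belongs to.
    src = Variable(torch.LongTensor(20, 4, 1).random_(1, 100))  # len x batch x nfeat
    lengths = torch.LongTensor([20, 18, 15, 12])                # batch (sorted for packing)
    hidden_t, outputs = encoder(src, lengths=lengths)
    # outputs  : len x batch x rnn_size  (memory bank)
    # hidden_t : final encoder state, e.g. (h_n, c_n) for an LSTM
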
Example #2
    def forward(self, input, context, state):
        """
        Forward through the decoder.

        Args:
            input (LongTensor): a sequence of input tokens tensors
                                of size (len x batch x nfeats).
            context (FloatTensor): output(tensor sequence) from the Encoder
                        RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the Encoder RNN for
                                 initializing the decoder.

        Returns:
            outputs (FloatTensor): a Tensor sequence of output from the Decoder
                                   of shape (len x batch x hidden_size).
            state (FloatTensor): final hidden state from the Decoder.
            attns (dict of (str, FloatTensor)): a dictionary of different
                                type of attention Tensor from the Decoder
                                of shape (src_len x batch).
        """
        # Args Check
        assert isinstance(state, RNNDecoderState)
        input_len, input_batch, _ = input.size()
        contxt_len, contxt_batch, _ = context.size()
        aeq(input_batch, contxt_batch)
        # END Args Check

        # Run the forward pass of the RNN.
        hidden, outputs, attns, coverage = \
            self._run_forward_pass(input, context, state)

        # Update the DecoderState with the result.
        final_output = outputs[-1]
        state = RNNDecoderState(
            hidden, final_output.unsqueeze(0),
            coverage.unsqueeze(0) if coverage is not None else None)

        # Concatenates sequence of tensors along a new dimension.
        outputs = torch.stack(outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])

        return outputs, state, attns
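
The per-step collection followed by torch.stack is a recurring pattern in these decoders; a self-contained sketch of the shape arithmetic (dimensions are illustrative assumptions):

    import torch

    steps = [torch.randn(4, 8) for _ in range(5)]  # 5 time steps of (batch x hidden)
    stacked = torch.stack(steps)                   # 5 x 4 x 8, i.e. len x batch x hidden
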
Example #3
    def score(self, h_t, h_s):
        """
        h_t (FloatTensor): batch x dim
        h_s (FloatTensor): batch x src_len x dim
        returns scores (FloatTensor): batch x src_len:
            raw attention scores for each src index
        """

        # Check input sizes
        src_batch, _, src_dim = h_s.size()
        tgt_batch, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t = self.linear_in(h_t)
            return torch.bmm(h_s, h_t.unsqueeze(2)).squeeze(2)
        else:
            # MLP
            # batch x 1 x dim
            wq = self.linear_query(h_t).unsqueeze(1)
            # batch x src_len x dim
            uh = self.linear_context(h_s.contiguous())
            # batch x src_len x dim
            wquh = uh + wq.expand_as(uh)
            # batch x src_len x dim
            wquh = self.tanh(wquh)
            # batch x src_len
            return self.v(wquh.contiguous()).squeeze(2)
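
For the "dot" branch above, the score reduces to a batched dot product between every source state and the single target state; a stand-alone sketch with assumed dimensions:

    import torch

    h_s = torch.randn(4, 10, 16)  # batch x src_len x dim
    h_t = torch.randn(4, 16)      # batch x dim
    scores = torch.bmm(h_s, h_t.unsqueeze(2)).squeeze(2)  # batch x src_len
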
Example #4
    def forward(self, input):
        """
        Return the embeddings for words, and features if there are any.
        Args:
            input (LongTensor): len x batch x nfeat
        Return:
            emb (FloatTensor): len x batch x self.embedding_size
        """
        in_length, in_batch, nfeat = input.size()
        aeq(nfeat, len(self.emb_luts))

        emb = self.make_embedding(input)

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)

        return emb
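
aeq is used throughout these examples as a size check. It is presumably a small assert-equal helper along these lines (a sketch, not necessarily the project's exact implementation):

    def aeq(*args):
        # Assert that all arguments are equal; used for tensor-dimension checks.
        base = args[0]
        assert all(arg == base for arg in args), \
            "mismatched sizes: %s" % str(args)
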
Example #5
    def forward(self, src_input):
        """
        Embed the words, or utilize features and an MLP.

        Args:
            src_input (LongTensor): len x batch x nfeat

        Return:
            emb (FloatTensor): len x batch x emb_size
                                emb_size is word_vec_size if there are no
                                features or the merge action is sum.
                                It is the sum of all feature dimensions
                                if the merge action is concatenate.
        """
        in_length, in_batch, nfeat = src_input.size()
        aeq(nfeat, len(self.emb_luts))

        if len(self.emb_luts) == 1:
            emb = self.word_lut(src_input.squeeze(2))
        else:
            feat_inputs = (feat.squeeze(2)
                           for feat in src_input.split(1, dim=2))
            features = [
                lut(feat) for lut, feat in zip(self.emb_luts, feat_inputs)
            ]
            emb = self.merge(features)

        if self.positional_encoding:
            emb = emb + Variable(
                self.pe[:emb.size(0), :1, :emb.size(2)].expand_as(emb))
            emb = self.dropout(emb)

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)

        return emb
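
The multi-feature branch splits the last dimension into one (len x batch) tensor per feature before looking each up in its own embedding table; a small sketch of that split, with assumed shapes:

    import torch

    src_input = torch.LongTensor(20, 4, 3).random_(1, 50)      # len x batch x nfeat
    feats = [f.squeeze(2) for f in src_input.split(1, dim=2)]  # nfeat tensors of len x batch
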
    def forward(self, src_input):
        """
        Embed the words, or utilize features and an MLP.

        Args:
            src_input (LongTensor): len x batch x nfeat

        Return:
            emb (FloatTensor): len x batch x emb_size
                                emb_size is word_vec_size 

        """
        in_length, in_batch, nfeat = src_input.size()
        aeq(nfeat, len(self.emb_luts))

        if len(self.emb_luts) == 1:
            emb = self.word_lut(src_input.squeeze(2))
        else:
            feat_inputs = (feat.squeeze(2)
                           for feat in src_input.split(1, dim=2))
            features = [
                lut(feat) for lut, feat in zip(self.emb_luts, feat_inputs)
            ]
            emb = self.merge(features)

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)

        return emb
    def forward(self, input, context, src_words, tgt_words):
        # CHECKS
        n_batch, t_len, _ = input.size()
        n_batch_, s_len, _ = context.size()
        n_batch__, s_len_ = src_words.size()
        n_batch___, t_len_ = tgt_words.size()
        aeq(n_batch, n_batch_, n_batch__, n_batch___)
        aeq(s_len, s_len_)
        aeq(t_len, t_len_)
        # END CHECKS

        attn_mask = get_attn_padding_mask(tgt_words, tgt_words)
        dec_mask = torch.gt(
            attn_mask +
            self.mask[:, :attn_mask.size(1), :attn_mask.size(1)].expand_as(
                attn_mask), 0)

        pad_mask = get_attn_padding_mask(tgt_words, src_words)
        query, attn = self.self_attn(input, input, input, mask=dec_mask)
        mid, attn = self.context_attn(context, context, query, mask=pad_mask)
        output = self.feed_forward(mid)

        return output, attn
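
The buffer self.mask combined with the padding mask above is typically a precomputed "subsequent positions" mask that blocks attention to future target words; a hedged sketch of how such a mask can be built:

    import numpy as np
    import torch

    size = 6
    # 1 x size x size upper-triangular mask: position i may not attend to j > i
    subsequent_mask = torch.from_numpy(
        np.triu(np.ones((1, size, size)), k=1).astype('uint8'))
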
    def score(self, h_t, h_s):
        """
        h_t (FloatTensor): batch x tgt_len x dim
        h_s (FloatTensor): batch x src_len x dim
        returns scores (FloatTensor): batch x tgt_len x src_len:
            raw attention scores for each src index
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
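
In this multi-step variant of score, the "dot"/"general" branch becomes a single batched matrix product over all target positions; a sketch with assumed sizes:

    import torch

    h_t = torch.randn(4, 7, 16)   # batch x tgt_len x dim
    h_s = torch.randn(4, 10, 16)  # batch x src_len x dim
    scores = torch.bmm(h_t, h_s.transpose(1, 2))  # batch x tgt_len x src_len
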
    def forward(self,
                input,
                src,
                context,
                state,
                fertility_vals=None,
                fert_dict=None,
                fert_sents=None,
                upper_bounds=None,
                test=False):
        """
        Forward through the decoder.

        Args:
            input (LongTensor):  (len x batch) -- Input tokens
            src (LongTensor)
            context:  (src_len x batch x rnn_size)  -- Memory bank
            state: an object initializing the decoder.

        Returns:
            outputs: (len x batch x rnn_size)
            final_states: an object of the same form as above
            attns: Dictionary of (src_len x batch)
        """
        # CHECKS
        t_len, n_batch = input.size()
        s_len, n_batch_, _ = src.size()
        s_len_, n_batch__, _ = context.size()
        aeq(n_batch, n_batch_, n_batch__)

        # aeq(s_len, s_len_)
        # END CHECKS
        if self.decoder_layer == "transformer":
            if state.previous_input is not None:
                input = torch.cat([state.previous_input.squeeze(2), input], 0)
        emb = self.embeddings(input.unsqueeze(2))
        # n.b. you can increase performance if you compute W_ih * x for all
        # iterations in parallel, but that's only possible if
        # self.input_feed=False
        outputs = []

        # Setup the different types of attention.
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []
        if self.exhaustion_loss:
            attns["upper_bounds"] = []
        if self.fertility_loss:
            attns["predicted_fertility_vals"] = []
            attns["true_fertility_vals"] = []
        if self.decoder_layer == "transformer":
            # Transformer Decoder.
            assert isinstance(state, TransformerDecoderState)
            output = emb.transpose(0, 1).contiguous()
            src_context = context.transpose(0, 1).contiguous()
            for i in range(self.layers):
                output, attn \
                    = self.transformer[i](output, src_context,
                                          src[:, :, 0].transpose(0, 1),
                                          input.transpose(0, 1))
            outputs = output.transpose(0, 1).contiguous()
            if state.previous_input is not None:
                outputs = outputs[state.previous_input.size(0):]
                attn = attn[:, state.previous_input.size(0):].squeeze()
                attn = torch.stack([attn])
            attns["std"] = attn
            if self._copy:
                attns["copy"] = attn
            state = TransformerDecoderState(input.unsqueeze(2))
        else:
            assert isinstance(state, RNNDecoderState)
            output = state.input_feed.squeeze(0)
            hidden = state.hidden
            # CHECKS
            n_batch_, _ = output.size()
            aeq(n_batch, n_batch_)
            # END CHECKS

            coverage = state.coverage.squeeze(0) \
                if state.coverage is not None else None

            # NOTE: something goes wrong when I try to define a "upper_bounds"
            # variable here -- memory blows up. Apparently the presence of such
            # variable prevents the computation graph to be deleted after
            # processing each batch. I need to investigate this further.
            # A workaround for now is to do one round of softmax (without
            # upper bound constraints) followed by several rounds of constrained
            # softmax.
            # upper_bounds = Variable(torch.ones(attn.size()).cuda())
            # Standard RNN decoder.
            for i, emb_t in enumerate(emb.split(1)):

                # Initialize upper bounds for the current batch

                if upper_bounds is None:
                    # if not test:
                    # 	tgt_lengths = [torch.nonzero(input[:,i].data).size(0) for i in range(n_batch_)]
                    # 	tgt_lengths = torch.Tensor(tgt_lengths).cuda()
                    # else:
                    #    # Maybe the ratio of tgt_len and src_len from training set would be a better estimate
                    #	 tgt_lengths = torch.ones(n_batch_).cuda()
                    if self.predict_fertility:
                        # comp_tensor = torch.Tensor([float(emb.size(0)) / context.size(0)]).repeat(n_batch_, s_len_).cuda()
                        # comp_tensor = (tgt_lengths/s_len_).unsqueeze(1).repeat(1, s_len_).cuda()
                        # print("fertility_vals:", fertility_vals.data)
                        # max_word_coverage = Variable(torch.max(fertility_vals.data, comp_tensor))
                        max_word_coverage = fertility_vals.clone()
                    elif self.guided_fertility:
                        # comp_tensor = torch.Tensor([float(emb.size(0)) / context.size(0)]).repeat(n_batch_, s_len_).cuda()
                        # comp_tensor = (tgt_lengths/s_len_).unsqueeze(1).repeat(1, s_len_).cuda()
                        # import pdb; pdb.set_trace()
                        fertility_vals = Variable(
                            evaluation.getBatchFertilities(
                                fert_dict, src).transpose(1, 0).contiguous())
                        max_word_coverage = fertility_vals
                        # max_word_coverage = Variable(torch.max(fertility_vals, comp_tensor))
                    elif self.supervised_fertility:
                        # k should be index of first sentence in batch
                        predicted_fertility_vals = fertility_vals
                        true_fertility_vals = fert_sents[k:k + n_batch_]
                        if test:
                            max_word_coverage = predicted_fertility_vals
                        else:
                            max_word_coverage = true_fertility_vals
                    else:
                        # max_word_coverage = max(
                        #    self.fertility, float(emb.size(0)) / context.size(0))
                        max_word_coverage = Variable(
                            torch.Tensor([self.fertility
                                          ]).repeat(n_batch_, s_len_)).cuda()
                        # max_word_coverage = Variable(torch.max(torch.FloatTensor([self.fertility]).repeat(n_batch_).cuda(),
                        #				     tgt_lengths/s_len_).unsqueeze(1).repeat(1, s_len_))
                        #    upper_bounds = -attn + max_word_coverage
                        # else:
                        #    upper_bounds -= attn
                    upper_bounds = max_word_coverage

                # Use <SINK> token for absorbing remaining attention weight

                # import pdb; pdb.set_trace()
                upper_bounds[:, -1] = Variable(
                    100. * torch.ones(upper_bounds.size(0)))
                # if (upper_bounds.size(0) > torch.sum(torch.sum(upper_bounds, 1)).cpu().data.numpy())[0]:
                #    print("inv sum:", torch.sum(upper_bounds, 1))
                #    print("att:", attn)

                emb_t = emb_t.squeeze(0)
                if self.input_feed:
                    emb_t = torch.cat([emb_t, output], 1)

                rnn_output, hidden = self.rnn(emb_t, hidden)
                attn_output, attn = self.attn(rnn_output,
                                              context.transpose(0, 1),
                                              upper_bounds=upper_bounds)
                # import pdb; pdb.set_trace()
                # print_attention = True
                # if print_attention:
                #    attn_probs = attn.data.cpu().numpy()
                #    for k in range(attn_probs.shape[0]):
                #        print('\t'.join(str(val) for val in list(attn_probs[k, :])))

                upper_bounds -= attn
                # k_attn = 1
                # upper_bounds = torch.max(upper_bounds - k_attn * attn, Variable(torch.zeros(upper_bounds.size(0), upper_bounds.size(1)).cuda()))
                # if np.any(upper_bounds.cpu().data.numpy()<1):
                #     print("upper bounds less than 1.0")
                # print("attn: ", attn)
                # print("upper_bounds: ", upper_bounds)

                if self.context_gate is not None:
                    output = self.context_gate(emb_t, rnn_output, attn_output)
                    output = self.dropout(output)
                else:
                    output = self.dropout(attn_output)
                outputs += [output]
                attns["std"] += [attn]

                # COVERAGE
                if self._coverage:
                    coverage = (coverage + attn) \
                        if coverage is not None else attn
                    attns["coverage"] += [coverage]

                # COPY
                if self._copy:
                    _, copy_attn = self.copy_attn(output,
                                                  context.transpose(0, 1))
                    attns["copy"] += [copy_attn]
                if self.exhaustion_loss:
                    attns["upper_bounds"] += [upper_bounds]
            if self.supervised_fertility:
                attns["true_fertility_vals"] += [true_fertility_vals]
                attns["predicted_fertility_vals"] += [predicted_fertility_vals]
            state = RNNDecoderState(
                hidden, output.unsqueeze(0),
                coverage.unsqueeze(0) if coverage is not None else None,
                upper_bounds)
            outputs = torch.stack(outputs)
            for k in attns:
                attns[k] = torch.stack(attns[k])
        return outputs, state, attns, upper_bounds
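
The constant-fertility path above amounts to giving each source word an attention budget and reserving the last (<SINK>) column to absorb leftover mass; a simplified sketch of that initialization with plain tensors and assumed sizes:

    import torch

    n_batch, s_len, fertility = 4, 10, 2.0
    upper_bounds = torch.Tensor([fertility]).repeat(n_batch, s_len)  # each word gets <= fertility mass
    upper_bounds[:, -1] = 100.0  # <SINK> column absorbs the remainder
    # after each decoder step: upper_bounds -= attn
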
Example #10
    def forward(self, input, lengths=None, hidden=None):
        """
        Args:
            input (LongTensor): len x batch x nfeat
            lengths (LongTensor): batch
            hidden: Initial hidden state.
        Returns:
            hidden_t (FloatTensor): Pair of layers x batch x rnn_size - final
                                    Encoder state
            outputs (FloatTensor):  len x batch x rnn_size -  Memory bank
        """
        # CHECKS
        s_len, n_batch, n_feats = input.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
        # END CHECKS

        emb = self.embeddings(input)
        s_len, n_batch, emb_dim = emb.size()

        if self.encoder_type == "mean":
            # No RNN, just take mean as final state.
            mean = emb.mean(0).expand(self.num_layers, n_batch, emb_dim)
            return (mean, mean), emb

        elif self.encoder_type == "transformer":
            # Self-attention transformer.
            out = emb.transpose(0, 1).contiguous()
            words = input[:, :, 0].transpose(0, 1)
            # CHECKS
            out_batch, out_len, _ = out.size()
            w_batch, w_len = words.size()
            aeq(out_batch, w_batch)
            aeq(out_len, w_len)
            # END CHECKS

            # Make mask.
            padding_idx = self.embeddings.padding_idx
            mask = words.data.eq(padding_idx).unsqueeze(1) \
                .expand(w_batch, w_len, w_len)

            # Run the forward pass of every layer of the transformer.
            for i in range(self.num_layers):
                out = self.transformer[i](out, mask)

            return Variable(emb.data), out.transpose(0, 1).contiguous()
        elif self.encoder_type == "cnn":
            out = emb.transpose(0, 1).contiguous()
            out, emb_remap = self.cnn(out)
            return emb_remap.transpose(0, 1).contiguous(),\
                out.transpose(0, 1).contiguous()
        else:
            # Standard RNN encoder.
            packed_emb = emb
            if lengths is not None:
                # Lengths data is wrapped inside a Variable.
                lengths = lengths.view(-1).tolist()
                packed_emb = pack(emb, lengths)
            outputs, hidden_t = self.rnn(packed_emb, hidden)
            if lengths:
                outputs = unpack(outputs)[0]
            return hidden_t, outputs
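
The transformer branch builds its padding mask by comparing word indices against padding_idx and broadcasting over the query length; an isolated sketch (padding index 0 is an assumption):

    import torch

    words = torch.LongTensor([[3, 5, 0, 0],
                              [7, 2, 9, 0]])         # batch x len, 0 = padding
    mask = words.eq(0).unsqueeze(1).expand(2, 4, 4)  # batch x len x len, nonzero where the key is padding
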
Example #11
    def forward(self,
                input,
                src,
                context,
                state,
                fertility_vals=None,
                fert_dict=None,
                fert_sents=None,
                upper_bounds=None,
                test=False):
        """
        Forward through the decoder.

        Args:
            input (LongTensor):  (len x batch) -- Input tokens
            src (LongTensor)
            context:  (src_len x batch x rnn_size)  -- Memory bank
            state: an object initializing the decoder.

        Returns:
            outputs: (len x batch x rnn_size)
            final_states: an object of the same form as above
            attns: Dictionary of (src_len x batch)
        """
        # CHECKS
        t_len, n_batch = input.size()
        s_len, n_batch_, _ = src.size()
        s_len_, n_batch__, _ = context.size()
        aeq(n_batch, n_batch_, n_batch__)

        # aeq(s_len, s_len_)
        # END CHECKS
        emb = self.embeddings(input.unsqueeze(2))
        # n.b. you can increase performance if you compute W_ih * x for all
        # iterations in parallel, but that's only possible if
        # self.input_feed=False
        outputs = []

        # Setup the different types of attention.
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []
        if self.exhaustion_loss:
            attns["upper_bounds"] = []

        assert isinstance(state, RNNDecoderState)
        output = state.input_feed.squeeze(0)
        hidden = state.hidden
        # CHECKS
        n_batch_, _ = output.size()
        aeq(n_batch, n_batch_)
        # END CHECKS

        coverage = state.coverage.squeeze(
            0) if state.coverage is not None else None

        for i, emb_t in enumerate(emb.split(1)):
            # Initialize upper bounds for the current batch
            if upper_bounds is None:
                upper_bounds = Variable(
                    torch.Tensor([self.fertility]).repeat(n_batch_,
                                                          s_len_)).cuda()

            # Use <SINK> token for absorbing remaining attention weight
            upper_bounds[:, -1] = Variable(100. *
                                           torch.ones(upper_bounds.size(0)))

            emb_t = emb_t.squeeze(0)
            if self.input_feed:
                emb_t = torch.cat([emb_t, output], 1)

            rnn_output, hidden = self.rnn(emb_t, hidden)
            attn_output, attn = self.attn(rnn_output,
                                          context.transpose(0, 1),
                                          upper_bounds=upper_bounds)

            upper_bounds -= attn
            if self.context_gate is not None:
                output = self.context_gate(emb_t, rnn_output, attn_output)
                output = self.dropout(output)
            else:
                output = self.dropout(attn_output)
            outputs += [output]
            attns["std"] += [attn]

            # COVERAGE
            if self._coverage:
                coverage = (coverage + attn) \
                    if coverage is not None else attn
                attns["coverage"] += [coverage]

            # COPY
            if self._copy:
                _, copy_attn = self.copy_attn(output, context.transpose(0, 1))
                attns["copy"] += [copy_attn]
            if self.exhaustion_loss:
                attns["upper_bounds"] += [upper_bounds]

        state = RNNDecoderState(
            hidden, output.unsqueeze(0),
            coverage.unsqueeze(0) if coverage is not None else None,
            upper_bounds)
        outputs = torch.stack(outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])
        return outputs, state, attns, upper_bounds
Example #12
    def forward(self, input, context, coverage=None):
        """
        input (FloatTensor): batch x dim: decoder's rnn's output.
        context (FloatTensor): batch x src_len x dim: src hidden states
        coverage (FloatTensor): batch x src_len
        """

        # Check input sizes
        batch, sourceL, dim = context.size()
        batch_, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if self.mask is not None:
            beam_, batch_, sourceL_ = self.mask.size()
            aeq(batch, batch_*beam_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            context += self.linear_cover(cover).view_as(context)
            context = self.tanh(context)

        # compute attention scores, as in Luong et al.
        a_t = self.score(input, context)

        if self.mask is not None:
            a_t.data.masked_fill_(self.mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vector = self.sm(a_t)

        # the context vector c_t is the weighted average
        # over all the source hidden states
        c_t = torch.bmm(align_vector.unsqueeze(1), context).squeeze(1)

        # concatenate
        attn_h_t = self.linear_out(torch.cat([c_t, input], 1))
        if self.attn_type in ["general", "dot"]:
            attn_h_t = self.tanh(attn_h_t)

        # Check output sizes
        batch_, sourceL_ = align_vector.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
        batch_, dim_ = attn_h_t.size()
        aeq(batch, batch_)
        aeq(dim, dim_)

        return attn_h_t, align_vector
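
The core of this attention forward is: normalize the scores, then take an attention-weighted average of the source states; a minimal stand-alone sketch with assumed dimensions:

    import torch
    import torch.nn.functional as F

    scores = torch.randn(4, 10)       # batch x src_len raw scores
    align = F.softmax(scores, dim=1)  # normalized attention weights
    context = torch.randn(4, 10, 16)  # batch x src_len x dim source states
    c_t = torch.bmm(align.unsqueeze(1), context).squeeze(1)  # batch x dim
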
Example #13
    def forward(self, input, context, coverage=None):
        """
        input (FloatTensor): batch x dim
        context (FloatTensor): batch x sourceL x dim
        coverage (FloatTensor): batch x sourceL
        """
        # Check input sizes
        batch, sourceL, dim = context.size()
        batch_, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if self.mask is not None:
            beam_, batch_, sourceL_ = self.mask.size()
            aeq(batch, batch_ * beam_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            context += self.linear_cover(coverage.view(-1).unsqueeze(1)) \
                           .view_as(context)
            context = self.tanh(context)

        # Alignment/Attention Function
        if self.attn_type == "dotprod":
            # batch x dim x 1
            targetT = self.linear_in(input).unsqueeze(2)
            # batch x sourceL
            attn = torch.bmm(context, targetT).squeeze(2)
        elif self.attn_type == "mlp":
            # batch x 1 x dim
            wq = self.linear_query(input).unsqueeze(1)
            # batch x sourceL x dim
            uh = self.linear_context(context.contiguous())
            # batch x sourceL x dim
            wquh = uh + wq.expand_as(uh)
            # batch x sourceL x dim
            wquh = self.mlp_tanh(wquh)
            # batch x sourceL
            attn = self.v(wquh.contiguous()).squeeze(2)

        if self.mask is not None:
            attn.data.masked_fill_(self.mask, -float('inf'))

        # SoftMax
        attn = self.sm(attn)

        # Compute context weighted by attention.
        # batch x 1 x sourceL
        attn3 = attn.view(attn.size(0), 1, attn.size(1))
        # batch x dim
        weightedContext = torch.bmm(attn3, context).squeeze(1)

        # Concatenate the input to context (Luong only)
        weightedContext = torch.cat((weightedContext, input), 1)
        weightedContext = self.linear_out(weightedContext)
        # if self.attn_type == "dotprod":
        weightedContext = self.tanh(weightedContext)

        # Check output sizes
        batch_, sourceL_ = attn.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
        batch_, dim_ = weightedContext.size()
        aeq(batch, batch_)
        aeq(dim, dim_)

        return weightedContext, attn
Example #14
    def forward(self, input, context, coverage=None):
        """
        input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output.
        context (FloatTensor): batch x src_len x dim: src hidden states
        coverage (FloatTensor): None (not supported yet)
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = context.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if self.mask is not None:
            beam_, batch_, sourceL_ = self.mask.size()
            aeq(batch, batch_ * beam_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            context += self.linear_cover(cover).view_as(context)
            context = self.tanh(context)

        # compute attention scores, as in Luong et al.
        align = self.score(input, context)

        if self.mask is not None:
            mask_ = self.mask.view(batch, 1, sourceL)  # make it broadcastable
            align.data.masked_fill_(mask_, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
Example #15
    def forward(self, input, context, src_pad_mask, tgt_pad_mask):
        # Args Checks
        input_batch, input_len, _ = input.size()
        contxt_batch, contxt_len, _ = context.size()
        aeq(input_batch, contxt_batch)

        src_batch, t_len, s_len = src_pad_mask.size()
        tgt_batch, t_len_, t_len__ = tgt_pad_mask.size()
        aeq(input_batch, contxt_batch, src_batch, tgt_batch)
        aeq(t_len, t_len_, t_len__, input_len)
        aeq(s_len, contxt_len)
        # END Args Checks

        dec_mask = torch.gt(
            tgt_pad_mask + self.mask[:, :tgt_pad_mask.size(1), :tgt_pad_mask.
                                     size(1)].expand_as(tgt_pad_mask), 0)
        query, attn = self.self_attn(input, input, input, mask=dec_mask)
        mid, attn = self.context_attn(context,
                                      context,
                                      query,
                                      mask=src_pad_mask)
        output = self.feed_forward(mid)

        # CHECKS
        output_batch, output_len, _ = output.size()
        aeq(input_len, output_len)
        aeq(contxt_batch, output_batch)

        n_batch_, t_len_, s_len_ = attn.size()
        aeq(input_batch, n_batch_)
        aeq(contxt_len, s_len_)
        aeq(input_len, t_len_)
        # END CHECKS

        return output, attn
Example #16
    def _run_forward_pass(self, input, context, state):
        """
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.

        Args:
            input (LongTensor): a sequence of input tokens tensors
                                of size (len x batch x nfeats).
            context (FloatTensor): output(tensor sequence) from the Encoder
                        RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the Encoder RNN for
                                 initializing the decoder.

        Returns:
            hidden (FloatTensor): final hidden state from the Decoder.
            outputs ([FloatTensor]): an array of output of every time
                                     step from the Decoder.
            attns (dict of (str, [FloatTensor]): a dictionary of different
                            type of attention Tensor array of every time
                            step from the Decoder.
            coverage (FloatTensor, optional): coverage from the Decoder.
        """
        assert not self._copy  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        # Initialize local and return variables.
        outputs = []
        attns = {"std": []}
        coverage = None

        emb = self.embeddings(input)
        assert emb.dim() == 3  # len x batch x embedding_dim

        # Run the forward pass of the RNN.
        rnn_output, hidden = self.rnn(emb, state.hidden)
        # Result Check
        input_len, input_batch, _ = input.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(input_len, output_len)
        aeq(input_batch, output_batch)
        # END Result Check

        # Calculate the attention.
        attn_outputs, attn_scores = self.attn(
            rnn_output.transpose(0, 1).contiguous(),  # (batch, output_len, d)
            context.transpose(0, 1)  # (batch, contxt_len, d)
        )
        attns["std"] = attn_scores

        # Calculate the context gate.
        if self.context_gate is not None:
            outputs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                attn_outputs.view(-1, attn_outputs.size(2)))
            outputs = outputs.view(input_len, input_batch, self.hidden_size)
            outputs = self.dropout(outputs)
        else:
            outputs = self.dropout(attn_outputs)  # (input_len, batch, d)

        # Return result.
        return hidden, outputs, attns, coverage
    def forward(self, input, lengths=None, hidden=None):
        """
        Args:
            input (LongTensor): len x batch x nfeat
            lengths (LongTensor): batch
            hidden: Initial hidden state.

        Returns:
            hidden_t (FloatTensor): Pair of layers x batch x rnn_size - final
                                    Encoder state
            outputs (FloatTensor):  len x batch x rnn_size -  Memory bank
        """
        # CHECKS
        s_len, n_batch, n_feats = input.size()
        if lengths is not None:
            _, n_batch_ = lengths.size()
            aeq(n_batch, n_batch_)
        # END CHECKS

        emb = self.embeddings(input)
        s_len, n_batch, vec_size = emb.size()

        if self.encoder_layer == "mean":
            # No RNN, just take mean as final state.
            mean = emb.mean(0) \
                .expand(self.layers, n_batch, vec_size)
            return (mean, mean), emb

        elif self.encoder_layer == "transformer":
            # Self-attention transformer.
            out = emb.transpose(0, 1).contiguous()
            for i in range(self.layers):
                out = self.transformer[i](out, input[:, :, 0].transpose(0, 1))
            return Variable(emb.data), out.transpose(0, 1).contiguous()

        else:
            # import pdb; pdb.set_trace()
            # Standard RNN encoder.
            packed_emb = emb
            if lengths is not None:
                # Lengths data is wrapped inside a Variable.
                lengths = lengths.data.view(-1).tolist()
                packed_emb = pack(emb, lengths)
            outputs, hidden_t = self.rnn(packed_emb, hidden)
            if lengths:
                outputs = unpack(outputs)[0]
            if self.predict_fertility:
                if self.use_sigmoid_fertility:
                    fertility_vals = self.fertility * F.sigmoid(
                        self.fertility_out(
                            torch.cat([
                                outputs.view(
                                    -1,
                                    self.hidden_size * self.num_directions),
                                emb.view(-1, vec_size)
                            ],
                                      dim=1)))
                else:
                    fertility_vals = F.relu(
                        self.fertility_linear(
                            torch.cat([
                                outputs.view(
                                    -1,
                                    self.hidden_size * self.num_directions),
                                emb.view(-1, vec_size)
                            ],
                                      dim=1)))
                    fertility_vals = F.relu(
                        self.fertility_linear_2(fertility_vals))
                    fertility_vals = 1 + torch.exp(
                        self.fertility_out(fertility_vals))
                fertility_vals = fertility_vals.view(n_batch, s_len)
                # fertility_vals = fertility_vals / torch.sum(fertility_vals, 1).repeat(1, s_len) * s_len
            elif self.guided_fertility:
                fertility_vals = None  # evaluation.get_fertility()
            elif self.supervised_fertility:
                fertility_vals = F.relu(
                    self.sup_linear(
                        outputs.view(-1,
                                     self.hidden_size * self.num_directions)))
                fertility_vals = F.relu(self.sup_linear_2(fertility_vals))
                fertility_vals = 1 + torch.exp(fertility_vals)
            else:
                fertility_vals = None
            return hidden_t, outputs, fertility_vals
Example #18
    def forward(self, key, value, query, mask=None):
        # CHECKS
        batch, k_len, d = key.size()
        batch_, k_len_, d_ = value.size()
        aeq(batch, batch_)
        aeq(k_len, k_len_)
        aeq(d, d_)
        batch_, q_len, d_ = query.size()
        aeq(batch, batch_)
        aeq(d, d_)
        aeq(self.d_model % 8, 0)
        if mask is not None:
            batch_, q_len_, k_len_ = mask.size()
            aeq(batch_, batch)
            aeq(k_len_, k_len)
            aeq(q_len_, q_len)
        # END CHECKS

        def shape_projection(x):
            b, l, d = x.size()
            return x.view(b, l, self.heads, self.d_k).transpose(1, 2) \
                    .contiguous().view(b * self.heads, l, self.d_k)

        def unshape_projection(x, q):
            b, l, d = q.size()
            return x.view(b, self.heads, l, self.d_k) \
                    .transpose(1, 2).contiguous() \
                    .view(b, l, self.heads * self.d_k)

        residual = query
        key_up = shape_projection(self.linear_keys(key))
        value_up = shape_projection(self.linear_values(value))
        query_up = shape_projection(self.linear_query(query))

        scaled = torch.bmm(query_up, key_up.transpose(1, 2))
        scaled = scaled / math.sqrt(self.d_k)
        bh, l, d_k = scaled.size()
        b = bh // self.heads
        if mask is not None:

            scaled = scaled.view(b, self.heads, l, d_k)
            mask = mask.unsqueeze(1).expand_as(scaled)
            scaled = scaled.masked_fill(Variable(mask), -float('inf')) \
                           .view(bh, l, d_k)
        attn = self.sm(scaled)
        # Return one attn
        top_attn = attn.view(b, self.heads, l, d_k)[:, 0, :, :].contiguous()

        drop_attn = self.dropout(self.sm(scaled))

        # values : (batch * 8) x qlen x dim
        out = unshape_projection(torch.bmm(drop_attn, value_up), residual)

        # Residual and layer norm
        res = self.res_dropout(out) + residual
        ret = self.layer_norm(res)

        # CHECK
        batch_, q_len_, d_ = ret.size()
        aeq(q_len, q_len_)
        aeq(batch, batch_)
        aeq(d, d_)
        # END CHECK
        return ret, top_attn
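
shape_projection and unshape_projection above implement the usual head split and merge for multi-head attention; a stand-alone version of the same reshaping with assumed sizes:

    import torch

    batch, length, heads, d_k = 2, 5, 8, 64
    x = torch.randn(batch, length, heads * d_k)

    # split into heads: (batch * heads) x length x d_k
    x_heads = x.view(batch, length, heads, d_k).transpose(1, 2) \
               .contiguous().view(batch * heads, length, d_k)

    # merge heads back: batch x length x (heads * d_k)
    x_merged = x_heads.view(batch, heads, length, d_k).transpose(1, 2) \
                      .contiguous().view(batch, length, heads * d_k)
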
Example #19
    def forward(self, input, context, state):
        """
        Forward through the TransformerDecoder.

        Args:
            input (LongTensor): a sequence of input tokens tensors
                                of size (len x batch x nfeats).
            context (FloatTensor): output(tensor sequence) from the Encoder
                        RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the Encoder RNN for
                                 initializing the decoder.

        Returns:
            outputs (FloatTensor): a Tensor sequence of output from the Decoder
                                   of shape (len x batch x hidden_size).
            state (FloatTensor): final hidden state from the Decoder.
            attns (dict of (str, FloatTensor)): a dictionary of different
                                type of attention Tensor from the Decoder
                                of shape (src_len x batch).
        """
        # CHECKS
        assert isinstance(state, TransformerDecoderState)
        input_len, input_batch, _ = input.size()
        contxt_len, contxt_batch, _ = context.size()
        aeq(input_batch, contxt_batch)

        if state.previous_input is not None:
            input = torch.cat([state.previous_input, input], 0)

        src = state.src
        src_words = src[:, :, 0].transpose(0, 1)
        tgt_words = input[:, :, 0].transpose(0, 1)
        src_batch, src_len = src_words.size()
        tgt_batch, tgt_len = tgt_words.size()
        aeq(input_batch, contxt_batch, src_batch, tgt_batch)
        aeq(contxt_len, src_len)
        aeq(input_len, tgt_len)
        # END CHECKS

        # Initialize return variables.
        outputs = []
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []

        # Run the forward pass of the TransformerDecoder.
        emb = self.embeddings(input)
        output = emb.transpose(0, 1).contiguous()
        src_context = context.transpose(0, 1).contiguous()

        padding_idx = self.embeddings.padding_idx
        src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
            .expand(src_batch, tgt_len, src_len)
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
            .expand(tgt_batch, tgt_len, tgt_len)

        for i in range(self.num_layers):
            output, attn \
                = self.transformer[i](output, src_context,
                                      src_pad_mask, tgt_pad_mask)

        # Process the result and update the attentions.
        outputs = output.transpose(0, 1).contiguous()
        if state.previous_input is not None:
            outputs = outputs[state.previous_input.size(0):]
            attn = attn[:, state.previous_input.size(0):].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self._copy:
            attns["copy"] = attn

        # Update the TransformerDecoderState.
        state = TransformerDecoderState(src, input)

        return outputs, state, attns
    def forward(self, input, src, context, state):
        """
        Forward through the decoder.

        Args:
            input (LongTensor):  (len x batch) -- Input tokens
            src (LongTensor)
            context:  (src_len x batch x rnn_size)  -- Memory bank
            state: an object initializing the decoder.

        Returns:
            outputs: (len x batch x rnn_size)
            final_states: an object of the same form as above
            attns: Dictionary of (src_len x batch)
        """
        # CHECKS
        t_len, n_batch = input.size()
        s_len, n_batch_, _ = src.size()
        s_len_, n_batch__, _ = context.size()
        aeq(n_batch, n_batch_, n_batch__)
        # aeq(s_len, s_len_)
        # END CHECKS

        emb = self.embeddings(input.unsqueeze(2))

        # n.b. you can increase performance if you compute W_ih * x for all
        # iterations in parallel, but that's only possible if
        # self.input_feed=False
        outputs = []

        # Setup the different types of attention.
        attns = {"std": []}
        if self._coverage:
            attns["coverage"] = []

        assert isinstance(state, RNNDecoderState)
        output = state.input_feed.squeeze(0)
        hidden = state.hidden
        # CHECKS
        n_batch_, _ = output.size()
        aeq(n_batch, n_batch_)
        # END CHECKS

        coverage = state.coverage.squeeze(0) \
            if state.coverage is not None else None

        # Standard RNN decoder.
        for i, emb_t in enumerate(emb.split(1)):
            emb_t = emb_t.squeeze(0)
            if self.input_feed:
                emb_t = torch.cat([emb_t, output], 1)

            rnn_output, hidden = self.rnn(emb_t, hidden)
            attn_output, attn = self.attn(rnn_output, context.transpose(0, 1))

            output = self.dropout(attn_output)
            outputs += [output]
            attns["std"] += [attn]

            # COVERAGE
            if self._coverage:
                coverage = (coverage + attn) if coverage is not None else attn
                attns["coverage"] += [coverage]

        state = RNNDecoderState(
            hidden, output.unsqueeze(0),
            coverage.unsqueeze(0) if coverage is not None else None)
        outputs = torch.stack(outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])
        return outputs, state, attns
Example #21
    def forward(self, input, src, context, state):
        """
        Forward through the decoder.

        Args:
            input (LongTensor):  (len x batch) -- Input tokens
            src (LongTensor)
            context:  (src_len x batch x rnn_size)  -- Memory bank
            state: an object initializing the decoder.

        Returns:
            outputs: (len x batch x rnn_size)
            final_states: an object of the same form as above
            attns: Dictionary of (src_len x batch)
        """
        # CHECKS
        t_len, n_batch = input.size()
        s_len, n_batch_, _ = src.size()
        s_len_, n_batch__, _ = context.size()
        aeq(n_batch, n_batch_, n_batch__)
        # aeq(s_len, s_len_)
        # END CHECKS
        if self.decoder_type == "transformer":
            if state.previous_input is not None:
                input = torch.cat([state.previous_input.squeeze(2), input], 0)

        emb = self.embeddings(input.unsqueeze(2))

        # n.b. you can increase performance if you compute W_ih * x for all
        # iterations in parallel, but that's only possible if
        # self.input_feed=False
        outputs = []

        # Setup the different types of attention.
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        if self.decoder_type == "transformer":
            # Transformer Decoder.
            assert isinstance(state, TransformerDecoderState)
            output = emb.transpose(0, 1).contiguous()
            src_context = context.transpose(0, 1).contiguous()
            for i in range(self.layers):
                output, attn \
                    = self.transformer[i](output, src_context,
                                          src[:, :, 0].transpose(0, 1),
                                          input.transpose(0, 1))
            outputs = output.transpose(0, 1).contiguous()
            if state.previous_input is not None:
                outputs = outputs[state.previous_input.size(0):]
                attn = attn[:, state.previous_input.size(0):].squeeze()
                attn = torch.stack([attn])
            attns["std"] = attn
            if self._copy:
                attns["copy"] = attn
            state = TransformerDecoderState(input.unsqueeze(2))
        elif self.input_feed:
            assert isinstance(state, RNNDecoderState)
            output = state.input_feed.squeeze(0)
            hidden = state.hidden
            # CHECKS
            n_batch_, _ = output.size()
            aeq(n_batch, n_batch_)
            # END CHECKS

            coverage = state.coverage.squeeze(0) \
                if state.coverage is not None else None

            # Standard RNN decoder.
            for i, emb_t in enumerate(emb.split(1)):
                emb_t = emb_t.squeeze(0)
                if self.input_feed:
                    emb_t = torch.cat([emb_t, output], 1)

                rnn_output, hidden = self.rnn(emb_t, hidden)
                attn_output, attn = self.attn(rnn_output,
                                              context.transpose(0, 1))
                if self.context_gate is not None:
                    output = self.context_gate(emb_t, rnn_output, attn_output)
                    output = self.dropout(output)
                else:
                    output = self.dropout(attn_output)
                outputs += [output]
                attns["std"] += [attn]

                # COVERAGE
                if self._coverage:
                    coverage = coverage + attn \
                               if coverage is not None else attn
                    attns["coverage"] += [coverage]

                # COPY
                if self._copy:
                    _, copy_attn = self.copy_attn(output,
                                                  context.transpose(0, 1))
                    attns["copy"] += [copy_attn]
            state = RNNDecoderState(
                hidden, output.unsqueeze(0),
                coverage.unsqueeze(0) if coverage is not None else None)
            outputs = torch.stack(outputs)
            for k in attns:
                attns[k] = torch.stack(attns[k])
        else:
            assert isinstance(state, RNNDecoderState)
            assert emb.dim() == 3

            assert not self._coverage
            assert state.coverage is None

            # TODO: copy
            assert not self._copy

            hidden = state.hidden
            rnn_output, hidden = self.rnn(emb, hidden)

            # CHECKS
            t_len_, n_batch_, _ = rnn_output.size()
            aeq(n_batch, n_batch_)
            aeq(t_len, t_len_)
            # END CHECKS

            attn_outputs, attn_scores = self.attn(
                rnn_output.transpose(0, 1).contiguous(),  # (batch, t_len, d)
                context.transpose(0, 1)  # (batch, s_len, d)
            )

            if self.context_gate is not None:
                outputs = self.context_gate(
                    emb.view(-1, emb.size(2)),
                    rnn_output.view(-1, rnn_output.size(2)),
                    attn_outputs.view(-1, attn_outputs.size(2)))
                outputs = outputs.view(t_len, n_batch, self.hidden_size)
                outputs = self.dropout(outputs)
            else:
                outputs = self.dropout(attn_outputs)  # (t_len, batch, d)
            state = RNNDecoderState(hidden, outputs[-1].unsqueeze(0), None)
            attns["std"] = attn_scores

        return outputs, state, attns
Example #22
    def forward(self, input, context, coverage=None, upper_bounds=None):
        """
        input (FloatTensor): batch x dim
        context (FloatTensor): batch x sourceL x dim
        coverage (FloatTensor): batch x sourceL
        upper_bounds (FloatTensor): batch x sourceL
        """
        # Check input sizes
        batch, sourceL, dim = context.size()
        batch_, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if self.mask is not None:
            beam_, batch_, sourceL_ = self.mask.size()
            aeq(batch, batch_ * beam_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            context += self.linear_cover(coverage.view(-1).unsqueeze(1)) \
                           .view_as(context)
            context = self.tanh(context)
        # Alignment/Attention Function
        if self.attn_type == "dotprod":
            # batch x dim x 1
            targetT = self.linear_in(input).unsqueeze(2)
            # batch x sourceL
            attn = torch.bmm(context, targetT).squeeze(2)
        elif self.attn_type == "mlp":
            # batch x 1 x dim
            wq = self.linear_query(input).unsqueeze(1)
            # batch x sourceL x dim
            uh = self.linear_context(context.contiguous())
            # batch x sourceL x dim
            wquh = uh + wq.expand_as(uh)
            # batch x sourceL x dim
            wquh = self.tanh(wquh)
            # batch x sourceL
            #print("self.v: ", self.v.weight)
            attn = self.v(wquh.contiguous()).squeeze(2)

        # EXPERIMENTAL

        if upper_bounds is not None and 'constrained' in self.attn_transform and self.c_attn != 0.0:
            indices = torch.arange(0, upper_bounds.size(1) - 1).cuda().long()
            uu = torch.index_select(upper_bounds.data, 1, indices)
            attn = attn + self.c_attn * Variable(
                torch.cat((uu, torch.zeros(upper_bounds.size(0), 1).cuda()), 1))

        if self.mask is not None:
            attn.data.masked_fill_(self.mask, -float('inf'))
        if self.attn_transform == 'constrained_softmax':
            if upper_bounds is None:
                attn = nn.Softmax()(attn)
            else:
                # assert round(np.sum(upper_bounds.cpu().data.numpy()), 5) >= 1.0, pdb.set_trace()
                attn = self.sm(attn, upper_bounds)
        elif self.attn_transform == 'constrained_sparsemax':
            if upper_bounds is None:
                attn = Sparsemax()(attn)
            else:
                attn = self.sm(attn, upper_bounds)
        else:
            attn = self.sm(attn)
            #if upper_bounds is None:
            #    attn = self.sm(attn)
            #else:
            #    attn = self.sm(attn - upper_bounds)

        # Compute context weighted by attention.
        # batch x 1 x sourceL
        attn3 = attn.view(attn.size(0), 1, attn.size(1))
        # batch x dim
        weightedContext = torch.bmm(attn3, context).squeeze(1)
        # Concatenate the input to context (Luong only)
        if self.attn_type == "dotprod":
            weightedContext = torch.cat((weightedContext, input), 1)
            weightedContext = self.linear_out(weightedContext)
            weightedContext = self.tanh(weightedContext)

        # Check output sizes
        batch_, sourceL_ = attn.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
        batch_, dim_ = weightedContext.size()
        aeq(batch, batch_)
        aeq(dim, dim_)

        return weightedContext, attn