class NSRM(nn.Module):
    """Class that classifies question pair as duplicate or not."""
    def __init__(self, dictionary, embedding_index, args):
        """"Constructor of the class."""
        super(NSRM, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.emsize,
                                        self.config.nhid_doc, True,
                                        self.config)
        self.session_encoder = Encoder(
            self.config.nhid_query * self.num_directions,
            self.config.nhid_session, False, self.config)
        self.projection = nn.Linear(
            (self.config.nhid_query * self.num_directions) +
            self.config.nhid_session,
            self.config.nhid_doc * self.num_directions)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        """
        Compute negative log-likelihood loss for a batch of predictions.
        :param logits: 2d tensor [batch_size x vocab_size]
        :param target: 2d tensor [batch_size x 1]
        :param seq_idx: an integer represents the current index of the sequences
        :param length: 1d tensor [batch_size], represents each sequences' true length
        :return: total loss over the input mini-batch [autograd Variable] and number of loss elements
        """
        # NLL of each target token (gather picks the log-probability assigned
        # to the target and negation turns it into a loss).
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1))
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        # Zero out contributions from positions past each sequence's true length.
        losses = losses * mask.float()
        # Number of positions that actually contributed to the loss.
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
            # Every sequence in the batch is already exhausted at this step.
            return losses.sum(), 0
        else:
            return losses.sum(), num_non_zero_elem[0]
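
    # Illustration of the masking (helper.mask's exact semantics are an
    # assumption based on how it is used here): with length = [3, 1] and
    # seq_idx = 1, the mask would be [[1], [0]], so the second sequence,
    # already past its true length, contributes nothing to the loss and is
    # excluded from the element count.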

    @staticmethod
    def compute_click_loss(logits, target):
        """
        Compute logistic loss for a batch of clicks. Return average loss for the input mini-batch.
        :param logits: 2d tensor [batch_size x num_clicks_per_query]
        :param target: 2d tensor [batch_size x num_clicks_per_query]
        :return: average loss over batch [autograd Variable]
        """
        # numerically stable binary cross-entropy with logits; taken from
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L695
        neg_abs = -logits.abs()
        loss = logits.clamp(min=0) - logits * target + (1 + neg_abs.exp()).log()
        return loss.mean()
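
    # Sanity-check sketch (not part of the model): the expression above is the
    # numerically stable form of binary cross-entropy with logits, so on a
    # recent PyTorch it should agree with the library routine:
    #
    #   import torch.nn.functional as F
    #   logits, target = torch.randn(4, 10), torch.bernoulli(torch.rand(4, 10))
    #   assert torch.allclose(NSRM.compute_click_loss(logits, target),
    #                         F.binary_cross_entropy_with_logits(logits, target))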

    def forward(self, batch_session, length, batch_clicks, click_labels):
        """
        Forward function of the neural click model. Return average loss for a batch of sessions.
        :param batch_session: 3d tensor [batch_size x session_length x max_query_length]
        :param length: 2d tensor [batch_size x session_length]
        :param batch_clicks: 4d tensor [batch_size x session_length x num_rel_docs_per_query x max_document_length]
        :param click_labels: 3d tensor [batch_size x session_length x num_rel_docs_per_query]
        :return: average loss over batch [autograd Variable]
        """
        # query level encoding
        embedded_queries = self.embedding(
            batch_session.view(-1, batch_session.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(
                embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries,
                                                (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(
                embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries,
                                                encoder_hidden)

        encoded_queries = torch.max(output, 1)[0].squeeze(1)
        # encoded_queries = batch_size x num_queries_in_a_session x hidden_size
        encoded_queries = encoded_queries.view(*batch_session.size()[:-1], -1)

        # document level encoding
        embedded_clicks = self.embedding(
            batch_clicks.view(-1, batch_clicks.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.document_encoder.init_weights(
                embedded_clicks.size(0))
            output, hidden = self.document_encoder(
                embedded_clicks, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.document_encoder.init_weights(
                embedded_clicks.size(0))
            output, hidden = self.document_encoder(embedded_clicks,
                                                   encoder_hidden)

        encoded_clicks = torch.max(output, 1)[0].squeeze(1)
        # encoded_clicks = batch_size x num_queries_in_a_session x num_rel_docs_per_query x hidden_size
        encoded_clicks = encoded_clicks.view(*batch_clicks.size()[:-1], -1)

        # session level encoding
        sess_hidden = self.session_encoder.init_weights(
            encoded_queries.size(0))
        sess_output = Variable(
            torch.zeros(self.config.batch_size, 1, self.config.nhid_session))
        if self.config.cuda:
            sess_output = sess_output.cuda()
        hidden_states, cell_states = [], []
        click_loss = 0
        for idx in range(encoded_queries.size(1)):
            combined_rep = torch.cat(
                (sess_output.squeeze(), encoded_queries[:, idx, :]), 1)
            combined_rep = self.projection(combined_rep)
            combined_rep = combined_rep.unsqueeze(1).expand(
                *encoded_clicks[:, idx, :, :].size())
            click_score = torch.sum(
                torch.mul(combined_rep, encoded_clicks[:, idx, :, :]),
                2).squeeze(2)
            click_loss += self.compute_click_loss(click_score,
                                                  click_labels[:, idx, :])
            # update session state using query representations
            sess_output, sess_hidden = self.session_encoder(
                encoded_queries[:, idx, :].unsqueeze(1), sess_hidden)
            hidden_states.append(sess_hidden[0])
            cell_states.append(sess_hidden[1])

        click_loss = click_loss / encoded_queries.size(1)

        hidden_states = torch.stack(hidden_states, 2).squeeze(0)
        cell_states = torch.stack(cell_states, 2).squeeze(0)

        # decoding in sequence-to-sequence learning
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        cell_states = cell_states[:, :-1, :].contiguous().view(
            -1, cell_states.size(-1)).unsqueeze(0)
        decoder_input = batch_session[:, 1:, :].contiguous().view(
            -1, batch_session.size(-1))
        target_length = length[:, 1:].contiguous().view(-1)
        input_variable = Variable(
            torch.LongTensor(decoder_input.size(0)).fill_(
                self.dictionary.word2idx[self.dictionary.start_token]))
        if self.config.cuda:
            input_variable = input_variable.cuda()
            hidden_states = hidden_states.cuda()
            cell_states = cell_states.cuda()

        # Initialize hidden states of decoder with the last hidden states of the session encoder
        decoder_hidden = (hidden_states, cell_states)
        decoding_loss = 0
        total_local_decoding_loss_element = 0
        for idx in range(decoder_input.size(1)):
            if idx != 0:
                input_variable = decoder_input[:, idx - 1]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx]
            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, target_variable, idx, target_length)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return click_loss + decoding_loss
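

# A minimal usage sketch for NSRM (shapes and args are hypothetical, chosen to
# match the documented tensor layout; note the batch size must equal
# args.batch_size because of how sess_output is initialized in forward):
#
#   model = NSRM(dictionary, embedding_index, args)  # args.batch_size == 8
#   batch_session = torch.LongTensor(8, 3, 10).random_(len(dictionary))
#   length = torch.LongTensor(8, 3).fill_(10)
#   batch_clicks = torch.LongTensor(8, 3, 5, 20).random_(len(dictionary))
#   click_labels = torch.zeros(8, 3, 5)
#   loss = model(batch_session, length, batch_clicks, click_labels)

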
class Sequence2Sequence(nn.Module):
    """Class that classifies question pair as duplicate or not."""
    def __init__(self, dictionary, embedding_index, args):
        """"Constructor of the class."""
        super(Sequence2Sequence, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query, self.config)
        self.session_encoder = Encoder(self.config.nhid_query,
                                       self.config.nhid_session, self.config)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)

    @staticmethod
    def compute_loss(logits, target, seq_idx, length):
        """
        Compute the average negative log-likelihood loss for one decoding step.
        :param logits: 2d tensor [batch_size x vocab_size], assumed to hold log-probabilities
        :param target: 1d tensor [batch_size]
        :param seq_idx: an integer representing the current index into the sequences
        :param length: 1d tensor [batch_size] holding each sequence's true length
        :return: average loss over the non-masked elements [autograd Variable]
        """
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1))
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
            loss = losses.sum()
        else:
            loss = losses.sum() / num_non_zero_elem[0]
        return loss

    def forward(self, batch_session, length):
        """"Defines the forward computation of the question classifier."""
        embedded_input = self.embedding(
            batch_session.view(-1, batch_session.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(
                embedded_input.size(0))
            output, hidden = self.query_encoder(embedded_input,
                                                (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(
                embedded_input.size(0))
            output, hidden = self.query_encoder(embedded_input, encoder_hidden)

        if self.config.bidirection:
            # average the forward and backward direction outputs
            output = torch.div(
                torch.add(output[:, :, 0:self.config.nhid_query],
                          output[:, :, self.config.nhid_query:2 * self.config.nhid_query]),
                2)

        session_input = output[:, -1, :].contiguous().view(
            batch_session.size(0), batch_session.size(1), -1)
        # session level encoding
        sess_hidden = self.session_encoder.init_weights(session_input.size(0))
        hidden_states, cell_states = [], []
        for idx in range(session_input.size(1)):
            sess_output, sess_hidden = self.session_encoder(
                session_input[:, idx, :].unsqueeze(1), sess_hidden)
            if self.config.bidirection:
                hidden_states.append(torch.mean(sess_hidden[0], 0))
                cell_states.append(torch.mean(sess_hidden[1], 0))
            else:
                hidden_states.append(sess_hidden[0])
                cell_states.append(sess_hidden[1])

        hidden_states = torch.stack(hidden_states, 2).squeeze(0)
        cell_states = torch.stack(cell_states, 2).squeeze(0)
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        cell_states = cell_states[:, :-1, :].contiguous().view(
            -1, cell_states.size(-1)).unsqueeze(0)

        decoder_input = batch_session[:, 1:, :].contiguous().view(
            -1, batch_session.size(-1))
        target_length = length[:, 1:].contiguous().view(-1)
        input_variable = Variable(
            torch.LongTensor(decoder_input.size(0)).fill_(
                self.dictionary.word2idx[self.dictionary.start_token]))
        if self.config.cuda:
            input_variable = input_variable.cuda()
            hidden_states = hidden_states.cuda()
            cell_states = cell_states.cuda()

        # Initialize hidden states of decoder with the last hidden states of the session encoder
        decoder_hidden = (hidden_states, cell_states)
        loss = 0
        for idx in range(decoder_input.size(1)):
            if idx != 0:
                input_variable = decoder_input[:, idx - 1]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx]
            loss += self.compute_loss(decoder_output, target_variable, idx,
                                      target_length)

        return loss
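

# The decoding loop above is teacher-forced: the ground-truth previous token
# (decoder_input[:, idx - 1]) is fed back at every step. At inference time one
# would feed back the model's own prediction instead; a rough greedy sketch
# (max_target_length is hypothetical):
#
#   token = input_variable  # batch of start-token ids
#   for step in range(max_target_length):
#       out, decoder_hidden = model.decoder(
#           model.embedding(token).unsqueeze(1), decoder_hidden)
#       token = out.max(1)[1]  # greedy argmax over the vocabulary

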
class MatchTensor(nn.Module):
    """Class that classifies question pair as duplicate or not."""
    def __init__(self, dictionary, embedding_index, args):
        """"Constructor of the class."""
        super(MatchTensor, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.linear_projection = nn.Linear(self.config.emsize,
                                           self.config.featsize)
        self.query_encoder = Encoder(self.config.featsize,
                                     self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.featsize,
                                        self.config.nhid_doc, True,
                                        self.config)
        self.query_projection = nn.Linear(
            self.config.nhid_query * self.num_directions,
            self.config.nchannels)
        self.document_projection = nn.Linear(
            self.config.nhid_doc * self.num_directions, self.config.nchannels)

        self.exact_match_channel = ExactMatchChannel()
        self.conv1 = nn.Conv2d(self.config.nchannels + 1,
                               self.config.nfilters, (3, 3),
                               padding=1)
        self.conv2 = nn.Conv2d(self.config.nchannels + 1,
                               self.config.nfilters, (3, 5),
                               padding=(1, 2))
        self.conv3 = nn.Conv2d(self.config.nchannels + 1,
                               self.config.nfilters, (3, 7),
                               padding=(1, 3))
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(self.config.nfilters * 3,
                              self.config.match_filter_size, (1, 1))
        self.output = nn.Linear(self.config.match_filter_size, 1)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)

    def forward(self, batch_queries, batch_docs):
        """
        Forward function of the match tensor model. Return average loss for a batch of sessions.
        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :return: average loss over batch [autograd Variable]
        """
        embedded_queries = self.embedding(batch_queries)
        embedded_docs = self.embedding(batch_docs.view(-1, batch_docs.size(-1)))
        embedded_queries = self.linear_projection(
            embedded_queries.view(-1, embedded_queries.size(-1)))
        embedded_docs = self.linear_projection(
            embedded_docs.view(-1, embedded_docs.size(-1)))

        embedded_queries = embedded_queries.view(*batch_queries.size(),
                                                 self.config.featsize)
        embedded_docs = embedded_docs.view(-1,
                                           batch_docs.size()[-1],
                                           self.config.featsize)

        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(
                embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries,
                                                (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(
                embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries,
                                                encoder_hidden)

        embedded_queries = self.query_projection(
            output.view(-1, output.size(-1))).view(*batch_queries.size(), -1)
        embedded_queries = embedded_queries.unsqueeze(1).expand(
            embedded_queries.size(0), batch_docs.size(1),
            *embedded_queries.size()[1:])
        embedded_queries = embedded_queries.contiguous().view(
            -1,
            *embedded_queries.size()[2:])

        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.document_encoder.init_weights(
                embedded_docs.size(0))
            output, hidden = self.document_encoder(
                embedded_docs, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.document_encoder.init_weights(
                embedded_docs.size(0))
            output, hidden = self.document_encoder(embedded_docs,
                                                   encoder_hidden)

        embedded_docs = self.document_projection(output.view(-1, output.size(-1)))
        embedded_docs = embedded_docs.view(-1, batch_docs.size(2),
                                           embedded_docs.size()[-1])

        embedded_queries = embedded_queries.unsqueeze(2).expand(
            *embedded_queries.size()[:2],
            batch_docs.size()[-1], embedded_queries.size(2))
        embedded_docs = embedded_docs.unsqueeze(1).expand(
            embedded_docs.size(0),
            batch_queries.size()[-1],
            *embedded_docs.size()[1:])

        query_document_product = embedded_queries * embedded_docs
        exact_match = self.exact_match_channel(batch_queries,
                                               batch_docs).unsqueeze(3)
        query_document_product = torch.cat(
            (query_document_product, exact_match), 3)
        query_document_product = query_document_product.transpose(
            2, 3).transpose(1, 2)

        convoluted_feat1 = self.conv1(query_document_product)
        convoluted_feat2 = self.conv2(query_document_product)
        convoluted_feat3 = self.conv3(query_document_product)
        convoluted_feat = self.relu(
            torch.cat((convoluted_feat1, convoluted_feat2, convoluted_feat3),
                      1))
        convoluted_feat = self.conv(convoluted_feat).transpose(1, 2).transpose(
            2, 3)

        max_pooled_feat = torch.max(convoluted_feat, 2)[0].squeeze()
        max_pooled_feat = torch.max(max_pooled_feat, 1)[0].squeeze()
        return self.output(max_pooled_feat).squeeze().view(
            *batch_docs.size()[:2])
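
# A minimal usage sketch for MatchTensor (hypothetical shapes): scoring five
# candidate documents for each of eight queries yields one relevance score per
# (query, document) pair:
#
#   model = MatchTensor(dictionary, embedding_index, args)
#   batch_queries = torch.LongTensor(8, 10).random_(len(dictionary))
#   batch_docs = torch.LongTensor(8, 5, 20).random_(len(dictionary))
#   scores = model(batch_queries, batch_docs)  # -> 8 x 5 tensor of scores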