def __init__(self, dictionary, embedding_index, args):
        """Build the NSRF sub-modules from the hyper-parameters in ``args``.

        Args:
            dictionary: vocabulary object; ``len(dictionary)`` sizes the
                embedding table and the decoder output layer.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace, stored as ``self.config``.
        """
        super(NSRF, self).__init__()
        self.config = args
        cfg = self.config
        self.num_directions = 2 if cfg.bidirection else 1
        # Output width of a query / document encoder (doubled when bidirectional).
        query_width = cfg.nhid_query * self.num_directions
        doc_width = cfg.nhid_doc * self.num_directions

        self.embedding = EmbeddingLayer(len(dictionary), cfg)
        self.embedding.init_embedding_weights(dictionary, embedding_index,
                                              cfg.emsize)

        self.query_encoder = Encoder(cfg.emsize, cfg.nhid_query,
                                     cfg.bidirection, cfg)
        self.document_encoder = Encoder(cfg.emsize, cfg.nhid_doc,
                                        cfg.bidirection, cfg)
        # The session-level cell is built with bidirection=False.
        self.session_query_encoder = EncoderCell(query_width,
                                                 cfg.nhid_session, False, cfg)
        # Maps [query encoding ; session state] into the document-encoding width.
        projection_layers = OrderedDict()
        projection_layers['linear'] = nn.Linear(
            query_width + cfg.nhid_session, doc_width)
        projection_layers['tanh'] = nn.Tanh()
        self.projection = nn.Sequential(projection_layers)
        self.decoder = DecoderCell(cfg.emsize, cfg.nhid_session,
                                   len(dictionary), cfg)
Exemple #2
0
    def __init__(self, dictionary, embeddings_index, args):
        """Build the embedding layer, sentence encoder and classifier head.

        Args:
            dictionary: vocabulary; ``len(dictionary)`` sizes the embedding table.
            embeddings_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace, stored as ``self.config``. When
                ``args.nonlinear_fc`` is truthy, a ``Tanh`` follows each of the
                first two dense layers of the head.
        """
        super(SentenceClassifier, self).__init__()
        self.config = args
        self.num_directions = 2 if args.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embeddings_index, self.config.emsize)
        self.encoder = Encoder(self.config.emsize, self.config.nhid, self.config.bidirection, self.config)

        # The head consumes 4x the sentence-encoding width (presumably the
        # combined sentence-pair features built in forward -- confirm there).
        in_features = self.config.nhid * self.num_directions * 4
        self.ffnn = self._build_ffnn(self.config, in_features, args.nonlinear_fc)

    @staticmethod
    def _build_ffnn(config, in_features, nonlinear):
        """Return the dropout/dense classifier head, with Tanh layers if ``nonlinear``.

        Every layer needs a unique name: ``OrderedDict`` keeps one entry per
        key, so a repeated name (e.g. two ``'tanh'`` entries) silently drops a
        layer. Tanh has no parameters, so checkpoint keys are unaffected.
        """
        layers = [('dropout1', nn.Dropout(config.dropout_fc)),
                  ('dense1', nn.Linear(in_features, config.fc_dim))]
        if nonlinear:
            layers.append(('tanh1', nn.Tanh()))
        layers += [('dropout2', nn.Dropout(config.dropout_fc)),
                   ('dense2', nn.Linear(config.fc_dim, config.fc_dim))]
        if nonlinear:
            layers.append(('tanh2', nn.Tanh()))
        layers += [('dropout3', nn.Dropout(config.dropout_fc)),
                   ('dense3', nn.Linear(config.fc_dim, config.num_classes))]
        return nn.Sequential(OrderedDict(layers))
 def __init__(self,
              bsize,
              embed_dim,
              encod_dim,
              numlabel,
              n_layers,
              recom=100,
              feature_vec=None,
              init=False,
              model='LSTM'):
     """Build the agent's item embedding, sequence encoder and output layer.

     Args:
         bsize: batch size; stored and forwarded to the encoder.
         embed_dim: item-embedding width.
         encod_dim: encoder hidden width (``self.enc_lstm_dim``); also the
             input width of ``enc2out``.
         numlabel: number of items; sizes the embedding table and the output
             layer (stored as ``self.n_classes``).
         n_layers: layer count forwarded to the encoder.
         recom: stored as ``self.recom`` (default 100).
         feature_vec: pre-trained item vectors. Currently unused, because
             loading them into the embedding is disabled.
         init: when true, ``self.init_params()`` runs after construction.
         model: cell type forwarded to the encoder (default ``'LSTM'``).
     """
     super(Agent, self).__init__()
     # classifier
     self.batch_size = bsize
     self.nonlinear_fc = False
     self.n_classes = numlabel
     self.enc_lstm_dim = encod_dim
     # Kept for backward compatibility; the encoder class is referenced
     # directly below instead of being resolved with eval().
     self.encoder_type = 'Encoder'
     self.model = model
     self.gamma = 0.9  # presumably an RL discount factor -- confirm
     self.n_layers = n_layers
     self.recom = recom  # presumably the recommendation-list cutoff -- confirm
     self.embedding = EmbeddingLayer(numlabel, embed_dim)
     # NOTE(review): pre-trained initialisation is disabled; call
     # self.embedding.init_embedding_weights(feature_vec) here to enable it.
     self.encoder = Encoder(self.batch_size, embed_dim, self.enc_lstm_dim,
                            self.model, self.n_layers)
     self.enc2out = nn.Linear(self.enc_lstm_dim, self.n_classes)
     if init:
         self.init_params()
Exemple #4
0
 def __init__(self,
              bsize,
              embed_dim,
              encod_dim,
              embed_dim_policy,
              encod_dim_policy,
              numlabel,
              rec_num,
              numclass=2,
              feature_vec=None,
              init_embed=False,
              model='LSTM'):
     """Build the discriminator's embedding, encoder and projection layers.

     Args:
         bsize: batch size; stored and forwarded to the encoder.
         embed_dim: item-embedding width.
         encod_dim: encoder hidden width (``self.enc_lstm_dim``).
         embed_dim_policy: unused here.
         encod_dim_policy: unused here.
         numlabel: number of items; sizes the embedding table.
         rec_num: length of each recommendation list.
         numclass: number of output classes (default 2).
         feature_vec: pre-trained vectors, loaded only when ``init_embed``.
         init_embed: whether to initialise the embedding from ``feature_vec``.
         model: cell type forwarded to the encoder (default ``'LSTM'``).
     """
     super(Discriminator, self).__init__()
     # classifier
     self.batch_size = bsize
     self.nonlinear_fc = False
     self.n_classes = numclass
     self.enc_lstm_dim = encod_dim
     # Kept for backward compatibility; the encoder class is referenced
     # directly below instead of being resolved with eval().
     self.encoder_type = 'Encoder'
     self.model = model
     self.embedding = EmbeddingLayer(numlabel, embed_dim)
     if init_embed:
         self.embedding.init_embedding_weights(feature_vec, embed_dim)
     # Per-step encoder input: item embedding + projected recommendation
     # list (rec_num features) + one reward feature; a single layer.
     self.encoder = Encoder(self.batch_size,
                            embed_dim + rec_num + 1,
                            self.enc_lstm_dim, self.model,
                            1)
     self.enc2out = nn.Linear(self.enc_lstm_dim, self.n_classes)
     self.rec2enc = nn.Linear(embed_dim * rec_num, rec_num)
Exemple #5
0
    def __init__(self, dictionary, embedding_index, args):
        """Assemble the biattentive classification network (BCN).

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace, stored as ``self.config``.
        """
        super(BCN, self).__init__()
        self.config = args
        cfg = self.config
        self.num_directions = 2 if cfg.bidirection else 1
        self.dictionary = dictionary

        self.embedding = EmbeddingLayer(len(self.dictionary), cfg.emsize,
                                        cfg.emtraining, cfg)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              cfg.emsize)

        # NOTE(review): called as (emsize, dropout); confirm this matches the
        # Selector signature actually in scope for this module.
        self.selector = Selector(cfg.emsize, cfg.dropout)

        relu_layers = OrderedDict()
        relu_layers['dense1'] = nn.Linear(cfg.emsize, cfg.nhid)
        relu_layers['nonlinearity'] = nn.ReLU()
        self.relu_network = nn.Sequential(relu_layers)

        # Output width of one (possibly bidirectional) encoder.
        enc_width = cfg.nhid * self.num_directions
        self.encoder = Encoder(cfg.nhid, cfg.nhid, cfg.bidirection,
                               cfg.nlayers, cfg)
        # Both biattentive encoders take 3 * enc_width features and use 1 layer.
        self.biatt_encoder1 = Encoder(enc_width * 3, cfg.nhid,
                                      cfg.bidirection, 1, cfg)
        self.biatt_encoder2 = Encoder(enc_width * 3, cfg.nhid,
                                      cfg.bidirection, 1, cfg)

        self.ffnn = nn.Linear(enc_width, 1)
        self.maxout_network = MaxoutNetwork(enc_width * 4 * 2, cfg.num_class,
                                            num_units=cfg.num_units)
Exemple #6
0
    def __init__(self, dictionary, embedding_index, args):
        """Build the Match-Tensor model layers.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors (stored as
                ``self.embedding_index``) loaded into the embedding at the end.
            args: configuration namespace, stored as ``self.config``.
        """
        super(MatchTensor, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.linear_projection = nn.Linear(self.config.emsize, self.config.featsize)
        # Encoder directionality follows config.bidirection. The projections
        # below are sized with num_directions, so a hard-coded bidirectional
        # encoder would presumably yield a width mismatch whenever
        # bidirection is off.
        self.query_encoder = Encoder(self.config.featsize, self.config.nhid_query,
                                     self.config.bidirection, self.config)
        self.document_encoder = Encoder(self.config.featsize, self.config.nhid_doc,
                                        self.config.bidirection, self.config)
        self.query_projection = nn.Linear(self.config.nhid_query * self.num_directions, self.config.nchannels)
        self.document_projection = nn.Linear(self.config.nhid_doc * self.num_directions, self.config.nchannels)

        self.exact_match_channel = ExactMatchChannel()
        # nchannels + 1 input channels: presumably the projected channels plus
        # the exact-match channel -- confirm in forward.
        self.conv1 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters, (3, 3), padding=1)
        self.conv2 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters, (3, 5), padding=(1, 2))
        self.conv3 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters, (3, 7), padding=(1, 3))
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(self.config.nfilters * 3, self.config.match_filter_size, (1, 1))
        self.output = nn.Linear(self.config.match_filter_size, 1)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary, self.embedding_index, self.config.emsize)
Exemple #7
0
class Selector(nn.Module):
    """Per-token linear scorer over word embeddings.

    Embeds the input token ids and maps every embedding vector to
    ``num_class`` scores with a single linear layer.
    """
    def __init__(self, dictionary, embedding_index, args):
        """Create the embedding table and the linear scoring layer.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace providing ``emsize``,
                ``emtraining`` and ``num_class``; stored as ``self.config``.
        """

        super(Selector, self).__init__()
        self.config = args
        self.dictionary = dictionary
        self.embedding = EmbeddingLayer(len(self.dictionary),
                                        self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              self.config.emsize)
        self.emsize = args.emsize
        self.num_labels = args.num_class
        self.linear = nn.Linear(self.emsize, self.num_labels)

    def forward(self, sentence1, threshold=0.5, is_train=0):
        """Score every token of ``sentence1``.

        Args:
            sentence1: token-id tensor fed to ``self.embedding``.
            threshold: unused; kept for interface compatibility.
            is_train: unused; kept for interface compatibility.

        Returns:
            ``self.linear(self.embedding(sentence1))`` with dim 1 squeezed,
            which is a no-op unless that dimension has size 1.
        """
        embedded_x1 = self.embedding(sentence1)
        score = self.linear(embedded_x1)
        return score.squeeze(1)
    def __init__(self, dictionary, embedding_index, args):
        """Build the NSRM sub-modules.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table and the decoder output layer.
            embedding_index: pre-trained vectors (stored as
                ``self.embedding_index``) loaded into the embedding at the end.
            args: configuration namespace, stored as ``self.config``.
        """
        super(NSRM, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        # Query/document encoders follow config.bidirection. The session
        # encoder input and the projection below are sized with
        # num_directions, so hard-coding a bidirectional encoder would
        # presumably mismatch widths whenever bidirection is off.
        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query,
                                     self.config.bidirection, self.config)
        self.document_encoder = Encoder(self.config.emsize,
                                        self.config.nhid_doc,
                                        self.config.bidirection,
                                        self.config)
        # The session encoder is built with bidirection=False.
        self.session_encoder = Encoder(
            self.config.nhid_query * self.num_directions,
            self.config.nhid_session, False, self.config)
        self.projection = nn.Linear(
            (self.config.nhid_query * self.num_directions) +
            self.config.nhid_session,
            self.config.nhid_doc * self.num_directions)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)
Exemple #9
0
class Discriminator(nn.Module):
    """Classifies (item sequence, reward, recommendation list) trajectories.

    Each step's input is the item embedding concatenated with a projection of
    the embedded recommendation list and the scalar reward. The encoder
    outputs are mean-pooled over time and mapped to ``numclass``
    log-probabilities.
    """
    def __init__(self,
                 bsize,
                 embed_dim,
                 encod_dim,
                 embed_dim_policy,
                 encod_dim_policy,
                 numlabel,
                 rec_num,
                 numclass=2,
                 feature_vec=None,
                 init_embed=False,
                 model='LSTM'):
        """Build the embedding, encoder and projection layers.

        Args:
            bsize: batch size; stored and forwarded to the encoder.
            embed_dim: item-embedding width.
            encod_dim: encoder hidden width (``self.enc_lstm_dim``).
            embed_dim_policy: unused here.
            encod_dim_policy: unused here.
            numlabel: number of items; sizes the embedding table.
            rec_num: length of each recommendation list.
            numclass: number of output classes (default 2).
            feature_vec: pre-trained vectors, loaded only when ``init_embed``.
            init_embed: whether to initialise the embedding from ``feature_vec``.
            model: cell type forwarded to the encoder (default ``'LSTM'``).
        """
        super(Discriminator, self).__init__()
        # classifier
        self.batch_size = bsize
        self.nonlinear_fc = False
        self.n_classes = numclass
        self.enc_lstm_dim = encod_dim
        # Kept for backward compatibility; the encoder class is referenced
        # directly below instead of being resolved with eval().
        self.encoder_type = 'Encoder'
        self.model = model
        self.embedding = EmbeddingLayer(numlabel, embed_dim)
        if init_embed:
            self.embedding.init_embedding_weights(feature_vec, embed_dim)
        # Per-step input: item embedding + projected rec list (rec_num) + reward (1).
        self.encoder = Encoder(self.batch_size,
                               embed_dim + rec_num + 1,
                               self.enc_lstm_dim, self.model,
                               1)
        self.enc2out = nn.Linear(self.enc_lstm_dim, self.n_classes)
        self.rec2enc = nn.Linear(embed_dim * rec_num, rec_num)

    def forward(self, seq, reward, rec):
        """Return class log-probabilities for a batch of trajectories.

        Args:
            seq: ``(item_ids, lengths)`` pair. ``item_ids`` is fed to
                ``self.embedding``; ``lengths`` must support ``.copy()``
                (e.g. a NumPy array or list) and holds each true length.
            reward: ``(batch, seq_len)`` tensor, appended as one feature per step.
            rec: ``(batch, rec_num, seq_len)`` recommendation ids, permuted to
                ``(batch, seq_len, rec_num)`` before embedding.

        Returns:
            ``(batch, n_classes)`` log-probabilities.
        """
        item_ids, seq_len = seq
        seq_em = self.embedding(item_ids)

        # Embed every recommended id, flatten each step's rec_num embeddings,
        # then project them down to rec_num features.
        rec = rec.permute(0, 2, 1)
        rec_em = self.embedding(rec.contiguous().view(-1, rec.size(2)))
        rec_em = rec_em.view(rec.size(0), rec.size(1), -1)
        rec_em = self.rec2enc(rec_em)

        # Concatenate with the reward.
        seq_em = torch.cat((seq_em, rec_em, reward.unsqueeze(2)), 2)
        # LSTM and non-LSTM encoders both return (outputs, hidden state);
        # only the per-step outputs are used.
        enc_out = self.encoder((seq_em, seq_len))[0]

        # Mean pooling over time. The length tensor is created on the encoder
        # output's device/dtype instead of being forced onto CUDA.
        # NOTE(review): assumes padded steps of enc_out are zero -- confirm.
        lengths = torch.tensor(seq_len.copy(), dtype=enc_out.dtype,
                               device=enc_out.device).unsqueeze(1)
        pooled = torch.sum(enc_out, 1) / lengths

        output = self.enc2out(pooled)
        return F.log_softmax(output, dim=1)  # batch x n_classes
Exemple #10
0
class CNN_ARC_II(nn.Module):
    """Implementation of the convolutional matching model (ARC-II)."""
    def __init__(self, dictionary, embedding_index, args):
        """Build the ARC-II layers.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors (stored as
                ``self.embedding_index``) loaded into the embedding at the end.
            args: configuration namespace, stored as ``self.config``.
        """
        super(CNN_ARC_II, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.conv1 = nn.Conv2d(self.config.emsize * 2, self.config.nfilters,
                               (3, 3))
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(self.config.nfilters, self.config.nfilters,
                               (2, 2))
        self.ffnn = nn.Sequential(
            nn.Linear(self.config.nfilters * 4, self.config.nfilters * 2),
            nn.Linear(self.config.nfilters * 2, self.config.nfilters),
            nn.Linear(self.config.nfilters, 1))

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)

    def forward(self, batch_queries, batch_docs):
        """
        Forward function of the ARC-II model: score each candidate document against its query.
        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :return: log-softmax over each query's candidate documents [batch_size x num_rel_docs_per_query]
        """
        # Repeat each query once per candidate document: [B*D x Lq x E].
        embedded_queries = self.embedding(batch_queries)
        embedded_queries = embedded_queries.unsqueeze(1).expand(
            *batch_docs.size()[:2],
            *embedded_queries.size()[1:])
        embedded_queries = embedded_queries.contiguous().view(
            -1,
            *embedded_queries.size()[2:])
        # Embed all documents as one flat batch: [B*D x Ld x E].
        embedded_docs = self.embedding(batch_docs.view(-1,
                                                       batch_docs.size(-1)))

        # Interaction grid [B*D x Ld x Lq x 2E]: broadcast the query along the
        # document axis and the document along the query axis.
        embedded_queries = embedded_queries.unsqueeze(1).expand(
            embedded_queries.size(0), batch_docs.size(2),
            *embedded_queries.size()[1:])
        embedded_docs = embedded_docs.unsqueeze(2).expand(
            *embedded_docs.size()[:2], batch_queries.size(1),
            embedded_docs.size(2))

        combined_rep = torch.cat((embedded_queries, embedded_docs), 3)
        # Channels-first for Conv2d: [B*D x 2E x Ld x Lq].
        combined_rep = combined_rep.transpose(2, 3).transpose(1, 2)
        conv1_out = self.pool1(F.relu(self.conv1(combined_rep)))
        # reshape (not view): the permuted input is channels_last-strided, and
        # conv/pool may keep that layout, which view() cannot flatten.
        conv2_out = self.pool1(F.relu(self.conv2(conv1_out))).reshape(
            -1, self.config.nfilters * 4)

        scores = self.ffnn(conv2_out).reshape(*batch_docs.size()[0:2])
        return F.log_softmax(scores, 1)
    def __init__(self, dictionary, embedding_index, args):
        """Create the embedding table and the word-embedding selector.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace providing ``emsize``,
                ``emtraining`` and ``dropout``; stored as ``self.config``.
        """
        super(Selector, self).__init__()
        self.config = cfg = args
        self.dictionary = dictionary
        vocab_size = len(self.dictionary)

        self.embedding = EmbeddingLayer(vocab_size, cfg.emsize,
                                        cfg.emtraining, cfg)
        self.embedding.init_embedding_weights(self.dictionary,
                                              embedding_index, cfg.emsize)

        self.we_selector = WE_Selector(cfg.emsize, cfg.dropout)
Exemple #12
0
    def __init__(self, dictionary, embedding_index, args):
        """Create the embedding table and a linear per-token scoring layer.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace providing ``emsize``,
                ``emtraining`` and ``num_class``; stored as ``self.config``.
        """
        super(Selector, self).__init__()
        self.config = args
        self.dictionary = dictionary
        cfg = self.config

        self.embedding = EmbeddingLayer(len(self.dictionary), cfg.emsize,
                                        cfg.emtraining, cfg)
        self.embedding.init_embedding_weights(self.dictionary,
                                              embedding_index, cfg.emsize)

        # Linear head: one score per class for each embedding vector.
        self.emsize, self.num_labels = args.emsize, args.num_class
        self.linear = nn.Linear(self.emsize, self.num_labels)
Exemple #13
0
    def __init__(self, dictionary, embedding_index, class_distributions, args):
        """Assemble the BCN classifier, optionally with POS-tag embeddings.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the word-embedding table.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            class_distributions: dict of class counts, stored as-is.
            args: configuration namespace, stored as ``self.config``. When
                ``args.pos`` is truthy, a POS-tag embedding layer is created
                and the first dense layer widens by ``POS_EMBEDDING_DIM``.
        """
        super(BCN, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary
        self.class_distributions = class_distributions  # dict of class counts

        # Model definition
        if self.config.pos:
            self.embedding_pos = EmbeddingLayer(len(pos_to_idx),
                                                POS_EMBEDDING_DIM, True,
                                                self.config)
            self.embedding_pos.init_pos_weights(pos_to_idx, POS_EMBEDDING_DIM)

        self.embedding = EmbeddingLayer(len(self.dictionary),
                                        self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              self.config.emsize)

        # The dense layer's input grows by the POS width when POS features are
        # enabled (presumably concatenated with the word embeddings in forward).
        relu_in = self.config.emsize
        if self.config.pos:
            relu_in += POS_EMBEDDING_DIM
        self.relu_network = nn.Sequential(
            OrderedDict([('dense1', nn.Linear(relu_in, self.config.nhid)),
                         ('nonlinearity', nn.ReLU())]))

        self.encoder = Encoder(self.config.nhid, self.config.nhid,
                               self.config.bidirection, self.config.nlayers,
                               self.config)
        self.biatt_encoder1 = Encoder(
            self.config.nhid * self.num_directions * 3, self.config.nhid,
            self.config.bidirection, 1, self.config)
        self.biatt_encoder2 = Encoder(
            self.config.nhid * self.num_directions * 3, self.config.nhid,
            self.config.bidirection, 1, self.config)

        self.ffnn = nn.Linear(self.config.nhid * self.num_directions, 1)
        self.maxout_network = MaxoutNetwork(self.config.nhid *
                                            self.num_directions * 4 * 2,
                                            self.config.num_class,
                                            num_units=self.config.num_units)
        # Each debug line's label now names the value it prints.
        print("BCN init num_class: ", self.config.num_class)
        print("BCN init num_units: ", self.config.num_units)
Exemple #14
0
    def __init__(self, dictionary, embedding_index, args):
        """Build the sequence-to-sequence embedding, encoder and decoder.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table and the decoder output layer.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace, stored as ``self.config``.
        """
        super(Seq2Seq, self).__init__()
        self.config = args
        cfg = self.config
        self.num_directions = 2 if cfg.bidirection else 1
        self.dictionary = dictionary

        self.embedding = EmbeddingLayer(len(self.dictionary), cfg)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              cfg.emsize)

        # NOTE(review): the encoder input width is cfg.input_size rather than
        # cfg.emsize -- confirm that is intended.
        self.encoder = Encoder(cfg.input_size, cfg.nhid_enc,
                               cfg.bidirection, cfg)
        # The decoder is sized to the (possibly bidirectional) encoder width.
        encoder_width = cfg.nhid_enc * self.num_directions
        self.decoder = Decoder(cfg.emsize, encoder_width,
                               len(self.dictionary), cfg)
 def __init__(self,
              bsize,
              embed_dim,
              encod_dim,
              numlabel,
              n_layers,
              feature_vec=None,
              init=False,
              model='LSTM'):
     """Build the generator's item embedding, encoder and output projections.

     Args:
         bsize: batch size; stored and forwarded to the encoder.
         embed_dim: item-embedding width; also the output width of both
             ``enc2out`` and ``enc2rewd``.
         encod_dim: encoder hidden width (``self.enc_lstm_dim``).
         numlabel: number of items; sizes the embedding table (also stored
             as ``self.n_classes``).
         n_layers: layer count forwarded to the encoder.
         feature_vec: pre-trained item vectors. Currently unused, because
             loading them into the embedding is disabled.
         init: when true, ``self.init_params()`` runs after construction.
         model: cell type forwarded to the encoder (default ``'LSTM'``).
     """
     super(Generator, self).__init__()
     # classifier
     self.batch_size = bsize
     self.nonlinear_fc = False
     self.n_classes = numlabel
     self.enc_lstm_dim = encod_dim
     # Kept for backward compatibility; the encoder class is referenced
     # directly below instead of being resolved with eval().
     self.encoder_type = 'Encoder'
     self.n_layers = n_layers
     self.model = model
     self.gamma = 0.9  # presumably an RL discount factor -- confirm
     self.embedding = EmbeddingLayer(numlabel, embed_dim, 0)
     # NOTE(review): pre-trained initialisation is disabled; call
     # self.embedding.init_embedding_weights(feature_vec, embed_dim) to enable it.
     self.encoder = Encoder(self.batch_size, embed_dim, self.enc_lstm_dim,
                            self.model, self.n_layers, 0)
     self.enc2out = nn.Linear(self.enc_lstm_dim, embed_dim)
     self.enc2rewd = nn.Linear(self.enc_lstm_dim, embed_dim)
     if init:
         self.init_params()
    def __init__(self, dictionary, embedding_index, args):
        """Build the hierarchical sequence-to-sequence sub-modules.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table and the decoder output layer.
            embedding_index: pre-trained vectors (stored as
                ``self.embedding_index``) loaded into the embedding at the end.
            args: configuration namespace, stored as ``self.config``.
        """
        super(Sequence2Sequence, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        cfg = self.config

        self.embedding = EmbeddingLayer(len(self.dictionary), cfg)
        # Query encodings (width nhid_query) feed the session encoder.
        self.query_encoder = Encoder(cfg.emsize, cfg.nhid_query, cfg)
        self.session_encoder = Encoder(cfg.nhid_query, cfg.nhid_session, cfg)
        self.decoder = Decoder(cfg.emsize, cfg.nhid_session,
                               len(self.dictionary), cfg)

        # Load the pre-trained vectors once all layers exist.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              cfg.emsize)
    def __init__(self, dictionary, embedding_index, args):
        """Build the DRMM embedding, gating network and scoring layers.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors (stored as
                ``self.embedding_index``) loaded into the embedding at the end.
            args: configuration namespace, stored as ``self.config``.
        """
        super(DRMM, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        # NOTE(review): presumably histogram bin edges; 1.0 appears twice,
        # possibly a dedicated exact-match bin -- confirm against the usage.
        self.bins = [-1.0, -0.5, 0, 0.5, 1.0, 1.0]

        cfg = self.config
        self.embedding = EmbeddingLayer(len(self.dictionary), cfg)
        self.gating_network = GatingNetwork(cfg.emsize)
        first_layer = nn.Linear(cfg.nbins, 1)
        second_layer = nn.Linear(1, 1)
        self.ffnn = nn.Sequential(first_layer, second_layer)
        self.output = nn.Linear(1, 1)

        # Load the pre-trained vectors once all layers exist.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              cfg.emsize)
Exemple #18
0
    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding, query encoder, session cell and decoder.

        Args:
            dictionary: vocabulary; ``len(dictionary)`` sizes the embedding
                table and the decoder output layer.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace, stored as ``self.config``.
        """
        super(HRED_QS, self).__init__()
        self.config = args
        cfg = self.config
        self.num_directions = 2 if cfg.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), cfg)
        self.embedding.init_embedding_weights(dictionary, embedding_index,
                                              cfg.emsize)

        self.query_encoder = Encoder(cfg.emsize, cfg.nhid_query,
                                     cfg.bidirection, cfg)
        # The session cell consumes the pooled query encodings and is built
        # with bidirection=False.
        query_width = cfg.nhid_query * self.num_directions
        self.session_encoder = EncoderCell(query_width, cfg.nhid_session,
                                           False, cfg)
        self.decoder = DecoderCell(cfg.emsize, cfg.nhid_session,
                                   len(dictionary), cfg)
Exemple #19
0
    def __init__(self, dictionary, embedding_index, args):
        """Build the ARC-II convolution, pooling and scoring layers.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors (stored as
                ``self.embedding_index``) loaded into the embedding at the end.
            args: configuration namespace, stored as ``self.config``.
        """
        super(CNN_ARC_II, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        emsize, nfilters = self.config.emsize, self.config.nfilters

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        # The first convolution sees query and document embeddings stacked
        # along the channel axis, hence 2 * emsize input channels.
        self.conv1 = nn.Conv2d(emsize * 2, nfilters, (3, 3))
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(nfilters, nfilters, (2, 2))
        head = [nn.Linear(nfilters * 4, nfilters * 2),
                nn.Linear(nfilters * 2, nfilters),
                nn.Linear(nfilters, 1)]
        self.ffnn = nn.Sequential(*head)

        # Load the pre-trained vectors once all layers exist.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index, emsize)
    def __init__(self, dictionary, embedding_index, args):
        """Build the ARC-I convolution and scoring layers.

        Args:
            dictionary: vocabulary (stored as ``self.dictionary``); its length
                sizes the embedding table.
            embedding_index: pre-trained vectors (stored as
                ``self.embedding_index``) loaded into the embedding at the end.
            args: configuration namespace, stored as ``self.config``.
        """
        super(CNN_ARC_I, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        emsize, nfilters = self.config.emsize, self.config.nfilters

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        # 1-D convolutions with kernel widths 1, 2 and 3; together they emit
        # 4 * nfilters channels.
        self.convolution1 = nn.Conv1d(emsize, nfilters, 1)
        self.convolution2 = nn.Conv1d(emsize, nfilters, 2)
        self.convolution3 = nn.Conv1d(emsize, nfilters * 2, 3)
        # NOTE(review): the head takes 8 * nfilters inputs, presumably the
        # query and document features concatenated -- confirm in forward.
        head = [nn.Linear(nfilters * 8, nfilters * 4),
                nn.Linear(nfilters * 4, nfilters * 2),
                nn.Linear(nfilters * 2, 1)]
        self.ffnn = nn.Sequential(*head)

        # Load the pre-trained vectors once all layers exist.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index, emsize)
Exemple #21
0
class HRED_QS(nn.Module):
    """Hierarchical recurrent encoder-decoder for session-based query suggestion.

    Each query in a session is encoded and pooled into a vector, a
    session-level recurrent cell runs over those vectors, and a decoder
    initialised from the session state after query ``i`` is trained to
    generate query ``i + 1``.
    """
    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding, query encoder, session cell and decoder.

        Args:
            dictionary: vocabulary; ``len(dictionary)`` sizes the embedding
                table and the decoder output layer.
            embedding_index: pre-trained vectors handed to
                ``EmbeddingLayer.init_embedding_weights``.
            args: configuration namespace, stored as ``self.config``.
        """
        super(HRED_QS, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embedding_index,
                                              self.config.emsize)

        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query,
                                     self.config.bidirection, self.config)
        self.session_encoder = EncoderCell(
            self.config.nhid_query * self.num_directions,
            self.config.nhid_session, False, self.config)
        self.decoder = DecoderCell(self.config.emsize,
                                   self.config.nhid_session, len(dictionary),
                                   self.config)

    @staticmethod
    def compute_loss(logits, target, seq_idx, length, regularize):
        """
        Compute the masked negative log-likelihood loss for one decoding step.
        :param logits: 2d tensor [batch_size x vocab_size], used as log-probabilities
        :param target: 1d tensor [batch_size]
        :param seq_idx: an integer represents the current index of the sequences
        :param length: 1d tensor [batch_size], represents each sequences' true length
        :param regularize: entropy-regularisation weight; a falsy value disables it
        :return: total loss over the input mini-batch [autograd Variable] and the
            number of unmasked loss elements (int)
        """
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).squeeze(1)
        # NOTE(review): the original comment says batch x 1. If so, multiplying
        # by the (batch,) losses would broadcast to batch x batch -- confirm
        # helper.mask returns a (batch,) tensor.
        mask = helper.mask(length, seq_idx)
        losses = losses * mask.float()
        # Number of unmasked positions (0 when everything is masked out).
        num_valid = int((mask.data != 0).sum())
        loss = losses.sum()
        if regularize:
            # Per-row sum of p * log p (negative entropy), scaled by regularize.
            # NOTE(review): this term is not masked, so padded rows contribute.
            neg_entropy = logits.exp().mul(logits).sum(1) * regularize
            loss = loss + neg_entropy.sum()
        return loss, num_valid

    def forward(self, session_queries, session_query_length):
        """
        Forward function of the neural click model. Return average loss for a batch of sessions.
        :param session_queries: 3d tensor [batch_size x session_length x max_query_length]
        :param session_query_length: 2d tensor [batch_size x session_length]
        :return: average loss over batch [autograd Variable]; the int 0 if no
            decoding step contributed any loss element
        """
        # query encoding
        embedded_queries = self.embedding(
            session_queries.view(-1, session_queries.size(-1)))
        encoded_queries = self.query_encoder(
            embedded_queries,
            session_query_length.view(-1).data.cpu().numpy())
        encoded_queries = self.apply_pooling(encoded_queries,
                                             self.config.pool_type)

        # encoded_queries: batch_size x session_length x (nhid_query * self.num_directions)
        encoded_queries = encoded_queries.contiguous().view(
            *session_queries.size()[:-1], -1)

        # session level encoding
        sess_query_hidden = self.session_encoder.init_weights(
            encoded_queries.size(0))
        hidden_states, cell_states = [], []

        # loop over all the queries in a session
        for idx in range(encoded_queries.size(1)):
            # update session-level query encoder state using query representations
            sess_q_out, sess_query_hidden = self.session_encoder(
                encoded_queries[:, idx, :].unsqueeze(1), sess_query_hidden)
            # -1: only consider hidden states of the last layer
            if self.config.model == 'LSTM':
                hidden_states.append(sess_query_hidden[0][-1])
                cell_states.append(sess_query_hidden[1][-1])
            else:
                hidden_states.append(sess_query_hidden[-1])

        hidden_states = torch.stack(hidden_states, 1)
        # remove the last hidden states which stand for the last queries in sessions
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        if self.config.model == 'LSTM':
            cell_states = torch.stack(cell_states, 1)
            cell_states = cell_states[:, :-1, :].contiguous().view(
                -1, cell_states.size(-1)).unsqueeze(0)
            # Initialize hidden states of decoder with the last hidden states of the session encoder
            decoder_hidden = (hidden_states, cell_states)
        else:
            # Initialize hidden states of decoder with the last hidden states of the session encoder
            decoder_hidden = hidden_states

        # train the decoder for all the queries in a session except the first
        embedded_queries = embedded_queries.view(*session_queries.size(), -1)
        decoder_input = embedded_queries[:, 1:, :, :].contiguous().view(
            -1,
            *embedded_queries.size()[2:])
        decoder_target = session_queries[:, 1:, :].contiguous().view(
            -1, session_queries.size(-1))
        target_length = session_query_length[:, 1:].contiguous().view(-1)
        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(decoder_input.size(1) - 1):
            input_variable = decoder_input[:, idx, :].unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                input_variable, decoder_hidden)
            # NOTE(review): the target is token idx + 1 but the mask is built
            # from idx; confirm helper.mask's convention accounts for the offset.
            local_loss, num_local_loss = self.compute_loss(
                decoder_output, decoder_target[:, idx + 1], idx, target_length,
                self.config.regularize)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss

    @staticmethod
    def apply_pooling(encodings, pool_type):
        """Pool ``encodings`` over dim 1.

        :param encodings: tensor [N x T x H]
        :param pool_type: one of ``'max'``, ``'mean'`` or ``'last'``
        :return: pooled tensor (squeezed for ``'max'`` / ``'mean'``)
        :raises ValueError: for any other ``pool_type`` (this previously
            surfaced as an UnboundLocalError)
        """
        if pool_type == 'max':
            return torch.max(encodings, 1)[0].squeeze()
        if pool_type == 'mean':
            return torch.sum(encodings, 1).squeeze() / encodings.size(1)
        if pool_type == 'last':
            return encodings[:, -1, :]
        raise ValueError('unknown pool_type: %r' % (pool_type,))
Exemple #22
0
class SentenceClassifier(nn.Module):
    """Predicts the label given a pair of sentences."""

    def __init__(self, dictionary, embeddings_index, args):
        """
        Constructor of the class.
        :param dictionary: vocabulary; len(dictionary) sizes the embedding table
        :param embeddings_index: pre-trained embeddings handed to init_embedding_weights
        :param args: configuration namespace (emsize, nhid, bidirection, nonlinear_fc,
            dropout_fc, fc_dim, num_classes and pool_type are read by this class)
        """
        super(SentenceClassifier, self).__init__()
        self.config = args
        self.num_directions = 2 if args.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embeddings_index, self.config.emsize)
        self.encoder = Encoder(self.config.emsize, self.config.nhid, self.config.bidirection, self.config)

        # The head consumes the concatenation (u, v, u * v, |u - v|) built in forward,
        # each part being nhid * num_directions wide.
        self.ffnn = self._build_ffnn(self.config.nhid * self.num_directions * 4,
                                     self.config.fc_dim,
                                     self.config.num_classes,
                                     self.config.dropout_fc,
                                     args.nonlinear_fc)

    @staticmethod
    def _build_ffnn(input_size, fc_dim, num_classes, dropout, nonlinear):
        """
        Build the three-layer feed-forward classification head.
        :param input_size: width of the combined sentence-pair representation
        :param fc_dim: width of the two hidden layers
        :param num_classes: number of output scores
        :param dropout: dropout probability applied before every linear layer
        :param nonlinear: when truthy, insert a Tanh after each hidden layer
        :return: nn.Sequential producing unnormalised scores [batch_size x num_classes]
        """
        # Each Tanh needs its own key: with a repeated 'tanh' key the OrderedDict keeps
        # only one entry, which silently removed the second nonlinearity.
        layers = [('dropout1', nn.Dropout(dropout)),
                  ('dense1', nn.Linear(input_size, fc_dim))]
        if nonlinear:
            layers.append(('tanh1', nn.Tanh()))
        layers += [('dropout2', nn.Dropout(dropout)),
                   ('dense2', nn.Linear(fc_dim, fc_dim))]
        if nonlinear:
            layers.append(('tanh2', nn.Tanh()))
        layers += [('dropout3', nn.Dropout(dropout)),
                   ('dense3', nn.Linear(fc_dim, num_classes))]
        return nn.Sequential(OrderedDict(layers))

    def _pool(self, encodings):
        """
        Reduce encoder outputs over the time axis according to config.pool_type.
        :param encodings: 3d tensor [batch_size x seq_length x nhid * num_directions]
        :return: 2d tensor [batch_size x nhid * num_directions]
        :raises ValueError: if config.pool_type is not 'max', 'mean' or 'last'
        """
        pool_type = self.config.pool_type
        if pool_type == 'max':
            return torch.max(encodings, 1)[0]
        if pool_type == 'mean':
            return torch.mean(encodings, 1)
        if pool_type == 'last':
            if self.num_directions == 2:
                nhid = self.config.nhid
                # forward half read at the last step, backward half at the first step
                return torch.cat((encodings[:, -1, :nhid], encodings[:, 0, nhid:]), 1)
            return encodings[:, -1, :]
        raise ValueError("Unsupported pool_type: '{}'".format(pool_type))

    def forward(self, batch_sentence1, sent_len1, batch_sentence2, sent_len2):
        """
        Defines the forward computation of the sentence pair classifier.
        :param batch_sentence1: word ids of the first sentences
        :param sent_len1: lengths of the first sentences, forwarded to the encoder
        :param batch_sentence2: word ids of the second sentences
        :param sent_len2: lengths of the second sentences, forwarded to the encoder
        :return: unnormalised class scores [batch_size x num_classes]
        """
        embedded1 = self.embedding(batch_sentence1)
        embedded2 = self.embedding(batch_sentence2)

        # For the first sentences in batch
        output1 = self.encoder(embedded1, sent_len1)
        # For the second sentences in batch
        output2 = self.encoder(embedded2, sent_len2)

        encoded_questions1 = self._pool(output1)
        encoded_questions2 = self._pool(output2)

        assert encoded_questions1.size(0) == encoded_questions2.size(0)

        # compute angle between sentence representation
        angle = torch.mul(encoded_questions1, encoded_questions2)
        # compute distance between sentence representation
        distance = torch.abs(encoded_questions1 - encoded_questions2)
        # combined_representation = batch_size x (hidden_size * num_directions * 4)
        combined_representation = torch.cat((encoded_questions1, encoded_questions2, angle, distance), 1)

        return self.ffnn(combined_representation)
class Selector(nn.Module):
    """Stochastic word selector for a pair of sentences.

    A keep-probability is produced for every token by ``WE_Selector``, a 0/1 keep
    decision is sampled per token, and the masked word ids are handed to
    ``helper.get_selected_tensor`` together with the original sentences and lengths.
    """
    def __init__(self, dictionary, embedding_index, args):
        """
        Build the embedding layer and the word-level selector.
        :param dictionary: vocabulary; its length sizes the embedding table
        :param embedding_index: pre-trained embeddings handed to init_embedding_weights
        :param args: configuration namespace (emsize, emtraining and dropout are read here)
        """

        super(Selector, self).__init__()
        self.config = args
        self.dictionary = dictionary

        cfg = self.config
        self.embedding = EmbeddingLayer(len(self.dictionary), cfg.emsize,
                                        cfg.emtraining, cfg)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              cfg.emsize)

        self.we_selector = WE_Selector(cfg.emsize, cfg.dropout)

    @staticmethod
    def _transitions(selection):
        """Per-row count of 0/1 switches between neighbouring positions of a 2d selection."""
        return (selection[:, 1:] - selection[:, :-1]).abs().sum(1)

    def forward(self,
                sentence1,
                sentence1_len_old,
                sentence2,
                sentence2_len_old,
                threshold=0.5,
                is_train=0):
        """
        Sample a word selection for each sentence and return the selected words.
        :param sentence1: 2d tensor of word ids [batch_size x max_length]
        :param sentence1_len_old: 1d numpy array [batch_size], lengths before selection
        :param sentence2: 2d tensor of word ids [batch_size x max_length]
        :param sentence2_len_old: 1d numpy array [batch_size], lengths before selection
        :param threshold: currently unused; selections are sampled with bernoulli()
        :param is_train: when 1, also compute the statistics needed to train the selector
        :return: (selected_x, sentence1_len, selected_y, sentence2_len, logpz, zsum, zdiff).
            In training mode logpz is the summed log-probability of the sampled
            selections, zsum the number of selected tokens and zdiff the number of
            0/1 switches, each per example and summed over both sentences, counting
            only tokens whose id is not 0. Otherwise the last three values are -1.0.
        """
        # Embedding order (x then y) is kept so any stochastic layers consume RNG identically.
        word_vecs_x = self.embedding(sentence1)
        word_vecs_y = self.embedding(sentence2)

        # Per-token keep probabilities, same shape as the id tensors.
        keep_prob_x = self.we_selector(word_vecs_x)
        keep_prob_y = self.we_selector(word_vecs_y)
        assert keep_prob_x.size() == sentence1.size()
        assert keep_prob_y.size() == sentence2.size()

        # Hard 0/1 decisions sampled from the keep probabilities.
        z_x = keep_prob_x.bernoulli().long()
        z_y = keep_prob_y.bernoulli().long()

        # Word ids where selected, zero elsewhere.
        kept_ids_x = sentence1.mul(z_x)
        kept_ids_y = sentence2.mul(z_y)

        # The returned lengths are numpy arrays (per the helper's contract).
        selected_x, sentence1_len = helper.get_selected_tensor(
            kept_ids_x, keep_prob_x, sentence1, sentence1_len_old,
            self.config.cuda)
        selected_y, sentence2_len = helper.get_selected_tensor(
            kept_ids_y, keep_prob_y, sentence2, sentence2_len_old,
            self.config.cuda)

        if is_train != 1:
            # Sentinels: selector statistics are only computed in training mode.
            return selected_x, sentence1_len, selected_y, sentence2_len, -1.0, -1.0, -1.0

        # Token id 0 is excluded from every statistic below.
        pad_mask_x = (sentence1 != 0).long()
        pad_mask_y = (sentence2 != 0).long()
        kept_x = z_x.mul(pad_mask_x)
        kept_y = z_y.mul(pad_mask_y)

        # Per-token log-likelihood of the sampled decisions (negated BCE, unreduced).
        token_logp_x = -helper.binary_cross_entropy(
            keep_prob_x, z_x.float().detach(), reduce=False)
        token_logp_y = -helper.binary_cross_entropy(
            keep_prob_y, z_y.float().detach(), reduce=False)
        assert token_logp_x.size() == sentence1.size()

        logpz = (token_logp_x.mul(pad_mask_x.float()).sum(1) +
                 token_logp_y.mul(pad_mask_y.float()).sum(1))

        switches_x = self._transitions(kept_x)
        switches_y = self._transitions(kept_y)
        assert switches_x.size()[0] == sentence1.size()[0]
        assert logpz.size()[0] == sentence1.size()[0]
        zdiff = switches_x + switches_y

        zsum = kept_x.sum(1) + kept_y.sum(1)
        assert zsum.size()[0] == sentence1.size()[0]
        assert logpz.dim() == zsum.dim()
        assert logpz.dim() == zdiff.dim()

        return (selected_x, sentence1_len, selected_y, sentence2_len, logpz,
                zsum.float(), zdiff.float())
class Sequence2Sequence(nn.Module):
    """Hierarchical query-session model.

    Every query is encoded by a query-level encoder, the per-query encodings are fed
    through a session-level encoder, and the session state after query ``i``
    initialises a decoder that is trained to generate query ``i + 1``.
    """
    def __init__(self, dictionary, embedding_index, args):
        """
        Constructor of the class.
        :param dictionary: vocabulary exposing word2idx and start_token
        :param embedding_index: pre-trained embeddings handed to init_embedding_weights
        :param args: configuration namespace
        """
        super(Sequence2Sequence, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query, self.config)
        self.session_encoder = Encoder(self.config.nhid_query,
                                       self.config.nhid_session, self.config)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)

    @staticmethod
    def compute_loss(logits, target, seq_idx, length):
        """
        Masked negative log-likelihood for one decoding step.
        :param logits: 2d tensor [batch_size x vocab_size], presumably log-probabilities
        :param target: 1d tensor [batch_size] of gold word ids
        :param seq_idx: index of the current decoding step
        :param length: 1d tensor [batch_size] of target lengths
        :return: summed loss divided by the number of unmasked elements, or the plain
            sum when every element is masked
        """
        # logits: batch x vocab_size, target: batch x 1
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1))
        # mask: batch x 1
        mask = helper.mask(length, seq_idx)
        losses = losses * mask.float()
        # torch.nonzero() on an all-zero mask returns a (0, k) tensor, whose size is a
        # non-empty (truthy) tuple, so the old check divided by zero and produced NaN.
        # Count the unmasked entries directly instead.
        num_non_zero_elem = int(mask.data.ne(0).sum())
        if num_non_zero_elem == 0:
            return losses.sum()
        return losses.sum() / num_non_zero_elem

    @staticmethod
    def _reduce_directions(state, bidirection):
        """
        Normalise one session-encoder state to a leading axis of size 1.
        :param state: tensor [num_layers * num_directions x batch_size x nhid]
        :param bidirection: when truthy, average over the first axis
        :return: tensor [1 x batch_size x nhid] when averaged, otherwise state unchanged
        """
        if bidirection:
            # Keep the leading axis so bidirectional states stack exactly like
            # unidirectional ones (previously a 2d tensor was stacked on the wrong axis).
            return torch.mean(state, 0).unsqueeze(0)
        return state

    @staticmethod
    def _flatten_session_states(states):
        """
        Turn per-query session states into initial decoder states.
        :param states: list (one entry per query) of tensors [1 x batch_size x nhid];
            assumes a single leading layer axis
        :return: tensor [1 x (batch_size * (session_length - 1)) x nhid]; the state after
            the last query is dropped because no following query is decoded from it
        """
        stacked = torch.stack(states, 2).squeeze(0)  # batch_size x session_length x nhid
        return stacked[:, :-1, :].contiguous().view(
            -1, stacked.size(-1)).unsqueeze(0)

    def forward(self, batch_session, length):
        """
        Compute the query-generation loss for a batch of sessions.
        :param batch_session: 3d tensor of word ids [batch_size x session_length x max_query_length]
        :param length: 2d tensor of query lengths [batch_size x session_length]
        :return: decoding loss summed over decoding steps (each step a masked mean)
        """
        # step1: encode every query of every session independently
        embedded_input = self.embedding(
            batch_session.view(-1, batch_session.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(
                embedded_input.size(0))
            output, _ = self.query_encoder(embedded_input,
                                           (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(
                embedded_input.size(0))
            output, _ = self.query_encoder(embedded_input, encoder_hidden)

        if self.config.bidirection:
            # average the first and second nhid_query-wide halves of every output vector
            nhid = self.config.nhid_query
            output = torch.div(
                torch.add(output[:, :, 0:nhid], output[:, :, nhid:2 * nhid]),
                2)

        # step2: session-level encoding over the last-step output of each query
        session_input = output[:, -1, :].contiguous().view(
            batch_session.size(0), batch_session.size(1), -1)
        sess_hidden = self.session_encoder.init_weights(session_input.size(0))
        hidden_states, cell_states = [], []
        for idx in range(session_input.size(1)):
            _, sess_hidden = self.session_encoder(
                session_input[:, idx, :].unsqueeze(1), sess_hidden)
            hidden_states.append(
                self._reduce_directions(sess_hidden[0], self.config.bidirection))
            cell_states.append(
                self._reduce_directions(sess_hidden[1], self.config.bidirection))

        # the state after query i initialises the decoding of query i + 1
        hidden_states = self._flatten_session_states(hidden_states)
        cell_states = self._flatten_session_states(cell_states)

        # step3: decode every query of a session except the first
        decoder_input = batch_session[:, 1:, :].contiguous().view(
            -1, batch_session.size(-1))
        target_length = length[:, 1:].contiguous().view(-1)
        input_variable = Variable(
            torch.LongTensor(decoder_input.size(0)).fill_(
                self.dictionary.word2idx[self.dictionary.start_token]))
        if self.config.cuda:
            input_variable = input_variable.cuda()
            hidden_states = hidden_states.cuda()
            cell_states = cell_states.cuda()

        # Initialize hidden states of decoder with the last hidden states of the session encoder
        decoder_hidden = (hidden_states, cell_states)
        loss = 0
        for idx in range(decoder_input.size(1)):
            if idx != 0:
                # teacher forcing: feed the previous gold token
                input_variable = decoder_input[:, idx - 1]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(
                1)
            decoder_output, decoder_hidden = self.decoder(
                embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx]
            loss += self.compute_loss(decoder_output, target_variable, idx,
                                      target_length)

        return loss
Exemple #25
0
class MatchTensor(nn.Module):
    """Multi-task session model built around a match-tensor document ranker.

    ``forward`` returns two losses: a click loss from scoring the candidate documents
    of every query, and a decoding loss from generating every query of a session
    (except the first) from the session-level state of the preceding queries.
    """
    def __init__(self, dictionary, embedding_index, args):
        """
        Constructor of the class.
        :param dictionary: vocabulary; its length sizes the embedding and decoder output
        :param embedding_index: pre-trained embeddings handed to init_embedding_weights
        :param args: configuration namespace
        """
        super(MatchTensor, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)
        self.linear_projection = nn.Linear(self.config.emsize,
                                           self.config.featsize)
        # Encoder directionality follows config.bidirection so that the encoder output
        # width matches the nhid * num_directions inputs of the projections below
        # (a hard-wired True broke the projections when bidirection was off).
        self.query_encoder = Encoder(self.config.featsize,
                                     self.config.nhid_query,
                                     self.config.bidirection, self.config)
        self.document_encoder = Encoder(self.config.featsize,
                                        self.config.nhid_doc,
                                        self.config.bidirection, self.config)
        self.query_projection = nn.Linear(
            self.config.nhid_query * self.num_directions,
            self.config.nchannels)
        self.document_projection = nn.Linear(
            self.config.nhid_doc * self.num_directions, self.config.nchannels)

        self.exact_match_channel = ExactMatchChannel()
        # Three convolutions of different widths over the (nchannels + 1)-channel match
        # tensor; the extra channel is the one appended by exact_match_channel.
        self.conv1 = nn.Conv2d(self.config.nchannels + 1,
                               self.config.nfilters, (3, 3),
                               padding=1)
        self.conv2 = nn.Conv2d(self.config.nchannels + 1,
                               self.config.nfilters, (3, 5),
                               padding=(1, 2))
        self.conv3 = nn.Conv2d(self.config.nchannels + 1,
                               self.config.nfilters, (3, 7),
                               padding=(1, 3))
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(self.config.nfilters * 3,
                              self.config.match_filter_size, (1, 1))
        self.output = nn.Linear(self.config.match_filter_size, 1)

        self.session_encoder = EncoderCell(self.config.nchannels,
                                           self.config.nhid_session, False,
                                           self.config)
        self.decoder = DecoderCell(self.config.emsize,
                                   self.config.nhid_session, len(dictionary),
                                   self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length, regularize):
        """
        Compute negative log-likelihood loss for a batch of predictions.
        :param logits: 2d tensor [batch_size x vocab_size], presumably log-probabilities
        :param target: 1d tensor [batch_size]
        :param seq_idx: an integer represents the current index of the sequences
        :param length: 1d tensor [batch_size], represents each sequences' true length
        :param regularize: weight of the extra sum(exp(logits) * logits) term; falsy disables it
        :return: (summed loss over the mini-batch [autograd Variable], number of unmasked elements)
        """
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).squeeze()
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        # Count the unmasked positions directly (0 when every sequence has ended).
        num_loss_elements = int(mask.data.ne(0).sum())
        loss = losses.sum()
        if regularize:
            # if logits are log-probabilities this is the (scaled) negative entropy
            regularized_loss = logits.exp().mul(logits).sum(
                1).squeeze() * regularize
            loss = loss + regularized_loss.sum()
        return loss, num_loss_elements

    def forward(self, session_queries, session_query_length, rel_docs,
                rel_docs_length, doc_labels):
        """
        Forward function of the neural click model. Return the losses for a batch of sessions.
        :param session_queries: 3d tensor [batch_size x session_length x max_query_length]
        :param session_query_length: 2d tensor [batch_size x session_length]
        :param rel_docs: 4d tensor [batch_size x session_length x num_rel_docs_per_query x max_doc_length]
        :param rel_docs_length: 3d tensor [batch_size x session_length x num_rel_docs_per_query]
        :param doc_labels: 3d tensor [batch_size x session_length x num_rel_docs_per_query]
        :return: (click_loss, decoding_loss) [autograd Variables]
        """
        batch_queries = session_queries.view(-1, session_queries.size(-1))
        batch_docs = rel_docs.view(-1, *rel_docs.size()[2:])

        projected_queries = self.encode_query(
            batch_queries, session_query_length)  # (B*S) x L x nchannels
        projected_docs = self.encode_document(batch_docs, rel_docs_length)
        score = self.document_ranker(projected_queries, projected_docs,
                                     batch_queries, batch_docs)
        click_loss = f.binary_cross_entropy_with_logits(
            score, doc_labels.view(-1, doc_labels.size(2)))

        # encoded_queries: batch_size x session_length x nchannels (max over query positions)
        encoded_queries = projected_queries.max(1)[0].view(
            *session_queries.size()[:2], -1)
        decoding_loss = self.query_recommender(session_queries,
                                               session_query_length,
                                               encoded_queries)

        return click_loss, decoding_loss

    def query_recommender(self, session_queries, session_query_length,
                          encoded_queries):
        """
        Teacher-forced decoding loss for generating each query from the session state
        of the queries before it.
        :param session_queries: 3d tensor [batch_size x session_length x max_query_length]
        :param session_query_length: 2d tensor [batch_size x session_length]
        :param encoded_queries: 3d tensor [batch_size x session_length x nchannels]
        :return: decoding loss averaged over unmasked target tokens
        """
        # session level encoding
        sess_q_hidden = self.session_encoder.init_weights(
            encoded_queries.size(0))
        hidden_states, cell_states = [], []
        # loop over all the queries in a session
        for idx in range(encoded_queries.size(1)):
            # update session-level query encoder state using query representations
            sess_q_out, sess_q_hidden = self.session_encoder(
                encoded_queries[:, idx, :].unsqueeze(1), sess_q_hidden)
            # -1 stands for: only consider hidden states from the last layer
            if self.config.model == 'LSTM':
                hidden_states.append(sess_q_hidden[0][-1])
                cell_states.append(sess_q_hidden[1][-1])
            else:
                hidden_states.append(sess_q_hidden[-1])

        hidden_states = torch.stack(hidden_states, 1)
        # remove the last hidden states which stand for the last queries in sessions
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        if self.config.model == 'LSTM':
            cell_states = torch.stack(cell_states, 1)
            cell_states = cell_states[:, :-1, :].contiguous().view(
                -1, cell_states.size(-1)).unsqueeze(0)
            # Initialize hidden states of decoder with the last hidden states of the session encoder
            decoder_hidden = (hidden_states, cell_states)
        else:
            # Initialize hidden states of decoder with the last hidden states of the session encoder
            decoder_hidden = hidden_states

        embedded_queries = self.embedding(
            session_queries.view(-1, session_queries.size(-1)))
        # train the decoder for all the queries in a session except the first
        embedded_queries = embedded_queries.view(*session_queries.size(), -1)
        decoder_input = embedded_queries[:, 1:, :, :].contiguous().view(
            -1,
            *embedded_queries.size()[2:])
        decoder_target = session_queries[:, 1:, :].contiguous().view(
            -1, session_queries.size(-1))
        target_length = session_query_length[:, 1:].contiguous().view(-1)

        decoding_loss, total_local_decoding_loss_element = 0, 0
        # token idx is the input and token idx + 1 the target
        for idx in range(decoder_input.size(1) - 1):
            input_variable = decoder_input[:, idx, :].unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                input_variable, decoder_hidden)
            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, decoder_target[:, idx + 1], idx, target_length,
                self.config.regularize)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss

    @staticmethod
    def _global_max_pool(features):
        """
        Max-pool a 4d feature tensor over both length axes.
        :param features: tensor [N x query_length x doc_length x channels]
        :return: tensor [N x channels]
        """
        # No squeeze() between the reductions: it would also drop N or query_length when
        # either equals 1, so the second max would reduce the wrong axis.
        return torch.max(torch.max(features, 2)[0], 1)[0]

    def document_ranker(self, projected_queries, projected_docs, batch_queries,
                        batch_docs):
        """
        Score every candidate document against its query.
        :param projected_queries: 3d tensor [(B*S) x max_query_length x nchannels]
        :param projected_docs: 2d tensor [(B*S*D*max_doc_length) x nchannels]
        :param batch_queries: 2d tensor of query word ids [(B*S) x max_query_length]
        :param batch_docs: 3d tensor of doc word ids [(B*S) x D x max_doc_length]
        :return: 2d tensor of scores [(B*S) x D]
        """
        # step6: 2d product between projected query and doc vectors
        # repeat each query once per candidate document: N = (B*S*D)
        projected_queries = projected_queries.unsqueeze(1).expand(
            projected_queries.size(0), batch_docs.size(1),
            *projected_queries.size()[1:])
        projected_queries = projected_queries.contiguous().view(
            -1,
            *projected_queries.size()[2:])

        projected_docs = projected_docs.view(-1, batch_docs.size(2),
                                             projected_docs.size()[-1])

        # broadcast both to N x query_length x doc_length x nchannels
        projected_queries = projected_queries.unsqueeze(2).expand(
            *projected_queries.size()[:2],
            batch_docs.size()[-1], projected_queries.size(2))
        projected_docs = projected_docs.unsqueeze(1).expand(
            projected_docs.size(0),
            batch_queries.size()[-1],
            *projected_docs.size()[1:])
        query_document_product = projected_queries * projected_docs

        # step7: append exact match channel
        exact_match = self.exact_match_channel(batch_queries,
                                               batch_docs).unsqueeze(3)
        query_document_product = torch.cat(
            (query_document_product, exact_match), 3)
        # channels first for Conv2d: N x (nchannels + 1) x query_length x doc_length
        query_document_product = query_document_product.transpose(
            2, 3).transpose(1, 2)

        # step8: run the convolutional operation, max-pooling and linear projection
        convoluted_feat1 = self.conv1(query_document_product)
        convoluted_feat2 = self.conv2(query_document_product)
        convoluted_feat3 = self.conv3(query_document_product)
        convoluted_feat = self.relu(
            torch.cat((convoluted_feat1, convoluted_feat2, convoluted_feat3),
                      1))
        # back to N x query_length x doc_length x match_filter_size
        convoluted_feat = self.conv(convoluted_feat).transpose(1, 2).transpose(
            2, 3)

        max_pooled_feat = self._global_max_pool(convoluted_feat)
        return self.output(max_pooled_feat).squeeze().view(
            *batch_docs.size()[:2])

    def encode_query(self, batch_queries, session_query_length):
        """
        Embed, project and encode a batch of queries.
        :param batch_queries: 2d tensor of word ids [(B*S) x max_query_length]
        :param session_query_length: 2d tensor [B x S] of query lengths
        :return: 3d tensor [(B*S) x max_query_length x nchannels]
        """
        # step1: apply embedding lookup
        embedded_queries = self.embedding(batch_queries)
        # step2: apply linear projection on embedded queries and documents
        embedded_queries = self.linear_projection(
            embedded_queries.view(-1, embedded_queries.size(-1)))
        # step3: transform the tensors so that they can be given as input to RNN
        embedded_queries = embedded_queries.view(*batch_queries.size(),
                                                 self.config.featsize)
        # step4: pass the projected queries through the query encoder
        encoded_queries = self.query_encoder(
            embedded_queries,
            session_query_length.view(-1).data.cpu().numpy())
        # step5: apply linear projection on query hidden states
        projected_queries = self.query_projection(
            encoded_queries.view(-1,
                                 encoded_queries.size()[-1])).view(
                                     *batch_queries.size(), -1)
        return projected_queries

    def encode_document(self, batch_docs, rel_docs_length):
        """
        Embed, project and encode a batch of documents.
        :param batch_docs: 3d tensor of word ids [(B*S) x D x max_doc_length]
        :param rel_docs_length: 3d tensor [B x S x D] of document lengths
        :return: 2d tensor [(B*S*D*max_doc_length) x nchannels]
        """
        # step1: apply embedding lookup
        embedded_docs = self.embedding(batch_docs.view(-1,
                                                       batch_docs.size(-1)))
        # step2: apply linear projection on embedded queries and documents
        embedded_docs = self.linear_projection(
            embedded_docs.view(-1, embedded_docs.size(-1)))
        # step3: transform the tensors so that they can be given as input to RNN
        embedded_docs = embedded_docs.view(-1,
                                           batch_docs.size()[-1],
                                           self.config.featsize)
        # step4: pass the projected documents through the document encoder
        encoded_docs = self.document_encoder(
            embedded_docs,
            rel_docs_length.view(-1).data.cpu().numpy())
        # step5: apply linear projection on document hidden states
        projected_docs = self.document_projection(
            encoded_docs.view(-1,
                              encoded_docs.size()[-1]))
        return projected_docs
Exemple #26
0
class BCN(nn.Module):
    """Biattentive classification network architecture for sentence classification."""
    def __init__(self, dictionary, embedding_index, args):
        """
        Constructor of the class.
        :param dictionary: vocabulary; its length sizes the embedding table
        :param embedding_index: pre-trained embeddings handed to init_embedding_weights
        :param args: configuration namespace
        """
        super(BCN, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary

        self.embedding = EmbeddingLayer(len(self.dictionary),
                                        self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              self.config.emsize)

        self.selector = Selector(self.config.emsize, self.config.dropout)

        self.relu_network = nn.Sequential(
            OrderedDict([('dense1',
                          nn.Linear(self.config.emsize, self.config.nhid)),
                         ('nonlinearity', nn.ReLU())]))

        self.encoder = Encoder(self.config.nhid, self.config.nhid,
                               self.config.bidirection, self.config.nlayers,
                               self.config)
        self.biatt_encoder1 = Encoder(
            self.config.nhid * self.num_directions * 3, self.config.nhid,
            self.config.bidirection, 1, self.config)
        self.biatt_encoder2 = Encoder(
            self.config.nhid * self.num_directions * 3, self.config.nhid,
            self.config.bidirection, 1, self.config)

        self.ffnn = nn.Linear(self.config.nhid * self.num_directions, 1)
        self.maxout_network = MaxoutNetwork(self.config.nhid *
                                            self.num_directions * 4 * 2,
                                            self.config.num_class,
                                            num_units=self.config.num_units)

    @staticmethod
    def _coattend(encoded_x, encoded_y):
        """
        Compute the co-attention summaries of two encoded sentences.
        :param encoded_x: 3d tensor [batch_size x len_x x hidden]
        :param encoded_y: 3d tensor [batch_size x len_y x hidden]
        :return: (conditioned_x [batch_size x len_y x hidden],
                  conditioned_y [batch_size x len_x x hidden])
        """
        # affinity matrix [batch_size x len_x x len_y]
        affinity_mat = torch.bmm(encoded_x, encoded_y.transpose(1, 2))
        # Each row of conditioned_x is a weighted sum of rows of encoded_x, so the
        # weights must be normalised over the x positions (dim 1). The previous code
        # normalised over dim 2, giving weights that do not sum to one.
        conditioned_x = torch.bmm(
            f.softmax(affinity_mat, 1).transpose(1, 2), encoded_x)
        # Likewise conditioned_y mixes rows of encoded_y: normalise over y positions.
        conditioned_y = torch.bmm(f.softmax(affinity_mat, 2), encoded_y)
        return conditioned_x, conditioned_y

    def forward(self, sentence1, sentence1_len_old, sentence2,
                sentence2_len_old):
        """
        Forward computation of the biattentive classification network.
        Returns classification scores for a batch of sentence pairs.
        :param sentence1: 2d tensor [batch_size x max_length]
        :param sentence1_len_old: lengths before selection (not read; lengths are
            recomputed by helper.get_selected_tensor)
        :param sentence2: 2d tensor [batch_size x max_length]
        :param sentence2_len_old: lengths before selection (not read)
        :return: (scores [batch_size x num_class], sentence1_len, sentence2_len, zdiff1, zdiff2)
        """
        # step1: embed the words into vectors [batch_size x max_length x emsize]
        embedded_x1 = self.embedding(sentence1)
        embedded_y1 = self.embedding(sentence2)

        ###################################### selection ######################################
        selection_x = self.selector(embedded_x1)
        selection_y = self.selector(embedded_y1)

        assert selection_x.size() == sentence1.size()
        assert selection_y.size() == sentence2.size()

        # word ids where selected, zeros elsewhere
        result_x = sentence1.mul(selection_x)
        result_y = sentence2.mul(selection_y)

        selected_x, sentence1_len = helper.get_selected_tensor(
            result_x, self.config.cuda)  # sentence1_len is a numpy array
        selected_y, sentence2_len = helper.get_selected_tensor(
            result_y, self.config.cuda)  # sentence2_len is a numpy array

        embedded_x = self.embedding(selected_x)
        embedded_y = self.embedding(selected_y)

        # number of switches between neighbouring selection values, per example
        zdiff1 = (selection_x[:, 1:] - selection_x[:, :-1]).abs().sum(1)
        zdiff2 = (selection_y[:, 1:] - selection_y[:, :-1]).abs().sum(1)

        assert zdiff1.size()[0] == len(sentence1_len)

        ###################################### selection ######################################

        # step2: pass the embedded words through the ReLU network [batch_size x max_length x hidden_size]
        embedded_x = self.relu_network(embedded_x)
        embedded_y = self.relu_network(embedded_y)

        # step3: pass the word vectors through the encoder [batch_size x max_length x hidden_size * num_directions]
        encoded_x = self.encoder(embedded_x, sentence1_len)
        # For the second sentences in batch
        encoded_y = self.encoder(embedded_y, sentence2_len)

        # step4-5: affinity matrix and conditioned representations
        conditioned_x, conditioned_y = self._coattend(encoded_x, encoded_y)

        # step6: generate input of the biattentive encoders [batch_size x max_length x hidden_size * num_directions * 3]
        biatt_input_x = torch.cat(
            (encoded_x, torch.abs(encoded_x - conditioned_y),
             torch.mul(encoded_x, conditioned_y)), 2)
        biatt_input_y = torch.cat(
            (encoded_y, torch.abs(encoded_y - conditioned_x),
             torch.mul(encoded_y, conditioned_x)), 2)

        # step7: pass the conditioned information through the biattentive encoders
        # [batch_size x max_length x hidden_size * num_directions]
        biatt_x = self.biatt_encoder1(biatt_input_x, sentence1_len)
        biatt_y = self.biatt_encoder2(biatt_input_y, sentence2_len)

        # step8: compute self-attentive pooling features
        att_weights_x = self.ffnn(biatt_x.view(-1, biatt_x.size(2))).squeeze(1)
        att_weights_x = f.softmax(att_weights_x.view(*biatt_x.size()[:-1]), 1)
        att_weights_y = self.ffnn(biatt_y.view(-1, biatt_y.size(2))).squeeze(1)
        att_weights_y = f.softmax(att_weights_y.view(*biatt_y.size()[:-1]), 1)
        self_att_x = torch.bmm(biatt_x.transpose(1, 2),
                               att_weights_x.unsqueeze(2)).squeeze(2)
        self_att_y = torch.bmm(biatt_y.transpose(1, 2),
                               att_weights_y.unsqueeze(2)).squeeze(2)

        # step9: compute the joint representations [batch_size x hidden_size * num_directions * 4]
        pooled_x = torch.cat((biatt_x.max(1)[0], biatt_x.mean(1),
                              biatt_x.min(1)[0], self_att_x), 1)
        pooled_y = torch.cat((biatt_y.max(1)[0], biatt_y.mean(1),
                              biatt_y.min(1)[0], self_att_y), 1)

        # step10: pass the pooled representations through the maxout network
        score = self.maxout_network(torch.cat((pooled_x, pooled_y), 1))
        return score, sentence1_len, sentence2_len, zdiff1, zdiff2
class Seq2Seq(nn.Module):
    """Sequence-to-sequence model over a sequence pair (q1, q2).

    ``q1`` is encoded and the decoder is run with teacher forcing over ``q2``
    (input token ``q2[:, t]``, target token ``q2[:, t + 1]``); ``forward``
    returns the accumulated decoding loss normalised by the number of
    unmasked target tokens.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Constructor of the class.

        :param dictionary: vocabulary; ``len(dictionary)`` sizes the embedding
            table and the decoder's output layer
        :param embedding_index: pretrained vectors passed to
            ``EmbeddingLayer.init_embedding_weights``
        :param args: config namespace (reads ``bidirection``, ``emsize``,
            ``nhid_enc``, ``attn_type``, ``model``, ``cuda``, ``regularize``)
        """
        super(Seq2Seq, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embedding_index,
                                              self.config.emsize)

        self.encoder = Encoder(self.config.emsize, self.config.nhid_enc,
                               self.config.bidirection, self.config)
        # The decoder's hidden size matches the concatenated encoder directions.
        if self.config.attn_type:
            self.decoder = AttentionDecoder(
                self.config.emsize, self.config.nhid_enc * self.num_directions,
                len(dictionary), self.config)
        else:
            self.decoder = Decoder(self.config.emsize,
                                   self.config.nhid_enc * self.num_directions,
                                   len(dictionary), self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length, regularize):
        """Masked loss for one decoding step.

        :param logits: 2d tensor [batch_size x vocab_size]; presumably
            log-probabilities (the loss is ``-logits[target]``) — TODO confirm
        :param target: 1d tensor [batch_size] of target token indices
        :param seq_idx: current decoding step, used to build the length mask
        :param length: target lengths, passed to ``helper.mask``
        :param regularize: scale of an extra ``sum(exp(logits) * logits)``
            term (a negative entropy if ``logits`` are log-probs); falsy
            disables it. Note this term is not masked.
        :return: (summed loss, number of unmasked positions)
        """
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).squeeze()
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        loss = losses.sum()
        if regularize:
            regularized_loss = logits.exp().mul(logits).sum(
                1).squeeze() * regularize
            loss = loss + regularized_loss.sum()
        # An empty size tuple (older torch, all-zero mask) means no tokens.
        num_tokens = num_non_zero_elem[0] if num_non_zero_elem else 0
        return loss, num_tokens

    def _init_decoder_hidden(self, hidden):
        """Build the decoder's initial state from the encoder's final state.

        Assumes ``hidden`` follows the PyTorch RNN ``h_n`` layout
        [num_layers * num_directions x batch x nhid]. For LSTM it is an
        ``(h_n, c_n)`` pair; for other cells it is a bare tensor or a
        1-tuple. Only the top layer is kept. For a bidirectional encoder
        its two directions are concatenated on the feature axis. The result
        is always 3-D: [1 x batch x nhid * num_directions].
        """
        def top_layer(state):
            if self.config.bidirection:
                return torch.cat((state[-2], state[-1]), 1).unsqueeze(0)
            return state[-1].unsqueeze(0)

        if self.config.model == 'LSTM':
            h_n, c_n = hidden[0], hidden[1]
            return top_layer(h_n), top_layer(c_n)
        if isinstance(hidden, (tuple, list)):
            hidden = hidden[0]
        return top_layer(hidden)

    def forward(self, q1_var, q1_len, q2_var, q2_len):
        """Return the length-normalised decoding loss of ``q2_var`` given ``q1_var``.

        :param q1_var: source token indices [batch_size x q1_max_length]
        :param q1_len: source lengths, passed to the encoder
        :param q2_var: target token indices [batch_size x q2_max_length]
        :param q2_len: target lengths, used to mask padded positions
        :return: summed step losses divided by the number of unmasked target
            tokens (the raw sum, or ``0``, when there are none)
        """
        # encode the query
        embedded_q1 = self.embedding(q1_var)
        encoded_q1, hidden = self.encoder(embedded_q1, q1_len)
        decoder_hidden = self._init_decoder_hidden(hidden)

        decoder_context = None
        if self.config.attn_type:
            # Zero initial attention context [batch x 1 x encoder feature size].
            decoder_context = Variable(
                torch.zeros(encoded_q1.size(0),
                            encoded_q1.size(2))).unsqueeze(1)
            if self.config.cuda:
                decoder_context = decoder_context.cuda()

        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(q2_var.size(1) - 1):
            # Teacher forcing: feed the gold token at idx, predict idx + 1.
            embedded_decoder_input = self.embedding(q2_var[:, idx]).unsqueeze(1)
            if self.config.attn_type:
                decoder_output, decoder_hidden, decoder_context, _ = self.decoder(
                    embedded_decoder_input, decoder_hidden, decoder_context,
                    encoded_q1)
            else:
                decoder_output, decoder_hidden = self.decoder(
                    embedded_decoder_input, decoder_hidden)

            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, q2_var[:, idx + 1], idx, q2_len,
                self.config.regularize)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss
class CNN_ARC_I(nn.Module):
    """Convolutional (ARC-I style) model that scores candidate documents for a query.

    Query and documents are embedded, passed through three 1-D convolutions
    (kernel widths 1, 2 and 3), max-pooled over time and concatenated; each
    query/document pair of representations is then scored by ``ffnn``.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Constructor of the class.

        :param dictionary: vocabulary; ``len(dictionary)`` sizes the embedding table
        :param embedding_index: pretrained vectors passed to
            ``EmbeddingLayer.init_embedding_weights``
        :param args: config namespace (reads ``emsize`` and ``nfilters``)
        """
        super(CNN_ARC_I, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        # Pooled feature size per text: nfilters + nfilters + 2 * nfilters = 4 * nfilters.
        self.convolution1 = nn.Conv1d(self.config.emsize, self.config.nfilters,
                                      1)
        self.convolution2 = nn.Conv1d(self.config.emsize, self.config.nfilters,
                                      2)
        self.convolution3 = nn.Conv1d(self.config.emsize,
                                      self.config.nfilters * 2, 3)
        # Input is query (4 * nfilters) + document (4 * nfilters) features.
        # NOTE(review): there is no nonlinearity between these layers, so the
        # stack is equivalent to one linear map; kept so parameter names stay
        # compatible with existing checkpoints.
        self.ffnn = nn.Sequential(
            nn.Linear(self.config.nfilters * 8, self.config.nfilters * 4),
            nn.Linear(self.config.nfilters * 4, self.config.nfilters * 2),
            nn.Linear(self.config.nfilters * 2, 1))

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)

    def _convolve_and_pool(self, embedded):
        """Apply the three convolutions and max-pool each over the time axis.

        :param embedded: 3d tensor [n x length x emsize]
        :return: 2d tensor [n x nfilters * 4]; the batch axis is always kept,
            including when n == 1
        """
        channels_first = embedded.transpose(1, 2)  # Conv1d wants [n x emsize x length]
        pooled = [
            torch.max(conv(channels_first), 2)[0]
            for conv in (self.convolution1, self.convolution2,
                         self.convolution3)
        ]
        return torch.cat(pooled, 1)

    def forward(self, batch_queries, batch_docs):
        """Score each candidate document against its query.

        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :return: 2d tensor [batch_size x num_rel_docs_per_query] of log-softmax
            scores, normalised over each query's documents
        """
        num_queries, docs_per_query = batch_docs.size(0), batch_docs.size(1)

        embedded_queries = self.embedding(batch_queries)
        embedded_docs = self.embedding(batch_docs.view(-1,
                                                       batch_docs.size(-1)))

        query_rep = self._convolve_and_pool(embedded_queries)
        doc_rep = self._convolve_and_pool(embedded_docs)

        # Repeat each query representation once per candidate document so the
        # rows line up with the flattened doc_rep.
        query_rep = query_rep.unsqueeze(1).expand(num_queries, docs_per_query,
                                                  query_rep.size(1))
        query_rep = query_rep.contiguous().view(-1, query_rep.size(2))

        combined_representation = torch.cat((query_rep, doc_rep), 1)
        scores = self.ffnn(combined_representation).view(num_queries,
                                                         docs_per_query)
        return F.log_softmax(scores, dim=1)
Example #29
0
class Seq2Seq(nn.Module):
    """Sequence-to-sequence model that decodes a token sequence from video features.

    The video feature sequence is encoded and pooled into one vector per
    example; that vector seeds the decoder, which is run with teacher forcing
    over ``decoder_input``.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Constructor of the class.

        :param dictionary: output vocabulary; ``len(dictionary)`` sizes the
            embedding table and the decoder's output layer
        :param embedding_index: pretrained vectors passed to
            ``EmbeddingLayer.init_embedding_weights``
        :param args: config namespace (reads ``bidirection``, ``emsize``,
            ``input_size``, ``nhid_enc``, ``pool_type``, ``model``, ``cuda``)
        """
        super(Seq2Seq, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              self.config.emsize)

        self.encoder = Encoder(self.config.input_size, self.config.nhid_enc,
                               self.config.bidirection, self.config)
        self.decoder = Decoder(self.config.emsize,
                               self.config.nhid_enc * self.num_directions,
                               len(self.dictionary), self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        """Masked loss for one decoding step.

        :param logits: 2d tensor [batch_size x vocab_size]; presumably
            log-probabilities (the loss is ``-logits[target]``) — TODO confirm
        :param target: 1d tensor [batch_size] of target token indices
        :param seq_idx: current decoding step, used to build the length mask
        :param length: target lengths, passed to ``helper.mask``
        :return: (summed masked loss, number of unmasked positions)
        """
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).squeeze()
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        # An empty size tuple (older torch, all-zero mask) means no tokens.
        num_tokens = num_non_zero_elem[0] if num_non_zero_elem else 0
        return losses.sum(), num_tokens

    def _pool_encoder_states(self, encoded_videos):
        """Pool encoder outputs [batch x length x nhid_enc * num_directions] to [batch x features].

        The batch axis is always kept (no ``squeeze``), so batch size 1 works.

        :raises ValueError: if ``config.pool_type`` is not 'max', 'mean' or 'last'
        """
        pool_type = self.config.pool_type
        if pool_type == 'max':
            return torch.max(encoded_videos, 1)[0]
        if pool_type == 'mean':
            return torch.sum(encoded_videos, 1) / encoded_videos.size(1)
        if pool_type == 'last':
            if self.num_directions == 2:
                nhid = self.config.nhid_enc
                # The forward direction finishes at the last time step and the
                # backward direction at the first, so take each half from
                # its own final step. (Assumes the standard batch_first
                # [forward | backward] feature layout.)
                # NOTE(review): with padded sequences the forward half at
                # position -1 may be padding; video_len is not consulted here.
                return torch.cat((encoded_videos[:, -1, :nhid],
                                  encoded_videos[:, 0, nhid:]), 1)
            return encoded_videos[:, -1, :]
        raise ValueError('unsupported pool_type: {!r}'.format(pool_type))

    def forward(self, videos, video_len, decoder_input, target_length):
        """Return the length-normalised decoding loss of ``decoder_input`` given ``videos``.

        :param videos: video feature sequences, passed to the encoder
        :param video_len: video lengths, passed to the encoder
        :param decoder_input: target token indices [batch_size x max_length];
            position t is fed as input and position t + 1 is the target
        :param target_length: target lengths, used to mask padded positions
        :return: summed step losses divided by the number of unmasked target
            tokens (the raw sum, or ``0``, when there are none)
        """
        # encode the video features
        encoded_videos = self.encoder(videos, video_len)
        hidden_states = self._pool_encoder_states(encoded_videos)

        # Initialize hidden states of decoder with the pooled encoder states
        if self.config.model == 'LSTM':
            cell_states = Variable(torch.zeros(*hidden_states.size()))
            if self.config.cuda:
                cell_states = cell_states.cuda()
            decoder_hidden = (hidden_states.unsqueeze(0).contiguous(),
                              cell_states.unsqueeze(0).contiguous())
        else:
            decoder_hidden = hidden_states.unsqueeze(0).contiguous()

        decoding_loss = 0
        total_local_decoding_loss_element = 0
        for idx in range(decoder_input.size(1) - 1):
            # Teacher forcing: feed the gold token at idx, predict idx + 1.
            embedded_decoder_input = self.embedding(
                decoder_input[:, idx]).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                embedded_decoder_input, decoder_hidden)

            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, decoder_input[:, idx + 1], idx, target_length)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss
Example #30
0
class BCN(nn.Module):
    """Biattentive classification network architecture for sentence classification."""

    def __init__(self, dictionary, embedding_index, class_distributions, args):
        """Constructor of the class.

        :param dictionary: vocabulary; ``len(dictionary)`` sizes the word
            embedding table
        :param embedding_index: pretrained vectors passed to
            ``EmbeddingLayer.init_embedding_weights``
        :param class_distributions: stored unchanged on the instance
            (described by the original author as a dict of class counts)
        :param args: config namespace (reads ``bidirection``, ``pos``,
            ``emsize``, ``emtraining``, ``nhid``, ``nlayers``, ``num_class``,
            ``num_units``)
        """
        super(BCN, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary
        self.class_distributions = class_distributions  # dict of class counts

        # Model definition
        if self.config.pos:
            self.embedding_pos = EmbeddingLayer(len(pos_to_idx),
                                                POS_EMBEDDING_DIM, True,
                                                self.config)
            self.embedding_pos.init_pos_weights(pos_to_idx, POS_EMBEDDING_DIM)

        self.embedding = EmbeddingLayer(len(self.dictionary),
                                        self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              self.config.emsize)

        # When POS tags are enabled, forward() concatenates their embeddings to
        # the word embeddings, so the first dense layer takes the wider input.
        relu_input_size = self.config.emsize
        if self.config.pos:
            relu_input_size += POS_EMBEDDING_DIM
        self.relu_network = nn.Sequential(
            OrderedDict([('dense1',
                          nn.Linear(relu_input_size, self.config.nhid)),
                         ('nonlinearity', nn.ReLU())]))

        self.encoder = Encoder(self.config.nhid, self.config.nhid,
                               self.config.bidirection, self.config.nlayers,
                               self.config)
        # Biattentive encoders read [encoded; |encoded - conditioned|; encoded * conditioned].
        self.biatt_encoder1 = Encoder(
            self.config.nhid * self.num_directions * 3, self.config.nhid,
            self.config.bidirection, 1, self.config)
        self.biatt_encoder2 = Encoder(
            self.config.nhid * self.num_directions * 3, self.config.nhid,
            self.config.bidirection, 1, self.config)

        self.ffnn = nn.Linear(self.config.nhid * self.num_directions, 1)
        # Input: 4 pooled features per sentence (max, mean, min, self-attention) x 2 sentences.
        self.maxout_network = MaxoutNetwork(self.config.nhid *
                                            self.num_directions * 4 * 2,
                                            self.config.num_class,
                                            num_units=self.config.num_units)
        print("BCN init num_units: ", self.config.num_units)

    def forward(self,
                sentence1,
                sentence1_len,
                sentence2,
                sentence2_len,
                pos_sent1=None,
                pos_sent2=None):
        """
        Forward computation of the biattentive classification network.
        Returns classification scores for a batch of sentence pairs.
        :param sentence1: 2d tensor [batch_size x max_length]
        :param sentence1_len: 1d numpy array [batch_size]
        :param sentence2: 2d tensor [batch_size x max_length]
        :param sentence2_len: 1d numpy array [batch_size]
        :param pos_sent1: optional POS-tag indices for sentence1; used only
            when ``config.pos`` is set and both POS inputs are given
        :param pos_sent2: optional POS-tag indices for sentence2 (see above)
        :return: classification scores over batch [batch_size x num_classes]
        """
        def softmax(tensor, dim):
            # PC is a module-level flag defined outside this class. When it is
            # set, softmax is called without ``dim`` and torch picks the axis
            # itself; otherwise ``dim`` is passed explicitly.
            if PC:
                return f.softmax(tensor)
            return f.softmax(tensor, dim)

        # step1: embed the words into vectors [batch_size x max_length x emsize]
        embedded_x = self.embedding(sentence1)
        embedded_y = self.embedding(sentence2)

        if self.config.pos and pos_sent1 is not None and pos_sent2 is not None:
            embedded_x = torch.cat((embedded_x, self.embedding_pos(pos_sent1)), 2)
            embedded_y = torch.cat((embedded_y, self.embedding_pos(pos_sent2)), 2)

        # step2: pass the embedded words through the ReLU network [batch_size x max_length x hidden_size]
        embedded_x = self.relu_network(embedded_x)
        embedded_y = self.relu_network(embedded_y)

        # step3: pass the word vectors through the encoder [batch_size x max_length x hidden_size * num_directions]
        encoded_x = self.encoder(embedded_x, sentence1_len)
        encoded_y = self.encoder(embedded_y, sentence2_len)

        # step4: compute affinity matrix [batch_size x sent1_max_length x sent2_max_length]
        affinity_mat = torch.bmm(encoded_x, encoded_y.transpose(1, 2))

        # step5: compute conditioned representations [batch_size x max_length x hidden_size * num_directions]
        # NOTE(review): with an explicit dim, each softmax normalises over the
        # axis that the following bmm does NOT sum over (e.g. conditioned_x is
        # normalised over sent2 positions but sums over sent1 positions), so the
        # weights of each weighted sum do not add up to 1. Padded positions are
        # not masked either. Confirm this is intended before changing it, as it
        # alters trained-model behaviour.
        conditioned_x = torch.bmm(
            softmax(affinity_mat, 2).transpose(1, 2), encoded_x)
        conditioned_y = torch.bmm(
            softmax(affinity_mat.transpose(1, 2), 2).transpose(1, 2),
            encoded_y)

        # step6: generate input of the biattentive encoders [batch_size x max_length x hidden_size * num_directions * 3]
        biatt_input_x = torch.cat(
            (encoded_x, torch.abs(encoded_x - conditioned_y),
             torch.mul(encoded_x, conditioned_y)), 2)
        biatt_input_y = torch.cat(
            (encoded_y, torch.abs(encoded_y - conditioned_x),
             torch.mul(encoded_y, conditioned_x)), 2)

        # step7: pass the conditioned information through the biattentive encoders
        # [batch_size x max_length x hidden_size * num_directions]
        biatt_x = self.biatt_encoder1(biatt_input_x, sentence1_len)
        biatt_y = self.biatt_encoder2(biatt_input_y, sentence2_len)

        # step8: compute self-attentive pooling features; one scalar score per
        # time step from self.ffnn, normalised over the time axis.
        att_weights_x = self.ffnn(biatt_x.view(-1, biatt_x.size(2))).squeeze(1)
        att_weights_x = softmax(att_weights_x.view(*biatt_x.size()[:-1]), 1)
        att_weights_y = self.ffnn(biatt_y.view(-1, biatt_y.size(2))).squeeze(1)
        att_weights_y = softmax(att_weights_y.view(*biatt_y.size()[:-1]), 1)

        self_att_x = torch.bmm(biatt_x.transpose(1, 2),
                               att_weights_x.unsqueeze(2)).squeeze(2)
        self_att_y = torch.bmm(biatt_y.transpose(1, 2),
                               att_weights_y.unsqueeze(2)).squeeze(2)

        # step9: compute the joint representations [batch_size x hidden_size * num_directions * 4]
        pooled_x = torch.cat((biatt_x.max(1)[0], biatt_x.mean(1),
                              biatt_x.min(1)[0], self_att_x), 1)
        pooled_y = torch.cat((biatt_y.max(1)[0], biatt_y.mean(1),
                              biatt_y.min(1)[0], self_att_y), 1)

        # step10: pass the pooled representations through the maxout network
        score = self.maxout_network(torch.cat((pooled_x, pooled_y), 1))
        return score