import torch
import torch.nn as nn
from torch.autograd import Variable

# Project-local modules: `helper` provides the sequence mask used in the loss
# computations, and the layer classes live alongside these models. The module
# name `nn_layer` is assumed from the surrounding repository layout.
import helper
from nn_layer import EmbeddingLayer, Encoder, Decoder, ExactMatchChannel

# NOTE: these models are written against the pre-0.4 PyTorch API
# (autograd.Variable, reduction ops that keep the reduced dimension), which
# the squeeze calls below rely on.


class NSRM(nn.Module):
    """Neural session relevance model: scores the candidate documents of every
    query in a session and learns to generate the follow-up query."""

    def __init__(self, dictionary, embedding_index, args):
        """Constructor of the class."""
        super(NSRM, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.query_encoder = Encoder(self.config.emsize, self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.emsize, self.config.nhid_doc, True, self.config)
        self.session_encoder = Encoder(self.config.nhid_query * self.num_directions,
                                       self.config.nhid_session, False, self.config)
        self.projection = nn.Linear(
            (self.config.nhid_query * self.num_directions) + self.config.nhid_session,
            self.config.nhid_doc * self.num_directions)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initialize the weight parameters of the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary, self.embedding_index,
                                              self.config.emsize)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        """
        Compute negative log-likelihood loss for a batch of predictions.
        :param logits: 2d tensor [batch_size x vocab_size]
        :param target: 1d tensor [batch_size]
        :param seq_idx: integer index of the current decoding step
        :param length: 1d tensor [batch_size], true length of each target sequence
        :return: total loss over the mini-batch [autograd Variable] and the
            number of loss elements
        """
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1))
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
            return losses.sum(), 0
        else:
            return losses.sum(), num_non_zero_elem[0]

    @staticmethod
    def compute_click_loss(logits, target):
        """
        Compute logistic loss for a batch of clicks and return the average
        loss over the mini-batch.
        :param logits: 2d tensor [batch_size x num_rel_docs_per_query]
        :param target: 2d tensor [batch_size x num_rel_docs_per_query]
        :return: average loss over batch [autograd Variable]
        """
        # Numerically stable binary cross-entropy with logits, taken from
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L695
        neg_abs = -logits.abs()
        loss = logits.clamp(min=0) - logits * target + (1 + neg_abs.exp()).log()
        return loss.mean()

    def forward(self, batch_session, length, batch_clicks, click_labels):
        """
        Forward function of the neural session relevance model. Returns the
        average loss for a batch of sessions.
        :param batch_session: 3d tensor [batch_size x session_length x max_query_length]
        :param length: 2d tensor [batch_size x session_length]
        :param batch_clicks: 4d tensor [batch_size x session_length x num_rel_docs_per_query x max_document_length]
        :param click_labels: 3d tensor [batch_size x session_length x num_rel_docs_per_query]
        :return: average loss over batch [autograd Variable]
        """
        # query level encoding
        embedded_queries = self.embedding(batch_session.view(-1, batch_session.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries, encoder_hidden)

        encoded_queries = torch.max(output, 1)[0].squeeze(1)
        # encoded_queries: batch_size x session_length x hidden_size
        encoded_queries = encoded_queries.view(*batch_session.size()[:-1], -1)

        # document level encoding
        embedded_clicks = self.embedding(batch_clicks.view(-1, batch_clicks.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.document_encoder.init_weights(embedded_clicks.size(0))
            output, hidden = self.document_encoder(embedded_clicks, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.document_encoder.init_weights(embedded_clicks.size(0))
            output, hidden = self.document_encoder(embedded_clicks, encoder_hidden)

        encoded_clicks = torch.max(output, 1)[0].squeeze(1)
        # encoded_clicks: batch_size x session_length x num_rel_docs_per_query x hidden_size
        encoded_clicks = encoded_clicks.view(*batch_clicks.size()[:-1], -1)

        # session level encoding
        sess_hidden = self.session_encoder.init_weights(encoded_queries.size(0))
        # Use the actual batch size rather than config.batch_size so the last,
        # possibly smaller, batch of an epoch also works.
        sess_output = Variable(torch.zeros(encoded_queries.size(0), 1, self.config.nhid_session))
        if self.config.cuda:
            sess_output = sess_output.cuda()

        hidden_states, cell_states = [], []
        click_loss = 0
        for idx in range(encoded_queries.size(1)):
            # squeeze only the time dimension so a batch size of one also works
            combined_rep = torch.cat((sess_output.squeeze(1), encoded_queries[:, idx, :]), 1)
            combined_rep = self.projection(combined_rep)
            combined_rep = combined_rep.unsqueeze(1).expand(*encoded_clicks[:, idx, :, :].size())
            click_score = torch.sum(torch.mul(combined_rep, encoded_clicks[:, idx, :, :]), 2).squeeze(2)
            click_loss += self.compute_click_loss(click_score, click_labels[:, idx, :])
            # update the session state using the query representation
            sess_output, sess_hidden = self.session_encoder(
                encoded_queries[:, idx, :].unsqueeze(1), sess_hidden)
            hidden_states.append(sess_hidden[0])
            cell_states.append(sess_hidden[1])

        click_loss = click_loss / encoded_queries.size(1)

        hidden_states = torch.stack(hidden_states, 2).squeeze(0)
        cell_states = torch.stack(cell_states, 2).squeeze(0)

        # decoding in sequence-to-sequence learning
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        cell_states = cell_states[:, :-1, :].contiguous().view(
            -1, cell_states.size(-1)).unsqueeze(0)
        decoder_input = batch_session[:, 1:, :].contiguous().view(-1, batch_session.size(-1))
        target_length = length[:, 1:].contiguous().view(-1)

        input_variable = Variable(torch.LongTensor(decoder_input.size(0)).fill_(
            self.dictionary.word2idx[self.dictionary.start_token]))
        if self.config.cuda:
            input_variable = input_variable.cuda()
            hidden_states = hidden_states.cuda()
            cell_states = cell_states.cuda()

        # Initialize the hidden states of the decoder with the hidden states
        # of the session encoder.
        decoder_hidden = (hidden_states, cell_states)
        decoding_loss = 0
        total_local_decoding_loss_element = 0
        for idx in range(decoder_input.size(1)):
            if idx != 0:
                input_variable = decoder_input[:, idx - 1]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx]
            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, target_variable, idx, target_length)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return click_loss + decoding_loss
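# A standalone sanity check (not part of the original models; assumes
# PyTorch >= 0.4): the clamp/abs/exp formulation in NSRM.compute_click_loss is
# the standard numerically stable logistic loss, so on random inputs it should
# match the library's binary_cross_entropy_with_logits.
def _demo_click_loss_matches_library():
    import torch.nn.functional as F
    logits = torch.randn(4, 10)                    # batch_size x num_rel_docs_per_query
    target = torch.randint(0, 2, (4, 10)).float()  # synthetic binary click labels
    neg_abs = -logits.abs()
    manual = (logits.clamp(min=0) - logits * target + (1 + neg_abs.exp()).log()).mean()
    assert torch.allclose(manual, F.binary_cross_entropy_with_logits(logits, target))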
class Sequence2Sequence(nn.Module):
    """Hierarchical sequence-to-sequence model that encodes the queries of a
    session and learns to generate the follow-up query."""

    def __init__(self, dictionary, embedding_index, args):
        """Constructor of the class."""
        super(Sequence2Sequence, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.query_encoder = Encoder(self.config.emsize, self.config.nhid_query, self.config)
        self.session_encoder = Encoder(self.config.nhid_query, self.config.nhid_session, self.config)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initialize the weight parameters of the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary, self.embedding_index,
                                              self.config.emsize)

    @staticmethod
    def compute_loss(logits, target, seq_idx, length):
        """
        Compute the masked negative log-likelihood loss for one decoding step.
        :param logits: 2d tensor [batch_size x vocab_size]
        :param target: 1d tensor [batch_size]
        :param seq_idx: integer index of the current decoding step
        :param length: 1d tensor [batch_size], true length of each target sequence
        :return: average loss over the non-masked elements [autograd Variable]
        """
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1))
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
            loss = losses.sum()
        else:
            loss = losses.sum() / num_non_zero_elem[0]
        return loss

    def forward(self, batch_session, length):
        """
        Forward function of the sequence-to-sequence model. Returns the total
        decoding loss for a batch of sessions.
        :param batch_session: 3d tensor [batch_size x session_length x max_query_length]
        :param length: 2d tensor [batch_size x session_length]
        :return: decoding loss over batch [autograd Variable]
        """
        # query level encoding
        embedded_input = self.embedding(batch_session.view(-1, batch_session.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(embedded_input.size(0))
            output, hidden = self.query_encoder(embedded_input, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(embedded_input.size(0))
            output, hidden = self.query_encoder(embedded_input, encoder_hidden)

        if self.config.bidirection:
            # average the forward and backward hidden states
            output = torch.div(
                torch.add(output[:, :, 0:self.config.nhid_query],
                          output[:, :, self.config.nhid_query:2 * self.config.nhid_query]), 2)

        session_input = output[:, -1, :].contiguous().view(
            batch_session.size(0), batch_session.size(1), -1)

        # session level encoding
        sess_hidden = self.session_encoder.init_weights(session_input.size(0))
        hidden_states, cell_states = [], []
        for idx in range(session_input.size(1)):
            sess_output, sess_hidden = self.session_encoder(
                session_input[:, idx, :].unsqueeze(1), sess_hidden)
            if self.config.bidirection:
                hidden_states.append(torch.mean(sess_hidden[0], 0))
                cell_states.append(torch.mean(sess_hidden[1], 0))
            else:
                hidden_states.append(sess_hidden[0])
                cell_states.append(sess_hidden[1])

        hidden_states = torch.stack(hidden_states, 2).squeeze(0)
        cell_states = torch.stack(cell_states, 2).squeeze(0)

        # decoding in sequence-to-sequence learning
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        cell_states = cell_states[:, :-1, :].contiguous().view(
            -1, cell_states.size(-1)).unsqueeze(0)
        decoder_input = batch_session[:, 1:, :].contiguous().view(-1, batch_session.size(-1))
        target_length = length[:, 1:].contiguous().view(-1)

        input_variable = Variable(torch.LongTensor(decoder_input.size(0)).fill_(
            self.dictionary.word2idx[self.dictionary.start_token]))
        if self.config.cuda:
            input_variable = input_variable.cuda()
            hidden_states = hidden_states.cuda()
            cell_states = cell_states.cuda()

        # Initialize the hidden states of the decoder with the hidden states
        # of the session encoder.
        decoder_hidden = (hidden_states, cell_states)
        loss = 0
        for idx in range(decoder_input.size(1)):
            if idx != 0:
                input_variable = decoder_input[:, idx - 1]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx]
            loss += self.compute_loss(decoder_output, target_variable, idx, target_length)

        return loss
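# `helper.mask(length, seq_idx)` is defined elsewhere in the repository; the
# sketch below is an assumption about its behavior, inferred from how
# compute_loss uses it: a [batch_size x 1] indicator that is 1 while decoding
# step seq_idx still lies inside the target sequence, so losses at padded
# positions are zeroed out by `losses * mask.float()`.
def _mask_sketch(length, seq_idx):
    # length: 1d LongTensor [batch_size] of true target lengths (hypothetical input)
    # returns: 2d LongTensor [batch_size x 1]
    return (length > seq_idx).long().unsqueeze(1)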
class MatchTensor(nn.Module):
    """Match tensor model that scores a set of candidate documents against a
    query for ad-hoc document ranking."""

    def __init__(self, dictionary, embedding_index, args):
        """Constructor of the class."""
        super(MatchTensor, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.linear_projection = nn.Linear(self.config.emsize, self.config.featsize)
        self.query_encoder = Encoder(self.config.featsize, self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.featsize, self.config.nhid_doc, True, self.config)
        self.query_projection = nn.Linear(self.config.nhid_query * self.num_directions,
                                          self.config.nchannels)
        self.document_projection = nn.Linear(self.config.nhid_doc * self.num_directions,
                                             self.config.nchannels)
        self.exact_match_channel = ExactMatchChannel()
        self.conv1 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters, (3, 3), padding=1)
        self.conv2 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters, (3, 5), padding=(1, 2))
        self.conv3 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters, (3, 7), padding=(1, 3))
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(self.config.nfilters * 3, self.config.match_filter_size, (1, 1))
        self.output = nn.Linear(self.config.match_filter_size, 1)

        # Initialize the weight parameters of the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary, self.embedding_index,
                                              self.config.emsize)

    def forward(self, batch_queries, batch_docs):
        """
        Forward function of the match tensor model. Returns a relevance score
        for every query-document pair in the batch.
        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :return: score matrix [batch_size x num_rel_docs_per_query]
        """
        # embed queries and documents and project them into the feature space
        embedded_queries = self.embedding(batch_queries)
        embedded_docs = self.embedding(batch_docs.view(-1, batch_docs.size(-1)))
        embedded_queries = self.linear_projection(
            embedded_queries.view(-1, embedded_queries.size(-1)))
        embedded_docs = self.linear_projection(embedded_docs.view(-1, embedded_docs.size(-1)))
        embedded_queries = embedded_queries.view(*batch_queries.size(), self.config.featsize)
        embedded_docs = embedded_docs.view(-1, batch_docs.size()[-1], self.config.featsize)

        # query encoding
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(embedded_queries.size(0))
            output, hidden = self.query_encoder(embedded_queries, encoder_hidden)

        embedded_queries = self.query_projection(
            output.view(-1, output.size()[-1])).view(*batch_queries.size(), -1)
        embedded_queries = embedded_queries.unsqueeze(1).expand(
            embedded_queries.size(0), batch_docs.size(1), *embedded_queries.size()[1:])
        embedded_queries = embedded_queries.contiguous().view(-1, *embedded_queries.size()[2:])

        # document encoding
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.document_encoder.init_weights(embedded_docs.size(0))
            output, hidden = self.document_encoder(embedded_docs, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.document_encoder.init_weights(embedded_docs.size(0))
            output, hidden = self.document_encoder(embedded_docs, encoder_hidden)

        embedded_docs = self.document_projection(output.view(-1, output.size()[-1]))
        embedded_docs = embedded_docs.view(-1, batch_docs.size(2), embedded_docs.size()[-1])

        # build the match tensor: an elementwise query-document interaction per
        # channel, plus one binary exact-match channel
        embedded_queries = embedded_queries.unsqueeze(2).expand(
            *embedded_queries.size()[:2], batch_docs.size()[-1], embedded_queries.size(2))
        embedded_docs = embedded_docs.unsqueeze(1).expand(
            embedded_docs.size(0), batch_queries.size()[-1], *embedded_docs.size()[1:])
        query_document_product = embedded_queries * embedded_docs
        exact_match = self.exact_match_channel(batch_queries, batch_docs).unsqueeze(3)
        query_document_product = torch.cat((query_document_product, exact_match), 3)
        query_document_product = query_document_product.transpose(2, 3).transpose(1, 2)

        # convolutions over the match tensor with three kernel widths, followed
        # by a 1x1 convolution and max-pooling over both spatial dimensions
        convoluted_feat1 = self.conv1(query_document_product)
        convoluted_feat2 = self.conv2(query_document_product)
        convoluted_feat3 = self.conv3(query_document_product)
        convoluted_feat = self.relu(
            torch.cat((convoluted_feat1, convoluted_feat2, convoluted_feat3), 1))
        convoluted_feat = self.conv(convoluted_feat).transpose(1, 2).transpose(2, 3)

        max_pooled_feat = torch.max(convoluted_feat, 2)[0].squeeze()
        max_pooled_feat = torch.max(max_pooled_feat, 1)[0].squeeze()
        return self.output(max_pooled_feat).squeeze().view(*batch_docs.size()[:2])
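# A minimal shape walkthrough of the match-tensor construction above, with
# synthetic tensors in place of the encoder outputs; the sizes are
# illustrative assumptions, not the repository's defaults.
def _demo_match_tensor_shapes():
    batch, docs, qlen, dlen, nchannels = 2, 3, 5, 7, 4
    q = torch.randn(batch * docs, qlen, nchannels)    # projected query states
    d = torch.randn(batch * docs, dlen, nchannels)    # projected document states
    q = q.unsqueeze(2).expand(batch * docs, qlen, dlen, nchannels)
    d = d.unsqueeze(1).expand(batch * docs, qlen, dlen, nchannels)
    match = q * d                                     # elementwise interaction
    exact = torch.zeros(batch * docs, qlen, dlen, 1)  # stand-in exact-match channel
    match = torch.cat((match, exact), 3)              # ... x (nchannels + 1)
    match = match.transpose(2, 3).transpose(1, 2)     # to NCHW layout for Conv2d
    assert match.size() == (batch * docs, nchannels + 1, qlen, dlen)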