class Discriminator(nn.Module):
    """Classify an interaction sequence from items, recommendations and rewards.

    At every step the embedded item is concatenated with a linear projection
    of the embedded recommendation list and with the scalar reward. The
    concatenated sequence is encoded, mean-pooled over time and mapped to
    log-probabilities over ``numclass`` classes.
    """

    def __init__(self, bsize, embed_dim, encod_dim, embed_dim_policy,
                 encod_dim_policy, numlabel, rec_num, numclass=2,
                 feature_vec=None, init_embed=False, model='LSTM'):
        """Build the embedding, encoder and output layers.

        :param bsize: batch size, forwarded to the encoder constructor
        :param embed_dim: size of one item embedding
        :param encod_dim: hidden size of the encoder
        :param embed_dim_policy: not used by this class (kept for callers)
        :param encod_dim_policy: not used by this class (kept for callers)
        :param numlabel: number of rows of the embedding table
        :param rec_num: number of recommended items per step
        :param numclass: number of output classes
        :param feature_vec: weights handed to ``init_embedding_weights`` when
            ``init_embed`` is true
        :param init_embed: whether to initialise the embedding from
            ``feature_vec``
        :param model: recurrent cell type, forwarded to the encoder
        """
        super(Discriminator, self).__init__()
        self.batch_size = bsize
        self.nonlinear_fc = False
        self.n_classes = numclass
        self.enc_lstm_dim = encod_dim
        self.encoder_type = 'Encoder'
        self.model = model

        self.embedding = EmbeddingLayer(numlabel, embed_dim)
        if init_embed:
            self.embedding.init_embedding_weights(feature_vec, embed_dim)

        # Encoder input = item embedding + projected recommendations + reward.
        # The class is referenced directly instead of eval()-ing its name.
        self.encoder = Encoder(self.batch_size, embed_dim + rec_num + 1,
                               self.enc_lstm_dim, self.model, 1)
        self.enc2out = nn.Linear(self.enc_lstm_dim, self.n_classes)
        # Squeezes the rec_num concatenated embeddings down to rec_num values.
        self.rec2enc = nn.Linear(embed_dim * rec_num, rec_num)

    def forward(self, seq, reward, rec):
        """Return log-probabilities over classes for every sequence.

        :param seq: pair ``(item_ids, seq_len)``; ``seq_len`` must provide
            ``.copy()`` and be convertible by ``torch.FloatTensor``
        :param reward: per-step reward
            (assumes [batch x seq_len] - TODO confirm against caller)
        :param rec: recommended item ids
            (assumes [batch x rec_num x seq_len] - TODO confirm against caller)
        :return: log-softmax scores, one row per sequence
        """
        seq_em, seq_len = seq
        seq_em = self.embedding(seq_em)

        # Rescale the recommendation list: embed every recommended item and
        # project each step's concatenated embeddings down to rec_num values.
        rec = rec.permute(0, 2, 1)
        rec_em = self.embedding(rec.contiguous().view(-1, rec.size(2)))
        rec_em = rec_em.view(rec.size(0), rec.size(1), -1)
        rec_em = self.rec2enc(rec_em)

        # Concatenate items, recommendations and reward along the feature axis.
        seq_em = torch.cat((seq_em, rec_em, reward.unsqueeze(2)), 2)

        # Both cell types return (outputs, state); only the outputs are pooled.
        enc_out, _ = self.encoder((seq_em, seq_len))

        # Mean pooling over time. The length tensor follows the encoder
        # output's device, so the module also runs without CUDA.
        lengths = torch.FloatTensor(seq_len.copy()).unsqueeze(1)
        lengths = lengths.to(enc_out.device)
        pooled = torch.sum(enc_out, 1) / lengths

        output = self.enc2out(pooled)
        return F.log_softmax(output, dim=1)  # batch * n_classes
class Selector(nn.Module):
    """Score embedded token ids with a single linear layer.

    NOTE(review): a second class named ``Selector`` appears later in this
    file; if both live in the same module the later definition rebinds the
    name.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Create the embedding table and the linear scoring layer.

        :param dictionary: vocabulary; its length sizes the embedding table
        :param embedding_index: pretrained vectors passed to
            ``init_embedding_weights``
        :param args: configuration providing ``emsize``, ``emtraining`` and
            ``num_class``
        """
        super(Selector, self).__init__()
        self.config = args
        self.dictionary = dictionary
        self.emsize = args.emsize
        self.num_labels = args.num_class

        vocab_size = len(self.dictionary)
        self.embedding = EmbeddingLayer(vocab_size, self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary,
                                              embedding_index,
                                              self.config.emsize)
        self.linear = nn.Linear(self.emsize, self.num_labels)

    def forward(self, sentence1, threshold=0.5, is_train=0):
        """Return linear scores of the embedded input.

        :param sentence1: tensor of token ids
        :param threshold: accepted but not used
        :param is_train: accepted but not used
        :return: scores with dimension 1 squeezed out when it has size one
        """
        return self.linear(self.embedding(sentence1)).squeeze(1)
class CNN_ARC_II(nn.Module): """Implementation of the convolutional matching model (ARC-II).""" def __init__(self, dictionary, embedding_index, args): """"Constructor of the class.""" super(CNN_ARC_II, self).__init__() self.dictionary = dictionary self.embedding_index = embedding_index self.config = args self.embedding = EmbeddingLayer(len(self.dictionary), self.config) self.conv1 = nn.Conv2d(self.config.emsize * 2, self.config.nfilters, (3, 3)) self.pool1 = nn.MaxPool2d((2, 2)) self.conv2 = nn.Conv2d(self.config.nfilters, self.config.nfilters, (2, 2)) self.ffnn = nn.Sequential( nn.Linear(self.config.nfilters * 4, self.config.nfilters * 2), nn.Linear(self.config.nfilters * 2, self.config.nfilters), nn.Linear(self.config.nfilters, 1)) # Initializing the weight parameters for the embedding layer. self.embedding.init_embedding_weights(self.dictionary, self.embedding_index, self.config.emsize) def forward(self, batch_queries, batch_docs): """ Forward function of the match tensor model. Return average loss for a batch of sessions. 
:param batch_queries: 2d tensor [batch_size x max_query_length] :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length] :return: score representing click probability [batch_size x num_clicks_per_query] """ embedded_queries = self.embedding(batch_queries) embedded_queries = embedded_queries.unsqueeze(1).expand( *batch_docs.size()[:2], *embedded_queries.size()[1:]) embedded_queries = embedded_queries.contiguous().view( -1, *embedded_queries.size()[2:]) embedded_docs = self.embedding(batch_docs.view(-1, batch_docs.size(-1))) embedded_queries = embedded_queries.unsqueeze(1).expand( embedded_queries.size(0), batch_docs.size(2), *embedded_queries.size()[1:]) embedded_docs = embedded_docs.unsqueeze(2).expand( *embedded_docs.size()[:2], batch_queries.size(1), embedded_docs.size(2)) combined_rep = torch.cat((embedded_queries, embedded_docs), 3) combined_rep = combined_rep.transpose(2, 3).transpose(1, 2) conv1_out = self.pool1(F.relu(self.conv1(combined_rep))) conv2_out = self.pool1(F.relu(self.conv2(conv1_out))).squeeze().view( -1, self.config.nfilters * 4) return F.log_softmax( self.ffnn(conv2_out).squeeze().view(*batch_docs.size()[0:2]), 1)
class CNN_ARC_I(nn.Module):
    """Convolutional query/document matching model.

    Queries and documents are embedded, convolved with kernels of width 1, 2
    and 3, max-pooled over positions, concatenated and scored by a
    feed-forward network.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding, the three convolutions and the scorer.

        :param dictionary: vocabulary; its length sizes the embedding table
        :param embedding_index: pretrained vectors for the embedding layer
        :param args: configuration providing ``emsize`` and ``nfilters``
        """
        super(CNN_ARC_I, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.convolution1 = nn.Conv1d(self.config.emsize,
                                      self.config.nfilters, 1)
        self.convolution2 = nn.Conv1d(self.config.emsize,
                                      self.config.nfilters, 2)
        self.convolution3 = nn.Conv1d(self.config.emsize,
                                      self.config.nfilters * 2, 3)
        # 4 * nfilters pooled features per side, query and document
        # concatenated -> 8 * nfilters inputs.
        self.ffnn = nn.Sequential(
            nn.Linear(self.config.nfilters * 8, self.config.nfilters * 4),
            nn.Linear(self.config.nfilters * 4, self.config.nfilters * 2),
            nn.Linear(self.config.nfilters * 2, 1))

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(self.dictionary,
                                              self.embedding_index,
                                              self.config.emsize)

    @staticmethod
    def _conv_max_pool(convolution, embedded):
        """Convolve along the sequence axis and max-pool over positions.

        The batch axis is always kept (no bare ``squeeze``), so a batch of
        one still yields a 2d result.
        """
        feature_maps = convolution(embedded.transpose(1, 2))
        return torch.max(feature_maps, 2)[0]

    def _encode(self, embedded):
        """Concatenate the pooled features of the three convolutions."""
        convolutions = (self.convolution1, self.convolution2,
                        self.convolution3)
        return torch.cat(
            [self._conv_max_pool(conv, embedded) for conv in convolutions],
            1)

    def forward(self, batch_queries, batch_docs):
        """
        Score every document of every query.

        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :return: log-softmax over the documents of each query
            [batch_size x num_rel_docs_per_query]
        """
        batch_size, docs_per_query = batch_docs.size()[0:2]

        embedded_queries = self.embedding(batch_queries)
        embedded_docs = self.embedding(
            batch_docs.view(-1, batch_docs.size(-1)))

        query_rep = self._encode(embedded_queries)
        doc_rep = self._encode(embedded_docs)

        # Repeat each query representation once per candidate document.
        query_rep = query_rep.unsqueeze(1).expand(
            batch_size, docs_per_query, query_rep.size(1))
        query_rep = query_rep.contiguous().view(-1, query_rep.size(2))

        combined_representation = torch.cat((query_rep, doc_rep), 1)
        scores = self.ffnn(combined_representation).view(
            batch_size, docs_per_query)
        # Normalise explicitly over the documents of each query.
        return F.log_softmax(scores, 1)
class BCN(nn.Module):
    """Biattentive classification network architecture for sentence classification."""

    def __init__(self, dictionary, embedding_index, class_distributions, args):
        """Build all sub-networks.

        :param dictionary: vocabulary; its length sizes the word embedding
        :param embedding_index: pretrained vectors for the word embedding
        :param class_distributions: dict of class counts (stored only)
        :param args: configuration providing ``bidirection``, ``pos``,
            ``emsize``, ``emtraining``, ``nhid``, ``nlayers``, ``num_class``
            and ``num_units``
        """
        super(BCN, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary
        self.class_distributions = class_distributions  # dict of class counts

        # Optional part-of-speech embedding, concatenated to word vectors.
        if self.config.pos:
            self.embedding_pos = EmbeddingLayer(len(pos_to_idx),
                                                POS_EMBEDDING_DIM, True,
                                                self.config)
            self.embedding_pos.init_pos_weights(pos_to_idx, POS_EMBEDDING_DIM)

        self.embedding = EmbeddingLayer(len(self.dictionary),
                                        self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary,
                                              embedding_index,
                                              self.config.emsize)

        relu_input_dim = self.config.emsize
        if self.config.pos:
            relu_input_dim += POS_EMBEDDING_DIM
        self.relu_network = nn.Sequential(
            OrderedDict([('dense1', nn.Linear(relu_input_dim,
                                              self.config.nhid)),
                         ('nonlinearity', nn.ReLU())]))

        encoded_dim = self.config.nhid * self.num_directions
        self.encoder = Encoder(self.config.nhid, self.config.nhid,
                               self.config.bidirection, self.config.nlayers,
                               self.config)
        # The biattentive encoders consume [encoded; |diff|; product].
        self.biatt_encoder1 = Encoder(encoded_dim * 3, self.config.nhid,
                                      self.config.bidirection, 1, self.config)
        self.biatt_encoder2 = Encoder(encoded_dim * 3, self.config.nhid,
                                      self.config.bidirection, 1, self.config)
        self.ffnn = nn.Linear(encoded_dim, 1)
        # Four pooled views per sentence, two sentences.
        self.maxout_network = MaxoutNetwork(encoded_dim * 4 * 2,
                                            self.config.num_class,
                                            num_units=self.config.num_units)
        # NOTE(review): the label says num_units but the value printed is
        # num_class - confirm which one was intended.
        print("BCN init num_units: ", self.config.num_class)

    def _self_attentive_pool(self, biatt):
        """Return the attention-weighted sum of ``biatt`` over positions."""
        weights = f.softmax(self.ffnn(biatt).squeeze(2), 1)
        return torch.bmm(biatt.transpose(1, 2),
                         weights.unsqueeze(2)).squeeze(2)

    def _pool(self, biatt):
        """Concatenate max, mean, min and self-attentive pooling."""
        return torch.cat((biatt.max(1)[0], biatt.mean(1), biatt.min(1)[0],
                          self._self_attentive_pool(biatt)), 1)

    def forward(self, sentence1, sentence1_len, sentence2, sentence2_len,
                pos_sent1=None, pos_sent2=None):
        """
        Forward computation of the biattentive classification network.

        Returns classification scores for a batch of sentence pairs.
        :param sentence1: 2d tensor [batch_size x max_length]
        :param sentence1_len: 1d numpy array [batch_size]
        :param sentence2: 2d tensor [batch_size x max_length]
        :param sentence2_len: 1d numpy array [batch_size]
        :param pos_sent1: optional part-of-speech ids for sentence1; only
            used when ``config.pos`` is set and both POS inputs are given
        :param pos_sent2: optional part-of-speech ids for sentence2
        :return: classification scores over batch [batch_size x num_classes]
        """
        # step1: embed the words into vectors [batch_size x max_length x emsize]
        embedded_x = self.embedding(sentence1)
        embedded_y = self.embedding(sentence2)
        if (self.config.pos and pos_sent1 is not None
                and pos_sent2 is not None):
            embedded_x = torch.cat(
                (embedded_x, self.embedding_pos(pos_sent1)), 2)
            embedded_y = torch.cat(
                (embedded_y, self.embedding_pos(pos_sent2)), 2)

        # step2: pass the embedded words through the ReLU network
        # [batch_size x max_length x hidden_size]
        embedded_x = self.relu_network(embedded_x)
        embedded_y = self.relu_network(embedded_y)

        # step3: pass the word vectors through the encoder
        # [batch_size x max_length x hidden_size * num_directions]
        encoded_x = self.encoder(embedded_x, sentence1_len)
        encoded_y = self.encoder(embedded_y, sentence2_len)

        # step4: compute affinity matrix
        # [batch_size x sent1_max_length x sent2_max_length]
        affinity_mat = torch.bmm(encoded_x, encoded_y.transpose(1, 2))

        # step5: compute conditioned representations. The softmax dimension
        # is always explicit: on a 3d tensor the implicit choice is dim 0,
        # which would normalise across the batch instead of across words.
        conditioned_x = torch.bmm(
            f.softmax(affinity_mat, 2).transpose(1, 2), encoded_x)
        conditioned_y = torch.bmm(
            f.softmax(affinity_mat.transpose(1, 2), 2).transpose(1, 2),
            encoded_y)

        # step6: generate input of the biattentive encoders
        # [batch_size x max_length x hidden_size * num_directions * 3]
        biatt_input_x = torch.cat(
            (encoded_x, torch.abs(encoded_x - conditioned_y),
             torch.mul(encoded_x, conditioned_y)), 2)
        biatt_input_y = torch.cat(
            (encoded_y, torch.abs(encoded_y - conditioned_x),
             torch.mul(encoded_y, conditioned_x)), 2)

        # step7: pass the conditioned information through the biattentive
        # encoders [batch_size x max_length x hidden_size * num_directions]
        biatt_x = self.biatt_encoder1(biatt_input_x, sentence1_len)
        biatt_y = self.biatt_encoder2(biatt_input_y, sentence2_len)

        # step8 + step9: self-attentive pooling and joint representations
        # [batch_size x hidden_size * num_directions * 4]
        pooled_x = self._pool(biatt_x)
        pooled_y = self._pool(biatt_y)

        # step10: pass the pooled representations through the maxout network
        return self.maxout_network(torch.cat((pooled_x, pooled_y), 1))
class HRED_QS(nn.Module):
    """Hierarchical encoder-decoder over sessions of queries.

    Each query is encoded and pooled, a session-level recurrent cell runs
    over the query representations, and a decoder initialised from the
    session state is trained to produce the next query of the session.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding, query/session encoders and the decoder.

        :param dictionary: vocabulary; its length sizes the embedding table
            and the decoder output
        :param embedding_index: pretrained vectors for the embedding layer
        :param args: configuration providing ``bidirection``, ``emsize``,
            ``nhid_query``, ``nhid_session``, ``model``, ``pool_type`` and
            ``regularize``
        """
        super(HRED_QS, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embedding_index,
                                              self.config.emsize)

        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query,
                                     self.config.bidirection, self.config)
        self.session_encoder = EncoderCell(
            self.config.nhid_query * self.num_directions,
            self.config.nhid_session, False, self.config)
        self.decoder = DecoderCell(self.config.emsize,
                                   self.config.nhid_session, len(dictionary),
                                   self.config)

    @staticmethod
    def compute_loss(logits, target, seq_idx, length, regularize):
        """
        Compute negative log-likelihood loss for a batch of predictions.

        :param logits: 2d tensor [batch_size x vocab_size]
        :param target: 1d tensor [batch_size]
        :param seq_idx: an integer represents the current index of the sequences
        :param length: 1d tensor [batch_size], represents each sequences' true length
        :param regularize: falsy to disable; otherwise the weight applied to
            the ``sum(p * log p)`` term added to the loss
        :return: total loss over the input mini-batch [autograd Variable] and number of loss elements
        """
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).squeeze()
        # NOTE(review): the mask shape comes from helper.mask (annotated
        # "batch x 1" in the original) - confirm it lines up with ``losses``.
        mask = helper.mask(length, seq_idx)
        loss = (losses * mask.float()).sum()
        if regularize:
            loss = loss + (logits.exp().mul(logits).sum(1) * regularize).sum()

        nonzero_shape = torch.nonzero(mask.data).size()
        # An empty size means that no position is active.
        num_elements = nonzero_shape[0] if nonzero_shape else 0
        return loss, num_elements

    def forward(self, session_queries, session_query_length):
        """
        Return the average next-query decoding loss for a batch of sessions.

        :param session_queries: 3d tensor [batch_size x session_length x max_query_length]
        :param session_query_length: 2d tensor [batch_size x session_length]
        :return: average loss over batch [autograd Variable]
        """
        # query encoding
        embedded_queries = self.embedding(
            session_queries.view(-1, session_queries.size(-1)))
        encoded_queries = self.query_encoder(
            embedded_queries,
            session_query_length.view(-1).data.cpu().numpy())
        encoded_queries = self.apply_pooling(encoded_queries,
                                             self.config.pool_type)
        # encoded_queries: batch_size x session_length x (nhid_query * self.num_directions)
        encoded_queries = encoded_queries.contiguous().view(
            *session_queries.size()[:-1], -1)

        # session level encoding
        is_lstm = self.config.model == 'LSTM'
        sess_query_hidden = self.session_encoder.init_weights(
            encoded_queries.size(0))
        hidden_states, cell_states = [], []
        # loop over all the queries in a session
        for idx in range(encoded_queries.size(1)):
            # update session-level query encoder state using query representations
            _, sess_query_hidden = self.session_encoder(
                encoded_queries[:, idx, :].unsqueeze(1), sess_query_hidden)
            # -1: only consider hidden states of the last layer
            if is_lstm:
                hidden_states.append(sess_query_hidden[0][-1])
                cell_states.append(sess_query_hidden[1][-1])
            else:
                hidden_states.append(sess_query_hidden[-1])

        hidden_states = torch.stack(hidden_states, 1)
        # remove the last hidden states which stand for the last queries in sessions
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        # Initialize hidden states of decoder with the hidden states of the
        # session encoder
        if is_lstm:
            cell_states = torch.stack(cell_states, 1)
            cell_states = cell_states[:, :-1, :].contiguous().view(
                -1, cell_states.size(-1)).unsqueeze(0)
            decoder_hidden = (hidden_states, cell_states)
        else:
            decoder_hidden = hidden_states

        # train the decoder for all the queries in a session except the first
        embedded_queries = embedded_queries.view(*session_queries.size(), -1)
        decoder_input = embedded_queries[:, 1:, :, :].contiguous().view(
            -1, *embedded_queries.size()[2:])
        decoder_target = session_queries[:, 1:, :].contiguous().view(
            -1, session_queries.size(-1))
        target_length = session_query_length[:, 1:].contiguous().view(-1)

        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(decoder_input.size(1) - 1):
            input_variable = decoder_input[:, idx, :].unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                input_variable, decoder_hidden)
            # The token at position idx predicts the token at idx + 1.
            local_loss, num_local_loss = self.compute_loss(
                decoder_output, decoder_target[:, idx + 1], idx,
                target_length, self.config.regularize)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element
        return decoding_loss

    @staticmethod
    def apply_pooling(encodings, pool_type):
        """Pool ``encodings`` over dimension 1.

        :param encodings: tensor whose dimension 1 is pooled away
        :param pool_type: ``'max'``, ``'mean'`` or ``'last'``
        :return: pooled tensor; the leading (batch) dimension is always kept
        :raises ValueError: if ``pool_type`` is not one of the known values
        """
        if pool_type == 'max':
            return torch.max(encodings, 1)[0]
        if pool_type == 'mean':
            # Averages over the full (padded) length of dimension 1.
            return torch.sum(encodings, 1) / encodings.size(1)
        if pool_type == 'last':
            return encodings[:, -1, :]
        raise ValueError(
            "unknown pool_type %r (expected 'max', 'mean' or 'last')"
            % (pool_type,))
class Seq2Seq(nn.Module):
    """Encoder-decoder model trained with a next-token decoding loss.

    The encoder output is pooled into one vector per example, which
    initialises the decoder; the decoder is then teacher-forced on
    ``decoder_input``.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding, encoder and decoder.

        :param dictionary: vocabulary; its length sizes the embedding table
            and the decoder output
        :param embedding_index: pretrained vectors for the embedding layer
        :param args: configuration providing ``bidirection``, ``emsize``,
            ``input_size``, ``nhid_enc``, ``pool_type`` and ``model``
        """
        super(Seq2Seq, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.embedding.init_embedding_weights(self.dictionary,
                                              embedding_index,
                                              self.config.emsize)

        self.encoder = Encoder(self.config.input_size, self.config.nhid_enc,
                               self.config.bidirection, self.config)
        self.decoder = Decoder(self.config.emsize,
                               self.config.nhid_enc * self.num_directions,
                               len(self.dictionary), self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        """Return the masked negative log-likelihood and its element count.

        :param logits: 2d tensor of log-probabilities, gathered at ``target``
        :param target: 1d tensor of target ids
        :param seq_idx: current position in the sequences
        :param length: per-sequence lengths passed to ``helper.mask``
        :return: (summed loss, number of non-zero mask entries)
        """
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).squeeze()
        # NOTE(review): the mask shape comes from helper.mask (annotated
        # "batch x 1" in the original) - confirm it lines up with ``losses``.
        mask = helper.mask(length, seq_idx)
        losses = losses * mask.float()
        nonzero_shape = torch.nonzero(mask.data).size()
        # An empty size means that no position is active.
        num_elements = nonzero_shape[0] if nonzero_shape else 0
        return losses.sum(), num_elements

    def _pool_encoder_states(self, encoded):
        """Reduce encoder outputs over dimension 1 as set by ``pool_type``.

        The batch dimension is always kept, so a batch of one stays 2d.
        """
        pool_type = self.config.pool_type
        if pool_type == 'max':
            return torch.max(encoded, 1)[0]
        if pool_type == 'mean':
            return torch.sum(encoded, 1) / encoded.size(1)
        if pool_type == 'last':
            # NOTE(review): for a bidirectional encoder this takes both
            # halves at the last step, exactly like the original
            # concatenation did; the backward direction presumably finishes
            # at step 0 - confirm the intent.
            return encoded[:, -1, :]
        raise ValueError(
            "unknown pool_type %r (expected 'max', 'mean' or 'last')"
            % (pool_type,))

    def _initial_decoder_hidden(self, hidden_states):
        """Wrap pooled encoder states as the decoder's initial state."""
        hidden = hidden_states.unsqueeze(0).contiguous()
        # Compare by value: ``is`` on strings tests object identity only.
        if self.config.model == 'LSTM':
            # Zero cell state on the same device/dtype as the hidden state.
            return hidden, torch.zeros_like(hidden)
        return hidden

    def forward(self, videos, video_len, decoder_input, target_length):
        """Return the average decoding loss for a batch.

        :param videos: input features handed to the encoder
        :param video_len: lengths handed to the encoder
        :param decoder_input: 2d tensor of token ids; position ``i`` is fed
            to the decoder to predict position ``i + 1``
        :param target_length: per-sequence lengths used for loss masking
        :return: summed loss divided by the number of counted elements (the
            plain sum when that count is zero)
        """
        # encode the video features
        encoded_videos = self.encoder(videos, video_len)
        hidden_states = self._pool_encoder_states(encoded_videos)
        # Initialize hidden states of decoder with the pooled encoder states
        decoder_hidden = self._initial_decoder_hidden(hidden_states)

        decoding_loss = 0
        total_local_decoding_loss_element = 0
        for idx in range(decoder_input.size(1) - 1):
            input_variable = decoder_input[:, idx]
            embedded_decoder_input = self.embedding(
                input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx + 1]
            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, target_variable, idx, target_length)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element
        return decoding_loss
class Agent(nn.Module):
    """Recommendation policy: encode a click sequence and pick ``recom`` items.

    The encoder output at the last step is mapped to a probability
    distribution over all items; items are sampled from it during training
    and taken greedily (top-k) during evaluation.
    """

    def __init__(self, bsize, embed_dim, encod_dim, numlabel, n_layers,
                 recom=100, feature_vec=None, init=False, model='LSTM'):
        """Build the embedding, encoder and output layer.

        :param bsize: batch size, forwarded to the encoder constructor
        :param embed_dim: size of one item embedding
        :param encod_dim: hidden size of the encoder
        :param numlabel: number of items (embedding rows and output classes)
        :param n_layers: number of encoder layers
        :param recom: number of items returned per example
        :param feature_vec: not used by the constructor; see ``copy_weight``
        :param init: when true, re-initialise every parameter from N(0, 1)
        :param model: recurrent cell type, forwarded to the encoder
        """
        super(Agent, self).__init__()
        self.batch_size = bsize
        self.nonlinear_fc = False
        self.n_classes = numlabel
        self.enc_lstm_dim = encod_dim
        self.encoder_type = 'Encoder'
        self.model = model
        self.gamma = 0.9
        self.n_layers = n_layers
        self.recom = recom  # number of items selected per example

        self.embedding = EmbeddingLayer(numlabel, embed_dim)
        # The class is referenced directly instead of eval()-ing its name.
        self.encoder = Encoder(self.batch_size, embed_dim, self.enc_lstm_dim,
                               self.model, self.n_layers)
        self.enc2out = nn.Linear(self.enc_lstm_dim, self.n_classes)
        if init:
            self.init_params()

    def copy_weight(self, feature_vec):
        """Initialise the embedding layer from ``feature_vec``."""
        self.embedding.init_embedding_weights(feature_vec)

    def init_params(self):
        """Draw every parameter from N(0, 1).

        Initialise oracle network with N(0,1); otherwise variance of
        initialisation is very small => high NLL for data sampled from the
        same model.
        """
        for param in self.parameters():
            init.normal_(param, 0, 1)

    def _recommend(self, enc_out, evaluate):
        """Turn encoder outputs into item probabilities and chosen items.

        :param enc_out: encoder outputs; only the last step is used
        :param evaluate: top-k selection when true, sampling otherwise
        :return: (probabilities [batch x n_classes],
            indices [batch x recom])
        """
        output = F.softmax(self.enc2out(enc_out[:, -1, :]), dim=1)
        if evaluate:
            # Only select from the top k
            _, indices = torch.topk(output, self.recom, dim=1, sorted=True)
        else:
            indices = torch.multinomial(output, self.recom)
        return output, indices

    def _as_state(self, hidden):
        """Return the recurrent state in the form handed back to callers."""
        if self.model == 'LSTM':
            h, c = hidden
            return (h, c)
        return hidden

    def forward(self, seq, evaluate=False):
        """Encode a full click sequence and recommend items.

        :param seq: pair ``(item_ids, seq_len)`` passed to the encoder
        :param evaluate: top-k selection when true, sampling otherwise
        :return: (probabilities, indices, recurrent state)
        """
        seq_em, seq_len = seq
        seq_em = self.embedding(seq_em)
        enc_out, hidden = self.encoder((seq_em, seq_len))
        output, indices = self._recommend(enc_out, evaluate)
        return output, indices, self._as_state(hidden)

    def step(self, click, hidden, evaluate=False):
        """Advance the encoder by one click and recommend items.

        :param click: item ids of the new click(s)
        :param hidden: recurrent state returned by ``forward`` or ``step``
        :param evaluate: top-k selection when true, sampling otherwise
        :return: (probabilities, indices, recurrent state)
        """
        seq_em = self.embedding(click)
        enc_out, hidden = self.encoder.step_cell(seq_em, hidden)
        output, indices = self._recommend(enc_out, evaluate)
        return output, indices, self._as_state(hidden)
class Selector(nn.Module):
    """Sample a per-token keep/drop decision for two batches of sentences.

    Every token is embedded and given a selection probability by
    ``WE_Selector``; a Bernoulli draw decides which tokens are kept. In
    training mode the log-probability of the draw and two statistics of the
    selection (count and number of transitions) are returned as well.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding layer and the selection network.

        :param dictionary: vocabulary; its length sizes the embedding table
        :param embedding_index: pretrained vectors for the embedding layer
        :param args: configuration providing ``emsize``, ``emtraining``,
            ``dropout`` and ``cuda``
        """
        super(Selector, self).__init__()
        self.config = args
        self.dictionary = dictionary

        vocab_size = len(self.dictionary)
        self.embedding = EmbeddingLayer(vocab_size, self.config.emsize,
                                        self.config.emtraining, self.config)
        self.embedding.init_embedding_weights(self.dictionary,
                                              embedding_index,
                                              self.config.emsize)
        self.we_selector = WE_Selector(self.config.emsize,
                                       self.config.dropout)

    def forward(self, sentence1, sentence1_len_old, sentence2,
                sentence2_len_old, threshold=0.5, is_train=0):
        """
        Select tokens from both sentence batches.

        :param sentence1: 2d tensor [batch_size x max_length]
        :param sentence1_len_old: lengths of sentence1 before selection
        :param sentence2: 2d tensor [batch_size x max_length]
        :param sentence2_len_old: lengths of sentence2 before selection
        :param threshold: accepted but not used (selection is sampled)
        :param is_train: ``1`` to also compute ``logpz``, ``zsum``, ``zdiff``
        :return: ``(selected_x, sentence1_len, selected_y, sentence2_len,
            logpz, zsum, zdiff)``; the last three are ``-1.0`` placeholders
            unless ``is_train == 1``, in which case they are per-example
            tensors
        """
        # Embed, then score every token with a selection probability.
        emb_x = self.embedding(sentence1)
        emb_y = self.embedding(sentence2)
        pbx = self.we_selector(emb_x)
        pby = self.we_selector(emb_y)
        assert pbx.size() == sentence1.size()
        assert pby.size() == sentence2.size()

        # Sample the selection (batch x len): 1 keeps the token, 0 drops it.
        selection_x = pbx.bernoulli().long()
        selection_y = pby.bernoulli().long()

        # Selected word ids; dropped positions become zero.
        kept_ids_x = sentence1.mul(selection_x)
        kept_ids_y = sentence2.mul(selection_y)

        # helper.get_selected_tensor builds the selected batches and their
        # lengths (numpy arrays according to the original comments).
        selected_x, sentence1_len = helper.get_selected_tensor(
            kept_ids_x, pbx, sentence1, sentence1_len_old, self.config.cuda)
        selected_y, sentence2_len = helper.get_selected_tensor(
            kept_ids_y, pby, sentence2, sentence2_len_old, self.config.cuda)

        if is_train != 1:
            # No selector statistics outside training: return placeholders.
            return (selected_x, sentence1_len, selected_y, sentence2_len,
                    -1.0, -1.0, -1.0)

        # Positions holding id 0 are excluded from all statistics
        # (presumably padding - verify against the data pipeline).
        mask1 = (sentence1 != 0).long()
        mask2 = (sentence2 != 0).long()
        masked_x = selection_x.mul(mask1)
        masked_y = selection_y.mul(mask2)

        # Element-wise log-probability of the sampled selection (batch x len)
        logpx = -helper.binary_cross_entropy(
            pbx, selection_x.float().detach(), reduce=False)
        logpy = -helper.binary_cross_entropy(
            pby, selection_y.float().detach(), reduce=False)
        assert logpx.size() == sentence1.size()

        # Sum over positions -> one value per example.
        logpz = (logpx.mul(mask1.float()).sum(1)
                 + logpy.mul(mask2.float()).sum(1))

        # Number of 0/1 transitions between neighbouring positions.
        zdiff1 = (masked_x[:, 1:] - masked_x[:, :-1]).abs().sum(1)
        zdiff2 = (masked_y[:, 1:] - masked_y[:, :-1]).abs().sum(1)
        assert zdiff1.size()[0] == sentence1.size()[0]
        assert logpz.size()[0] == sentence1.size()[0]
        zdiff = zdiff1 + zdiff2

        # Number of selected tokens.
        zsum = masked_x.sum(1) + masked_y.sum(1)
        assert zsum.size()[0] == sentence1.size()[0]
        assert logpz.dim() == zsum.dim()
        assert logpz.dim() == zdiff.dim()

        return (selected_x, sentence1_len, selected_y, sentence2_len, logpz,
                zsum.float(), zdiff.float())
class SentenceClassifier(nn.Module):
    """Class that classifies question pair as duplicate or not."""

    def __init__(self, dictionary, embeddings_index, args,
                 select_method='max'):
        """Build the embedding, encoder and feed-forward classifier.

        :param dictionary: vocabulary; its length sizes the embedding table
        :param embeddings_index: pretrained vectors for the embedding layer
        :param args: configuration providing ``bidirection``, ``emsize``,
            ``nhid``, ``nonlinear_fc``, ``dropout_fc`` and ``fc_dim``
        :param select_method: how encoder outputs are reduced over time:
            ``'max'``, ``'average'`` or ``'last'``
        """
        super(SentenceClassifier, self).__init__()
        self.config = args
        self.feature_select_method = select_method
        self.num_directions = 2 if args.bidirection else 1
        print("finish sending in arg")

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        print("finish embed. layer")
        self.embedding.init_embedding_weights(dictionary, embeddings_index,
                                              self.config.emsize)
        self.encoder = Encoder(self.config.emsize, self.config.nhid,
                               self.config.bidirection, self.config)
        print("finish encoder")

        self.ffnn = self._build_ffnn()

    def _build_ffnn(self):
        """Create the three-layer classifier head.

        With ``config.nonlinear_fc`` a Tanh follows each of the first two
        dense layers. The layers need distinct names: a repeated
        ``OrderedDict`` key overwrites the earlier entry, which used to drop
        the second Tanh silently. The first keeps the name ``tanh``.
        NOTE(review): models trained with the old single-Tanh head behave
        differently once the second Tanh is active.
        """
        cfg = self.config
        in_features = cfg.nhid * self.num_directions * 4

        layers = [('dropout1', nn.Dropout(cfg.dropout_fc)),
                  ('dense1', nn.Linear(in_features, cfg.fc_dim))]
        if cfg.nonlinear_fc:
            layers.append(('tanh', nn.Tanh()))
        layers.extend([('dropout2', nn.Dropout(cfg.dropout_fc)),
                       ('dense2', nn.Linear(cfg.fc_dim, cfg.fc_dim))])
        if cfg.nonlinear_fc:
            layers.append(('tanh2', nn.Tanh()))
        layers.extend([('dropout3', nn.Dropout(cfg.dropout_fc)),
                       ('dense3', nn.Linear(cfg.fc_dim, 2))])
        return nn.Sequential(OrderedDict(layers))

    def _select_features(self, encoder_output, padded_length):
        """Reduce encoder outputs over time, always keeping the batch axis.

        :raises ValueError: for an unknown ``feature_select_method``
        """
        method = self.feature_select_method
        if method == 'max':
            return torch.max(encoder_output, 1)[0]
        if method == 'average':
            # Divides by the padded length, not by the true lengths.
            return torch.sum(encoder_output, 1) / padded_length
        if method == 'last':
            return encoder_output[:, -1, :]
        raise ValueError(
            "unknown select_method %r (expected 'max', 'average' or 'last')"
            % (method,))

    def forward(self, batch_sentence1, sent_len1, batch_sentence2, sent_len2):
        """Defines the forward computation of the question classifier.

        :param batch_sentence1: 2d tensor of token ids for the first sentences
        :param sent_len1: lengths handed to the encoder for sentence 1
        :param batch_sentence2: 2d tensor of token ids for the second sentences
        :param sent_len2: lengths handed to the encoder for sentence 2
        :return: scores with 2 columns, one row per sentence pair
        :raises ValueError: for an unknown ``feature_select_method``
        """
        embedded1 = self.embedding(batch_sentence1)
        embedded2 = self.embedding(batch_sentence2)

        # The same encoder is shared by both sentences of a pair.
        output1 = self.encoder(embedded1, sent_len1)
        output2 = self.encoder(embedded2, sent_len2)

        encoded_questions1 = self._select_features(
            output1, batch_sentence1.size(1))
        encoded_questions2 = self._select_features(
            output2, batch_sentence2.size(1))
        assert encoded_questions1.size(0) == encoded_questions2.size(0)

        # compute angle between question representation
        angle = torch.mul(encoded_questions1, encoded_questions2)
        # compute distance between question representation
        distance = torch.abs(encoded_questions1 - encoded_questions2)
        # combined_representation = batch_size x (hidden_size * num_directions * 4)
        combined_representation = torch.cat(
            (encoded_questions1, encoded_questions2, angle, distance), 1)

        return self.ffnn(combined_representation)
class NSRM(nn.Module):
    """Session-level ranking model with a next-query decoder.

    Queries and candidate documents are encoded separately, a session encoder
    consumes the query encodings one by one, each document is scored against a
    projection of the session output and the current query, and a decoder is
    trained to reproduce the following query. ``forward`` returns the sum of
    the click loss and the decoding loss.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Create all sub-modules and load the pre-trained embeddings.

        :param dictionary: vocabulary object (needs ``len``, ``word2idx`` and
            ``start_token``)
        :param embedding_index: pre-trained vectors passed to
            ``init_embedding_weights``
        :param args: configuration object; reads ``bidirection``, ``emsize``,
            ``nhid_query``, ``nhid_doc``, ``nhid_session``, ``model`` and
            ``cuda``
        """
        super(NSRM, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        # NOTE(review): both encoders receive a literal True while the layer
        # sizes below use num_directions from config.bidirection - confirm the
        # two agree when config.bidirection is False.
        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.emsize,
                                        self.config.nhid_doc, True, self.config)
        self.session_encoder = Encoder(
            self.config.nhid_query * self.num_directions,
            self.config.nhid_session, False, self.config)
        self.projection = nn.Linear(
            (self.config.nhid_query * self.num_directions) +
            self.config.nhid_session,
            self.config.nhid_doc * self.num_directions)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(
            self.dictionary, self.embedding_index, self.config.emsize)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        """
        Compute negative log-likelihood loss for a batch of predictions.

        :param logits: 2d tensor [batch_size x vocab_size]
        :param target: 1d tensor [batch_size] (a column axis is added here
            before the gather)
        :param seq_idx: an integer represents the current index of the sequences
        :param length: 1d tensor [batch_size], represents each sequences' true length
        :return: total loss over the input mini-batch [autograd Variable] and
            number of loss elements
        """
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1))
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
            return losses.sum(), 0
        return losses.sum(), num_non_zero_elem[0]

    @staticmethod
    def compute_click_loss(logits, target):
        """
        Compute logistic loss for a batch of clicks.
        Return average loss for the input mini-batch.

        :param logits: 2d tensor [batch_size x num_clicks_per_query]
        :param target: 2d tensor [batch_size x num_clicks_per_query]
        :return: average loss over batch [autograd Variable]
        """
        # taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L695
        neg_abs = -logits.abs()
        loss = logits.clamp(min=0) - logits * target + (1 + neg_abs.exp()).log()
        return loss.mean()

    @staticmethod
    def _click_scores(query_rep, doc_reps):
        """Dot product between a query vector and each candidate document.

        :param query_rep: 2d tensor [batch_size x dim]
        :param doc_reps: 3d tensor [batch_size x num_docs x dim]
        :return: 2d tensor [batch_size x num_docs]
        """
        expanded = query_rep.unsqueeze(1).expand(*doc_reps.size())
        # keepdim + squeeze(2) removes exactly the reduced axis; squeezing
        # axis 2 of an already reduced 2d result is out of range.
        return torch.sum(torch.mul(expanded, doc_reps), 2,
                         keepdim=True).squeeze(2)

    def _run_encoder(self, encoder, embedded):
        """Run ``encoder`` from a fresh initial state and return its outputs."""
        if self.config.model == 'LSTM':
            init_hidden, init_cell = encoder.init_weights(embedded.size(0))
            output, _ = encoder(embedded, (init_hidden, init_cell))
        else:
            init_hidden = encoder.init_weights(embedded.size(0))
            output, _ = encoder(embedded, init_hidden)
        return output

    def forward(self, batch_session, length, batch_clicks, click_labels):
        """
        Forward function of the neural click model.
        Return average loss for a batch of sessions.

        :param batch_session: 3d tensor [batch_size x session_length x max_query_length]
        :param length: 2d tensor [batch_size x session_length]
        :param batch_clicks: 4d tensor [batch_size x session_length x num_rel_docs_per_query x max_document_length]
        :param click_labels: 3d tensor [batch_size x session_length x num_rel_docs_per_query]
        :return: click loss plus decoding loss [autograd Variable]
        """
        # query level encoding (max-pooled over time)
        embedded_queries = self.embedding(
            batch_session.view(-1, batch_session.size(-1)))
        output = self._run_encoder(self.query_encoder, embedded_queries)
        encoded_queries = torch.max(output, 1)[0].squeeze(1)
        # encoded_queries = batch_size x num_queries_in_a_session x hidden_size
        encoded_queries = encoded_queries.view(*batch_session.size()[:-1], -1)

        # document level encoding (max-pooled over time)
        embedded_clicks = self.embedding(
            batch_clicks.view(-1, batch_clicks.size(-1)))
        output = self._run_encoder(self.document_encoder, embedded_clicks)
        encoded_clicks = torch.max(output, 1)[0].squeeze(1)
        # encoded_clicks = batch_size x num_queries_in_a_session x num_rel_docs_per_query x hidden_size
        encoded_clicks = encoded_clicks.view(*batch_clicks.size()[:-1], -1)

        # session level encoding
        # Size the initial state from the data so that a final batch smaller
        # than config.batch_size still concatenates correctly.
        batch_size = encoded_queries.size(0)
        sess_hidden = self.session_encoder.init_weights(batch_size)
        sess_output = Variable(
            torch.zeros(batch_size, 1, self.config.nhid_session))
        if self.config.cuda:
            sess_output = sess_output.cuda()

        hidden_states, cell_states = [], []
        click_loss = 0
        for idx in range(encoded_queries.size(1)):
            # squeeze(1) drops only the time axis, so a batch of one survives
            combined_rep = torch.cat(
                (sess_output.squeeze(1), encoded_queries[:, idx, :]), 1)
            combined_rep = self.projection(combined_rep)
            click_score = self._click_scores(combined_rep,
                                             encoded_clicks[:, idx, :, :])
            click_loss += self.compute_click_loss(click_score,
                                                  click_labels[:, idx, :])

            # update session state using query representations
            sess_output, sess_hidden = self.session_encoder(
                encoded_queries[:, idx, :].unsqueeze(1), sess_hidden)
            # NOTE(review): indexing [0]/[1] assumes an LSTM-style
            # (hidden, cell) pair even when config.model is not 'LSTM'.
            hidden_states.append(sess_hidden[0])
            cell_states.append(sess_hidden[1])

        click_loss = click_loss / encoded_queries.size(1)

        hidden_states = torch.stack(hidden_states, 2).squeeze(0)
        cell_states = torch.stack(cell_states, 2).squeeze(0)

        # decoding in sequence-to-sequence learning: the state after query t
        # initialises the decoder that has to produce query t + 1
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        cell_states = cell_states[:, :-1, :].contiguous().view(
            -1, cell_states.size(-1)).unsqueeze(0)

        decoder_input = batch_session[:, 1:, :].contiguous().view(
            -1, batch_session.size(-1))
        target_length = length[:, 1:].contiguous().view(-1)

        input_variable = Variable(
            torch.LongTensor(decoder_input.size(0)).fill_(
                self.dictionary.word2idx[self.dictionary.start_token]))
        if self.config.cuda:
            input_variable = input_variable.cuda()
            hidden_states = hidden_states.cuda()
            cell_states = cell_states.cuda()

        # Initialize hidden states of decoder with the last hidden states of the session encoder
        decoder_hidden = (hidden_states, cell_states)
        decoding_loss = 0
        total_local_decoding_loss_element = 0
        for idx in range(decoder_input.size(1)):
            # step 0 is fed the start token, later steps the previous word
            if idx != 0:
                input_variable = decoder_input[:, idx - 1]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx]
            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, target_variable, idx, target_length)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return click_loss + decoding_loss
class DRMM(nn.Module):
    """Implementation of the deep relevance matching model."""

    def __init__(self, dictionary, embedding_index, args):
        """Create the embedding, gating and scoring layers.

        :param dictionary: vocabulary object; its length sizes the embedding
        :param embedding_index: pre-trained vectors passed to
            ``init_embedding_weights``
        :param args: configuration object; reads ``emsize``, ``nbins`` and
            ``cuda``
        """
        super(DRMM, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        # Histogram edges over cosine similarity. The repeated 1.0 makes the
        # last bin the degenerate interval [1.0, 1.0].
        self.bins = [-1.0, -0.5, 0, 0.5, 1.0, 1.0]

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.gating_network = GatingNetwork(self.config.emsize)
        self.ffnn = nn.Sequential(nn.Linear(self.config.nbins, 1),
                                  nn.Linear(1, 1))
        self.output = nn.Linear(1, 1)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(
            self.dictionary, self.embedding_index, self.config.emsize)

    @staticmethod
    def cosine_similarity(x1, x2, dim=1, eps=1e-8):
        """
        Returns cosine similarity between x1 and x2, computed along dim.
        # taken from http://pytorch.org/docs/master/_modules/torch/nn/functional.html#cosine_similarity

        :param x1: (Variable): First input.
        :param x2: (Variable): Second input (of size matching x1).
        :param dim: (int, optional): Dimension of vectors. Default: 1
        :param eps: Small value to avoid division by zero. Default: 1e-8
        :return: similarity with ``dim`` reduced and every remaining axis of
            size one squeezed away
        """
        w12 = torch.sum(x1 * x2, dim)
        w1 = torch.norm(x1, 2, dim)
        w2 = torch.norm(x2, 2, dim)
        return (w12 / (w1 * w2).clamp(min=eps)).squeeze()

    @staticmethod
    def _histogram_features(similarities, bins):
        """Count, along the last axis, how many similarities fall in each bin.

        Only the counts returned by ``numpy.histogram`` are kept, so
        ``apply_along_axis`` never has to pack the (counts, edges) pair of
        unequal lengths into a single array.

        :param similarities: numpy array [... x doc_length]
        :param bins: monotonically non-decreasing bin edges
        :return: float tensor [... x (len(bins) - 1)]
        """
        counts = numpy.apply_along_axis(
            lambda row: numpy.histogram(row, bins=bins)[0], -1, similarities)
        return torch.from_numpy(numpy.ascontiguousarray(counts)).float()

    def forward(self, batch_queries, batch_docs):
        """
        Forward function of the deep relevance matching model.
        Return a relevance score for every (query, document) pair.

        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :return: score representing click probability [batch_size x num_clicks_per_query]
        """
        embedded_queries = self.embedding(batch_queries)
        # one gate value per query term, repeated for every candidate document
        term_weights = self.gating_network(embedded_queries).unsqueeze(
            1).expand(*batch_docs.size()[:2], batch_queries.size(1))
        embedded_docs = self.embedding(
            batch_docs.view(-1, batch_docs.size(-1)))

        # Broadcast queries and documents to a common
        # [batch * num_docs x query_len x doc_len x emsize] layout.
        embedded_queries = embedded_queries.unsqueeze(1).expand(
            *batch_docs.size()[:2], *embedded_queries.size()[1:])
        embedded_queries = embedded_queries.contiguous().view(
            -1, *embedded_queries.size()[2:])
        embedded_queries = embedded_queries.unsqueeze(2).expand(
            *embedded_queries.size()[:2], batch_docs.size(2),
            embedded_queries.size(2))
        embedded_docs = embedded_docs.unsqueeze(1).expand(
            embedded_docs.size(0), batch_queries.size(1),
            *embedded_docs.size()[1:])

        cos_sim = self.cosine_similarity(embedded_queries, embedded_docs, 3)
        # cosine_similarity squeezes every singleton axis; restore the rank so
        # that the histogram is always taken along the document axis.
        cos_sim = cos_sim.contiguous().view(*embedded_docs.size()[:3])

        histogram_feats = self._histogram_features(
            cos_sim.data.cpu().numpy(), self.bins)
        if self.config.cuda:
            histogram_feats = Variable(histogram_feats).cuda()
        else:
            histogram_feats = Variable(histogram_feats)

        ffnn_out = self.ffnn(histogram_feats.view(
            -1, self.config.nbins)).squeeze().view(-1, batch_queries.size(1))
        weighted_ffnn_out = ffnn_out * term_weights.contiguous().view(
            -1, term_weights.size(2))
        score = self.output(
            torch.sum(weighted_ffnn_out, 1, keepdim=True)).squeeze()
        return score.view(*batch_docs.size()[:2])
class MatchTensor(nn.Module):
    """Match-tensor ranking model.

    Queries and documents are embedded, projected, encoded and projected again
    to a fixed number of channels; their element-wise product (plus an
    exact-match channel) forms a 2d tensor per pair that is convolved,
    max-pooled and mapped to a single score.

    NOTE(review): a second class named ``MatchTensor`` is defined later in
    this file and rebinds the name once the module is fully loaded.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Create all sub-modules and load the pre-trained embeddings.

        :param dictionary: vocabulary object; its length sizes the embedding
        :param embedding_index: pre-trained vectors passed to
            ``init_embedding_weights``
        :param args: configuration object; reads ``bidirection``, ``emsize``,
            ``featsize``, ``nhid_query``, ``nhid_doc``, ``nchannels``,
            ``nfilters`` and ``match_filter_size``
        """
        super(MatchTensor, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.linear_projection = nn.Linear(self.config.emsize,
                                           self.config.featsize)
        self.query_encoder = Encoder(self.config.featsize,
                                     self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.featsize,
                                        self.config.nhid_doc, True, self.config)
        self.query_projection = nn.Linear(
            self.config.nhid_query * self.num_directions,
            self.config.nchannels)
        self.document_projection = nn.Linear(
            self.config.nhid_doc * self.num_directions, self.config.nchannels)

        self.exact_match_channel = ExactMatchChannel()
        # Three kernels of growing width along the last axis; padding keeps
        # the spatial size so that the outputs can be concatenated.
        self.conv1 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 3), padding=1)
        self.conv2 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 5), padding=(1, 2))
        self.conv3 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 7), padding=(1, 3))
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(self.config.nfilters * 3,
                              self.config.match_filter_size, (1, 1))
        self.output = nn.Linear(self.config.match_filter_size, 1)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(
            self.dictionary, self.embedding_index, self.config.emsize)

    @staticmethod
    def _global_max_pool(feature_map):
        """Max-pool a feature map over its two middle axes.

        The reductions already remove the pooled axis, so no bare
        ``squeeze()`` is applied: it would also drop a leading or query axis
        of size one and make the second pool reduce the wrong axis.

        :param feature_map: 4d tensor [n x query_length x doc_length x channels]
        :return: 2d tensor [n x channels]
        """
        return feature_map.max(2)[0].max(1)[0]

    def forward(self, batch_queries, query_len, batch_docs, doc_len):
        """
        Forward function of the match tensor model.
        Return a relevance score for every (query, document) pair.

        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param query_len: 1d numpy array [batch_size]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :param doc_len: 2d numpy array [batch_size x num_clicks_per_query]
        :return: score representing click probability [batch_size x num_clicks_per_query]
        """
        # step1: apply embedding lookup
        embedded_queries = self.embedding(batch_queries)
        embedded_docs = self.embedding(
            batch_docs.view(-1, batch_docs.size(-1)))

        # step2: apply linear projection on embedded queries and documents
        embedded_queries = self.linear_projection(
            embedded_queries.view(-1, embedded_queries.size(-1)))
        embedded_docs = self.linear_projection(
            embedded_docs.view(-1, embedded_docs.size(-1)))

        # step3: transform the tensors so that they can be given as input to RNN
        embedded_queries = embedded_queries.view(*batch_queries.size(),
                                                 self.config.featsize)
        embedded_docs = embedded_docs.view(-1, batch_docs.size()[-1],
                                           self.config.featsize)

        # step4: pass the encoded query and doc through a bi-LSTM
        encoded_queries = self.query_encoder(embedded_queries, query_len)
        encoded_docs = self.document_encoder(embedded_docs,
                                             doc_len.reshape(-1))

        # step5: apply linear projection on query hidden states
        projected_queries = self.query_projection(
            encoded_queries.view(-1, encoded_queries.size()[-1])).view(
                *batch_queries.size(), -1)
        # repeat every query once per candidate document
        projected_queries = projected_queries.unsqueeze(1).expand(
            projected_queries.size(0), batch_docs.size(1),
            *projected_queries.size()[1:])
        projected_queries = projected_queries.contiguous().view(
            -1, *projected_queries.size()[2:])

        projected_docs = self.document_projection(
            encoded_docs.view(-1, encoded_docs.size()[-1]))
        projected_docs = projected_docs.view(-1, batch_docs.size(2),
                                             projected_docs.size()[-1])

        # bring both sides to [n x query_length x doc_length x nchannels]
        projected_queries = projected_queries.unsqueeze(2).expand(
            *projected_queries.size()[:2], batch_docs.size()[-1],
            projected_queries.size(2))
        projected_docs = projected_docs.unsqueeze(1).expand(
            projected_docs.size(0), batch_queries.size()[-1],
            *projected_docs.size()[1:])

        # step6: 2d product between projected query and doc vectors
        query_document_product = projected_queries * projected_docs

        # step7: append exact match channel
        exact_match = self.exact_match_channel(batch_queries,
                                               batch_docs).unsqueeze(3)
        query_document_product = torch.cat(
            (query_document_product, exact_match), 3)
        # channels-last -> channels-first for Conv2d
        query_document_product = query_document_product.transpose(
            2, 3).transpose(1, 2)

        # step8: run the convolutional operation, max-pooling and linear projection
        convoluted_feat1 = self.conv1(query_document_product)
        convoluted_feat2 = self.conv2(query_document_product)
        convoluted_feat3 = self.conv3(query_document_product)
        convoluted_feat = self.relu(
            torch.cat((convoluted_feat1, convoluted_feat2, convoluted_feat3),
                      1))
        # back to channels-last before pooling over both sequence axes
        convoluted_feat = self.conv(convoluted_feat).transpose(
            1, 2).transpose(2, 3)
        max_pooled_feat = self._global_max_pool(convoluted_feat)
        return self.output(max_pooled_feat).view(*batch_docs.size()[:2])
class Seq2Seq(nn.Module):
    """Encoder-decoder model trained to generate ``q2`` from ``q1``.

    ``forward`` returns the masked negative log-likelihood of ``q2`` (plus an
    optional regularisation term), normalised by the number of unmasked
    target positions.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Create the embedding layer, the encoder and the decoder.

        :param dictionary: vocabulary object; its length sizes the embedding
            and the decoder output
        :param embedding_index: pre-trained vectors passed to
            ``init_embedding_weights``
        :param args: configuration object; reads ``bidirection``, ``emsize``,
            ``nhid_enc``, ``attn_type``, ``model``, ``cuda`` and ``regularize``
        """
        super(Seq2Seq, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embedding_index,
                                              self.config.emsize)

        self.encoder = Encoder(self.config.emsize, self.config.nhid_enc,
                               self.config.bidirection, self.config)
        if self.config.attn_type:
            self.decoder = AttentionDecoder(
                self.config.emsize,
                self.config.nhid_enc * self.num_directions, len(dictionary),
                self.config)
        else:
            self.decoder = Decoder(self.config.emsize,
                                   self.config.nhid_enc * self.num_directions,
                                   len(dictionary), self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length, regularize):
        """Compute the masked negative log-likelihood of one decoding step.

        :param logits: 2d tensor [batch_size x vocab_size]; gathered and
            negated directly, and exponentiated in the regulariser, so they
            are presumably log-probabilities - TODO confirm against the decoder
        :param target: 1d tensor [batch_size]
        :param seq_idx: index of the current decoding step
        :param length: true length of every target sequence
        :param regularize: weight of the ``sum(p * log p)`` term; falsy
            disables it
        :return: total loss over the mini-batch and number of unmasked elements
        """
        # Flatten both operands so the product stays element-wise whether
        # helper.mask returns [batch] or [batch x 1]; a [batch] * [batch x 1]
        # product would broadcast to a matrix.
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).view(-1)
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float().contiguous().view(-1)
        num_non_zero_elem = torch.nonzero(mask.data).size()

        loss = losses.sum()
        if regularize:
            regularized_loss = logits.exp().mul(logits).sum(
                1).squeeze() * regularize
            loss = loss + regularized_loss.sum()

        if not num_non_zero_elem:
            return loss, 0
        return loss, num_non_zero_elem[0]

    @staticmethod
    def _join_directions(state):
        """Concatenate the last two slices of a stacked recurrent state.

        :param state: 3d tensor [num_layers * 2 x batch_size x hidden_size]
        :return: 3d tensor [1 x batch_size x 2 * hidden_size]
        """
        last = state[-2:]
        return torch.cat((last[0].unsqueeze(0), last[1].unsqueeze(0)), 2)

    def forward(self, q1_var, q1_len, q2_var, q2_len):
        """Return the normalised decoding loss of ``q2_var`` given ``q1_var``.

        :param q1_var: 2d tensor of source word ids [batch_size x max_length]
        :param q1_len: lengths of the source sequences (passed to the encoder)
        :param q2_var: 2d tensor of target word ids [batch_size x max_length]
        :param q2_len: lengths of the target sequences (used for masking)
        :return: decoding loss [autograd Variable]
        """
        # encode the query
        embedded_q1 = self.embedding(q1_var)
        encoded_q1, hidden = self.encoder(embedded_q1, q1_len)

        is_lstm = self.config.model == 'LSTM'
        if self.config.bidirection:
            if is_lstm:
                decoder_hidden = (self._join_directions(hidden[0]),
                                  self._join_directions(hidden[1]))
            else:
                # A non-LSTM state is one tensor: slice its first axis, not
                # the batch axis of its first slice.
                decoder_hidden = self._join_directions(hidden)
        else:
            # NOTE(review): these states are 2d while the bidirectional branch
            # yields 3d states - confirm what the decoder expects.
            if is_lstm:
                decoder_hidden = hidden[0][-1], hidden[1][-1]
            else:
                decoder_hidden = hidden[-1]

        if self.config.attn_type:
            decoder_context = Variable(
                torch.zeros(encoded_q1.size(0),
                            encoded_q1.size(2))).unsqueeze(1)
            if self.config.cuda:
                decoder_context = decoder_context.cuda()

        decoding_loss, total_local_decoding_loss_element = 0, 0
        # teacher forcing: feed token idx, predict token idx + 1
        for idx in range(q2_var.size(1) - 1):
            input_variable = q2_var[:, idx]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            if self.config.attn_type:
                decoder_output, decoder_hidden, decoder_context, attn_weights = self.decoder(
                    embedded_decoder_input, decoder_hidden, decoder_context,
                    encoded_q1)
            else:
                decoder_output, decoder_hidden = self.decoder(
                    embedded_decoder_input, decoder_hidden)

            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, q2_var[:, idx + 1], idx, q2_len,
                self.config.regularize)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss
class Sequence2Sequence(nn.Module):
    """Hierarchical encoder-decoder over the queries of a session.

    Every query is encoded, a session encoder consumes the query encodings in
    order, and a decoder initialised with the session state after query ``t``
    is trained to produce query ``t + 1``.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Create all sub-modules and load the pre-trained embeddings.

        :param dictionary: vocabulary object (needs ``len``, ``word2idx`` and
            ``start_token``)
        :param embedding_index: pre-trained vectors passed to
            ``init_embedding_weights``
        :param args: configuration object; reads ``emsize``, ``nhid_query``,
            ``nhid_session``, ``model``, ``bidirection`` and ``cuda``
        """
        super(Sequence2Sequence, self).__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.query_encoder = Encoder(self.config.emsize,
                                     self.config.nhid_query, self.config)
        self.session_encoder = Encoder(self.config.nhid_query,
                                       self.config.nhid_session, self.config)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_session,
                               len(self.dictionary), self.config)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(
            self.dictionary, self.embedding_index, self.config.emsize)

    @staticmethod
    def _masked_mean(losses, mask):
        """Average ``losses`` over the positions where ``mask`` is non-zero.

        The element count comes from the mask values themselves, so a fully
        masked step yields a zero loss instead of a division by zero.

        :param losses: tensor with one loss per batch element
        :param mask: tensor with the same number of elements as ``losses``
        :return: scalar tensor
        """
        weights = mask.float().contiguous().view(-1)
        total = (losses.contiguous().view(-1) * weights).sum()
        kept = int((weights != 0).sum())
        if kept == 0:
            return total
        return total / kept

    @staticmethod
    def compute_loss(logits, target, seq_idx, length):
        """Compute the mean masked negative log-likelihood of one step.

        :param logits: 2d tensor [batch_size x vocab_size]
        :param target: 1d tensor [batch_size]
        :param seq_idx: index of the current decoding step
        :param length: true length of every target sequence
        :return: average loss over the unmasked elements [autograd Variable]
        """
        # logits: batch x vocab_size, target: batch x 1 after unsqueeze
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1))
        # mask: batch x 1
        mask = helper.mask(length, seq_idx)
        return Sequence2Sequence._masked_mean(losses, mask)

    def forward(self, batch_session, length):
        """Return the summed per-step decoding loss for a batch of sessions.

        :param batch_session: 3d tensor [batch_size x session_length x max_query_length]
        :param length: 2d tensor [batch_size x session_length]
        :return: sum over decoding steps of the mean masked loss
        """
        # query level encoding
        embedded_input = self.embedding(
            batch_session.view(-1, batch_session.size(-1)))
        if self.config.model == 'LSTM':
            encoder_hidden, encoder_cell = self.query_encoder.init_weights(
                embedded_input.size(0))
            output, hidden = self.query_encoder(
                embedded_input, (encoder_hidden, encoder_cell))
        else:
            encoder_hidden = self.query_encoder.init_weights(
                embedded_input.size(0))
            output, hidden = self.query_encoder(embedded_input, encoder_hidden)

        if self.config.bidirection:
            # average the forward and backward halves of the output
            nhid = self.config.nhid_query
            output = torch.div(
                torch.add(output[:, :, 0:nhid], output[:, :, nhid:2 * nhid]),
                2)

        # the output at the last time step represents each query
        session_input = output[:, -1, :].contiguous().view(
            batch_session.size(0), batch_session.size(1), -1)

        # session level encoding
        sess_hidden = self.session_encoder.init_weights(session_input.size(0))
        hidden_states, cell_states = [], []
        for idx in range(session_input.size(1)):
            sess_output, sess_hidden = self.session_encoder(
                session_input[:, idx, :].unsqueeze(1), sess_hidden)
            # NOTE(review): indexing [0]/[1] assumes an LSTM-style
            # (hidden, cell) pair even when config.model is not 'LSTM'.
            if self.config.bidirection:
                # keepdim keeps the leading axis so that both branches stack
                # to the same [1 x batch x steps x hidden] layout
                hidden_states.append(
                    torch.mean(sess_hidden[0], 0, keepdim=True))
                cell_states.append(
                    torch.mean(sess_hidden[1], 0, keepdim=True))
            else:
                hidden_states.append(sess_hidden[0])
                cell_states.append(sess_hidden[1])

        hidden_states = torch.stack(hidden_states, 2).squeeze(0)
        cell_states = torch.stack(cell_states, 2).squeeze(0)

        # the state after query t initialises the decoder for query t + 1
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        cell_states = cell_states[:, :-1, :].contiguous().view(
            -1, cell_states.size(-1)).unsqueeze(0)

        decoder_input = batch_session[:, 1:, :].contiguous().view(
            -1, batch_session.size(-1))
        target_length = length[:, 1:].contiguous().view(-1)

        input_variable = Variable(
            torch.LongTensor(decoder_input.size(0)).fill_(
                self.dictionary.word2idx[self.dictionary.start_token]))
        if self.config.cuda:
            input_variable = input_variable.cuda()
            hidden_states = hidden_states.cuda()
            cell_states = cell_states.cuda()

        # Initialize hidden states of decoder with the last hidden states of the session encoder
        decoder_hidden = (hidden_states, cell_states)
        loss = 0
        for idx in range(decoder_input.size(1)):
            # step 0 is fed the start token, later steps the previous word
            if idx != 0:
                input_variable = decoder_input[:, idx - 1]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                embedded_decoder_input, decoder_hidden)
            target_variable = decoder_input[:, idx]
            loss += self.compute_loss(decoder_output, target_variable, idx,
                                      target_length)

        return loss
class BCN(nn.Module):
    """Biattentive classification network architecture for sentence classification."""

    def __init__(self, dictionary, embedding_index, args):
        """Create the embedding, selector, encoders and output networks.

        :param dictionary: vocabulary object; its length sizes the embedding
        :param embedding_index: pre-trained vectors passed to
            ``init_embedding_weights``
        :param args: configuration object; reads ``bidirection``, ``emsize``,
            ``emtraining``, ``dropout``, ``nhid``, ``nlayers``, ``num_class``,
            ``num_units`` and ``cuda``
        """
        super(BCN, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1
        self.dictionary = dictionary

        cfg = self.config
        encoded_dim = cfg.nhid * self.num_directions

        self.embedding = EmbeddingLayer(len(self.dictionary), cfg.emsize,
                                        cfg.emtraining, cfg)
        self.embedding.init_embedding_weights(self.dictionary, embedding_index,
                                              cfg.emsize)
        self.selector = Selector(cfg.emsize, cfg.dropout)
        self.relu_network = nn.Sequential(
            OrderedDict([('dense1', nn.Linear(cfg.emsize, cfg.nhid)),
                         ('nonlinearity', nn.ReLU())]))
        self.encoder = Encoder(cfg.nhid, cfg.nhid, cfg.bidirection,
                               cfg.nlayers, cfg)
        self.biatt_encoder1 = Encoder(encoded_dim * 3, cfg.nhid,
                                      cfg.bidirection, 1, cfg)
        self.biatt_encoder2 = Encoder(encoded_dim * 3, cfg.nhid,
                                      cfg.bidirection, 1, cfg)
        self.ffnn = nn.Linear(encoded_dim, 1)
        self.maxout_network = MaxoutNetwork(encoded_dim * 4 * 2,
                                            cfg.num_class,
                                            num_units=cfg.num_units)

    def _self_attentive_pool(self, states):
        """Weight every time step with a softmax over ``ffnn`` scores and sum.

        :param states: 3d tensor [batch_size x length x features]
        :return: 2d tensor [batch_size x features]
        """
        scores = self.ffnn(states.view(-1, states.size(2))).squeeze(1)
        weights = f.softmax(scores.view(*states.size()[:-1]), 1)
        return torch.bmm(states.transpose(1, 2),
                         weights.unsqueeze(2)).squeeze(2)

    def forward(self, sentence1, sentence1_len_old, sentence2, sentence2_len_old):
        """
        Forward computation of the biattentive classification network.
        Returns classification scores for a batch of sentence pairs.

        The two ``*_len_old`` arguments are not read; the lengths used (and
        returned) are the ones produced by ``helper.get_selected_tensor``.

        :param sentence1: 2d tensor [batch_size x max_length]
        :param sentence1_len_old: unused
        :param sentence2: 2d tensor [batch_size x max_length]
        :param sentence2_len_old: unused
        :return: tuple of (classification scores, selected length of
            sentence 1, selected length of sentence 2, selection variation of
            sentence 1, selection variation of sentence 2)
        """
        # step1: embed the words into vectors [batch_size x max_length x emsize]
        full_emb1 = self.embedding(sentence1)
        full_emb2 = self.embedding(sentence2)

        # ---- selection: one value per token, multiplied into the word ids
        select1 = self.selector(full_emb1)
        select2 = self.selector(full_emb2)
        assert select1.size() == sentence1.size()
        assert select2.size() == sentence2.size()

        # ids of unselected words become zero
        kept_ids1 = sentence1.mul(select1)
        kept_ids2 = sentence2.mul(select2)

        # the returned lengths are numpy arrays
        packed1, sentence1_len = helper.get_selected_tensor(
            kept_ids1, self.config.cuda)
        packed2, sentence2_len = helper.get_selected_tensor(
            kept_ids2, self.config.cuda)

        emb1 = self.embedding(packed1)
        emb2 = self.embedding(packed2)

        # summed absolute change of the selection between neighbouring tokens
        zdiff1 = (select1[:, 1:] - select1[:, :-1]).abs().sum(1)
        zdiff2 = (select2[:, 1:] - select2[:, :-1]).abs().sum(1)
        assert zdiff1.size()[0] == len(sentence1_len)

        # step2: pass the embedded words through the ReLU network [batch_size x max_length x hidden_size]
        emb1 = self.relu_network(emb1)
        emb2 = self.relu_network(emb2)

        # step3: pass the word vectors through the encoder [batch_size x max_length x hidden_size * num_directions]
        enc1 = self.encoder(emb1, sentence1_len)
        enc2 = self.encoder(emb2, sentence2_len)

        # step4: compute affinity matrix [batch_size x sent1_max_length x sent2_max_length]
        affinity = torch.bmm(enc1, enc2.transpose(1, 2))

        # step5: compute conditioned representations [batch_size x max_length x hidden_size * num_directions]
        attn_over_2 = f.softmax(affinity, 2)
        attn_over_1 = f.softmax(affinity.transpose(1, 2), 2)
        cond1 = torch.bmm(attn_over_2.transpose(1, 2), enc1)
        cond2 = torch.bmm(attn_over_1.transpose(1, 2), enc2)

        # step6: generate input of the biattentive encoders [batch_size x max_length x hidden_size * num_directions * 3]
        biatt_in1 = torch.cat(
            (enc1, torch.abs(enc1 - cond2), torch.mul(enc1, cond2)), 2)
        biatt_in2 = torch.cat(
            (enc2, torch.abs(enc2 - cond1), torch.mul(enc2, cond1)), 2)

        # step7: pass the conditioned information through the biattentive encoders
        # [batch_size x max_length x hidden_size * num_directions]
        biatt1 = self.biatt_encoder1(biatt_in1, sentence1_len)
        biatt2 = self.biatt_encoder2(biatt_in2, sentence2_len)

        # step8: compute self-attentive pooling features
        self_att1 = self._self_attentive_pool(biatt1)
        self_att2 = self._self_attentive_pool(biatt2)

        # step9: compute the joint representations [batch_size x hidden_size * num_directions * 4]
        pooled1 = torch.cat((biatt1.max(1)[0], biatt1.mean(1),
                             biatt1.min(1)[0], self_att1), 1)
        pooled2 = torch.cat((biatt2.max(1)[0], biatt2.mean(1),
                             biatt2.min(1)[0], self_att2), 1)

        # step10: pass the pooled representations through the maxout network
        score = self.maxout_network(torch.cat((pooled1, pooled2), 1))

        return score, sentence1_len, sentence2_len, zdiff1, zdiff2
class MatchTensor(nn.Module):
    """Multi-task session model: document ranking plus next-query decoding.

    Queries and documents are embedded, linearly projected and passed through
    ``Encoder`` modules.  Their projected states are multiplied position by
    position into a match tensor that convolutions turn into one relevance
    score per (query, document) pair.  The max-pooled query representations
    additionally drive a session encoder whose states initialise a decoder
    trained to produce the following query of the session.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding, encoders, match-tensor scorer and decoder.

        :param dictionary: vocabulary object; ``len(dictionary)`` sizes the
            embedding layer and the decoder output
        :param embedding_index: passed to
            ``EmbeddingLayer.init_embedding_weights`` (presumably pre-trained
            word vectors - confirm against that helper)
        :param args: configuration object; this class reads ``bidirection``,
            ``emsize``, ``featsize``, ``nhid_query``, ``nhid_doc``,
            ``nchannels``, ``nfilters``, ``match_filter_size``,
            ``nhid_session``, ``model`` and ``regularize``
        """
        # Zero-argument super() does not depend on the module-level binding
        # of the name ``MatchTensor``.
        super().__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.embedding.init_embedding_weights(
            self.dictionary, self.embedding_index, self.config.emsize)
        self.linear_projection = nn.Linear(self.config.emsize,
                                           self.config.featsize)
        self.query_encoder = Encoder(self.config.featsize,
                                     self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.featsize,
                                        self.config.nhid_doc, True, self.config)
        self.query_projection = nn.Linear(
            self.config.nhid_query * self.num_directions,
            self.config.nchannels)
        self.document_projection = nn.Linear(
            self.config.nhid_doc * self.num_directions,
            self.config.nchannels)

        self.exact_match_channel = ExactMatchChannel()
        # The extra input channel (+1) carries the exact-match signal.
        self.conv1 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 3), padding=1)
        self.conv2 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 5), padding=(1, 2))
        self.conv3 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 7), padding=(1, 3))
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(self.config.nfilters * 3,
                              self.config.match_filter_size, (1, 1))
        self.output = nn.Linear(self.config.match_filter_size, 1)

        self.session_encoder = EncoderCell(self.config.nchannels,
                                           self.config.nhid_session, False,
                                           self.config)
        self.decoder = DecoderCell(self.config.emsize,
                                   self.config.nhid_session, len(dictionary),
                                   self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length, regularize):
        """
        Compute negative log-likelihood loss for a batch of predictions.

        :param logits: 2d tensor [batch_size x vocab_size]
        :param target: 1d tensor [batch_size]
        :param seq_idx: an integer represents the current index of the sequences
        :param length: 1d tensor [batch_size], represents each sequences' true length
        :param regularize: scale of the extra ``sum(exp(logits) * logits)``
            term; a falsy value disables that term
        :return: total loss over the input mini-batch and the number of
            non-zero mask entries (loss elements)
        """
        # squeeze(1) rather than squeeze(): the batch axis must survive even
        # when the batch holds a single item.
        losses = -torch.gather(logits, dim=1,
                               index=target.unsqueeze(1)).squeeze(1)
        # presumably one mask entry per batch item, zero once seq_idx is past
        # the item's length - confirm against helper.mask
        mask = helper.mask(length, seq_idx)
        loss = (losses * mask.float()).sum()
        if regularize:
            # NOTE(review): this term is not multiplied by the mask, so it
            # also covers finished sequences - confirm that this is intended.
            loss = loss + (logits.exp().mul(logits).sum(1) * regularize).sum()

        nonzero_size = torch.nonzero(mask.data).size()
        # An empty size tuple means there were no non-zero mask entries.
        num_loss_elements = nonzero_size[0] if nonzero_size else 0
        return loss, num_loss_elements

    def forward(self, session_queries, session_query_length, rel_docs,
                rel_docs_length, doc_labels):
        """
        Forward function of the neural click model.

        :param session_queries: 3d tensor [batch_size x session_length x max_query_length]
        :param session_query_length: 2d tensor [batch_size x session_length]
        :param rel_docs: 4d tensor [batch_size x session_length x num_rel_docs_per_query x max_doc_length]
        :param rel_docs_length: 3d tensor [batch_size x session_length x num_rel_docs_per_query]
        :param doc_labels: 3d tensor [batch_size x session_length x num_rel_docs_per_query]
        :return: tuple ``(click_loss, decoding_loss)``: the binary
            cross-entropy (with logits) of the document scores against
            ``doc_labels`` and the loss returned by ``query_recommender``
        """
        # Fold the session axis into the batch axis.
        batch_queries = session_queries.view(-1, session_queries.size(-1))
        batch_docs = rel_docs.view(-1, *rel_docs.size()[2:])

        projected_queries = self.encode_query(
            batch_queries, session_query_length)  # (B*S) x L x H
        projected_docs = self.encode_document(batch_docs, rel_docs_length)
        score = self.document_ranker(projected_queries, projected_docs,
                                     batch_queries, batch_docs)
        click_loss = f.binary_cross_entropy_with_logits(
            score, doc_labels.view(-1, doc_labels.size(2)))

        # encoded_queries: batch_size x session_length x nchannels
        # (max over the query positions of the projected states)
        encoded_queries = projected_queries.max(1)[0].view(
            *session_queries.size()[:2], -1)
        decoding_loss = self.query_recommender(
            session_queries, session_query_length, encoded_queries)

        return click_loss, decoding_loss

    def query_recommender(self, session_queries, session_query_length,
                          encoded_queries):
        """Train the decoder to generate each next query of a session.

        :param session_queries: 3d tensor [batch_size x session_length x max_query_length]
        :param session_query_length: 2d tensor [batch_size x session_length]
        :param encoded_queries: 3d tensor [batch_size x session_length x dim],
            one vector per query
        :return: decoding loss summed over all decoding steps and divided by
            the number of counted loss elements (left undivided when that
            count is zero)
        """
        is_lstm = self.config.model == 'LSTM'

        # session level encoding
        sess_q_hidden = self.session_encoder.init_weights(
            encoded_queries.size(0))
        hidden_states, cell_states = [], []
        # loop over all the queries in a session
        for idx in range(encoded_queries.size(1)):
            # update session-level query encoder state using query representations
            _, sess_q_hidden = self.session_encoder(
                encoded_queries[:, idx, :].unsqueeze(1), sess_q_hidden)
            # -1 stands for: only consider hidden states from the last layer
            if is_lstm:
                hidden_states.append(sess_q_hidden[0][-1])
                cell_states.append(sess_q_hidden[1][-1])
            else:
                hidden_states.append(sess_q_hidden[-1])

        hidden_states = torch.stack(hidden_states, 1)
        # remove the last hidden states which stand for the last queries in sessions
        hidden_states = hidden_states[:, :-1, :].contiguous().view(
            -1, hidden_states.size(-1)).unsqueeze(0)
        # Initialize hidden states of decoder with the hidden states of the
        # session encoder
        if is_lstm:
            cell_states = torch.stack(cell_states, 1)
            cell_states = cell_states[:, :-1, :].contiguous().view(
                -1, cell_states.size(-1)).unsqueeze(0)
            decoder_hidden = (hidden_states, cell_states)
        else:
            decoder_hidden = hidden_states

        embedded_queries = self.embedding(
            session_queries.view(-1, session_queries.size(-1)))
        embedded_queries = embedded_queries.view(*session_queries.size(), -1)

        # The decoder is trained on every query of a session except the
        # first one: query i+1 is decoded from the session state after query i.
        decoder_input = embedded_queries[:, 1:, :, :].contiguous().view(
            -1, *embedded_queries.size()[2:])
        decoder_target = session_queries[:, 1:, :].contiguous().view(
            -1, session_queries.size(-1))
        target_length = session_query_length[:, 1:].contiguous().view(-1)

        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(decoder_input.size(1) - 1):
            input_variable = decoder_input[:, idx, :].unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(
                input_variable, decoder_hidden)
            # The token at position idx predicts the token at idx + 1.
            local_loss, num_local_loss = self.compute_decoding_loss(
                decoder_output, decoder_target[:, idx + 1], idx,
                target_length, self.config.regularize)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss

    def document_ranker(self, projected_queries, projected_docs,
                        batch_queries, batch_docs):
        """Score every candidate document against its query.

        :param projected_queries: 3d tensor [num_queries x max_query_length x nchannels]
        :param projected_docs: projected document states; reshaped here to
            [(num_queries * num_docs) x max_doc_length x nchannels]
        :param batch_queries: 2d tensor [num_queries x max_query_length]
        :param batch_docs: 3d tensor [num_queries x num_docs x max_doc_length]
        :return: 2d tensor of scores [num_queries x num_docs]
        """
        num_docs = batch_docs.size(1)
        doc_length = batch_docs.size(-1)
        query_length = batch_queries.size(-1)

        # step6: 2d product between projected query and doc vectors
        # Repeat each query once per candidate document.
        projected_queries = projected_queries.unsqueeze(1).expand(
            projected_queries.size(0), num_docs,
            *projected_queries.size()[1:])
        projected_queries = projected_queries.contiguous().view(
            -1, *projected_queries.size()[2:])
        projected_docs = projected_docs.view(-1, doc_length,
                                             projected_docs.size(-1))

        # Broadcast both sides to
        # [pairs x query_length x doc_length x nchannels].
        projected_queries = projected_queries.unsqueeze(2).expand(
            *projected_queries.size()[:2], doc_length,
            projected_queries.size(2))
        projected_docs = projected_docs.unsqueeze(1).expand(
            projected_docs.size(0), query_length,
            *projected_docs.size()[1:])
        query_document_product = projected_queries * projected_docs

        # step7: append exact match channel
        exact_match = self.exact_match_channel(batch_queries,
                                               batch_docs).unsqueeze(3)
        query_document_product = torch.cat(
            (query_document_product, exact_match), 3)
        # Move channels first for Conv2d:
        # [pairs x channels x query_length x doc_length].
        query_document_product = query_document_product.transpose(
            2, 3).transpose(1, 2)

        # step8: run the convolutional operation, max-pooling and linear projection
        convoluted_feat = self.relu(torch.cat(
            (self.conv1(query_document_product),
             self.conv2(query_document_product),
             self.conv3(query_document_product)), 1))
        convoluted_feat = self.conv(convoluted_feat)
        # Max-pool over document positions, then over query positions.  The
        # reduced axes are named explicitly and nothing is squeezed, so a
        # single pair, a one-token query or a single filter cannot make the
        # second reduction hit the wrong axis.
        max_pooled_feat = convoluted_feat.max(dim=3)[0].max(dim=2)[0]
        return self.output(max_pooled_feat).view(*batch_docs.size()[:2])

    def encode_query(self, batch_queries, session_query_length):
        """Embed, project and encode the queries.

        :param batch_queries: 2d tensor [num_queries x max_query_length]
        :param session_query_length: tensor of query lengths; flattened and
            converted to a numpy array before it is given to the encoder
        :return: 3d tensor [num_queries x max_query_length x nchannels]
        """
        # step1: apply embedding lookup
        embedded_queries = self.embedding(batch_queries)
        # step2: apply linear projection on embedded queries
        embedded_queries = self.linear_projection(
            embedded_queries.view(-1, embedded_queries.size(-1)))
        # step3: transform the tensors so that they can be given as input to RNN
        embedded_queries = embedded_queries.view(*batch_queries.size(),
                                                 self.config.featsize)
        # step4: pass the projected queries through the query encoder
        encoded_queries = self.query_encoder(
            embedded_queries,
            session_query_length.view(-1).data.cpu().numpy())
        # step5: apply linear projection on query hidden states
        projected_queries = self.query_projection(
            encoded_queries.view(-1, encoded_queries.size()[-1])).view(
                *batch_queries.size(), -1)
        return projected_queries

    def encode_document(self, batch_docs, rel_docs_length):
        """Embed, project and encode the documents.

        :param batch_docs: tensor of token ids whose last axis is the
            document length
        :param rel_docs_length: tensor of document lengths; flattened and
            converted to a numpy array before it is given to the encoder
        :return: 2d tensor [(num_documents * max_doc_length) x nchannels]
        """
        # step1: apply embedding lookup
        embedded_docs = self.embedding(
            batch_docs.view(-1, batch_docs.size(-1)))
        # step2: apply linear projection on embedded documents
        embedded_docs = self.linear_projection(
            embedded_docs.view(-1, embedded_docs.size(-1)))
        # step3: transform the tensors so that they can be given as input to RNN
        embedded_docs = embedded_docs.view(-1, batch_docs.size()[-1],
                                           self.config.featsize)
        # step4: pass the projected documents through the document encoder
        encoded_docs = self.document_encoder(
            embedded_docs, rel_docs_length.view(-1).data.cpu().numpy())
        # step5: apply linear projection on document hidden states
        projected_docs = self.document_projection(
            encoded_docs.view(-1, encoded_docs.size()[-1]))
        return projected_docs
class SentenceClassifier(nn.Module):
    """Predicts the label given a pair of sentences."""

    def __init__(self, dictionary, embeddings_index, args):
        """Build the embedding layer, sentence encoder and classifier.

        :param dictionary: vocabulary object; ``len(dictionary)`` sizes the
            embedding layer
        :param embeddings_index: passed to
            ``EmbeddingLayer.init_embedding_weights`` (presumably pre-trained
            word vectors - confirm against that helper)
        :param args: configuration object; this class reads ``bidirection``,
            ``emsize``, ``nhid``, ``nonlinear_fc``, ``dropout_fc``,
            ``fc_dim``, ``num_classes`` and ``pool_type``
        """
        super().__init__()
        self.config = args
        self.num_directions = 2 if args.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embeddings_index,
                                              self.config.emsize)
        self.encoder = Encoder(self.config.emsize, self.config.nhid,
                               self.config.bidirection, self.config)
        # The classifier input is the concatenation of four vectors of size
        # nhid * num_directions (see forward).
        self.ffnn = self._build_ffnn(
            self.config.nhid * self.num_directions * 4,
            self.config.fc_dim,
            self.config.num_classes,
            self.config.dropout_fc,
            args.nonlinear_fc)

    @staticmethod
    def _build_ffnn(input_dim, fc_dim, num_classes, dropout, nonlinear):
        """Assemble the three-layer classifier, with or without Tanh layers.

        :param input_dim: size of the combined sentence-pair representation
        :param fc_dim: width of the two hidden layers
        :param num_classes: number of output units
        :param dropout: dropout probability used before every dense layer
        :param nonlinear: when truthy, a Tanh follows ``dense1`` and ``dense2``
        :return: ``nn.Sequential`` with named sub-modules
        """
        layers = [
            ('dropout1', nn.Dropout(dropout)),
            ('dense1', nn.Linear(input_dim, fc_dim)),
        ]
        if nonlinear:
            layers.append(('tanh1', nn.Tanh()))
        layers += [
            ('dropout2', nn.Dropout(dropout)),
            ('dense2', nn.Linear(fc_dim, fc_dim)),
        ]
        if nonlinear:
            # Needs its own key: a repeated OrderedDict key replaces the
            # earlier entry instead of adding a second activation.
            layers.append(('tanh2', nn.Tanh()))
        layers += [
            ('dropout3', nn.Dropout(dropout)),
            ('dense3', nn.Linear(fc_dim, num_classes)),
        ]
        return nn.Sequential(OrderedDict(layers))

    def forward(self, batch_sentence1, sent_len1, batch_sentence2, sent_len2):
        """Defines the forward computation of the sentence pair classifier.

        :param batch_sentence1: first sentences of the batch, given to the
            embedding layer
        :param sent_len1: lengths of the first sentences, given to the encoder
        :param batch_sentence2: second sentences of the batch
        :param sent_len2: lengths of the second sentences
        :return: output of the classifier for the combined representation
        :raises ValueError: if ``config.pool_type`` is not ``'max'``,
            ``'mean'`` or ``'last'``
        """
        embedded1 = self.embedding(batch_sentence1)
        embedded2 = self.embedding(batch_sentence2)

        # For the first sentences in batch
        output1 = self.encoder(embedded1, sent_len1)
        # For the second sentences in batch
        output2 = self.encoder(embedded2, sent_len2)

        pool_type = self.config.pool_type
        if pool_type == 'max':
            encoded_questions1 = torch.max(output1, 1)[0]
            encoded_questions2 = torch.max(output2, 1)[0]
        elif pool_type == 'mean':
            encoded_questions1 = torch.mean(output1, 1)
            encoded_questions2 = torch.mean(output2, 1)
        elif pool_type == 'last':
            if self.num_directions == 2:
                nhid = self.config.nhid
                # First nhid features are read at the last step, the
                # remaining ones at step 0 (presumably the forward and
                # backward directions - confirm against Encoder).
                encoded_questions1 = torch.cat(
                    (output1[:, -1, :nhid], output1[:, 0, nhid:]), 1)
                encoded_questions2 = torch.cat(
                    (output2[:, -1, :nhid], output2[:, 0, nhid:]), 1)
            else:
                encoded_questions1 = output1[:, -1, :]
                encoded_questions2 = output2[:, -1, :]
        else:
            raise ValueError(
                "unsupported pool_type: {!r}".format(pool_type))

        assert encoded_questions1.size(0) == encoded_questions2.size(0)

        # compute angle between sentence representation
        angle = torch.mul(encoded_questions1, encoded_questions2)
        # compute distance between sentence representation
        distance = torch.abs(encoded_questions1 - encoded_questions2)
        # combined_representation = batch_size x (hidden_size * num_directions * 4)
        combined_representation = torch.cat(
            (encoded_questions1, encoded_questions2, angle, distance), 1)

        return self.ffnn(combined_representation)
class MatchTensor(nn.Module):
    """Match-tensor model that scores candidate documents against a query.

    Queries and documents are embedded, linearly projected and encoded; the
    projected encoder states are multiplied position by position into a match
    tensor, extended with an exact-match channel, and reduced to one score
    per (query, document) pair by convolutions and max-pooling.
    """

    def __init__(self, dictionary, embedding_index, args):
        """Build the embedding, encoders and match-tensor scorer.

        :param dictionary: vocabulary object; ``len(dictionary)`` sizes the
            embedding layer
        :param embedding_index: passed to
            ``EmbeddingLayer.init_embedding_weights`` (presumably pre-trained
            word vectors - confirm against that helper)
        :param args: configuration object; this class reads ``bidirection``,
            ``emsize``, ``featsize``, ``nhid_query``, ``nhid_doc``,
            ``nchannels``, ``nfilters``, ``match_filter_size`` and ``model``
        """
        # Zero-argument super() does not depend on the module-level binding
        # of the name ``MatchTensor``.
        super().__init__()
        self.dictionary = dictionary
        self.embedding_index = embedding_index
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(self.dictionary), self.config)
        self.linear_projection = nn.Linear(self.config.emsize,
                                           self.config.featsize)
        self.query_encoder = Encoder(self.config.featsize,
                                     self.config.nhid_query, True, self.config)
        self.document_encoder = Encoder(self.config.featsize,
                                        self.config.nhid_doc, True, self.config)
        self.query_projection = nn.Linear(
            self.config.nhid_query * self.num_directions,
            self.config.nchannels)
        self.document_projection = nn.Linear(
            self.config.nhid_doc * self.num_directions,
            self.config.nchannels)

        self.exact_match_channel = ExactMatchChannel()
        # The extra input channel (+1) carries the exact-match signal.
        self.conv1 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 3), padding=1)
        self.conv2 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 5), padding=(1, 2))
        self.conv3 = nn.Conv2d(self.config.nchannels + 1, self.config.nfilters,
                               (3, 7), padding=(1, 3))
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(self.config.nfilters * 3,
                              self.config.match_filter_size, (1, 1))
        self.output = nn.Linear(self.config.match_filter_size, 1)

        # Initializing the weight parameters for the embedding layer.
        self.embedding.init_embedding_weights(
            self.dictionary, self.embedding_index, self.config.emsize)

    def _run_encoder(self, encoder, inputs):
        """Run ``encoder`` on ``inputs`` from a fresh initial state; return its output."""
        state = encoder.init_weights(inputs.size(0))
        if self.config.model == 'LSTM':
            # LSTM encoders take a (hidden, cell) pair.
            hidden, cell = state
            state = (hidden, cell)
        output, _ = encoder(inputs, state)
        return output

    def forward(self, batch_queries, batch_docs):
        """
        Forward function of the match tensor model.

        :param batch_queries: 2d tensor [batch_size x max_query_length]
        :param batch_docs: 3d tensor [batch_size x num_rel_docs_per_query x max_document_length]
        :return: 2d tensor of scores [batch_size x num_rel_docs_per_query]
        """
        num_docs = batch_docs.size(1)
        doc_length = batch_docs.size(-1)
        query_length = batch_queries.size(-1)

        # Embedding lookup followed by the shared linear projection.
        embedded_queries = self.embedding(batch_queries)
        embedded_docs = self.embedding(batch_docs.view(-1, doc_length))
        embedded_queries = self.linear_projection(
            embedded_queries.view(-1, embedded_queries.size(-1)))
        embedded_docs = self.linear_projection(
            embedded_docs.view(-1, embedded_docs.size(-1)))
        embedded_queries = embedded_queries.view(*batch_queries.size(),
                                                 self.config.featsize)
        embedded_docs = embedded_docs.view(-1, doc_length,
                                           self.config.featsize)

        # Encode the queries, project to nchannels and repeat each query once
        # per candidate document.
        output = self._run_encoder(self.query_encoder, embedded_queries)
        embedded_queries = self.query_projection(
            output.view(-1, output.size()[-1])).view(
                *batch_queries.size(), -1)
        embedded_queries = embedded_queries.unsqueeze(1).expand(
            embedded_queries.size(0), num_docs,
            *embedded_queries.size()[1:])
        embedded_queries = embedded_queries.contiguous().view(
            -1, *embedded_queries.size()[2:])

        # Encode the documents and project to nchannels.
        output = self._run_encoder(self.document_encoder, embedded_docs)
        embedded_docs = self.document_projection(
            output.view(-1, output.size()[-1]))
        embedded_docs = embedded_docs.view(-1, doc_length,
                                           embedded_docs.size()[-1])

        # Broadcast both sides to
        # [pairs x query_length x doc_length x nchannels] and multiply.
        embedded_queries = embedded_queries.unsqueeze(2).expand(
            *embedded_queries.size()[:2], doc_length,
            embedded_queries.size(2))
        embedded_docs = embedded_docs.unsqueeze(1).expand(
            embedded_docs.size(0), query_length,
            *embedded_docs.size()[1:])
        query_document_product = embedded_queries * embedded_docs

        # Append the exact-match channel and move channels first for Conv2d:
        # [pairs x channels x query_length x doc_length].
        exact_match = self.exact_match_channel(batch_queries,
                                               batch_docs).unsqueeze(3)
        query_document_product = torch.cat(
            (query_document_product, exact_match), 3)
        query_document_product = query_document_product.transpose(
            2, 3).transpose(1, 2)

        convoluted_feat = self.relu(torch.cat(
            (self.conv1(query_document_product),
             self.conv2(query_document_product),
             self.conv3(query_document_product)), 1))
        convoluted_feat = self.conv(convoluted_feat)
        # Max-pool over document positions, then over query positions.  The
        # reduced axes are named explicitly and nothing is squeezed, so a
        # single pair, a one-token query or a single filter cannot make the
        # second reduction hit the wrong axis.
        max_pooled_feat = convoluted_feat.max(dim=3)[0].max(dim=2)[0]
        return self.output(max_pooled_feat).view(*batch_docs.size()[:2])