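# Assumed module-level imports for this section; the encoder/predictor modules
# (KnowledgeEncoder, StateEncoder, DocumentEncoder, RelevanceEstimator, ExamPredictor)
# are defined elsewhere in this repository and are not reproduced here.
import logging

import torch
import torch.nn as nn

# Assumption: the original file defines a global `use_cuda` flag; a typical definition is
use_cuda = torch.cuda.is_available()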
class CACMN(nn.Module):
    def __init__(self, args, query_size, url_size, vtype_size, n_layers=1):
        super(CACMN, self).__init__()
        self.n_layers = n_layers
        self.args = args
        self.knowledge_hidden_size = args.hidden_size
        self.state_hidden_size = args.hidden_size
        self.document_hidden_size = args.hidden_size
        self.hidden_size = args.hidden_size
        self.batch_size = args.batch_size
        self.embed_size = args.embed_size
        self.softmax1 = torch.nn.Softmax(dim=0)
        self.softmax2 = torch.nn.Softmax(dim=1)
        self.logger = logging.getLogger("CACM")
        self.query_size = query_size
        self.url_size = url_size
        self.vtype_size = vtype_size
        self.dropout_rate = args.dropout_rate
        self.encode_gru_num_layer = 1
        self.use_knowledge = args.use_knowledge
        self.use_knowledge_attention = args.use_knowledge_attention
        self.use_state_attention = args.use_state_attention

        # whether to use pre-trained embeddings
        if args.use_knowledge:
            self.knowledge_embedding_size = args.embed_size
        else:
            self.knowledge_embedding_size = query_size

        # context-aware relevance estimator
        self.knowledge_encoder = KnowledgeEncoder(self.args, self.query_size)
        self.state_encoder = StateEncoder(self.args, self.url_size, self.vtype_size)
        self.document_encoder = DocumentEncoder(self.args, self.url_size, self.vtype_size)
        self.relevance_estimator = RelevanceEstimator(self.args.hidden_size * 3, args.hidden_size)

        # examination predictor
        self.examination_predictor = ExamPredictor(self.args, self.vtype_size)

        # set the combination function of relevance and examination
        if self.args.combine == 'exp_mul' or self.args.combine == 'exp_sigmoid_log':
            self.lamda = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.mu = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            # initialization
            self.lamda.data.fill_(1.0)
            self.mu.data.fill_(1.0)
        elif self.args.combine == 'linear':
            self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.beta = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.alpha.data.fill_(0.5)
            self.beta.data.fill_(0.5)
        elif self.args.combine == 'nonlinear':
            self.w11 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.w12 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.w21 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.w22 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.w31 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.w32 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
            self.sigmoid = nn.Sigmoid()
            self.w11.data.fill_(0.5)
            self.w12.data.fill_(0.5)
            self.w21.data.fill_(0.5)
            self.w22.data.fill_(0.5)
            self.w31.data.fill_(0.5)
            self.w32.data.fill_(0.5)

    def get_clicks(self, relevances, exams):
        clicks = []
        combine = self.args.combine
        if combine == 'mul':
            clicks = torch.mul(relevances, exams)
        elif combine == 'exp_mul':
            clicks = torch.mul(torch.pow(relevances, self.lamda), torch.pow(exams, self.mu))
        elif combine == 'linear':
            clicks = torch.add(torch.mul(relevances, self.alpha), torch.mul(exams, self.beta))
        elif combine == 'nonlinear':
            # 2-layer
            out1 = self.sigmoid(torch.add(torch.mul(relevances, self.w11), torch.mul(exams, self.w12)))
            out2 = self.sigmoid(torch.add(torch.mul(relevances, self.w21), torch.mul(exams, self.w22)))
            clicks = self.sigmoid(torch.add(torch.mul(out1, self.w31), torch.mul(out2, self.w32)))
        elif combine == 'sigmoid_log':
            clicks = 4 * torch.div(
                torch.mul(relevances, exams),
                torch.mul(torch.add(relevances, 1), torch.add(exams, 1)))
        return clicks
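    # Summary of the combination functions implemented in get_clicks above
    # (r = relevance, e = examination probability, c = click probability):
    #   mul:         c = r * e
    #   exp_mul:     c = r^lamda * e^mu
    #   linear:      c = alpha * r + beta * e
    #   nonlinear:   c = sigmoid(w31 * sigmoid(w11*r + w12*e) + w32 * sigmoid(w21*r + w22*e))
    #   sigmoid_log: c = 4 * r * e / ((r + 1) * (e + 1))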
    # inputs include: knowledge, interaction, document
    def forward(self, knowledge_variable, interaction_variable,
                document_variable, examination_context, data):
        # each variable corresponds to a query-doc pair to be predicted
        # forward one query session at a time

        # knowledge encoding
        knowledge_input_variable = knowledge_variable
        knowledge_input_variable = knowledge_input_variable.cuda() if use_cuda else knowledge_input_variable
        knowledge_output_list = []
        for batch_idx, batch_knowledge in enumerate(knowledge_input_variable):
            batch_knowledge_output = []
            for sess_pos_idx, knowledge in enumerate(batch_knowledge):
                # integer division: 10 documents per query in a session
                query_idx = sess_pos_idx // 10 + 1
                knowledge_hidden = self.knowledge_encoder.initHidden()
                this_knowledge = knowledge[:query_idx]
                knowledge_output, knowledge_hidden = self.knowledge_encoder.forward(
                    this_knowledge, knowledge_hidden, data, query_idx)
                # attention for knowledge
                if self.use_knowledge_attention:
                    a = torch.mm(knowledge_output, torch.transpose(knowledge_hidden, 0, 1))
                    a = self.softmax1(a).view(-1, 1)
                    knowledge_memory = torch.mul(knowledge_output, a)
                    knowledge_output = knowledge_memory.sum(dim=0)
                else:
                    knowledge_output = knowledge_output[-1]
                batch_knowledge_output.append(knowledge_output)
            batch_knowledge_output = torch.stack(tuple(batch_knowledge_output), 0)
            knowledge_output_list.append(batch_knowledge_output)
        knowledge_output = torch.stack(tuple(knowledge_output_list), 0)

        # state encoding from interaction
        # interaction: batch_size * session_doc_num * data
        interaction_input_variable = interaction_variable
        interaction_input_variable = interaction_input_variable.cuda() if use_cuda else interaction_input_variable
        interaction_hidden = self.state_encoder.initHidden()
        # interaction_input_variable[:, :, i] has 4 parts: url, rank, vtype, click, each one a one-hot vector
        interaction_output, interaction_hidden = self.state_encoder.forward(
            interaction_input_variable[:, :, 0], interaction_input_variable[:, :, 1],
            interaction_input_variable[:, :, 2], interaction_input_variable[:, :, 3],
            interaction_hidden, data)
        if self.use_state_attention:
            interaction_attention_output = []
            for batch_idx, batch_interaction in enumerate(interaction_output):
                batch_interaction_output = []
                for sess_pos_idx, interaction in enumerate(batch_interaction):
                    prev_hidden = interaction_output[batch_idx][:sess_pos_idx + 1]
                    interaction = interaction.view(1, -1)
                    a = torch.mm(interaction, torch.transpose(prev_hidden, 0, 1))
                    a = self.softmax2(a).view(-1, 1)
                    interaction_memory = torch.mul(prev_hidden, a)
                    this_interaction_output = interaction_memory.sum(dim=0)
                    batch_interaction_output.append(this_interaction_output)
                batch_interaction_output = torch.stack(tuple(batch_interaction_output), 0)
                interaction_attention_output.append(batch_interaction_output)
            interaction_output = torch.stack(tuple(interaction_attention_output), 0)

        # document encoding
        # document_input_variable has 3 parts: url, rank, vtype, each one a one-hot vector
        document_input_variable = document_variable
        document_input_variable = document_input_variable.cuda() if use_cuda else document_input_variable
        document_output = self.document_encoder.forward(
            document_input_variable[:, :, 0], document_input_variable[:, :, 1],
            document_input_variable[:, :, 2], document_input_variable[:, :, 3], data)

        # concatenation and relevance estimator
        concat_output = torch.cat((knowledge_output, interaction_output, document_output), dim=2)
        relevance = self.relevance_estimator.forward(concat_output, self.batch_size)

        # examination prediction
        examination_input_variable = examination_context
        examination_input_variable = examination_input_variable.cuda() if use_cuda else examination_input_variable
        examination_list_output = []
        for batch_idx, batch_examination in enumerate(examination_input_variable):
            batch_examination_output = []
            # integer division: each query contributes 10 documents
            query_num = batch_examination.size()[0] // 10
            for query_idx in range(query_num):
                this_query_context = batch_examination[query_idx * 10:(query_idx + 1) * 10]
                this_query_context = this_query_context.view(1, 10, -1)
                this_hidden = self.examination_predictor.initHidden()
                this_examination_output = self.examination_predictor.forward(
                    this_query_context[:, :, 2], this_query_context[:, :, 3],
                    this_query_context[:, :, 1], this_hidden)
                batch_examination_output.append(this_examination_output)
            batch_examination_output = torch.cat(tuple(batch_examination_output), 1)
            examination_list_output.append(batch_examination_output)
        examination_output = torch.cat(tuple(examination_list_output), 0)
        exam_prob = examination_output

        # combine the relevance and the examination according to the combination type
        clicks = self.get_clicks(relevance, exam_prob)
        return relevance, exam_prob, clicks
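

# Illustrative sketch (not part of the original model): a minimal, self-contained
# check of the relevance/examination combination functions used by get_clicks,
# computed directly on dummy tensors. The parameter values (lamda = mu = 1.0,
# alpha = beta = 0.5) mirror the initializations in __init__; the tensors below
# are arbitrary examples, not data from this repository.
if __name__ == '__main__':
    r = torch.tensor([0.9, 0.4, 0.2])   # dummy relevance scores
    e = torch.tensor([0.8, 0.7, 0.1])   # dummy examination probabilities

    # 'mul': plain product of relevance and examination
    print('mul:        ', r * e)
    # 'exp_mul' with lamda = mu = 1.0 reduces to the plain product
    print('exp_mul:    ', torch.pow(r, 1.0) * torch.pow(e, 1.0))
    # 'linear' with alpha = beta = 0.5 averages the two scores
    print('linear:     ', 0.5 * r + 0.5 * e)
    # 'sigmoid_log': 4 * r * e / ((r + 1) * (e + 1)), which stays in [0, 1] for r, e in [0, 1]
    print('sigmoid_log:', 4 * (r * e) / ((r + 1) * (e + 1)))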