import torch
import torch.nn.functional as F

# `misc` is the project-internal torch utility module that provides
# mask_for_lengths (see the sketch below).


def forward(self, unique_word_chars, unique_word_lengths, sequences_as_uniqs=None):
    # Pick tensor types based on where the input actually lives, not on
    # global CUDA availability, so CPU inputs stay on the CPU.
    long_tensor = torch.cuda.LongTensor if unique_word_chars.is_cuda else torch.LongTensor

    embedded_chars = self._embeddings(unique_word_chars.type(long_tensor))  # [N, S, L]
    # 1D convolution over the character dimension
    conv_out = self._conv(embedded_chars.transpose(1, 2))  # [N, L', S]
    # mask positions beyond each word's length, then max-pool over time
    conv_mask = misc.mask_for_lengths(unique_word_lengths)
    conv_out = conv_out + conv_mask.unsqueeze(1)
    embedded_words = conv_out.max(2)[0]  # [N, L']

    if sequences_as_uniqs is None:
        return embedded_words

    if not isinstance(sequences_as_uniqs, list):
        sequences_as_uniqs = [sequences_as_uniqs]

    # gather the pooled word embeddings back into each token sequence
    all_embedded = []
    for word_idxs in sequences_as_uniqs:
        all_embedded.append(
            F.embedding(word_idxs.type(long_tensor), embedded_words))
    return all_embedded
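# A minimal sketch of what misc.mask_for_lengths presumably does, inferred
# from how it is used in the three forward passes here; the project's actual
# implementation may differ. It returns a [B, L] tensor holding `value` at
# masked positions and 0 elsewhere, so it can be added directly to score or
# feature tensors. The name and the default `value` are assumptions.
def _mask_for_lengths_sketch(lengths, max_length=None, mask_right=True, value=-1e6):
    if max_length is None:
        max_length = int(lengths.max())
    positions = torch.arange(max_length, device=lengths.device).unsqueeze(0)  # [1, L]
    if mask_right:
        masked = positions >= lengths.unsqueeze(1)  # mask padding, right of the sequence
    else:
        masked = positions < lengths.unsqueeze(1)   # mask positions left of the length
    return masked.float() * value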
def forward(self, emb_question, question_length,
            emb_support, support_length,
            unique_word_chars, unique_word_char_length,
            question_words2unique, support_words2unique,
            word_in_question, correct_start, answer2support, is_eval):
    """fast_qa model

    Args:
        emb_question: [Q, L_q, N]
        question_length: [Q]
        emb_support: [Q, L_s, N]
        support_length: [Q]
        unique_word_chars
        unique_word_char_length
        question_words2unique
        support_words2unique
        word_in_question: [Q, L_s]
        correct_start: [A], only during training, i.e., is_eval=False
        answer2support: [A], only during training, i.e., is_eval=False
        is_eval: []

    Returns:
        start_scores [B, L_s], end_scores [B, L_s], span_prediction [B, 3]
    """
    # Some helpers
    float_tensor = torch.cuda.FloatTensor if emb_question.is_cuda else torch.FloatTensor
    long_tensor = torch.cuda.LongTensor if emb_question.is_cuda else torch.LongTensor
    batch_size = question_length.data.shape[0]
    max_question_length = question_length.max().data[0]
    support_mask = misc.mask_for_lengths(support_length)
    question_binary_mask = misc.mask_for_lengths(question_length, mask_right=False, value=1.0)

    if self._with_char_embeddings:
        # compute combined embeddings
        [char_emb_question, char_emb_support] = self._conv_char_embedding(
            unique_word_chars, unique_word_char_length,
            [question_words2unique, support_words2unique])

        emb_question = torch.cat([emb_question, char_emb_question], 2)
        emb_support = torch.cat([emb_support, char_emb_support], 2)

    # compute encoder features
    # question tokens are trivially "in the question", so their two
    # wiq features are constant ones
    question_features = torch.autograd.Variable(
        torch.ones(batch_size, max_question_length, 2, out=float_tensor()))
    question_features = question_features.type_as(emb_question)

    v_wiqw = self._v_wiq_w
    # [B, L_q, L_s]
    wiq_w = torch.matmul(emb_question * v_wiqw, emb_support.transpose(1, 2))
    # [B, L_q, L_s]
    wiq_w = wiq_w + support_mask.unsqueeze(1)

    wiq_w = F.softmax(wiq_w.view(batch_size * max_question_length, -1), dim=1).view(
        batch_size, max_question_length, -1)
    # [B, L_s]
    wiq_w = torch.matmul(question_binary_mask.unsqueeze(1), wiq_w).squeeze(1)

    # [B, L_s, 2]
    support_features = torch.stack([word_in_question, wiq_w], dim=2)

    if self._with_char_embeddings:
        # highway layer to allow for interaction between concatenated embeddings
        emb_question = self._embedding_projection(emb_question)
        emb_support = self._embedding_projection(emb_support)
        emb_question = self._embedding_highway(emb_question)
        emb_support = self._embedding_highway(emb_support)

    # dropout
    dropout = self._shared_resources.config.get("dropout", 0.0)
    emb_question = F.dropout(emb_question, dropout, training=not is_eval)
    emb_support = F.dropout(emb_support, dropout, training=not is_eval)

    # extend embeddings with features
    emb_question_ext = torch.cat([emb_question, question_features], 2)
    emb_support_ext = torch.cat([emb_support, support_features], 2)

    # encode question and support
    # [B, L, 2 * size]
    encoded_question = self._bilstm(emb_question_ext)[0]
    encoded_support = self._bilstm(emb_support_ext)[0]

    # [B, L, size]
    encoded_support = F.tanh(
        F.linear(encoded_support, self._support_projection))
    encoded_question = F.tanh(
        F.linear(encoded_question, self._question_projection))

    start_scores, end_scores, predicted_start_pointer, predicted_end_pointer = \
        self._answer_layer(encoded_question, question_length,
                           encoded_support, support_length,
                           correct_start, answer2support, is_eval)

    # no multi-paragraph support yet
    doc_idx = torch.autograd.Variable(
        torch.zeros(predicted_start_pointer.data.shape[0], out=long_tensor()))
    span = torch.stack(
        [doc_idx, predicted_start_pointer, predicted_end_pointer], 1)

    return start_scores, end_scores, span
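# Toy illustration (not part of the model) of the weighted word-in-question
# feature computed above: a [B, L_q, L_s] similarity matrix between question
# and support tokens, softmax-normalized over the support axis, then summed
# over the (masked) question words to yield one feature per support token.
# Shapes and values are made up; `v_wiqw` is assumed to be a size-N vector.
B, L_q, L_s, N = 2, 4, 7, 8
emb_question_demo = torch.randn(B, L_q, N)
emb_support_demo = torch.randn(B, L_s, N)
v_wiqw_demo = torch.randn(N)
question_binary_mask_demo = torch.ones(B, L_q)  # all question tokens valid

wiq_w_demo = torch.matmul(emb_question_demo * v_wiqw_demo,
                          emb_support_demo.transpose(1, 2))     # [B, L_q, L_s]
wiq_w_demo = F.softmax(wiq_w_demo.view(B * L_q, -1), dim=1).view(B, L_q, -1)
wiq_w_demo = torch.matmul(question_binary_mask_demo.unsqueeze(1),
                          wiq_w_demo).squeeze(1)                # [B, L_s]
print(wiq_w_demo.shape)  # torch.Size([2, 7]) -- one feature per support token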
def forward(self, encoded_question, question_length,
            encoded_support, support_length,
            correct_start, answer2question, is_eval):
    # casting
    long_tensor = torch.cuda.LongTensor if encoded_question.is_cuda else torch.LongTensor
    answer2question = answer2question.type(long_tensor)

    # compute single-time attention over the question
    attention_scores = self._linear_question_attention(encoded_question)
    q_mask = misc.mask_for_lengths(question_length)
    attention_scores = attention_scores.squeeze(2) + q_mask
    question_attention_weights = F.softmax(attention_scores, dim=1)
    question_state = torch.matmul(question_attention_weights.unsqueeze(1),
                                  encoded_question).squeeze(1)

    # Prediction
    # start
    start_input = torch.cat(
        [question_state.unsqueeze(1) * encoded_support, encoded_support], 2)

    q_start_state = self._linear_q_start(start_input) + \
        self._linear_q_start_q(question_state).unsqueeze(1)
    start_scores = self._linear_start_scores(F.relu(q_start_state)).squeeze(2)

    support_mask = misc.mask_for_lengths(support_length)
    start_scores = start_scores + support_mask
    _, predicted_start_pointer = start_scores.max(1)

    def align(t):
        # expand per-question tensors to one row per answer
        return torch.index_select(t, 0, answer2question)

    if is_eval:
        start_pointer = predicted_start_pointer
    else:
        # use the correct start during training, because p(end|start) should be optimized
        start_pointer = correct_start.type(long_tensor)
        predicted_start_pointer = align(predicted_start_pointer)
        start_scores = align(start_scores)
        start_input = align(start_input)
        encoded_support = align(encoded_support)
        question_state = align(question_state)
        support_mask = align(support_mask)

    # end, conditioned on the support state at the chosen start position
    u_s = []
    for b, p in enumerate(start_pointer):
        u_s.append(encoded_support[b, p.data[0]])
    u_s = torch.stack(u_s)

    end_input = torch.cat(
        [encoded_support * u_s.unsqueeze(1), start_input], 2)

    q_end_state = self._linear_q_end(end_input) + \
        self._linear_q_end_q(question_state).unsqueeze(1)
    end_scores = self._linear_end_scores(F.relu(q_end_state)).squeeze(2)
    end_scores = end_scores + support_mask

    max_support = support_length.max().data[0]
    if is_eval:
        # constrain the end to lie at or after the predicted start
        end_scores += misc.mask_for_lengths(start_pointer, max_support,
                                            mask_right=False)

    _, predicted_end_pointer = end_scores.max(1)

    return start_scores, end_scores, predicted_start_pointer, predicted_end_pointer
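# Toy illustration (not part of the model) of the `align` trick used above:
# when several gold answers map to the same question, per-question tensors
# are expanded to one row per answer via index_select, so the end predictor
# scores each answer's start independently. Values here are made up.
encoded_support_demo = torch.randn(3, 5, 4)         # [Q, L_s, size] for 3 questions
answer2question_demo = torch.LongTensor([0, 0, 2])  # answers 0 and 1 belong to question 0
aligned = torch.index_select(encoded_support_demo, 0, answer2question_demo)
print(aligned.shape)  # torch.Size([3, 5, 4]); rows 0 and 1 copy question 0's support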