def forward(self, question_embed, question_mask, column_embed, column_name_mask, column_mask):
    B, C_L, N_L, embed_D = list(column_embed.size())

    # Column Encoder
    encoded_column = utils.encode_column(column_embed, column_name_mask, self.column_rnn)
    encoded_question, _ = self.question_rnn(question_embed)

    if self.column_attention:
        attn_matrix = torch.bmm(
            encoded_column,
            self.linear_attn(encoded_question).transpose(1, 2))
        attn_matrix = f.add_masked_value(attn_matrix, question_mask.unsqueeze(1), value=-1e7)
        attn_matrix = F.softmax(attn_matrix, dim=-1)
        attn_question = (encoded_question.unsqueeze(1) * attn_matrix.unsqueeze(3)).sum(2)
    else:
        attn_matrix = self.seq_attn(encoded_question, question_mask)
        attn_question = f.weighted_sum(attn_matrix, encoded_question)
        attn_question = attn_question.unsqueeze(1)

    logits = self.mlp(
        self.linear_question(attn_question) + self.linear_column(encoded_column)).squeeze()
    logits = f.add_masked_value(logits, column_mask, value=-1e7)
    return logits
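# NOTE: A minimal, self-contained sketch of what a column encoder like the
# `utils.encode_column` call above could look like: run an RNN over each
# column's name tokens and masked-mean-pool them into one vector per column.
# This is an illustrative assumption (including the batch_first RNN), not
# necessarily the actual `utils` implementation.
import torch


def encode_column_sketch(column_embed, column_name_mask, column_rnn):
    """column_embed: (B, C_L, N_L, D), column_name_mask: (B, C_L, N_L) -> (B, C_L, H)."""
    B, C_L, N_L, D = column_embed.size()
    flat_embed = column_embed.view(B * C_L, N_L, D)
    flat_mask = column_name_mask.view(B * C_L, N_L).float()

    outputs, _ = column_rnn(flat_embed)  # (B*C_L, N_L, H)

    # Masked mean pooling over the name-token axis.
    lengths = flat_mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = (outputs * flat_mask.unsqueeze(-1)).sum(dim=1) / lengths
    return pooled.view(B, C_L, -1)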
def forward(self, context_embed, question_embed, context_mask=None, question_mask=None):
    C, Q = context_embed, question_embed
    B, C_L, Q_L, D = C.size(0), C.size(1), Q.size(1), Q.size(2)

    similarity_matrix_shape = torch.zeros(B, C_L, Q_L, D)  # (B, C_L, Q_L, D)

    C_ = C.unsqueeze(2).expand_as(similarity_matrix_shape)
    Q_ = Q.unsqueeze(1).expand_as(similarity_matrix_shape)
    C_Q = torch.mul(C_, Q_)

    S = self.W_0(torch.cat([C_, Q_, C_Q], 3)).squeeze(3)  # (B, C_L, Q_L)

    S_question = S
    if question_mask is not None:
        S_question = f.add_masked_value(S_question, question_mask.unsqueeze(1), value=-1e7)
    S_q = F.softmax(S_question, 2)  # (B, C_L, Q_L)

    S_context = S.transpose(1, 2)
    if context_mask is not None:
        S_context = f.add_masked_value(S_context, context_mask.unsqueeze(1), value=-1e7)
    S_c = F.softmax(S_context, 2)  # (B, Q_L, C_L)

    c2q = torch.bmm(S_q, Q)  # context2query (B, C_L, D)
    q2c = torch.bmm(S_q, S_c).bmm(C)  # query2context (B, C_L, D)

    out = torch.cat([C, c2q, C * c2q, C * q2c], dim=-1)
    return out
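# NOTE: A self-contained, unmasked sketch of the attention flow above, useful
# as a shape sanity check. `w0` stands in for self.W_0 and is an assumed
# nn.Linear(3 * D, 1); the output is 4 * D wide per context position
# ([C; c2q; C*c2q; C*q2c]), i.e. 8d when C and Q come from a BiRNN of size 2d.
import torch
import torch.nn as nn
import torch.nn.functional as F


def attention_flow_sketch(C, Q, w0):
    """C: (B, C_L, D), Q: (B, Q_L, D), w0: nn.Linear(3 * D, 1) -> (B, C_L, 4 * D)."""
    B, C_L, D = C.size()
    Q_L = Q.size(1)

    C_ = C.unsqueeze(2).expand(B, C_L, Q_L, D)
    Q_ = Q.unsqueeze(1).expand(B, C_L, Q_L, D)
    S = w0(torch.cat([C_, Q_, C_ * Q_], dim=3)).squeeze(3)  # (B, C_L, Q_L)

    S_q = F.softmax(S, dim=2)                 # attend over query positions
    S_c = F.softmax(S.transpose(1, 2), dim=2)  # attend over context positions

    c2q = torch.bmm(S_q, Q)
    q2c = torch.bmm(S_q, S_c).bmm(C)
    return torch.cat([C, c2q, C * c2q, C * q2c], dim=-1)

# Usage: attention_flow_sketch(torch.rand(2, 7, 4), torch.rand(2, 5, 4),
#                              nn.Linear(12, 1)).shape == (2, 7, 16)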
def decode_then_output(
        self,
        encoded_used_column,
        encoded_question,
        question_mask,
        decoder_input,
        decoder_hidden=None,
):
    B = encoded_used_column.size(0)
    decoder_output, decoder_hidden = self.decoder(
        decoder_input.view(B * self.column_maxlen, -1, self.token_maxlen),
        decoder_hidden)
    decoder_output = decoder_output.contiguous().view(
        B, self.column_maxlen, -1, self.model_dim)
    decoder_output = decoder_output.unsqueeze(3)

    logits = self.mlp(
        self.linear_column(encoded_used_column)
        + self.linear_conds(decoder_output)
        + self.linear_question(encoded_question)).squeeze()
    logits = f.add_masked_value(logits, question_mask.unsqueeze(1).unsqueeze(1), value=-1e7)
    return logits, decoder_hidden
def forward(self, question_embed, question_mask, column_embed, column_name_mask, col_idx):
    B, C_L, N_L, embed_D = list(column_embed.size())

    # Column Encoder
    encoded_column = utils.encode_column(column_embed, column_name_mask, self.column_rnn)
    encoded_used_column = utils.filter_used_column(
        encoded_column, col_idx, padding_count=self.column_maxlen)
    encoded_question, _ = self.question_rnn(question_embed)

    if self.column_attention:
        attn_matrix = torch.matmul(
            self.linear_attn(encoded_question).unsqueeze(1),
            encoded_used_column.unsqueeze(3)).squeeze()
        attn_matrix = f.add_masked_value(attn_matrix, question_mask.unsqueeze(1), value=-1e7)
        attn_matrix = F.softmax(attn_matrix, dim=-1)
        attn_question = (encoded_question.unsqueeze(1) * attn_matrix.unsqueeze(3)).sum(2)
    else:
        attn_matrix = self.seq_attn(encoded_question, question_mask)
        attn_question = f.weighted_sum(attn_matrix, encoded_question)
        attn_question = attn_question.unsqueeze(1)

    return self.mlp(
        self.linear_question(attn_question) + self.linear_column(encoded_used_column)).squeeze()
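# NOTE: `utils.filter_used_column` above presumably gathers the encodings of
# the columns selected by `col_idx` and zero-pads the result to a fixed
# `padding_count`. A hypothetical sketch of that behavior (an assumption, not
# necessarily the repo's implementation):
import torch


def filter_used_column_sketch(encoded_column, col_idx, padding_count):
    """encoded_column: (B, C_L, H), col_idx: per-example lists of column indices."""
    B, _, H = encoded_column.size()
    used = encoded_column.new_zeros(B, padding_count, H)
    for b, indices in enumerate(col_idx):
        for slot, col in enumerate(indices[:padding_count]):
            used[b, slot] = encoded_column[b, col]
    return used  # (B, padding_count, H)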
def _scaled_dot_product(self, query, key, value, mask=None):
    K_D = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(K_D)
    if mask is not None:
        # (B, 1, 1, K_L): broadcast over heads and query positions
        mask = mask.unsqueeze(1).unsqueeze(1)
        scores = f.add_masked_value(scores, mask, value=-1e7)

    attn = F.softmax(scores, dim=-1)
    attn = self.dropout(attn)
    return torch.matmul(attn, value)
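# NOTE: A self-contained, single-head version of the scaled dot-product above,
# using the same additive -1e7 masking trick. This is an illustrative sketch;
# the method above additionally broadcasts over a head dimension and applies
# dropout to the attention weights.
import math
import torch
import torch.nn.functional as F


def scaled_dot_product_sketch(query, key, value, mask=None):
    """query: (B, Q_L, D), key/value: (B, K_L, D), mask: (B, K_L) of 0/1."""
    d = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d)  # (B, Q_L, K_L)
    if mask is not None:
        # Push padded key positions toward -inf so softmax gives them ~0 weight.
        mask = mask.unsqueeze(1).float()
        scores = scores * mask + (1 - mask) * -1e7
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, value)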
def forward(self, context, context_mask, query, query_mask):
    c, c_mask, q, q_mask = context, context_mask, query, query_mask

    S = self._make_similiarity_matrix(c, q)  # (B, C_L, Q_L)
    masked_S = f.add_masked_value(S, query_mask.unsqueeze(1), value=-1e7)

    c2q = self._context2query(S, q, q_mask)
    q2c = self._query2context(masked_S.max(dim=-1)[0], c, c_mask)

    # [h; u˜; h◦u˜; h◦h˜] ~ (B, C_L, 8d)
    G = torch.cat((c, c2q, c * c2q, c * q2c), dim=-1)
    return G
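# NOTE: One plausible reading of the two helpers used above (illustrative
# assumptions, not necessarily the actual `_context2query` / `_query2context`):
# c2q softmaxes each row of S over query positions and takes a weighted sum of
# query vectors; q2c softmaxes the per-context max similarity over context
# positions, pools the context into one vector, and tiles it across C_L.
import torch
import torch.nn.functional as F


def context2query_sketch(S, q, q_mask):
    """S: (B, C_L, Q_L), q: (B, Q_L, D), q_mask: (B, Q_L) -> (B, C_L, D)."""
    m = q_mask.unsqueeze(1).float()
    attn = F.softmax(S * m + (1 - m) * -1e7, dim=-1)
    return torch.bmm(attn, q)


def query2context_sketch(s_max, c, c_mask):
    """s_max: (B, C_L) row-max of masked S, c: (B, C_L, D) -> (B, C_L, D)."""
    m = c_mask.float()
    attn = F.softmax(s_max * m + (1 - m) * -1e7, dim=-1)  # (B, C_L)
    pooled = torch.bmm(attn.unsqueeze(1), c)  # (B, 1, D)
    return pooled.expand(-1, c.size(1), -1)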
def test_add_masked_value():
    a = torch.rand(3, 5)
    a_mask = torch.FloatTensor([
        [1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
    ])

    tensor = f.add_masked_value(a, a_mask, value=100)
    assert tensor[0][3] == 100
    assert tensor[0][4] == 100
    assert tensor[1][2] == 100
    assert tensor[1][3] == 100
    assert tensor[1][4] == 100
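# NOTE: A minimal implementation sketch of `f.add_masked_value` that would
# satisfy the test above: keep positions where mask == 1 and overwrite
# positions where mask == 0 with `value`. This is inferred from the assertions
# and the -1e7 usage elsewhere, not necessarily the library's exact code.
def add_masked_value_sketch(tensor, mask, value=-1e7):
    mask = mask.float()
    return tensor * mask + value * (1 - mask)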
def forward(self, x, x_mask, key, key_mask):
    S = self._trilinear(x, key)

    if self.self_attn:
        seq_length = x.size(1)
        diag_mask = self.diag_mask.narrow(0, 0, seq_length).narrow(1, 0, seq_length)
        joint_mask = 1 - self._compute_attention_mask(x_mask, key_mask)
        mask = torch.clamp(diag_mask + joint_mask, 0, 1)
        masked_S = S + mask * (-1e7)

        x2key = self._x2key(masked_S, key, key_mask)
        return torch.cat((x, x2key, x * x2key), dim=-1)
    else:
        joint_mask = 1 - self._compute_attention_mask(x_mask, key_mask)
        masked_S = S + joint_mask * (-1e7)
        x2key = self._x2key(masked_S, key, key_mask)

        masked_S = f.add_masked_value(S, key_mask.unsqueeze(1), value=-1e7)
        key2x = self._key2x(masked_S.max(dim=-1)[0], x, x_mask)
        return torch.cat((x, x2key, x * x2key, x * key2x), dim=-1)
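# NOTE: In the self-attention branch above, `self.diag_mask` is presumably a
# buffer that blocks each position from attending to itself (its diagonal is
# later scaled by -1e7). A hypothetical construction, assuming a registered
# buffer sized to some maximum sequence length:
import torch

MAX_LEN = 512  # assumed maximum sequence length
diag_mask = torch.eye(MAX_LEN)  # 1 on the diagonal, 0 elsewhere
# e.g. self.register_buffer("diag_mask", torch.eye(MAX_LEN)) inside __init__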
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        label: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - answer_idx: the question id, mapping with answer
        - loss: A scalar loss to be optimised.
    """

    context = features["context"]
    question = features["question"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    context_seq_config = f.get_sorted_seq_config(context)
    query_seq_config = f.get_sorted_seq_config(question)

    # Embedding Layer (Char + Word -> Contextual)
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()
    query_mask = f.get_mask_from_tokens(question).float()

    B, C_L = context_embed.size(0), context_embed.size(1)

    context_embed = self.context_highway(context_embed)
    query_embed = self.query_highway(query_embed)

    context_encoded = f.forward_rnn_with_pack(self.context_contextual_rnn,
                                              context_embed, context_seq_config)
    context_encoded = self.dropout(context_encoded)

    query_encoded = f.forward_rnn_with_pack(self.query_contextual_rnn,
                                            query_embed, query_seq_config)
    query_encoded = self.dropout(query_encoded)

    # Attention Flow Layer
    attention_context_query = self.attention(context_encoded, context_mask,
                                             query_encoded, query_mask)

    # Modeling Layer
    modeled_context = f.forward_rnn_with_pack(self.modeling_rnn,
                                              attention_context_query,
                                              context_seq_config)
    modeled_context = self.dropout(modeled_context)
    M_D = modeled_context.size(-1)

    # Output Layer
    span_start_input = self.dropout(
        torch.cat([attention_context_query, modeled_context], dim=-1))  # (B, C_L, 10d)
    span_start_logits = self.span_start_linear(span_start_input).squeeze(-1)  # (B, C_L)
    span_start_probs = f.masked_softmax(span_start_logits, context_mask)

    span_start_representation = f.weighted_sum(attention=span_start_probs,
                                               matrix=modeled_context)
    tiled_span_start_representation = span_start_representation.unsqueeze(1).expand(B, C_L, M_D)

    span_end_representation = torch.cat(
        [
            attention_context_query,
            modeled_context,
            tiled_span_start_representation,
            modeled_context * tiled_span_start_representation,
        ],
        dim=-1,
    )
    encoded_span_end = f.forward_rnn_with_pack(self.output_end_rnn,
                                               span_end_representation,
                                               context_seq_config)
    encoded_span_end = self.dropout(encoded_span_end)

    span_end_input = self.dropout(
        torch.cat([attention_context_query, encoded_span_end], dim=-1))
    span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    # No_Answer Bias
    bias = self.bias.expand(B, 1)
    span_start_logits = torch.cat([span_start_logits, bias], dim=-1)
    span_end_logits = torch.cat([span_end_logits, bias], dim=-1)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(
            span_start_logits[:, :-1],  # except no_answer bias
            span_end_logits[:, :-1],
            answer_maxlen=self.answer_maxlen,
        ),
    }

    if labels:
        answer_idx = labels["answer_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]
        answerable = labels["answerable"]

        # No_Answer Case
        C_L = context_mask.size(1)
        answer_start_idx = answer_start_idx.masked_fill(answerable.eq(0), C_L)
        answer_end_idx = answer_end_idx.masked_fill(answerable.eq(0), C_L)

        output_dict["answer_idx"] = answer_idx

        # Loss
        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        label: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - data_idx: the question id, mapping with answer
        - loss: A scalar loss to be optimised.
    """

    context = features["context"]  # aka paragraph
    question = features["question"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    context_seq_config = f.get_sorted_seq_config(context)
    query_seq_config = f.get_sorted_seq_config(question)

    # Embedding
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()
    query_mask = f.get_mask_from_tokens(question).float()

    context_embed = self.dropout(context_embed)
    query_embed = self.dropout(query_embed)

    # RNN (LSTM)
    context_encoded = f.forward_rnn_with_pack(self.paragraph_rnn, context_embed,
                                              context_seq_config)
    context_encoded = self.dropout(context_encoded)

    query_encoded = f.forward_rnn_with_pack(self.query_rnn, query_embed,
                                            query_seq_config)  # (B, Q_L, H*2)
    query_encoded = self.dropout(query_encoded)

    query_attention = self.query_att(query_encoded, query_mask)  # (B, Q_L)
    query_att_sum = f.weighted_sum(query_attention, query_encoded)  # (B, H*2)

    span_start_logits = self.start_attn(context_encoded, query_att_sum, context_mask)
    span_end_logits = self.end_attn(context_encoded, query_att_sum, context_mask)

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(span_start_logits, span_end_logits,
                                        answer_maxlen=self.answer_maxlen),
    }

    if labels:
        data_idx = labels["data_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]

        output_dict["data_idx"] = data_idx

        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)

    return output_dict
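# NOTE: `self.start_attn` / `self.end_attn` above score each context position
# against the pooled question vector. A common (DrQA-style) bilinear form,
# x_i^T W y, is sketched below as an assumption about those modules; the model
# above masks the logits again afterwards, which is harmless.
import torch
import torch.nn as nn


class BilinearSeqAttnSketch(nn.Module):
    """Scores each position of x against a single vector y via x_i^T W y."""

    def __init__(self, x_dim, y_dim):
        super().__init__()
        self.linear = nn.Linear(y_dim, x_dim)

    def forward(self, x, y, x_mask):
        # x: (B, L, x_dim), y: (B, y_dim), x_mask: (B, L) -> logits (B, L)
        Wy = self.linear(y)  # (B, x_dim)
        logits = x.bmm(Wy.unsqueeze(2)).squeeze(2)  # (B, L)
        return logits * x_mask + (1 - x_mask) * -1e7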
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"sequence": [0, 3, 4, 1]}

    * Kwargs:
        label: label dictionary like below.
            {"class_idx": 2, "data_idx": 0}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - sequence_embed: embedding vector of the sequence
        - class_logits: representing unnormalized log probabilities of the class.
        - class_idx: target class idx
        - data_idx: data idx
        - loss: a scalar loss to be optimized
    """

    sequence = features["sequence"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    sequence_config = f.get_sorted_seq_config(sequence)

    token_embed = self.token_embedder(sequence)
    token_encodings = f.forward_rnn_with_pack(
        self.encoder, token_embed, sequence_config)  # [B, L, encoding_rnn_hidden_dim]

    attention = self.A(token_encodings).transpose(1, 2)  # [B, num_attention_heads, L]

    sequence_mask = f.get_mask_from_tokens(sequence).float()  # [B, L]
    sequence_mask = sequence_mask.unsqueeze(1).expand_as(attention)
    attention = F.softmax(f.add_masked_value(attention, sequence_mask) + 1e-13, dim=2)

    attended_encodings = torch.bmm(
        attention, token_encodings)  # [B, num_attention_heads, sequence_embed_dim]
    sequence_embed = self.fully_connected(
        attended_encodings.view(attended_encodings.size(0), -1))  # [B, sequence_embed_dim]

    class_logits = self.classifier(sequence_embed)  # [B, num_classes]

    output_dict = {"sequence_embed": sequence_embed, "class_logits": class_logits}

    if labels:
        class_idx = labels["class_idx"]
        data_idx = labels["data_idx"]

        output_dict["class_idx"] = class_idx
        output_dict["data_idx"] = data_idx

        # Loss
        loss = self.criterion(class_logits, class_idx)
        loss += self.penalty(attention)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
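# NOTE: `self.penalty(attention)` above matches the regularizer from
# "A Structured Self-Attentive Sentence Embedding" (Lin et al., 2017):
# P = ||A A^T - I||_F^2, which pushes the attention heads to focus on
# different positions. A sketch under that assumption:
import torch


def attention_penalty_sketch(attention):
    """attention: (B, num_heads, L) -> scalar Frobenius-norm penalty."""
    num_heads = attention.size(1)
    identity = torch.eye(num_heads, device=attention.device)
    gram = torch.bmm(attention, attention.transpose(1, 2))  # (B, H, H)
    return ((gram - identity) ** 2).sum(dim=(1, 2)).mean()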
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        label: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - answer_idx: the question id, mapping with answer
        - loss: A scalar loss to be optimised.
    """

    context = features["context"]
    question = features["question"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    context_seq_config = f.get_sorted_seq_config(context)
    query_seq_config = f.get_sorted_seq_config(question)

    # Embedding
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()  # B X 1 X C_L
    query_mask = f.get_mask_from_tokens(question).float()  # B X 1 X Q_L

    # Pre-process
    context_embed = self.dropout(context_embed)
    context_encoded = f.forward_rnn_with_pack(self.context_preprocess_rnn,
                                              context_embed, context_seq_config)
    context_encoded = self.dropout(context_encoded)

    query_embed = self.dropout(query_embed)
    query_encoded = f.forward_rnn_with_pack(self.query_preprocess_rnn,
                                            query_embed, query_seq_config)
    query_encoded = self.dropout(query_encoded)

    # Attention -> Projection
    context_attnded = self.bi_attention(context_encoded, context_mask,
                                        query_encoded, query_mask)
    context_attnded = self.activation_fn(self.attn_linear(context_attnded))  # B X C_L X dim*2

    # Residual Self-Attention
    context_attnded = self.dropout(context_attnded)
    context_encoded = f.forward_rnn_with_pack(self.modeling_rnn, context_attnded,
                                              context_seq_config)
    context_encoded = self.dropout(context_encoded)
    context_self_attnded = self.self_attention(context_encoded, context_mask)  # B X C_L X dim*2
    context_final = self.dropout(context_attnded + context_self_attnded)  # B X C_L X dim*2

    # Prediction
    span_start_input = f.forward_rnn_with_pack(self.span_start_rnn, context_final,
                                               context_seq_config)  # B X C_L X dim*2
    span_start_input = self.dropout(span_start_input)
    span_start_logits = self.span_start_linear(span_start_input).squeeze(-1)  # B X C_L

    span_end_input = torch.cat([span_start_input, context_final], dim=-1)  # B X C_L X dim*4
    span_end_input = f.forward_rnn_with_pack(self.span_end_rnn, span_end_input,
                                             context_seq_config)  # B X C_L X dim*2
    span_end_input = self.dropout(span_end_input)
    span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)  # B X C_L

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(span_start_logits, span_end_logits,
                                        answer_maxlen=self.answer_maxlen),
    }

    if labels:
        answer_idx = labels["answer_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]

        output_dict["answer_idx"] = answer_idx

        # Loss
        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
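# NOTE: `self.get_best_span`, used by all the span models here, is presumably
# the standard argmax-over-spans decoding: pick (start, end) maximizing
# start_logit + end_logit subject to start <= end, optionally bounding the
# span length by answer_maxlen. A sketch of that decoding rule (an assumption,
# not necessarily the repo's implementation):
import torch


def get_best_span_sketch(span_start_logits, span_end_logits, answer_maxlen=None):
    """(B, C_L) logits -> (B, 2) LongTensor of (start, end) indices."""
    B, C_L = span_start_logits.size()
    scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)  # (B, C_L, C_L)

    # Keep only valid spans: start <= end, and optionally bounded length.
    valid = torch.triu(torch.ones(C_L, C_L, device=scores.device))
    if answer_maxlen is not None:
        valid = torch.tril(valid, diagonal=answer_maxlen - 1)
    scores = scores.masked_fill(valid.eq(0), -1e7)

    best = scores.view(B, -1).argmax(dim=1)  # flattened index: start * C_L + end
    return torch.stack([best // C_L, best % C_L], dim=1)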
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        label: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - data_idx: the question id, mapping with answer
        - loss: A scalar loss to be optimised.
    """

    context = features["context"]
    question = features["question"]

    # 1. Input Embedding Layer
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()
    query_mask = f.get_mask_from_tokens(question).float()

    context_embed = self.context_highway(context_embed)
    context_embed = self.dropout(context_embed)
    context_embed = self.context_embed_pointwise_conv(context_embed)

    query_embed = self.query_highway(query_embed)
    query_embed = self.dropout(query_embed)
    query_embed = self.query_embed_pointwise_conv(query_embed)

    # 2. Embedding Encoder Layer
    for encoder_block in self.embed_encoder_blocks:
        context = encoder_block(context_embed)
        context_embed = context

        query = encoder_block(query_embed)
        query_embed = query

    # 3. Context-Query Attention Layer
    context_query_attention = self.co_attention(context, query, context_mask, query_mask)

    # Projection (memory issue)
    context_query_attention = self.pointwise_conv(context_query_attention)
    context_query_attention = self.dropout(context_query_attention)

    # 4. Model Encoder Layer
    model_encoder_block_inputs = context_query_attention

    # Stacked Model Encoder Block
    stacked_model_encoder_blocks = []
    for i in range(3):
        for _, model_encoder_block in enumerate(self.model_encoder_blocks):
            output = model_encoder_block(model_encoder_block_inputs, context_mask)
            model_encoder_block_inputs = output

        stacked_model_encoder_blocks.append(output)

    # 5. Output Layer
    span_start_inputs = torch.cat(
        [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[1]], dim=-1)
    span_start_inputs = self.dropout(span_start_inputs)
    span_start_logits = self.span_start_linear(span_start_inputs).squeeze(-1)

    span_end_inputs = torch.cat(
        [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[2]], dim=-1)
    span_end_inputs = self.dropout(span_end_inputs)
    span_end_logits = self.span_end_linear(span_end_inputs).squeeze(-1)

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(span_start_logits, span_end_logits),
    }

    if labels:
        data_idx = labels["data_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]

        output_dict["data_idx"] = data_idx

        # Loss
        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
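# NOTE: The `*_pointwise_conv` projections above are, in QANet terms,
# kernel-size-1 1D convolutions that map embeddings down to the model
# dimension. A minimal sketch of such a layer (an assumption about the
# module; nn.Conv1d expects channels-first input, hence the transposes):
import torch
import torch.nn as nn


class PointwiseConvSketch(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = nn.Conv1d(in_dim, out_dim, kernel_size=1)

    def forward(self, x):
        # x: (B, L, in_dim) -> (B, L, out_dim)
        return self.conv(x.transpose(1, 2)).transpose(1, 2)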