import torch

# `f` below refers to the project's functional helper module
# (masking, masked softmax, RNN pack utilities).


def test_last_dim_masked_softmax_with_2_dim():
    tensor = torch.FloatTensor([
        [2, 3, 1, 0, 0],
        [4, 1, 0, 0, 0],
        [1, 5, 2, 4, 1],
    ])
    mask = f.get_mask_from_tokens({"word": tensor}).float()

    result = f.last_dim_masked_softmax(tensor, mask)
    assert result.argmax(dim=-1).equal(torch.LongTensor([1, 0, 1]))
def test_get_mask_from_tokens_with_2_dim():
    tokens = {
        "word": torch.LongTensor([
            [1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1],
        ]),
    }
    mask = f.get_mask_from_tokens(tokens)
    assert mask.equal(tokens["word"])
def test_get_mask_from_tokens_with_3_dim():
    tokens = {
        "char": torch.LongTensor([
            [[4, 2], [3, 6], [0, 0]],
            [[5, 1], [0, 0], [0, 0]],
            [[1, 3], [2, 4], [3, 6]],
        ]),
    }
    mask = f.get_mask_from_tokens(tokens)

    expect_tensor = torch.LongTensor([
        [1, 1, 0],
        [1, 0, 0],
        [1, 1, 1],
    ])
    assert mask.equal(expect_tensor)
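
# A minimal reference sketch of the two helpers the tests above exercise.
# These are hypothetical re-implementations written to match the asserted
# behavior, not the library's actual code: padding id 0 marks masked
# positions, and char-level (3-dim) inputs count a timestep as real if any
# of its character ids is non-zero.
import torch
import torch.nn.functional as F


def get_mask_from_tokens_sketch(tokens):
    # Use any token tensor in the dict; non-zero ids are real tokens.
    tensor = next(iter(tokens.values()))
    if tensor.dim() == 3:  # (B, L, char_len) -> (B, L)
        return tensor.gt(0).long().sum(dim=-1).gt(0).long()
    return tensor.gt(0).long()


def last_dim_masked_softmax_sketch(logits, mask):
    # Fill masked positions with a large negative value so softmax ~ignores them.
    masked_logits = logits.masked_fill(mask.eq(0), -1e7)
    return F.softmax(masked_logits, dim=-1)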
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        labels: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - answer_idx: the question id, mapping with answer
        - loss: a scalar loss to be optimized.
    """

    context = features["context"]
    question = features["question"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    context_seq_config = f.get_sorted_seq_config(context)
    query_seq_config = f.get_sorted_seq_config(question)

    # Embedding Layer (Char + Word -> Contextual)
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()
    query_mask = f.get_mask_from_tokens(question).float()

    B, C_L = context_embed.size(0), context_embed.size(1)

    context_embed = self.context_highway(context_embed)
    query_embed = self.query_highway(query_embed)

    context_encoded = f.forward_rnn_with_pack(
        self.context_contextual_rnn, context_embed, context_seq_config)
    context_encoded = self.dropout(context_encoded)

    query_encoded = f.forward_rnn_with_pack(
        self.query_contextual_rnn, query_embed, query_seq_config)
    query_encoded = self.dropout(query_encoded)

    # Attention Flow Layer
    attention_context_query = self.attention(
        context_encoded, context_mask, query_encoded, query_mask)

    # Modeling Layer
    modeled_context = f.forward_rnn_with_pack(
        self.modeling_rnn, attention_context_query, context_seq_config)
    modeled_context = self.dropout(modeled_context)
    M_D = modeled_context.size(-1)

    # Output Layer
    span_start_input = self.dropout(
        torch.cat([attention_context_query, modeled_context], dim=-1))  # (B, C_L, 10d)
    span_start_logits = self.span_start_linear(span_start_input).squeeze(-1)  # (B, C_L)
    span_start_probs = f.masked_softmax(span_start_logits, context_mask)

    span_start_representation = f.weighted_sum(
        attention=span_start_probs, matrix=modeled_context)
    tiled_span_start_representation = span_start_representation.unsqueeze(1).expand(B, C_L, M_D)

    span_end_representation = torch.cat(
        [
            attention_context_query,
            modeled_context,
            tiled_span_start_representation,
            modeled_context * tiled_span_start_representation,
        ],
        dim=-1,
    )
    encoded_span_end = f.forward_rnn_with_pack(
        self.output_end_rnn, span_end_representation, context_seq_config)
    encoded_span_end = self.dropout(encoded_span_end)

    span_end_input = self.dropout(
        torch.cat([attention_context_query, encoded_span_end], dim=-1))
    span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    # No_Answer Bias
    bias = self.bias.expand(B, 1)
    span_start_logits = torch.cat([span_start_logits, bias], dim=-1)
    span_end_logits = torch.cat([span_end_logits, bias], dim=-1)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(
            span_start_logits[:, :-1],  # except no_answer bias
            span_end_logits[:, :-1],
            answer_maxlen=self.answer_maxlen,
        ),
    }

    if labels:
        answer_idx = labels["answer_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]
        answerable = labels["answerable"]

        # No_Answer case: point both span indices at the appended bias position
        C_L = context_mask.size(1)
        answer_start_idx = answer_start_idx.masked_fill(answerable.eq(0), C_L)
        answer_end_idx = answer_end_idx.masked_fill(answerable.eq(0), C_L)

        output_dict["answer_idx"] = answer_idx

        # Loss
        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        labels: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - data_idx: the question id, mapping with answer
        - loss: a scalar loss to be optimized.
    """

    context = features["context"]  # aka paragraph
    question = features["question"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    context_seq_config = f.get_sorted_seq_config(context)
    query_seq_config = f.get_sorted_seq_config(question)

    # Embedding
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()
    query_mask = f.get_mask_from_tokens(question).float()

    context_embed = self.dropout(context_embed)
    query_embed = self.dropout(query_embed)

    # RNN (LSTM)
    context_encoded = f.forward_rnn_with_pack(
        self.paragraph_rnn, context_embed, context_seq_config)
    context_encoded = self.dropout(context_encoded)

    query_encoded = f.forward_rnn_with_pack(
        self.query_rnn, query_embed, query_seq_config)  # (B, Q_L, H*2)
    query_encoded = self.dropout(query_encoded)

    query_attention = self.query_att(query_encoded, query_mask)  # (B, Q_L)
    query_att_sum = f.weighted_sum(query_attention, query_encoded)  # (B, H*2)

    span_start_logits = self.start_attn(context_encoded, query_att_sum, context_mask)
    span_end_logits = self.end_attn(context_encoded, query_att_sum, context_mask)

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(
            span_start_logits, span_end_logits, answer_maxlen=self.answer_maxlen),
    }

    if labels:
        data_idx = labels["data_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]

        output_dict["data_idx"] = data_idx

        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)

    return output_dict
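
# Two helpers above compress the question into one vector. A minimal sketch of
# their assumed semantics (hypothetical re-implementations): `query_att` is a
# learned self-attention over query positions, and `weighted_sum` is the
# attention-weighted average of the encodings.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearSelfAttnSketch(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x, mask):
        # x: (B, L, D), mask: (B, L) -> attention weights (B, L)
        scores = self.linear(x).squeeze(-1)
        scores = scores.masked_fill(mask.eq(0), -1e7)
        return F.softmax(scores, dim=-1)


def weighted_sum_sketch(attention, matrix):
    # attention: (B, L), matrix: (B, L, D) -> (B, D)
    return attention.unsqueeze(1).bmm(matrix).squeeze(1)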
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"sequence": [0, 3, 4, 1]}

    * Kwargs:
        labels: label dictionary like below.
            {"class_idx": 2, "data_idx": 0}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - sequence_embed: embedding vector of the sequence
        - class_logits: representing unnormalized log probabilities of the class.
        - class_idx: target class idx
        - data_idx: data idx
        - loss: a scalar loss to be optimized
    """

    sequence = features["sequence"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    sequence_config = f.get_sorted_seq_config(sequence)

    token_embed = self.token_embedder(sequence)
    token_encodings = f.forward_rnn_with_pack(
        self.encoder, token_embed, sequence_config
    )  # [B, L, encoding_rnn_hidden_dim]

    attention = self.A(token_encodings).transpose(1, 2)  # [B, num_attention_heads, L]

    sequence_mask = f.get_mask_from_tokens(sequence).float()  # [B, L]
    sequence_mask = sequence_mask.unsqueeze(1).expand_as(attention)
    attention = F.softmax(f.add_masked_value(attention, sequence_mask) + 1e-13, dim=2)

    attended_encodings = torch.bmm(
        attention, token_encodings
    )  # [B, num_attention_heads, sequence_embed_dim]
    sequence_embed = self.fully_connected(
        attended_encodings.view(attended_encodings.size(0), -1)
    )  # [B, sequence_embed_dim]

    class_logits = self.classifier(sequence_embed)  # [B, num_classes]

    output_dict = {"sequence_embed": sequence_embed, "class_logits": class_logits}

    if labels:
        class_idx = labels["class_idx"]
        data_idx = labels["data_idx"]

        output_dict["class_idx"] = class_idx
        output_dict["data_idx"] = data_idx

        # Loss
        loss = self.criterion(class_logits, class_idx)
        loss += self.penalty(attention)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
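
# `self.penalty` is assumed to be the penalization term from the structured
# self-attentive embedding formulation: P = ||A A^T - I||_F^2, which pushes
# the attention heads to focus on different positions. A minimal sketch under
# that assumption (hypothetical helper, not necessarily this model's code):
import torch


def attention_penalty_sketch(attention):
    # attention: (B, num_heads, L); rows are per-head distributions over positions
    gram = torch.bmm(attention, attention.transpose(1, 2))  # (B, heads, heads)
    identity = torch.eye(gram.size(-1), device=attention.device).unsqueeze(0)
    # Squared Frobenius norm of (A A^T - I), averaged over the batch
    return ((gram - identity) ** 2).sum(dim=(1, 2)).mean()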
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        labels: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - answer_idx: the question id, mapping with answer
        - loss: a scalar loss to be optimized.
    """

    context = features["context"]
    question = features["question"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    context_seq_config = f.get_sorted_seq_config(context)
    query_seq_config = f.get_sorted_seq_config(question)

    # Embedding
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()  # B X 1 X C_L
    query_mask = f.get_mask_from_tokens(question).float()  # B X 1 X Q_L

    # Pre-process
    context_embed = self.dropout(context_embed)
    context_encoded = f.forward_rnn_with_pack(
        self.context_preprocess_rnn, context_embed, context_seq_config)
    context_encoded = self.dropout(context_encoded)

    query_embed = self.dropout(query_embed)
    query_encoded = f.forward_rnn_with_pack(
        self.query_preprocess_rnn, query_embed, query_seq_config)
    query_encoded = self.dropout(query_encoded)

    # Attention -> Projection
    context_attnded = self.bi_attention(
        context_encoded, context_mask, query_encoded, query_mask)
    context_attnded = self.activation_fn(self.attn_linear(context_attnded))  # B X C_L X dim*2

    # Residual Self-Attention
    context_attnded = self.dropout(context_attnded)
    context_encoded = f.forward_rnn_with_pack(
        self.modeling_rnn, context_attnded, context_seq_config)
    context_encoded = self.dropout(context_encoded)
    context_self_attnded = self.self_attention(context_encoded, context_mask)  # B X C_L X dim*2
    context_final = self.dropout(context_attnded + context_self_attnded)  # B X C_L X dim*2

    # Prediction
    span_start_input = f.forward_rnn_with_pack(
        self.span_start_rnn, context_final, context_seq_config)  # B X C_L X dim*2
    span_start_input = self.dropout(span_start_input)
    span_start_logits = self.span_start_linear(span_start_input).squeeze(-1)  # B X C_L

    span_end_input = torch.cat([span_start_input, context_final], dim=-1)  # B X C_L X dim*4
    span_end_input = f.forward_rnn_with_pack(
        self.span_end_rnn, span_end_input, context_seq_config)  # B X C_L X dim*2
    span_end_input = self.dropout(span_end_input)
    span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)  # B X C_L

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(
            span_start_logits, span_end_logits, answer_maxlen=self.answer_maxlen),
    }

    if labels:
        answer_idx = labels["answer_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]

        output_dict["answer_idx"] = answer_idx

        # Loss
        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
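
# `add_masked_value` shows up in every model above. Its assumed semantics, as
# a minimal sketch (hypothetical re-implementation): keep values at real token
# positions and overwrite padded positions with a large negative constant so a
# following softmax or argmax ignores them.
def add_masked_value_sketch(tensor, mask, value=-1e7):
    # mask is 1.0 for real tokens and 0.0 for padding, broadcastable to tensor
    return tensor * mask + (1 - mask) * value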
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        labels: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Do not calculate loss when there is no label. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - data_idx: the question id, mapping with answer
        - loss: a scalar loss to be optimized.
    """

    context = features["context"]
    question = features["question"]

    # 1. Input Embedding Layer
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()
    query_mask = f.get_mask_from_tokens(question).float()

    context_embed = self.context_highway(context_embed)
    context_embed = self.dropout(context_embed)
    context_embed = self.context_embed_pointwise_conv(context_embed)

    query_embed = self.query_highway(query_embed)
    query_embed = self.dropout(query_embed)
    query_embed = self.query_embed_pointwise_conv(query_embed)

    # 2. Embedding Encoder Layer
    for encoder_block in self.embed_encoder_blocks:
        context = encoder_block(context_embed)
        context_embed = context

        query = encoder_block(query_embed)
        query_embed = query

    # 3. Context-Query Attention Layer
    context_query_attention = self.co_attention(context, query, context_mask, query_mask)

    # Projection (memory issue)
    context_query_attention = self.pointwise_conv(context_query_attention)
    context_query_attention = self.dropout(context_query_attention)

    # 4. Model Encoder Layer
    model_encoder_block_inputs = context_query_attention

    # Stacked Model Encoder Block
    stacked_model_encoder_blocks = []
    for i in range(3):
        for _, model_encoder_block in enumerate(self.model_encoder_blocks):
            output = model_encoder_block(model_encoder_block_inputs, context_mask)
            model_encoder_block_inputs = output
        stacked_model_encoder_blocks.append(output)

    # 5. Output Layer
    span_start_inputs = torch.cat(
        [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[1]], dim=-1)
    span_start_inputs = self.dropout(span_start_inputs)
    span_start_logits = self.span_start_linear(span_start_inputs).squeeze(-1)

    span_end_inputs = torch.cat(
        [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[2]], dim=-1)
    span_end_inputs = self.dropout(span_end_inputs)
    span_end_logits = self.span_end_linear(span_end_inputs).squeeze(-1)

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(span_start_logits, span_end_logits),
    }

    if labels:
        data_idx = labels["data_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]

        output_dict["data_idx"] = data_idx

        # Loss
        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

    return output_dict
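
# The pointwise conv layers above are assumed to be QANet-style kernel-size-1
# 1D convolutions used as cheap per-position projections. A minimal sketch
# (hypothetical module, not necessarily this model's implementation):
import torch.nn as nn


class PointwiseConvSketch(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=1)

    def forward(self, x):
        # x: (B, L, D_in) -> (B, L, D_out); Conv1d expects channels first
        return self.conv(x.transpose(1, 2)).transpose(1, 2)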
def forward(self, features, labels=None):
    column = features["column"]
    question = features["question"]

    column_embed = self.token_embedder(column)
    question_embed = self.token_embedder(question)

    B, C_L = column_embed.size(0), column_embed.size(1)

    column_indexed = column[next(iter(column))]

    column_name_mask = column_indexed.gt(0).float()  # NOTE: hard-code
    column_lengths = utils.get_column_lengths(column_embed, column_name_mask)
    column_mask = column_lengths.view(B, C_L).gt(0).float()  # NOTE: hard-code
    question_mask = f.get_mask_from_tokens(question).float()

    agg_logits = self.agg_predictor(question_embed, question_mask)
    sel_logits = self.sel_predictor(
        question_embed, question_mask, column_embed, column_name_mask, column_mask)

    conds_col_idx, conds_val_pos = None, None
    if labels:
        data_idx = labels["data_idx"]
        ground_truths = self._dataset.get_ground_truths(data_idx)

        conds_col_idx = [ground_truth["conds_col"] for ground_truth in ground_truths]
        conds_val_pos = [ground_truth["conds_val_pos"] for ground_truth in ground_truths]

    conds_logits = self.conds_predictor(
        question_embed,
        question_mask,
        column_embed,
        column_name_mask,
        column_mask,
        conds_col_idx,
        conds_val_pos,
    )

    # Convert GPU to CPU
    agg_logits = agg_logits.cpu()
    sel_logits = sel_logits.cpu()
    conds_logits = [logits.cpu() for logits in conds_logits]

    output_dict = {
        "agg_logits": agg_logits,
        "sel_logits": sel_logits,
        "conds_logits": conds_logits,
    }

    if labels:
        data_idx = labels["data_idx"]
        output_dict["data_idx"] = data_idx

        ground_truths = self._dataset.get_ground_truths(data_idx)

        # Aggregator, Select Column
        target_agg_idx = torch.LongTensor(
            [ground_truth["agg_idx"] for ground_truth in ground_truths])
        target_sel_idx = torch.LongTensor(
            [ground_truth["sel_idx"] for ground_truth in ground_truths])

        loss = 0
        loss += self.cross_entropy(agg_logits, target_agg_idx)
        loss += self.cross_entropy(sel_logits, target_sel_idx)

        conds_num_logits, conds_column_logits, conds_op_logits, conds_value_logits = conds_logits

        # Conditions
        # 1. The number of conditions
        target_conds_num = torch.LongTensor(
            [ground_truth["conds_num"] for ground_truth in ground_truths])
        target_conds_column = [
            ground_truth["conds_col"] for ground_truth in ground_truths
        ]
        loss += self.cross_entropy(conds_num_logits, target_conds_num)

        # 2. Columns of conditions
        B = conds_column_logits.size(0)
        target_conds_columns = np.zeros(list(conds_column_logits.size()), dtype=np.float32)
        for i in range(B):
            target_conds_column_idx = target_conds_column[i]
            if len(target_conds_column_idx) == 0:
                continue
            target_conds_columns[i][target_conds_column_idx] = 1
        target_conds_columns = torch.from_numpy(target_conds_columns)

        conds_column_probs = torch.sigmoid(conds_column_logits)
        bce_loss = -torch.mean(
            self.conds_column_loss_alpha
            * (target_conds_columns * torch.log(conds_column_probs + 1e-10))
            + (1 - target_conds_columns) * torch.log(1 - conds_column_probs + 1e-10)
        )
        loss += bce_loss

        # 3. Operator of conditions
        conds_op_loss = 0
        for i in range(B):
            target_conds_op = ground_truths[i]["conds_op"]
            if len(target_conds_op) == 0:
                continue

            target_conds_op = torch.from_numpy(np.array(target_conds_op))
            logits_conds_op = conds_op_logits[i, :len(target_conds_op)]

            target_op_count = len(target_conds_op)
            conds_op_loss += (
                self.cross_entropy(logits_conds_op, target_conds_op) / target_op_count)
        loss += conds_op_loss

        # 4. Value of conditions
        conds_val_pos = [
            ground_truth["conds_val_pos"] for ground_truth in ground_truths
        ]
        conds_value_loss = 0
        for i in range(B):
            for j in range(len(conds_val_pos[i])):
                cond_val_pos = conds_val_pos[i][j]
                if len(cond_val_pos) == 1:
                    continue

                target_cond_val_pos = torch.from_numpy(
                    np.array(cond_val_pos[1:]))  # index 0: START_TOKEN
                logits_cond_val_pos = conds_value_logits[i, j, :len(cond_val_pos) - 1]
                conds_value_loss += self.cross_entropy(
                    logits_cond_val_pos, target_cond_val_pos) / len(conds_val_pos[i])
        loss += conds_value_loss / B

        output_dict["loss"] = loss.unsqueeze(0)

    return output_dict
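
# The condition-column head is multi-label (a WHERE clause can reference
# several columns), hence the sigmoid plus weighted binary cross entropy above
# rather than a softmax. The same loss in isolation, as a minimal sketch; the
# alpha default is an illustrative assumption standing in for
# self.conds_column_loss_alpha:
import torch


def conds_column_bce_sketch(logits, targets, alpha=3.0):
    # logits, targets: (B, num_columns); targets are {0, 1} indicators
    probs = torch.sigmoid(logits)
    positive_term = alpha * targets * torch.log(probs + 1e-10)
    negative_term = (1 - targets) * torch.log(1 - probs + 1e-10)
    return -torch.mean(positive_term + negative_term)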
def forward(self, context, query, context_params={}, query_params={}, query_align=False):
    """
    * Args:
        context: context inputs (eg. {"token_name1": tensor, "token_name2": tensor, ...})
        query: query inputs (eg. {"token_name1": tensor, "token_name2": tensor, ...})

    * Kwargs:
        context_params: custom context parameters
        query_params: custom query parameters
        query_align: f_align(p_i) = sum_j(a_ij * E(q_j)), where the attention score
            a_ij captures the similarity between p_i and each question word q_j.
            These features add soft alignments between similar but non-identical
            words (e.g., car and vehicle). It is applied only to 'context_embed'.
    """

    if set(self.token_names) != set(context.keys()):
        raise ValueError(
            f"Mismatch token_names inputs: {context.keys()}, embeddings: {self.token_names}"
        )

    context_tokens, query_tokens = {}, {}
    for token_name, context_tensors in context.items():
        embedding = getattr(self, token_name)
        context_tokens[token_name] = embedding(
            context_tensors, **context_params.get(token_name, {}))

        if token_name in query:
            query_tokens[token_name] = embedding(
                query[token_name], **query_params.get(token_name, {}))

    # query_align_embedding
    if query_align:
        common_context = self._filter(context_tokens, exclusive=False)
        embedded_common_context = torch.cat(list(common_context.values()), dim=-1)

        exclusive_context = self._filter(context_tokens, exclusive=True)
        embedded_exclusive_context = None
        if exclusive_context != {}:
            embedded_exclusive_context = torch.cat(
                list(exclusive_context.values()), dim=-1)

        query_mask = f.get_mask_from_tokens(query_tokens)
        embedded_query = torch.cat(list(query_tokens.values()), dim=-1)
        embedded_aligned_query = self.align_attention(
            embedded_common_context, embedded_query, query_mask)

        # Merge context embedded
        embedded_context = [embedded_common_context, embedded_aligned_query]
        if embedded_exclusive_context is not None:
            embedded_context.append(embedded_exclusive_context)

        context_output = torch.cat(embedded_context, dim=-1)
        query_output = embedded_query
    else:
        context_output = torch.cat(list(context_tokens.values()), dim=-1)
        query_output = torch.cat(list(query_tokens.values()), dim=-1)

    return context_output, query_output
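
# `align_attention` is assumed to implement the aligned question embedding
# described in the docstring: a_ij = softmax_j(proj(E(p_i)) . proj(E(q_j)))
# and f_align(p_i) = sum_j a_ij * E(q_j). A minimal sketch under that
# assumption (hypothetical module; the projection size is illustrative):
import torch
import torch.nn as nn
import torch.nn.functional as F


class AlignAttentionSketch(nn.Module):
    def __init__(self, embed_dim, proj_dim=128):
        super().__init__()
        self.projection = nn.Linear(embed_dim, proj_dim)

    def forward(self, context_embed, query_embed, query_mask):
        # context_embed: (B, C_L, D), query_embed: (B, Q_L, D), query_mask: (B, Q_L)
        context_proj = F.relu(self.projection(context_embed))  # (B, C_L, H)
        query_proj = F.relu(self.projection(query_embed))      # (B, Q_L, H)

        scores = context_proj.bmm(query_proj.transpose(1, 2))  # (B, C_L, Q_L)
        scores = scores.masked_fill(query_mask.unsqueeze(1).eq(0), -1e7)
        alpha = F.softmax(scores, dim=-1)
        return alpha.bmm(query_embed)                          # (B, C_L, D)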