def test_accuracy_computation(self):
    """Exercise BooleanAccuracy over plain, masked, accumulated and reset calls.

    The expected denominators (4, 8, 12, then 4 again after ``reset()``)
    encode the expectation that the metric accumulates counts across calls
    until it is reset.
    """
    metric = BooleanAccuracy()
    predicted = torch.Tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
    gold = torch.Tensor([[0, 1], [2, 2], [4, 5], [7, 7]])

    # First call without a mask: rows 0 and 2 match exactly.
    metric(predicted, gold)
    assert metric.get_metric() == 2. / 4

    # Second call masks out element [1, 1], the only mismatch in row 1.
    element_mask = torch.ones(4, 2)
    element_mask[1, 1] = 0
    metric(predicted, gold, element_mask)
    assert metric.get_metric() == 5. / 8

    # Make row 1 of the gold labels agree with the predictions, then score
    # a third time without a mask.
    gold[1, 1] = 3
    metric(predicted, gold)
    assert metric.get_metric() == 8. / 12

    # After a reset only the next call is counted.
    metric.reset()
    metric(predicted, gold)
    assert metric.get_metric() == 3. / 4
def __init__(self, vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             num_highway_layers: int,
             phrase_layer: Seq2SeqEncoder,
             similarity_function: SimilarityFunction,
             modeling_layer: Seq2SeqEncoder,
             span_end_encoder: Seq2SeqEncoder,
             dropout: float = 0.2,
             mask_lstms: bool = True,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    """Wire up the sub-modules, span predictors and metrics of the model.

    The sizes of the two linear span predictors are derived from the output
    dimensions reported by ``phrase_layer``, ``modeling_layer`` and
    ``span_end_encoder``.  Three ``check_dimensions_match`` calls then
    verify that the configured layers fit together, and finally
    ``initializer`` is applied to the whole model.
    """
    super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

    # Ask every component for its dimensions once and name the results.
    embedding_dim = text_field_embedder.get_output_dim()
    phrase_dim = phrase_layer.get_output_dim()
    modeling_dim = modeling_layer.get_output_dim()
    span_end_dim = span_end_encoder.get_output_dim()
    merged_dim = 4 * phrase_dim

    self._text_field_embedder = text_field_embedder
    self._highway_layer = TimeDistributed(
        Highway(embedding_dim, num_highway_layers))
    self._phrase_layer = phrase_layer
    self._matrix_attention = LegacyMatrixAttention(similarity_function)
    self._modeling_layer = modeling_layer
    self._span_end_encoder = span_end_encoder

    # One scalar score per position for the span start and the span end.
    self._span_start_predictor = TimeDistributed(
        torch.nn.Linear(merged_dim + modeling_dim, 1))
    self._span_end_predictor = TimeDistributed(
        torch.nn.Linear(merged_dim + span_end_dim, 1))

    # Bidaf has lots of layer dimensions which need to match up - these
    # aren't necessarily obvious from the configuration files, so we check
    # here.
    check_dimensions_match(modeling_layer.get_input_dim(),
                           merged_dim,
                           "modeling layer input dim",
                           "4 * encoding dim")
    check_dimensions_match(embedding_dim,
                           phrase_layer.get_input_dim(),
                           "text field embedder output dim",
                           "phrase layer input dim")
    check_dimensions_match(span_end_encoder.get_input_dim(),
                           merged_dim + 3 * modeling_dim,
                           "span end encoder input dim",
                           "4 * encoding dim + 3 * modeling dim")

    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    # With dropout disabled an identity function keeps call sites uniform.
    self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
    self._mask_lstms = mask_lstms

    initializer(self)
class BertSpanPointerResolution(Model):
    """Point from query positions to context spans and splice them into the query.

    For every selected query position the model scores each context position
    as a span start and as a span end (two attention modules), picks the best
    span, and -- when metadata is supplied -- inserts the pointed-to context
    tokens into the query to build a rewritten utterance.

    (Original author's note, translated: the model predicts the mask
    positions together with the start/end positions of the spans.)
    """

    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str = None,
                 start_attention: Attention = None,
                 end_attention: Attention = None,
                 text_field_embedder: TextFieldEmbedder = None,
                 task_pretrained_file: str = None,
                 neg_sample_ratio: float = 0.0,
                 max_turn_len: int = 3,
                 start_token: str = "[CLS]",
                 end_token: str = "[SEP]",
                 index_name: str = "bert",
                 eps: float = 1e-8,
                 seed: int = 42,
                 loss_factor: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None):
        """Build the embedder, the two pointer attentions and the metrics.

        Parameters
        ----------
        model_name / text_field_embedder :
            At least one must be given.  When ``text_field_embedder`` is
            falsy, a ``PretrainedChineseBertMismatchedEmbedder`` is created
            from ``model_name``.
        start_attention / end_attention :
            Optional attention modules; ``BilinearAttention`` sized to the
            embedder output is used when omitted.
        task_pretrained_file :
            Optional state-dict file loaded non-strictly at the end of
            construction when the path exists.
        neg_sample_ratio :
            When > 0 (and the model is training), extra query positions are
            randomly switched on in the mask label inside ``forward``.
        start_token / end_token :
            Marker tokens used when cleaning up predicted spans.
        index_name :
            Key into the concatenated input dict used when the embedder is
            not a ``TextFieldEmbedder``.
        eps :
            Added to attention probabilities before taking the log.
        loss_factor :
            Only used by ``_calc_loss_weight``.

        Raises
        ------
        ValueError
            If both ``model_name`` and ``text_field_embedder`` are ``None``.
        """
        super().__init__(vocab, regularizer)
        if model_name is None and text_field_embedder is None:
            raise ValueError(
                "`model_name` and `text_field_embedder` can't both equal to None."
            )
        # For the plain resolution task only the last-layer embedding is
        # needed from the encoder.
        self._text_field_embedder = text_field_embedder or PretrainedChineseBertMismatchedEmbedder(
            model_name,
            return_all=False,
            output_hidden_states=False,
            max_turn_length=max_turn_len)

        seed_everything(seed)
        self._neg_sample_ratio = neg_sample_ratio
        self._start_token = start_token
        self._end_token = end_token
        self._index_name = index_name
        self._initializer = initializer

        linear_input_size = self._text_field_embedder.get_output_dim()
        # Span boundaries are scored with attention over the context.
        self.start_attention = start_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        self.end_attention = end_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)

        # Metrics.  (Original note, translated: for the mask we mainly care
        # about F-score, and in particular recall of label `1`.)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._rewrite_em = RewriteEM(valid_keys="semr,nr_semr,re_semr")
        self._restore_score = RestorationScore(compute_restore_tokens=True)
        self._metrics = [
            TokenBasedBLEU(mode="1,2"),
            TokenBasedROUGE(mode="1r,2r")
        ]
        self._eps = eps
        self._loss_factor = loss_factor

        # Only the attention modules are (re-)initialised; the embedder is
        # left untouched.
        self._initializer(self.start_attention)
        self._initializer(self.end_attention)

        # Optionally warm-start from weights trained on a related task.
        # NOTE(review): ``torch.load`` unpickles the file -- only point this
        # at trusted checkpoints.
        if task_pretrained_file is not None and os.path.isfile(
                task_pretrained_file):
            logger.info("loading related task pretrained weights...")
            self.load_state_dict(torch.load(task_pretrained_file),
                                 strict=False)

    def _calc_loss(self, span_start_logits: torch.Tensor,
                   span_end_logits: torch.Tensor,
                   use_mask_label: torch.Tensor, start_label: torch.Tensor,
                   end_label: torch.Tensor, best_spans: torch.Tensor):
        """Return the averaged start/end cross-entropy loss and update span metrics.

        ``start_label`` / ``end_label`` are filtered down to the positions
        switched on in ``use_mask_label`` so that they line up with the rows
        of the logits.  Labels equal to ``-1`` are ignored by the loss and
        masked out of the accuracy metrics.  Each summed loss is divided by
        ``start_label.size(0)``.

        Raises
        ------
        AssertionError
            If either loss exceeds the sanity bound (diagnostics are logged
            at critical level first).
        """
        batch_size = start_label.size(0)
        loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-1)

        # Keep only the labels of the selected query positions: [B_mask, ]
        selected = use_mask_label.to(dtype=torch.bool)
        span_start_label = start_label.masked_select(selected)
        span_end_label = end_label.masked_select(selected)

        # Rows labelled -1 are excluded when computing accuracies.
        train_span_mask = (span_start_label != -1)
        # [B_mask, 2]
        answer_spans = torch.stack([span_start_label, span_end_label], dim=-1)
        self._span_accuracy(
            best_spans, answer_spans,
            train_span_mask.unsqueeze(-1).expand_as(best_spans))

        # -- start loss --
        start_losses = loss_fct(span_start_logits, span_start_label)
        start_loss = torch.sum(start_losses) / batch_size

        # Sanity check on the loss value.
        big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
        if torch.any(start_loss > big_constant):
            logger.critical("Start loss too high (%r)", start_loss)
            logger.critical("span_start_logits: %r", span_start_logits)
            logger.critical("span_start: %r", span_start_label)
            # Explicit raise: a bare ``assert False`` vanishes under ``-O``.
            raise AssertionError("Start loss too high")

        # -- end loss --
        end_losses = loss_fct(span_end_logits, span_end_label)
        end_loss = torch.sum(end_losses) / batch_size
        if torch.any(end_loss > big_constant):
            logger.critical("End loss too high (%r)", end_loss)
            logger.critical("span_end_logits: %r", span_end_logits)
            logger.critical("span_end: %r", span_end_label)
            raise AssertionError("End loss too high")

        span_loss = (start_loss + end_loss) / 2

        self._span_start_accuracy(span_start_logits, span_start_label,
                                  train_span_mask)
        self._span_end_accuracy(span_end_logits, span_end_label,
                                train_span_mask)
        return span_loss

    def _calc_loss_weight(self, label: torch.Tensor):
        """Return ``1 + loss_factor`` where ``label`` is non-zero and ``1`` elsewhere.

        NOTE(review): the mask is cast to float16 and this helper is not
        called anywhere in this class -- confirm before relying on it.
        """
        label_mask = (label != 0).to(torch.float16)
        label_weight = label_mask * self._loss_factor + 1.0
        return label_weight

    def _get_rewrite_result(self, use_mask_label: torch.Tensor,
                            best_spans: torch.Tensor,
                            query_lens: torch.Tensor,
                            context_lens: torch.Tensor,
                            metadata: List[Dict[str, Any]]):
        """Insert the predicted context spans into each query and score the result.

        ``best_spans`` holds one ``(start, end)`` pair per switched-on entry
        of ``use_mask_label``; pairs are consumed in row-major order while
        walking the queries.  Each metadata dict must provide
        ``context_tokens`` and ``query_tokens`` and may provide
        ``rewrite_tokens`` / ``restore_tokens``; when ``rewrite_tokens`` is
        present the rewrite metrics are updated.

        Returns a list with one rewritten string per instance.
        """
        # Move both tensors off the graph/device: [B, query_len] labels and
        # a plain list of [start, end] pairs.
        use_mask_label = use_mask_label.detach().cpu().numpy()
        best_spans = best_spans.detach().cpu().numpy().tolist()
        # Cursor into ``best_spans`` (avoids repeated O(n) ``pop(0)``).
        next_span = 0

        predict_rewrite_results = []
        for cur_query_len, cur_context_len, cur_query_mask_labels, mdata in zip(
                query_lens, context_lens, use_mask_label, metadata):
            context_tokens = mdata['context_tokens']
            query_tokens = mdata['query_tokens']
            cur_rewrite_result = copy.deepcopy(query_tokens)
            # Number of tokens inserted so far (shifts later insert points).
            already_insert_tokens = 0
            # Smallest start / largest end among spans inserted so far.
            already_insert_min_start = cur_context_len
            already_insert_max_end = 0

            for i in range(cur_query_len):
                # Only positions labelled 1 receive an insertion.
                if not cur_query_mask_labels[i]:
                    continue
                predict_start, predict_end = best_spans[next_span]
                next_span += 1

                # (0, 0) means "nothing to insert here".
                if predict_start == 0 and predict_end == 0:
                    continue
                # A start beyond the real context is unusable.
                if predict_start >= cur_context_len:
                    continue
                # Skip spans lying entirely inside what was already inserted.
                if (predict_start >= already_insert_min_start
                        and predict_end <= already_insert_max_end):
                    continue

                # Clamp the boundaries.
                if predict_start < 0 or context_tokens[
                        predict_start] == self._start_token:
                    predict_start = 1
                if predict_end >= cur_context_len:
                    predict_end = cur_context_len - 1

                predict_span_tokens = context_tokens[
                    predict_start:predict_end + 1]

                if predict_start < already_insert_min_start:
                    already_insert_min_start = predict_start
                if predict_end > already_insert_max_end:
                    already_insert_max_end = predict_end

                # Keep only the tokens before the first end token, if any.
                try:
                    index = predict_span_tokens.index(self._end_token)
                except ValueError:
                    # No end token inside the span: keep it whole.
                    pass
                else:
                    predict_span_tokens = predict_span_tokens[:index]

                # The span is inserted *before* query position ``i``
                # (inserting after it would need ``+ 1``).
                cur_insert_index = i + already_insert_tokens
                cur_rewrite_result = cur_rewrite_result[:cur_insert_index] + \
                    predict_span_tokens + cur_rewrite_result[cur_insert_index:]
                already_insert_tokens += len(predict_span_tokens)

            # Drop the last token of the query.
            cur_rewrite_result = cur_rewrite_result[:-1]
            # Metrics are computed on strings rather than token lists.
            cur_rewrite_string = "".join(cur_rewrite_result)
            rewrite_tokens = mdata.get("rewrite_tokens", None)
            if rewrite_tokens is not None:
                rewrite_string = "".join(rewrite_tokens)
                # Strip the trailing [SEP] token from the query.
                query_string = "".join(query_tokens[:-1])
                self._rewrite_em(cur_rewrite_string, rewrite_string,
                                 query_string)
                # Additional token-based metrics.
                for metric in self._metrics:
                    metric(cur_rewrite_result, rewrite_tokens)
                # Restoration score, optionally with gold restore tokens.
                restore_tokens = mdata.get("restore_tokens", None)
                self._restore_score(cur_rewrite_result,
                                    rewrite_tokens,
                                    queries=query_tokens[:-1],
                                    restore_tokens=restore_tokens)

            predict_rewrite_results.append("".join(cur_rewrite_result))
        return predict_rewrite_results

    @overrides
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                context_lens: torch.Tensor,
                query_lens: torch.Tensor,
                mask_label: Optional[torch.Tensor] = None,
                start_label: Optional[torch.Tensor] = None,
                end_label: Optional[torch.Tensor] = None,
                metadata: List[Dict[str, Any]] = None):
        """Encode context+query jointly and predict one context span per query position.

        The context and query tensors are concatenated along dimension 1 and
        embedded together; the output is split back at ``max(context_lens)``.
        During training with ``neg_sample_ratio > 0`` only the positions in
        ``mask_label`` plus randomly sampled extra positions are used;
        otherwise every non-padding query position is used.

        Returns a dict with ``span_start_logits``, ``span_start_probs``,
        ``span_end_logits``, ``span_end_probs``, ``best_spans`` and
        ``best_span_scores``; plus ``loss`` when ``start_label`` is given
        and ``rewrite_results`` when ``metadata`` is given.
        """
        # Padded lengths of the two segments.
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()
        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)

        # Concatenate context and query, per indexer and per tensor key,
        # along the length dimension.
        dialogue_ids = {}
        for indexer in context_ids.keys():
            dialogue_ids[indexer] = {}
            for key, context in context_ids[indexer].items():
                query = query_ids[indexer][key]
                dialogue_ids[indexer][key] = torch.cat([context, query],
                                                       dim=1)

        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # Split the joint encoding back: [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        # ------- span prediction -------
        # For every query position we want the context span to insert
        # there, so the context is replicated per query position:
        # [B, query_len, context_len, embed_size]
        context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
            -1, query_len, -1, -1).contiguous()
        # [B, query_len, context_len]
        context_expand_mask = context_mask.unsqueeze(dim=1).expand(
            -1, query_len, -1).contiguous()
        span_embed_size = context_last_layer.size(-1)

        if self.training and self._neg_sample_ratio > 0.0:
            # Sample extra (negative) query positions.  Work on a copy so
            # the caller's ``mask_label`` tensor is not modified in place.
            # [B*query_len, ]
            sample_mask_label = mask_label.reshape(-1).clone()
            mask_length = sample_mask_label.size(0)
            mask_sum = int(
                torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
            mask_sum = max(10, mask_sum)
            neg_indexes = torch.randint(low=0,
                                        high=mask_length,
                                        size=(mask_sum, ))
            # Never use more sampled indices than there are positions.
            neg_indexes = neg_indexes[:mask_length]
            sample_mask_label[neg_indexes] = 1
            # [B, query_len]
            use_mask_label = sample_mask_label.view(
                -1, query_len).to(dtype=torch.bool)
            # Drop padded query positions.
            use_mask_label = use_mask_label & query_mask
        else:
            use_mask_label = query_mask

        # Select the rows belonging to the chosen query positions.
        span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
        # [B_mask, context_len, span_embed_size]
        span_context_matrix = context_last_layer.masked_select(
            span_mask).view(-1, context_len, span_embed_size).contiguous()
        # [B_mask, span_embed_size]
        span_query_vector = query_last_layer.masked_select(
            span_mask.squeeze(dim=-1)).view(-1,
                                            span_embed_size).contiguous()
        # [B_mask, context_len]
        span_context_mask = context_expand_mask.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()

        # Attention output per context position: [B_mask, context_len]
        span_start_probs = self.start_attention(span_query_vector,
                                                span_context_matrix,
                                                span_context_mask)
        span_end_probs = self.end_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)
        # ``eps`` keeps the log finite for zero probabilities.
        span_start_logits = torch.log(span_start_probs + self._eps)
        span_end_logits = torch.log(span_end_probs + self._eps)

        # [B_mask, 2]: last dimension is (start, end).
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # Score of each best span = start logit + end logit: [B_mask, ]
        best_span_scores = (
            torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1))
            + torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_spans": best_spans,
            "best_span_scores": best_span_scores
        }

        # Loss is only available when gold labels are supplied.
        if start_label is not None:
            output_dict["loss"] = self._calc_loss(span_start_logits,
                                                  span_end_logits,
                                                  use_mask_label,
                                                  start_label, end_label,
                                                  best_spans)
        if metadata is not None:
            output_dict['rewrite_results'] = self._get_rewrite_result(
                use_mask_label, best_spans, query_lens, context_lens,
                metadata)
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Collect span accuracy plus the BLEU/ROUGE, rewrite-EM and restoration metrics."""
        metrics = {}
        metrics["span_acc"] = self._span_accuracy.get_metric(reset)
        for metric in self._metrics:
            metrics.update(metric.get_metric(reset))
        metrics.update(self._rewrite_em.get_metric(reset))
        metrics.update(self._restore_score.get_metric(reset))
        return metrics

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Reduce the output to the ``rewrite_results`` entry only."""
        new_output_dict = {}
        new_output_dict["rewrite_results"] = output_dict["rewrite_results"]
        return new_output_dict
def __init__(self, **kwargs):
    """Forward all keyword arguments to the parent constructor and attach a metric.

    After the parent class is initialised, a fresh ``BooleanAccuracy``
    instance is stored on ``self.acc``.
    """
    super(MySeq2Seq, self).__init__(**kwargs)
    self.acc = BooleanAccuracy()
def __init__(self, vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             num_highway_layers: int,
             phrase_layer: Seq2SeqEncoder,
             attention_similarity_function: SimilarityFunction,
             modeling_layer: Seq2SeqEncoder,
             span_end_encoder: Seq2SeqEncoder,
             dropout: float = 0.2,
             mask_lstms: bool = True,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    """Assemble the model's layers, span predictors and metrics.

    Sizes for the two linear span predictors come from the output
    dimensions reported by the supplied encoders.  After construction the
    layer dimensions are cross-checked and a ``ConfigurationError`` is
    raised on the first mismatch; finally ``initializer`` is applied to
    the whole model.
    """
    super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

    # Gather every dimension once so the arithmetic below reads clearly.
    embedder_out_dim = text_field_embedder.get_output_dim()
    encoding_dim = phrase_layer.get_output_dim()
    modeling_dim = modeling_layer.get_output_dim()
    span_end_encoding_dim = span_end_encoder.get_output_dim()

    self._text_field_embedder = text_field_embedder
    self._highway_layer = TimeDistributed(
        Highway(embedder_out_dim, num_highway_layers))
    self._phrase_layer = phrase_layer
    self._matrix_attention = MatrixAttention(attention_similarity_function)
    self._modeling_layer = modeling_layer
    self._span_end_encoder = span_end_encoder

    # One scalar score per position for span start and span end.
    self._span_start_predictor = TimeDistributed(
        torch.nn.Linear(encoding_dim * 4 + modeling_dim, 1))
    self._span_end_predictor = TimeDistributed(
        torch.nn.Linear(encoding_dim * 4 + span_end_encoding_dim, 1))

    # Bidaf has lots of layer dimensions which need to match up - these
    # aren't necessarily obvious from the configuration files, so we check
    # here.
    modeling_in_dim = modeling_layer.get_input_dim()
    if modeling_in_dim != 4 * encoding_dim:
        raise ConfigurationError(
            "The input dimension to the modeling_layer must be "
            "equal to 4 times the encoding dimension of the phrase_layer. "
            "Found {} and 4 * {} respectively.".format(
                modeling_in_dim, encoding_dim))

    phrase_in_dim = phrase_layer.get_input_dim()
    if embedder_out_dim != phrase_in_dim:
        raise ConfigurationError(
            "The output dimension of the text_field_embedder (embedding_dim + "
            "char_cnn) must match the input dimension of the phrase_encoder. "
            "Found {} and {}, respectively.".format(
                embedder_out_dim, phrase_in_dim))

    span_end_in_dim = span_end_encoder.get_input_dim()
    if span_end_in_dim != encoding_dim * 4 + modeling_dim * 3:
        raise ConfigurationError(
            "The input dimension of the span_end_encoder should be equal to "
            "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. "
            "Found {} and (4 * {} + 3 * {}) "
            "respectively.".format(span_end_in_dim, encoding_dim,
                                   modeling_dim))

    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    # With dropout disabled an identity function keeps call sites uniform.
    self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
    self._mask_lstms = mask_lstms

    initializer(self)
class RNet(Model):
    """Span-extraction reader: encoders, pair/self attention and an output layer.

    The question and passage are embedded and encoded, the passage is then
    passed through ``pair_encoder`` (against the question) and
    ``self_encoder`` (against itself), and ``output_layer`` produces start
    and end logits over passage positions.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 passage_encoder: Seq2SeqEncoder,
                 pair_encoder: AttentionEncoder,
                 self_encoder: AttentionEncoder,
                 output_layer: QAOutputLayer,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 share_encoder: bool = False):
        """Store the sub-modules, create the metrics and apply ``initializer``.

        When ``share_encoder`` is true, ``forward`` encodes the passage with
        ``question_encoder`` instead of ``passage_encoder``.
        """
        super().__init__(vocab, regularizer)

        # Sub-modules.
        self.text_field_embedder = text_field_embedder
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.pair_encoder = pair_encoder
        self.self_encoder = self_encoder
        self.output_layer = output_layer

        # Metrics.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        self.share_encoder = share_encoder
        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """Run the reader and return logits, probabilities and the best span.

        Adds ``loss`` when ``span_start`` is given.  When ``metadata`` is
        given, adds ``best_span_str``, ``question_tokens`` and
        ``passage_tokens`` and updates the SQuAD metric for every instance
        that carries non-empty ``answer_texts``.
        """
        question_embeded = self.text_field_embedder(question)
        passage_embeded = self.text_field_embedder(passage)
        question_mask = get_text_field_mask(question).byte()
        passage_mask = get_text_field_mask(passage).byte()

        question_repr = self.question_encoder(question_embeded, question_mask)
        # Optionally reuse the question encoder for the passage.
        encoder_for_passage = (self.question_encoder
                               if self.share_encoder else self.passage_encoder)
        passage_repr = encoder_for_passage(passage_embeded, passage_mask)

        # Question-aware, then self-matched passage representation.
        passage_repr = self.pair_encoder(passage_repr, passage_mask,
                                         question_repr, question_mask)
        passage_repr = self.self_encoder(passage_repr, passage_mask,
                                         passage_repr, passage_mask)

        span_start_logits, span_end_logits = self.output_layer(
            question_repr, question_mask, passage_repr, passage_mask)

        # Calculating loss and making prediction
        # Following code is copied from allennlp.models.BidirectionalAttentionFlow
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(
            span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(
            span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if span_start is not None:
            gold_start = span_start.squeeze(-1)
            gold_end = span_end.squeeze(-1)

            start_loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                gold_start)
            self._span_start_accuracy(span_start_logits, gold_start)
            end_loss = nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                gold_end)
            self._span_end_accuracy(span_end_logits, gold_end)
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = start_loss + end_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            span_strings = []
            output_dict['best_span_str'] = span_strings
            all_question_tokens = []
            all_passage_tokens = []
            for idx in range(question_embeded.size(0)):
                entry = metadata[idx]
                all_question_tokens.append(entry['question_tokens'])
                all_passage_tokens.append(entry['passage_tokens'])

                # Map the predicted token span back to character offsets
                # in the original passage text.
                token_offsets = entry['token_offsets']
                first_token, last_token = tuple(
                    best_span[idx].detach().cpu().numpy())
                char_begin = token_offsets[first_token][0]
                char_end = token_offsets[last_token][1]
                span_text = entry['original_passage'][char_begin:char_end]
                span_strings.append(span_text)

                gold_texts = entry.get('answer_texts', [])
                if gold_texts:
                    self._squad_metrics(span_text, gold_texts)
            output_dict['question_tokens'] = all_question_tokens
            output_dict['passage_tokens'] = all_passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span accuracy together with SQuAD EM and F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        metrics = {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }
        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """Find, per row, the highest-scoring span with start <= end.

        Both inputs must be 2-dimensional.  A single left-to-right sweep
        keeps the best start seen so far and pairs it with the current
        position as the end.  Returns a long tensor with two columns
        (start index, end index), one row per input row.

        Raises
        ------
        ValueError
            If either input is not 2-dimensional.
        """
        if not (span_start_logits.dim() == 2 and span_end_logits.dim() == 2):
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")

        batch_size, passage_length = span_start_logits.size()
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        start_scores = span_start_logits.detach().cpu().numpy()
        end_scores = span_end_logits.detach().cpu().numpy()

        for row in range(batch_size):
            best_start = 0          # argmax of the start scores seen so far
            best_total = -1e20      # best start+end score seen so far
            for pos in range(passage_length):
                start_value = start_scores[row, best_start]
                if start_value < start_scores[row, pos]:
                    best_start = pos
                    start_value = start_scores[row, pos]

                total = start_value + end_scores[row, pos]
                if total > best_total:
                    best_word_span[row, 0] = best_start
                    best_word_span[row, 1] = pos
                    best_total = total
        return best_word_span
class TransformerQA(Model):
    """Extractive QA on top of a Hugging Face question-answering model.

    Supports an optional "no answer" prediction (position 0) and an
    optional mode that forces the answer to be "yes" or "no".
    """

    def __init__(self,
                 serialization_dir: str,
                 pretrained_model: str,
                 tokenizer_wrapper: HFTokenizerWrapper,
                 enable_no_answer: bool = False,
                 force_yes_no: bool = False,
                 **kwargs) -> None:
        """Load the pretrained QA model, sync its embeddings with the tokenizer, build metrics.

        The token embeddings are resized to the tokenizer length twice:
        once for the tokenizer passed in and once more after the tokenizer
        has been reloaded from (and saved to) ``serialization_dir``.
        """
        super().__init__(**kwargs)
        self._tokenizer_wrapper = tokenizer_wrapper
        self._enable_no_answer = enable_no_answer
        self.force_yes_no = force_yes_no

        self._qa_model = AutoModelForQuestionAnswering.from_pretrained(
            pretrained_model, return_dict=True)
        self._qa_model.resize_token_embeddings(
            len(tokenizer_wrapper.tokenizer))

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._boolq_accuracy = Squad2EmAndF1()
        self._per_instance_metrics = Squad2EmAndF1()

        # Initializer placeholder
        self._tokenizer_wrapper.tokenizer = self._tokenizer_wrapper.load(
            serialization_dir, pending=True)
        self._tokenizer_wrapper.save(serialization_dir)
        self._qa_model.resize_token_embeddings(
            len(tokenizer_wrapper.tokenizer))

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        yes_no_span: torch.IntTensor = None,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, torch.LongTensor]`
            Passed as keyword arguments to the underlying QA model; must
            contain ``input_ids``.
        context_span : `torch.IntTensor`
            One ``(start, end)`` pair (inclusive) per instance marking the
            positions from which answers can come.  A pair containing
            ``-1`` is skipped.  Ignored when ``force_yes_no`` is set.
        yes_no_span : `torch.IntTensor`, optional
            One ``(start, end)`` pair per instance; the positions
            ``start..end`` are also allowed as answers, and column 0 /
            column 1 are used as the indices of the "yes" / "no" scores.
            When omitted, ``yes_scores`` and ``no_scores`` are not
            returned.
        answer_span : `torch.IntTensor`, optional
            Gold ``(start, end)`` pairs.  If given, accuracies are updated
            and a loss is added to the output.  ``-1`` marks "no label".
        metadata : `List[Dict[str, Any]]`, optional
            Per-instance dicts with the keys ``modified_question``,
            ``context``, ``offset_mapping`` and ``special_tokens_mask``;
            optionally ``answers`` (and then also ``is_boolq``).

        # Returns

        An output dictionary consisting of:

        best_span : `torch.IntTensor`
            The best ``(start, end)`` pair per instance, restricted to the
            allowed answer positions.
        best_span_scores : `torch.FloatTensor`
            Start logit plus end logit of each best span.
        yes_scores / no_scores : `torch.FloatTensor`, only with ``yes_no_span``
            ``logits[:, yes_no_span[:, k]]`` summed over start and end
            logits; entry ``[i, i]`` is the score belonging to instance
            ``i``.
        no_answer_scores : `torch.FloatTensor`, only with ``enable_no_answer``
            Start logit plus end logit at position 0.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        best_span_str : `List[str]`, only with ``metadata``
            The predicted answer string per instance.
        """
        outputs = self._qa_model(**question_with_context)
        span_start_logits = outputs["start_logits"]
        span_end_logits = outputs["end_logits"]

        # Build the mask of positions that may be part of an answer.
        with torch.no_grad():
            possible_answer_mask = torch.zeros_like(
                question_with_context["input_ids"],
                dtype=torch.bool,
            )
            if not self.force_yes_no:
                for i, (start, end) in enumerate(context_span):
                    if start != -1 and end != -1:
                        possible_answer_mask[i, start:end + 1] = True
            if yes_no_span is not None:
                for i, (start, end) in enumerate(yes_no_span):
                    if start != -1 and end != -1:
                        possible_answer_mask[i, start:end + 1] = True
            # Every instance needs at least one allowed position.
            for i in range(len(possible_answer_mask)):
                assert any(possible_answer_mask[i])

        # Replace the masked values with a very negative constant.
        context_masked_span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, possible_answer_mask)
        context_masked_span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, possible_answer_mask)
        best_spans = get_best_span(context_masked_span_start_logits,
                                   context_masked_span_end_logits)
        best_span_scores = torch.gather(
            context_masked_span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                context_masked_span_end_logits, 1,
                best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }
        # ``yes_no_span`` is optional; indexing ``None`` would raise.
        if yes_no_span is not None:
            yes_index = yes_no_span[:, 0]
            no_index = yes_no_span[:, 1]
            output_dict["yes_scores"] = (span_start_logits[:, yes_index] +
                                         span_end_logits[:, yes_index])
            output_dict["no_scores"] = (span_start_logits[:, no_index] +
                                        span_end_logits[:, no_index])

        if self._enable_no_answer:
            no_answer_scores = span_start_logits[:, 0] + span_end_logits[:, 0]
            output_dict.update({"no_answer_scores": no_answer_scores})

        # Compute metrics and set loss
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            if self._enable_no_answer:
                # Gold "no answer" (start 0) is excluded from accuracies.
                span_mask &= span_start != 0
            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))
            self._span_start_accuracy(context_masked_span_start_logits,
                                      span_start, span_mask)
            self._span_end_accuracy(context_masked_span_end_logits, span_end,
                                    span_mask)

            if self._enable_no_answer:
                # Position 0 must be a legal target for the loss.
                possible_answer_mask[:, 0] = True
            # Replace the masked values with a very negative constant.
            masked_span_start_logits = replace_masked_values_with_big_negative_number(
                span_start_logits, possible_answer_mask)
            masked_span_end_logits = replace_masked_values_with_big_negative_number(
                span_end_logits, possible_answer_mask)

            loss_fct = CrossEntropyLoss(ignore_index=-1)
            start_loss = loss_fct(masked_span_start_logits, span_start)
            end_loss = loss_fct(masked_span_end_logits, span_end)
            output_dict["loss"] = (start_loss + end_loss) / 2

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()
            output_dict["best_span_str"] = []
            for i, (metadata_entry,
                    best_span) in enumerate(zip(metadata, best_spans)):
                best_span_string = TokensInterpreter.extract_span_string_from_origin_texts(
                    Span(*best_span),
                    [
                        metadata_entry["modified_question"],
                        metadata_entry["context"]
                    ],
                    metadata_entry["offset_mapping"],
                    metadata_entry["special_tokens_mask"],
                )
                if self.force_yes_no:
                    # The score tensors are (batch, batch); element [i, i]
                    # pairs instance i's logits with instance i's yes/no
                    # index.  ``[i].item()`` only worked for batch size 1.
                    yes_score = output_dict["yes_scores"][i, i].item()
                    no_score = output_dict["no_scores"][i, i].item()
                    best_span_string = "yes" if yes_score > no_score else "no"
                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if answers is not None and len(answers) > 0:
                    if self._enable_no_answer:
                        # Predict the empty string when "no answer" scores
                        # at least as high as the best span.
                        final_pred = (best_span_string if
                                      best_span_scores[i] > no_answer_scores[i]
                                      else "")
                    else:
                        final_pred = best_span_string
                    if metadata_entry["is_boolq"]:
                        self._boolq_accuracy(final_pred, answers)
                    else:
                        self._per_instance_metrics(final_pred, answers)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span accuracies, the BoolQ exact-match score and the per-instance metrics."""
        metrics = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "boolq_acc": self._boolq_accuracy.get_metric(reset)["em"],
        }
        metrics.update(self._per_instance_metrics.get_metric(reset))
        return metrics

    default_predictor = "transformer_qa_v2"
class MultiGranularityHierarchicalAttentionFusionNetworks(Model):
    """
    Span-extraction reading comprehension model.

    The question and the passage are embedded and encoded with a shared phrase
    layer, fused with attention in both directions, the passage is additionally
    fused with a self-attended copy of itself, and finally every passage token
    is scored as a possible answer start and answer end against a single pooled
    question vector.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds both the ``question`` and the ``passage`` text fields.
    phrase_layer : ``Seq2SeqEncoder``
        Encoder applied to the embedded question and the embedded passage. Its
        output dimension is used as the width of every internal projection.
    passage_self_attention : ``Seq2SeqEncoder``
        Encoder applied to the question-aware passage representation.
    semantic_rep_layer : ``Seq2SeqEncoder``
        Encoder applied to the passage after the self-attention fusion.
    contextual_question_layer : ``Seq2SeqEncoder``
        Encoder applied to the passage-aware question representation.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, dropout with this probability is applied to the
        outputs of ``phrase_layer``.
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, no mask is passed to the encoders. The masks derived from
        the text fields are still used for the masked softmaxes.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
    initializer : ``InitializerApplicator``, optional
        Applied to the whole model at the end of construction.
    """

    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            phrase_layer: Seq2SeqEncoder,
            passage_self_attention: Seq2SeqEncoder,
            semantic_rep_layer: Seq2SeqEncoder,
            contextual_question_layer: Seq2SeqEncoder,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            regularizer: Optional[RegularizerApplicator] = None,
            initializer: InitializerApplicator = InitializerApplicator(),
    ) -> None:
        super(MultiGranularityHierarchicalAttentionFusionNetworks,
              self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._encoding_dim = self._phrase_layer.get_output_dim()

        # Shared projection applied to both question and passage encodings
        # before the similarity matrix is computed.
        self._atten_linear_layer = torch.nn.Linear(
            in_features=self._encoding_dim,
            out_features=self._encoding_dim,
            bias=False)
        self._relu = torch.nn.ReLU()
        self._softmax_d1 = torch.nn.Softmax(dim=1)
        self._softmax_d2 = torch.nn.Softmax(dim=2)
        self._atten_fusion = FusionLayer(self._encoding_dim)
        self._tanh = torch.nn.Tanh()
        self._sigmoid = torch.nn.Sigmoid()

        self._passage_self_attention = passage_self_attention
        self._self_atten_layer = torch.nn.Bilinear(self._encoding_dim,
                                                   self._encoding_dim,
                                                   self._encoding_dim,
                                                   bias=False)
        self._self_atten_fusion = FusionLayer(self._encoding_dim)

        self._semantic_rep_layer = semantic_rep_layer
        self._contextual_question_layer = contextual_question_layer

        # Scores each question position; used to pool the question into one vector.
        self._vector_linear = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=1,
                                              bias=False)
        # Separate projections of the pooled question vector for start / end scoring.
        self._model_layer_s = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=self._encoding_dim,
                                              bias=False)
        self._model_layer_e = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=self._encoding_dim,
                                              bias=False)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            Text-field tensors for the question.
        passage : Dict[str, torch.LongTensor]
            Text-field tensors for the passage.
        span_start : ``torch.IntTensor``, optional
            Gold start index; its last dimension is squeezed before use. When
            given, a loss is computed and the accuracy metrics are updated.
        span_end : ``torch.IntTensor``, optional
            Gold end index, same layout as ``span_start``.
        metadata : ``List[Dict[str, Any]]``, optional
            One entry per instance with the keys ``question_tokens``,
            ``passage_tokens``, ``original_passage`` and ``token_offsets``, and
            optionally ``answer_texts``.

        Returns
        -------
        A dictionary with ``best_span`` (shape ``(batch_size, 2)``), ``loss``
        when gold spans were given, and ``best_span_str``, ``question_tokens``
        and ``passage_tokens`` when metadata was given.
        """
        embedded_question = self._text_field_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        embedded_passage = self._text_field_embedder(passage)
        passage_mask = util.get_text_field_mask(passage).float()
        batch_size = embedded_passage.size(0)

        # Only the encoders may run unmasked; softmaxes below always use the
        # real masks so that ``mask_lstms=False`` does not dereference ``None``.
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # Shape: (batch_size, question_length, encoding_dim)
        u_q = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim)
        u_p = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))

        u_q = self._relu(self._atten_linear_layer(u_q))
        u_p = self._relu(self._atten_linear_layer(u_p))

        # Shape: (batch_size, question_length, passage_length)
        # s[b, i, j] is the similarity between the i-th question word and the
        # j-th passage word.
        s = torch.bmm(u_q, u_p.transpose(2, 1))

        # NOTE(review): both masks below are unsqueezed on the last axis, so
        # they broadcast along the softmax axis and zero out padded rows
        # rather than excluding padded columns from the normalisation.
        # Confirm this is the intended masking.
        # Passage-to-question attention.
        # Shape: (batch_size, passage_length, encoding_dim)
        q_ = attention_weight_sum_batch(
            util.masked_softmax(s.transpose(2, 1),
                                passage_mask.unsqueeze(-1),
                                dim=2), u_q)
        # Question-to-passage attention.
        # Shape: (batch_size, question_length, encoding_dim)
        p_ = attention_weight_sum_batch(
            util.masked_softmax(s, question_mask.unsqueeze(-1), dim=2), u_p)

        # Fuse each encoding with its attended counterpart.
        pp = self._atten_fusion(u_p, q_)
        # Shape: (batch_size, question_length, encoding_dim)
        qq = self._atten_fusion(u_q, p_)

        # Shape: (batch_size, passage_length, encoding_dim)
        d = self._passage_self_attention(pp, passage_lstm_mask)
        # Bilinear self-attention weights, normalised over the feature axis.
        # Shape: (batch_size, passage_length, encoding_dim)
        self_atten_weights = self._softmax_d2(self._self_atten_layer(d, d))
        # Shape: (batch_size, passage_length, encoding_dim)
        d_ = self_atten_weights * d
        # Shape: (batch_size, passage_length, encoding_dim)
        dd = self._self_atten_fusion(d, d_)
        # Shape: (batch_size, passage_length, encoding_dim)
        ddd = self._semantic_rep_layer(dd, passage_lstm_mask)

        # Shape: (batch_size, question_length, encoding_dim)
        qqq = self._contextual_question_layer(qq, question_lstm_mask)
        # Shape: (batch_size, question_length)
        gamma = self._softmax_d1(self._vector_linear(qqq).squeeze(-1))
        # Pool the question into a single vector.
        # (batch_size, 1, question_length) x (batch_size, question_length, encoding_dim)
        vec_q = torch.bmm(gamma.unsqueeze(1), qqq)

        # Output layer: score every passage position against the pooled
        # question vector. Shape of each: (batch_size, passage_length)
        passage_t = ddd.transpose(2, 1)
        span_start_logits = util.masked_softmax(
            torch.bmm(self._model_layer_s(vec_q), passage_t).squeeze(1),
            passage_mask,
            dim=1)
        span_end_logits = util.masked_softmax(
            torch.bmm(self._model_layer_e(vec_q), passage_t).squeeze(1),
            passage_mask,
            dim=1)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output = dict()
        output['best_span'] = best_span

        # Compute the loss for training.
        if span_start is not None:
            gold_start = span_start.squeeze(-1)
            gold_end = span_end.squeeze(-1)
            # NOTE(review): the "logits" here are already softmax-normalised
            # and CrossEntropyLoss applies log-softmax again - confirm this is
            # intended before changing it, since it affects trained models.
            loss = self._loss(span_start_logits, gold_start)
            self._span_start_accuracy(span_start_logits, gold_start)
            loss += self._loss(span_end_logits, gold_end)
            self._span_end_accuracy(span_end_logits, gold_end)
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                entry = metadata[i]
                question_tokens.append(entry['question_tokens'])
                passage_tokens.append(entry['passage_tokens'])
                passage_str = entry['original_passage']
                offsets = entry['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Character offsets: start of the first token, end of the last.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output['best_span_str'].append(best_span_string)
                answer_texts = entry.get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output['question_tokens'] = question_tokens
            output['passage_tokens'] = passage_tokens
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span accuracy plus SQuAD exact match and F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Pick, per instance, the ``(start, end)`` pair with ``start <= end``
        that maximises ``span_start_logits[start] + span_end_logits[end]``.

        A single left-to-right pass keeps the best start seen so far, so every
        candidate end is only paired with a start at or before it.

        Raises
        ------
        ValueError
            If either input is not 2-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        # Allocated from the input tensor so the result shares its device.
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
class EdgeProbingTask(Task):
    """ Generic class for fine-grained edge probing.

    Acts as a classifier, but with multiple targets for each input text.

    Targets are of the form (span1, span2, label), where span1 and span2 are
    half-open token intervals [i, j).

    Subclass this for each dataset, or use register_task with appropriate kw
    args.
    """

    @property
    def _tokenizer_suffix(self):
        """ Suffix to make sure we use the correct source files,
        based on the given tokenizer.
        """
        if self.tokenizer_name:
            return ".retokenized." + self.tokenizer_name
        else:
            return ""

    def tokenizer_is_supported(self, tokenizer_name):
        """ Check if the tokenizer is supported for this task. """
        # Assume all tokenizers supported; if retokenized data not found
        # for this particular task, we'll just crash on file loading.
        return True

    def __init__(
            self,
            path: str,
            max_seq_len: int,
            name: str,
            label_file: str = None,
            files_by_split: Dict[str, str] = None,
            is_symmetric: bool = False,
            single_sided: bool = False,
            **kw,
    ):
        """Construct an edge probing task.

        path, max_seq_len, and name are passed by the code in preprocess.py;
        remaining arguments should be provided by a subclass constructor or via
        @register_task.

        Args:
            path: data directory
            max_seq_len: maximum sequence length (currently ignored)
            name: task name
            label_file: relative path to labels file
            files_by_split: split name ('train', 'val', 'test') mapped to
                relative filenames (e.g. 'train': 'train.json')
            is_symmetric: if true, span1 and span2 are assumed to be the same
                type and share parameters. Otherwise, we learn a separate
                projection layer and attention weight for each.
            single_sided: if true, only use span1.
        """
        super().__init__(name, **kw)

        assert label_file is not None
        assert files_by_split is not None
        self._files_by_split = {
            split: os.path.join(path, fname) + self._tokenizer_suffix
            for split, fname in files_by_split.items()
        }
        self.path = path
        self.label_file = label_file
        self.max_seq_len = max_seq_len
        self.is_symmetric = is_symmetric
        self.single_sided = single_sided

        # Populated by load_data().
        self._iters_by_split = None
        self.all_labels = None
        self.n_classes = None

        # see add_task_label_namespace in preprocess.py
        self._label_namespace = self.name + "_labels"

        # Scorers
        self.mcc_scorer = FastMatthews()
        self.acc_scorer = BooleanAccuracy()  # binary accuracy
        self.f1_scorer = F1Measure(positive_label=1)  # binary F1 overall
        self.val_metric = "%s_f1" % self.name  # TODO: switch to MCC?
        self.val_metric_decreases = False

    def load_data(self):
        """Load the label inventory and the records of every split.

        The class previously defined ``load_data`` twice; the later definition
        replaced the earlier one, so ``all_labels`` and ``n_classes`` were
        never filled in. Both steps now live in this single method.
        """
        label_file = os.path.join(self.path, self.label_file)
        self.all_labels = list(utils.load_lines(label_file))
        self.n_classes = len(self.all_labels)

        iters_by_split = collections.OrderedDict()
        for split, filename in self._files_by_split.items():
            # Records are materialized eagerly so that each split can be
            # iterated more than once and supports len().
            iters_by_split[split] = list(self._stream_records(filename))
        self._iters_by_split = iters_by_split

    @classmethod
    def _stream_records(cls, filename):
        """Yield the records of ``filename`` that have at least one target.

        A summary of read / skipped / total counts is logged once the file has
        been fully consumed.
        """
        skip_ctr = 0
        total_ctr = 0
        for record in utils.load_json_data(filename):
            total_ctr += 1
            # Skip records with empty targets.
            # TODO(ian): don't do this if generating negatives!
            if not record.get("targets", None):
                skip_ctr += 1
                continue
            yield record
        log.info(
            "Read=%d, Skip=%d, Total=%d from %s",
            total_ctr - skip_ctr,
            skip_ctr,
            total_ctr,
            filename,
        )

    @staticmethod
    def merge_preds(record: Dict, preds: Dict) -> Dict:
        """ Merge predictions into record, in-place.

        List-valued predictions should align to targets,
        and are attached to the corresponding target entry.

        Non-list predictions are attached to the top-level record.
        """
        record["preds"] = {}
        for target in record["targets"]:
            target["preds"] = {}
        for key, val in preds.items():
            if isinstance(val, list):
                assert len(val) == len(record["targets"])
                for i, target in enumerate(record["targets"]):
                    target["preds"][key] = val[i]
            else:
                # non-list predictions, attach to top-level preds
                record["preds"][key] = val
        return record

    def get_split_text(self, split: str):
        """ Get split text as iterable of records.

        Split should be one of 'train', 'val', or 'test'.
        """
        return self._iters_by_split[split]

    @classmethod
    def get_num_examples(cls, split_text):
        """ Return number of examples in the result of get_split_text.

        Subclass can override this if data is not stored in column format.
        """
        return len(split_text)

    @classmethod
    def _make_span_field(cls, s, text_field, offset=1):
        """Build a SpanField from ``s``, shifting both ends by ``offset``.

        The end index is decremented by one before the shift.
        """
        return SpanField(s[0] + offset, s[1] - 1 + offset, text_field)

    def _pad_tokens(self, tokens):
        """Pad tokens according to the current tokenization style."""
        if self.tokenizer_name.startswith("bert-"):
            # standard padding for BERT; see
            # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/extract_features.py#L85 # noqa
            return ["[CLS]"] + tokens + ["[SEP]"]
        else:
            return [utils.SOS_TOK] + tokens + [utils.EOS_TOK]

    def make_instance(self, record, idx, indexers) -> Instance:
        """Convert a single record to an AllenNLP Instance."""
        tokens = record["text"].split()  # already space-tokenized by Moses
        tokens = self._pad_tokens(tokens)
        text_field = sentence_to_text_field(tokens, indexers)

        d = {}
        d["idx"] = MetadataField(idx)

        d["input1"] = text_field

        # Offset of 1 accounts for the single token _pad_tokens prepends.
        d["span1s"] = ListField([
            self._make_span_field(t["span1"], text_field, 1)
            for t in record["targets"]
        ])
        if not self.single_sided:
            d["span2s"] = ListField([
                self._make_span_field(t["span2"], text_field, 1)
                for t in record["targets"]
            ])

        # Always use multilabel targets, so be sure each label is a list.
        labels = [
            utils.wrap_singleton_string(t["label"]) for t in record["targets"]
        ]
        d["labels"] = ListField([
            MultiLabelField(label_set,
                            label_namespace=self._label_namespace,
                            skip_indexing=False) for label_set in labels
        ])
        return Instance(d)

    def process_split(self, records, indexers) -> Iterable[Instance]:
        """ Process split text into a lazy iterable of AllenNLP Instances. """

        def _map_fn(r, idx):
            return self.make_instance(r, idx, indexers)

        # itertools.count() supplies the running index of each record.
        return map(_map_fn, records, itertools.count())

    def get_all_labels(self) -> List[str]:
        """Return the label inventory read by load_data()."""
        return self.all_labels

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """ Yield sentences, used to compute vocabulary. """
        for split, records in self._iters_by_split.items():
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            for record in records:
                yield record["text"].split()

    def get_metrics(self, reset=False):
        """Get metrics specific to the task"""
        metrics = {}
        metrics["mcc"] = self.mcc_scorer.get_metric(reset)
        metrics["acc"] = self.acc_scorer.get_metric(reset)
        precision, recall, f1 = self.f1_scorer.get_metric(reset)
        metrics["precision"] = precision
        metrics["recall"] = recall
        metrics["f1"] = f1
        return metrics
class CMVDiscriminator(FeedForward):
    """
    Feed-forward discriminator with gated (highway-style) hidden layers.

    The first layer is a plain feed-forward layer. Every later layer mixes its
    own output with its input through a learned sigmoid gate. A final
    single-unit layer turns the last hidden state into one logit per example,
    trained to be high for "real" inputs and low for "fake" ones.

    Parameters
    ----------
    input_dim, num_layers, hidden_dims, activations, dropout
        Forwarded unchanged to ``FeedForward``.
    gate_bias : ``float``, optional (default=-2)
        Initial value of every gate bias.
    """

    def __init__(self,
                 input_dim: int,
                 num_layers: int,
                 hidden_dims: Union[int, Sequence[int]],
                 activations: Union[Activation, Sequence[Activation]],
                 dropout: Union[float, Sequence[float]] = 0.0,
                 gate_bias: float = -2) -> None:
        super(CMVDiscriminator, self).__init__(input_dim, num_layers,
                                               hidden_dims, activations,
                                               dropout)

        # One width per layer. The previous ``num_layers - 1`` produced one
        # gate too few for an int ``hidden_dims``, which silently dropped the
        # last linear layer in ``_get_hidden`` and raised IndexError below
        # for ``num_layers == 1``.
        if not isinstance(hidden_dims, list):
            hidden_dims = [hidden_dims] * num_layers

        # The first layer is ungated; the placeholder keeps the gate list
        # aligned with the linear layers when they are zipped together.
        gate_layers = [None]
        # Each gate reads the previous layer's output and must produce the
        # current layer's width.
        for layer_input_dim, layer_output_dim in zip(hidden_dims[:-1],
                                                     hidden_dims[1:]):
            gate_layer = torch.nn.Linear(layer_input_dim, layer_output_dim)
            gate_layer.bias.data.fill_(gate_bias)
            gate_layers.append(gate_layer)
        self._gate_layers = torch.nn.ModuleList(gate_layers)

        # FeedForward requires an Activation so we just use the identity.
        self._output_feedforward = FeedForward(hidden_dims[-1], 1, 1,
                                               lambda x: x)

        self._accuracy = BooleanAccuracy()

    def _get_hidden(self, output):
        """Run ``output`` through the first layer and then the gated layers."""
        layers = list(
            zip(self._linear_layers, self._activations, self._dropout,
                self._gate_layers))

        layer, activation, dropout, _ = layers[0]
        output = dropout(activation(layer(output)))

        for layer, activation, dropout, gate in layers[1:]:
            gate_output = torch.sigmoid(gate(output))
            new_output = dropout(activation(layer(output)))
            # Convex combination of the layer output and its input.
            output = torch.add(torch.mul(gate_output, new_output),
                               torch.mul(1 - gate_output, output))

        return output

    def forward(self, real_output, fake_output=None):
        """
        Score ``real_output`` (target 1) and, if given, ``fake_output``
        (target 0).

        Returns
        -------
        A dictionary with the summed binary cross-entropy ``loss``, the
        thresholded ``predictions`` and the float ``labels``, with real
        examples first and fake examples after them.
        """
        real_hidden = self._get_hidden(real_output)
        real_value = self._output_feedforward(real_hidden)
        # Created directly on the logits' device.
        labels = torch.ones(real_hidden.size(0), device=real_value.device)

        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            real_value.view(-1), labels)
        predictions = torch.sigmoid(real_value) > 0.5

        if fake_output is not None:
            fake_hidden = self._get_hidden(fake_output)
            fake_value = self._output_feedforward(fake_hidden)
            fake_labels = torch.zeros(fake_hidden.size(0),
                                      device=fake_value.device)

            loss += torch.nn.functional.binary_cross_entropy_with_logits(
                fake_value.view(-1), fake_labels)
            predictions = torch.cat(
                [predictions, torch.sigmoid(fake_value) > 0.5])
            labels = torch.cat([labels, fake_labels])

        self._accuracy(predictions, labels.byte())

        return {'loss': loss, 'predictions': predictions, 'labels': labels}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running real-vs-fake classification accuracy."""
        return {'accuracy': self._accuracy.get_metric(reset)}
class BertQA(Model): """ This class implements Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). The basic layout is pretty simple: encode words as a combination of word embeddings and a character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of attentions to put question information into the passage word representations (this is the only part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. 
This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, sim_text_field_embedder: TextFieldEmbedder, loss_weights: Dict, sim_class_weights: List, pretrained_sim_path: str = None, use_scenario_encoding: bool = True, sim_pretraining: bool = False, dropout: float = 0.2, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(BertQA, self).__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder if use_scenario_encoding: self._sim_text_field_embedder = sim_text_field_embedder self.loss_weights = loss_weights self.sim_class_weights = sim_class_weights self.use_scenario_encoding = use_scenario_encoding self.sim_pretraining = sim_pretraining if self.sim_pretraining and not self.use_scenario_encoding: raise ValueError( "When pretraining Scenario Interpretation Module, you should use it." 
) embedding_dim = self._text_field_embedder.get_output_dim() self._action_predictor = torch.nn.Linear(embedding_dim, 4) self._sim_token_label_predictor = torch.nn.Linear(embedding_dim, 4) self._span_predictor = torch.nn.Linear(embedding_dim, 2) self._action_accuracy = CategoricalAccuracy() self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() self._span_loss_metric = Average() self._action_loss_metric = Average() self._sim_loss_metric = Average() self._sim_yes_f1 = F1Measure(2) self._sim_no_f1 = F1Measure(3) if use_scenario_encoding and pretrained_sim_path is not None: logger.info("Loading pretrained model..") self.load_state_dict(torch.load(pretrained_sim_path)) for param in self._sim_text_field_embedder.parameters(): param.requires_grad = False if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x initializer(self) def get_passage_representation(self, bert_output, bert_input): # Shape: (batch_size, bert_input_len) input_type_ids = self.get_input_type_ids( bert_input['bert-type-ids'], bert_input['bert-offsets'], self._text_field_embedder._token_embedders['bert']).float() # Shape: (batch_size, bert_input_len) input_mask = util.get_text_field_mask(bert_input).float() passage_mask = input_mask - input_type_ids # works only with one [SEP] # Shape: (batch_size, bert_input_len, embedding_dim) passage_representation = bert_output * passage_mask.unsqueeze(2) # Shape: (batch_size, passage_len, embedding_dim) passage_representation = passage_representation[:, passage_mask.sum( dim=0) > 0, :] # Shape: (batch_size, passage_len) passage_mask = passage_mask[:, passage_mask.sum(dim=0) > 0] return passage_representation, passage_mask def forward( self, # type: ignore bert_input: Dict[str, torch.LongTensor], sim_bert_input: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, 
metadata: List[Dict[str, Any]] = None, label: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question tokens, passage tokens, original passage text, and token offsets into the passage for each instance in the batch. The length of this list should be the batch size, and each dictionary should have the keys ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``. Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span end position (inclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. 
best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)`` and each offset is a token index. loss : torch.FloatTensor, optional A scalar loss to be optimised. best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. """ if self.use_scenario_encoding: # Shape: (batch_size, sim_bert_input_len_wp) sim_bert_input_token_labels_wp = sim_bert_input[ 'scenario_gold_encoding'] # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim) sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input) # Shape: (batch_size, sim_bert_input_len_wp) sim_input_mask_wp = (sim_bert_input['bert'] != 0).float() # Shape: (batch_size, sim_bert_input_len_wp) sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[ 'bert-type-ids'].float() # works only with one [SEP] # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim) sim_passage_representation_wp = sim_bert_output_wp * sim_passage_mask_wp.unsqueeze( 2) # Shape: (batch_size, passage_len_wp, embedding_dim) sim_passage_representation_wp = sim_passage_representation_wp[:, sim_passage_mask_wp .sum( dim =0 ) > 0, :] # Shape: (batch_size, passage_len_wp) sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[:, sim_passage_mask_wp .sum( dim =0 ) > 0] # Shape: (batch_size, passage_len_wp) sim_passage_mask_wp = sim_passage_mask_wp[:, sim_passage_mask_wp.sum( dim=0) > 0] # Shape: (batch_size, passage_len_wp, 4) sim_token_logits_wp = self._sim_token_label_predictor( sim_passage_representation_wp) if span_start is not None: # during training and validation class_weights = torch.tensor(self.sim_class_weights, device=sim_token_logits_wp.device, dtype=torch.float) sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4), sim_passage_token_labels_wp.view(-1), ignore_index=0, 
weight=class_weights) self._sim_loss_metric(sim_loss.item()) self._sim_yes_f1(sim_token_logits_wp, sim_passage_token_labels_wp, sim_passage_mask_wp) self._sim_no_f1(sim_token_logits_wp, sim_passage_token_labels_wp, sim_passage_mask_wp) if self.sim_pretraining: return {'loss': sim_loss} if not self.sim_pretraining: # Shape: (batch_size, passage_len_wp) bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax( dim=2)) * sim_passage_mask_wp.long() # Shape: (batch_size, bert_input_len_wp) bert_input_wp_len = bert_input['history_encoding'].size(1) if bert_input['scenario_encoding'].size(1) > bert_input_wp_len: # Shape: (batch_size, bert_input_len_wp) bert_input['scenario_encoding'] = bert_input[ 'scenario_encoding'][:, :bert_input_wp_len] else: batch_size = bert_input['scenario_encoding'].size(0) difference = bert_input_wp_len - bert_input[ 'scenario_encoding'].size(1) zeros = torch.zeros( batch_size, difference, dtype=bert_input['scenario_encoding'].dtype, device=bert_input['scenario_encoding'].device) # Shape: (batch_size, bert_input_len_wp) bert_input['scenario_encoding'] = torch.cat( [bert_input['scenario_encoding'], zeros], dim=1) # Shape: (batch_size, bert_input_len + 1, embedding_dim) bert_output = self._text_field_embedder(bert_input) # Shape: (batch_size, embedding_dim) pooled_output = bert_output[:, 0] # Shape: (batch_size, bert_input_len, embedding_dim) bert_output = bert_output[:, 1:, :] # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len) passage_representation, passage_mask = self.get_passage_representation( bert_output, bert_input) # Shape: (batch_size, 4) action_logits = self._action_predictor(pooled_output) # Shape: (batch_size, passage_len, 2) span_logits = self._span_predictor(passage_representation) # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1) span_start_logits, span_end_logits = span_logits.split(1, dim=2) # Shape: (batch_size, passage_len) span_start_logits = span_start_logits.squeeze(2) # Shape: 
(batch_size, passage_len) span_end_logits = span_end_logits.squeeze(2) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = get_best_span(span_start_logits, span_end_logits) output_dict = { "pooled_output": pooled_output, "passage_representation": passage_representation, "action_logits": action_logits, "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span, } if self.use_scenario_encoding: output_dict["sim_token_logits"] = sim_token_logits_wp # Compute the loss for training (and for validation) if span_start is not None: # Shape: (batch_size,) span_loss = nll_loss(util.masked_log_softmax( span_start_logits, passage_mask), span_start.squeeze(1), reduction='none') # Shape: (batch_size,) span_loss += nll_loss(util.masked_log_softmax( span_end_logits, passage_mask), span_end.squeeze(1), reduction='none') # Shape: (batch_size,) more_mask = (label == self.vocab.get_token_index( 'More', namespace="labels")).float() # Shape: (batch_size,) span_loss = (span_loss * more_mask).sum() / (more_mask.sum() + 1e-6) if more_mask.sum() > 1e-7: self._span_start_accuracy(span_start_logits, span_start.squeeze(1), more_mask) self._span_end_accuracy(span_end_logits, span_end.squeeze(1), more_mask) # Shape: (batch_size, 2) span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long() self._span_accuracy(best_span, torch.cat([span_start, span_end], dim=1), span_acc_mask) action_loss = cross_entropy(action_logits, label) self._action_accuracy(action_logits, label) self._span_loss_metric(span_loss.item()) self._action_loss_metric(action_loss.item()) output_dict['loss'] = self.loss_weights[ 'span_loss'] * span_loss + 
self.loss_weights[ 'action_loss'] * action_loss # Compute the EM and F1 on SQuAD and add the tokenized input to the output. if not self.training: # true during validation and test output_dict['best_span_str'] = [] batch_size = len(metadata) for i in range(batch_size): passage_text = metadata[i]['passage_text'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_str = passage_text[start_offset:end_offset] output_dict['best_span_str'].append(best_span_str) if 'gold_span' in metadata[i]: if metadata[i]['action'] == 'More': gold_span = metadata[i]['gold_span'] self._squad_metrics(best_span_str, [gold_span]) return output_dict def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: action_probs = softmax(output_dict['action_logits'], dim=1) output_dict['action_probs'] = action_probs predictions = action_probs.cpu().data.numpy() argmax_indices = numpy.argmax(predictions, axis=1) labels = [ self.vocab.get_token_from_index(x, namespace="labels") for x in argmax_indices ] output_dict['label'] = labels return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: if self.use_scenario_encoding: sim_loss = self._sim_loss_metric.get_metric(reset) _, _, yes_f1 = self._sim_yes_f1.get_metric(reset) _, _, no_f1 = self._sim_no_f1.get_metric(reset) if self.sim_pretraining: return {'sim_macro_f1': (yes_f1 + no_f1) / 2} try: action_acc = self._action_accuracy.get_metric(reset) except ZeroDivisionError: action_acc = 0 try: start_acc = self._span_start_accuracy.get_metric(reset) except ZeroDivisionError: start_acc = 0 try: end_acc = self._span_end_accuracy.get_metric(reset) except ZeroDivisionError: end_acc = 0 try: span_acc = self._span_accuracy.get_metric(reset) except ZeroDivisionError: span_acc = 0 exact_match, f1_score = self._squad_metrics.get_metric(reset) span_loss = 
self._span_loss_metric.get_metric(reset) action_loss = self._action_loss_metric.get_metric(reset) agg_metric = span_acc + action_acc * 0.45 metrics = { 'action_acc': action_acc, 'span_acc': span_acc, 'span_loss': span_loss, 'action_loss': action_loss, 'agg_metric': agg_metric } if self.use_scenario_encoding: metrics['sim_macro_f1'] = (yes_f1 + no_f1) / 2 if not self.training: # during validation metrics['em'] = exact_match metrics['f1'] = f1_score return metrics @staticmethod def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor: # We call the inputs "logits" - they could either be unnormalized logits or normalized log # probabilities. A log_softmax operation is a constant shifting of the entire logit # vector, so taking an argmax over either one gives the same result. if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError( "Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() device = span_start_logits.device # (batch_size, passage_length, passage_length) span_log_probs = span_start_logits.unsqueeze( 2) + span_end_logits.unsqueeze(1) # Only the upper triangle of the span matrix is valid; the lower triangle has entries where # the span ends before it starts. span_log_mask = torch.triu( torch.ones((passage_length, passage_length), device=device)).log().unsqueeze(0) valid_span_log_probs = span_log_probs + span_log_mask # Here we take the span matrix and flatten it, then find the best span using argmax. We # can recover the start and end indices from this flattened list using simple modular # arithmetic. 
# (batch_size, passage_length * passage_length) best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1) span_start_indices = best_spans // passage_length span_end_indices = best_spans % passage_length return torch.stack([span_start_indices, span_end_indices], dim=-1) def get_input_type_ids(self, type_ids, offsets, embedder): "Converts (bsz, seq_len_wp) to (bsz, seq_len_wp) by indexing." batch_size = type_ids.size(0) full_seq_len = type_ids.size(1) if full_seq_len > embedder.max_pieces: # Recombine if we had used sliding window approach assert batch_size == 1 and type_ids.max() > 0 num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero( ).size(0) select_indices = embedder.indices_to_select( full_seq_len, num_question_tokens) type_ids = type_ids[:, select_indices] range_vector = util.get_range_vector( batch_size, device=util.get_device_of(type_ids)).unsqueeze(1) type_ids = type_ids[range_vector, offsets] return type_ids
def __init__(
    self,
    vocab: Vocabulary,
    text_field_embedder: TextFieldEmbedder,
    num_highway_layers: int,
    phrase_layer: Seq2SeqEncoder,
    matrix_attention: MatrixAttention,
    modeling_layer: Seq2SeqEncoder,
    span_end_encoder: Seq2SeqEncoder,
    dropout: float = 0.2,
    mask_lstms: bool = True,
    initializer: InitializerApplicator = InitializerApplicator(),
    regularizer: Optional[RegularizerApplicator] = None,
) -> None:
    """
    Wire up the sub-modules, span predictors and metrics of the model.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Forwarded to the parent constructor.
    text_field_embedder : ``TextFieldEmbedder``
        Embedder whose output dimension sizes the highway layer and must equal the
        phrase layer's input dimension.
    num_highway_layers : ``int``
        Number of layers given to the ``Highway`` module.
    phrase_layer : ``Seq2SeqEncoder``
        Encoder whose output dimension is referred to as the "encoding dim" below.
    matrix_attention : ``MatrixAttention``
        Stored as-is for use by the rest of the model.
    modeling_layer : ``Seq2SeqEncoder``
        Encoder whose input dimension must be ``4 * encoding dim``.
    span_end_encoder : ``Seq2SeqEncoder``
        Encoder whose input dimension must be ``4 * encoding dim + 3 * modeling dim``.
    dropout : ``float``, optional (default=0.2)
        If greater than 0 a ``torch.nn.Dropout`` with this probability is created,
        otherwise an identity function is used instead.
    mask_lstms : ``bool``, optional (default=True)
        Stored on the instance; not interpreted here.
    initializer : ``InitializerApplicator``, optional
        Applied to the fully constructed model as the last step.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Forwarded to the parent constructor.
    """
    super().__init__(vocab, regularizer)

    # Keep the registration order of the sub-modules stable: it determines both the
    # parameter order and the order in which random initial weights are drawn.
    self._text_field_embedder = text_field_embedder
    highway = Highway(text_field_embedder.get_output_dim(), num_highway_layers)
    self._highway_layer = TimeDistributed(highway)
    self._phrase_layer = phrase_layer
    self._matrix_attention = matrix_attention
    self._modeling_layer = modeling_layer
    self._span_end_encoder = span_end_encoder

    encoding_dim = phrase_layer.get_output_dim()
    modeling_dim = modeling_layer.get_output_dim()

    # Each span predictor scores a concatenation of 4 * encoding_dim features with
    # the output of one further encoder.
    start_scorer = torch.nn.Linear(encoding_dim * 4 + modeling_dim, 1)
    self._span_start_predictor = TimeDistributed(start_scorer)

    end_scorer = torch.nn.Linear(encoding_dim * 4 + span_end_encoder.get_output_dim(), 1)
    self._span_end_predictor = TimeDistributed(end_scorer)

    # Bidaf has lots of layer dimensions which need to match up - these aren't
    # necessarily obvious from the configuration files, so we check here.
    modeling_input_dim = modeling_layer.get_input_dim()
    check_dimensions_match(
        modeling_input_dim,
        4 * encoding_dim,
        "modeling layer input dim",
        "4 * encoding dim",
    )
    embedder_output_dim = text_field_embedder.get_output_dim()
    check_dimensions_match(
        embedder_output_dim,
        phrase_layer.get_input_dim(),
        "text field embedder output dim",
        "phrase layer input dim",
    )
    span_end_input_dim = span_end_encoder.get_input_dim()
    check_dimensions_match(
        span_end_input_dim,
        4 * encoding_dim + 3 * modeling_dim,
        "span end encoder input dim",
        "4 * encoding dim + 3 * modeling dim",
    )

    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    # With no dropout requested, fall back to an identity function so callers can
    # invoke ``self._dropout`` unconditionally.
    self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
    self._mask_lstms = mask_lstms

    initializer(self)
class QaNet(Model):
    """
    An implementation of Adams Wei Yu's `QANet Model
    <https://openreview.net/forum?id=B14TlG-RW>`_ (ICLR 2018) for machine reading
    comprehension.

    QANet stays architecturally close to BiDAF.  The main change is that the RNN encoder is
    replaced by CNN + self-attention; the modeling and output layers also differ slightly.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the ``question`` and ``passage`` ``TextFields`` the model receives.
    num_highway_layers : ``int``
        How many highway layers are placed between the embedding and the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        Encoder (with its own internal stacking) applied to the embedded tokens before the
        passage-question attention.
    matrix_attention_layer : ``MatrixAttention``
        Matrix attention used to compare the encoded passage with the encoded question.
    modeling_layer : ``Seq2SeqEncoder``
        Encoder (with its own internal stacking) applied between the bidirectional attention
        and the span start / end prediction.
    dropout_prob : ``float``, optional (default=0.1)
        When greater than 0, dropout with this probability is applied between layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Initializes the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        When given, used to compute the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention_layer: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        dropout_prob: float = 0.1,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        embedder_out = text_field_embedder.get_output_dim()
        phrase_in = phrase_layer.get_input_dim()
        phrase_out = phrase_layer.get_output_dim()
        modeling_in = modeling_layer.get_input_dim()
        modeling_out = modeling_layer.get_output_dim()

        # Modules are created in data-flow order; the order is kept stable because it
        # fixes both parameter ordering and the draw order of random initial weights.
        self._text_field_embedder = text_field_embedder
        self._embedding_proj_layer = torch.nn.Linear(embedder_out, phrase_in)
        self._highway_layer = Highway(phrase_in, num_highway_layers)
        self._encoding_proj_layer = torch.nn.Linear(phrase_in, phrase_in)
        self._phrase_layer = phrase_layer
        self._matrix_attention = matrix_attention_layer
        # The attention step concatenates four encoding-sized tensors.
        self._modeling_proj_layer = torch.nn.Linear(phrase_out * 4, modeling_in)
        self._modeling_layer = modeling_layer
        # Each span predictor sees two stacked modeling outputs.
        self._span_start_predictor = torch.nn.Linear(modeling_out * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out * 2, 1)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()

        if dropout_prob > 0:
            self._dropout = torch.nn.Dropout(p=dropout_prob)
        else:
            # Identity stand-in so ``self._dropout`` can always be called.
            self._dropout = lambda x: x

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The passage is assumed to contain the answer; the model
            predicts where in the passage the answer starts and ends.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  Gold `inclusive` token index of the answer start.  When
            given, a loss is computed and added to the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  Gold `inclusive` token index of the answer end.  When
            given, a loss is computed and added to the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            One dictionary per instance in the batch, each with the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage`` and
            ``token_offsets``; ``answer_texts`` is read when present.

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            The masked softmax of the passage-question similarity matrix.
        span_start_logits : torch.FloatTensor
            Shape ``(batch_size, passage_length)``; unnormalized log probabilities of the
            span start position, with masked positions replaced by ``-1e32``.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            Shape ``(batch_size, passage_length)``; unnormalized log probabilities of the
            span end position (inclusive), with masked positions replaced by ``-1e32``.
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            Result of ``get_best_span`` on the two logit tensors.  Shape is
            ``(batch_size, 2)`` and each entry is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised; present only when ``span_start`` is given.
        best_span_str : List[str]
            Present only when ``metadata`` is given: the substring of the original passage
            covered by the predicted span.  ``question_tokens`` and ``passage_tokens`` are
            copied from the metadata alongside it.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # Embed (+ dropout), then project to the encoder size and run the highway layers.
        question_emb = self._dropout(self._text_field_embedder(question))
        passage_emb = self._dropout(self._text_field_embedder(passage))
        question_emb = self._highway_layer(self._embedding_proj_layer(question_emb))
        passage_emb = self._highway_layer(self._embedding_proj_layer(passage_emb))

        batch_size = question_emb.size(0)

        question_proj = self._encoding_proj_layer(question_emb)
        passage_proj = self._encoding_proj_layer(passage_emb)

        question_enc = self._dropout(self._phrase_layer(question_proj, question_mask))
        passage_enc = self._dropout(self._phrase_layer(passage_proj, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        similarity = self._matrix_attention(passage_enc, question_enc)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            similarity, question_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(question_enc, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            similarity.transpose(1, 2), passage_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, passage_length)
        passage_passage_attention = torch.bmm(
            passage_question_attention, question_passage_attention
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(passage_enc, passage_passage_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        attention_features = [
            passage_enc,
            passage_question_vectors,
            passage_enc * passage_question_vectors,
            passage_enc * passage_passage_vectors,
        ]
        merged = self._dropout(torch.cat(attention_features, dim=-1))

        # The same modeling layer is applied three times in sequence; every intermediate
        # output is kept because the predictors combine different pairs of them.
        modeling_input = self._modeling_proj_layer(merged)
        first_pass = self._dropout(self._modeling_layer(modeling_input, passage_mask))
        second_pass = self._dropout(self._modeling_layer(first_pass, passage_mask))
        third_pass = self._dropout(self._modeling_layer(second_pass, passage_mask))

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        start_features = torch.cat([first_pass, second_pass], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(start_features).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        end_features = torch.cat([first_pass, third_pass], dim=-1)
        # Shape: (batch_size, passage_length)
        span_end_logits = self._span_end_predictor(end_features).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            gold_start = span_start.squeeze(-1)
            start_loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask), gold_start
            )
            self._span_start_accuracy(span_start_logits, gold_start)

            gold_end = span_end.squeeze(-1)
            end_loss = nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask), gold_end
            )
            self._span_end_accuracy(span_end_logits, gold_end)

            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = start_loss + end_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_span_strings = []
            output_dict["best_span_str"] = best_span_strings
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                instance = metadata[i]
                question_tokens.append(instance["question_tokens"])
                passage_tokens.append(instance["passage_tokens"])

                # Map the predicted token span back to character offsets in the passage.
                token_offsets = instance["token_offsets"]
                start_token, end_token = best_span[i].detach().cpu().numpy()
                char_start = token_offsets[start_token][0]
                char_end = token_offsets[end_token][1]
                predicted_text = instance["original_passage"][char_start:char_end]
                best_span_strings.append(predicted_text)

                gold_answers = instance.get("answer_texts", [])
                if gold_answers:
                    self._metrics(predicted_text, gold_answers)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the span accuracies and the EM / F1 scores, forwarding ``reset`` to each
        underlying metric.
        """
        exact_match, f1_score = self._metrics.get_metric(reset)
        metrics = {}
        metrics["start_acc"] = self._span_start_accuracy.get_metric(reset)
        metrics["end_acc"] = self._span_end_accuracy.get_metric(reset)
        metrics["span_acc"] = self._span_accuracy.get_metric(reset)
        metrics["em"] = exact_match
        metrics["f1"] = f1_score
        return metrics
class BERT_QA(Model):
    """
    Span-extraction QA model: a single linear layer on top of a text field embedder that
    produces a start logit and an end logit for every token of the ``context`` input.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Embeds the ``context`` ``TextField``; its output dimension sizes the linear layer.
    dropout : ``float``, optional (default=0.0)
        If greater than 0 a ``torch.nn.Dropout`` is created, otherwise an identity function.
        NOTE(review): ``self._dropout`` is not applied anywhere inside this class.
    max_span_length : ``int``, optional (default=30)
        Stored on the instance.  NOTE(review): not used by the span search in this class.
    initializer : ``InitializerApplicator``, optional
        Applied to the constructed model.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Forwarded to the parent constructor.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 dropout: float = 0.0,
                 max_span_length: int = 30,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._max_span_length = max_span_length
        # One output unit for the span start score, one for the span end score.
        self.qa_outputs = torch.nn.Linear(
            self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._span_qa_metrics = SquadEmAndF1()

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            context: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            Accepted for interface compatibility; not read by this method.
        passage : Dict[str, torch.LongTensor]
            Only used to build the token mask.  The mask must contain exactly
            ``batch_size * context_length`` elements, otherwise the reshape below fails.
        context : Dict[str, torch.LongTensor]
            The input that is actually embedded (per the original author: the concatenation
            of ``question`` and ``passage``).  ``context['tokens']`` must be 2-dimensional.
        span_start, span_end : ``torch.IntTensor``, optional
            Gold span indices.  When ``span_start`` is given, the loss is computed and the
            accuracy metrics are updated.
        metadata : ``List[Dict[str, Any]]``, optional
            Per-instance dictionaries with the keys ``question_tokens``, ``passage_tokens``,
            ``paragraph_words``, ``answer_offset`` and ``tok_to_word_index``;
            ``answer_text`` is read when present.

        Returns
        -------
        A dictionary with ``span_start_logits``, ``span_start_probs``, ``span_end_logits``,
        ``span_end_probs`` and ``best_span``; plus ``loss`` when gold spans are given, and
        ``best_span_str``, ``question_tokens`` and ``passage_tokens`` when metadata is given.
        """
        # The `context` is the concat of `question` and `passage`, so we just use `context`.
        # Unpacking into two names also enforces that the token tensor is 2-dimensional.
        batch_size, _ = context['tokens'].size()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors
        # of start and end spans.
        embedded_context = self._text_field_embedder(context)
        context_length = embedded_context.size(1)
        logits = self.qa_outputs(embedded_context)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Bring the mask to the logits' (batch_size, context_length) layout once and use
        # that single mask everywhere below (masking, softmax and loss).
        passage_mask = util.get_text_field_mask(passage).float()
        span_mask = passage_mask.reshape(batch_size, context_length)

        # -1e7 is a large negative value that keeps masked positions out of the argmax
        # while staying numerically stable.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       span_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits, span_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     span_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, span_mask)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict: Dict[str, Any] = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, span_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, span_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on span qa and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_words = metadata[i]["paragraph_words"]
                answer_offset = metadata[i]["answer_offset"]
                tok_to_word_index = metadata[i]["tok_to_word_index"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Predicted indices are positions in `context`; subtracting
                # `answer_offset` converts them to indices into `tok_to_word_index`.
                # NOTE(review): a prediction that falls before `answer_offset` yields a
                # negative index, which Python resolves from the end of the list --
                # confirm whether that case can occur.
                start_position = tok_to_word_index[predicted_span[0] -
                                                   answer_offset]
                end_position = tok_to_word_index[predicted_span[1] -
                                                 answer_offset]
                best_span_str = " ".join(
                    passage_words[start_position:end_position + 1])
                output_dict["best_span_str"].append(best_span_str)

                answer_text = metadata[i].get("answer_text", [])
                if answer_text:
                    # The metric is called with a list of reference answers.
                    answer_text = [answer_text]
                    self._span_qa_metrics(best_span_str, answer_text)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start / end / span accuracy and EM / F1, forwarding ``reset`` to each."""
        exact_match, f1_score = self._span_qa_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Find, for every batch element, the ``(start, end)`` pair with ``start <= end`` that
        maximises ``span_start_logits[start] + span_end_logits[end]``.

        Parameters
        ----------
        span_start_logits, span_end_logits : ``torch.Tensor``
            Both of shape ``(batch_size, passage_length)``.

        Returns
        -------
        A tensor of shape ``(batch_size, 2)`` holding the start and end index.

        Raises
        ------
        ValueError
            If either input is not 2-dimensional.
        """
        # We call the inputs "logits" - they could either be unnormalized logits or
        # normalized log probabilities.  A log_softmax operation is a constant shifting of
        # the entire logit vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has
        # entries where the span ends before it starts.  log(1) = 0 keeps valid entries,
        # log(0) = -inf removes invalid ones from the argmax.
        span_log_mask = (torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0))
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using
        # argmax.  We can recover the start and end indices from this flattened list using
        # simple modular arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
class RobertaSpanPredictionModel(Model):
    """
    RoBERTa encoder followed by a linear layer that scores every input token as a span
    start and as a span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    pretrained_model : ``str``, optional
        Name or path handed to ``RobertaConfig.from_pretrained`` /
        ``RobertaModel.from_pretrained``.
    requires_grad : ``bool``, optional (default=True)
        Whether the transformer parameters are trainable.
    transformer_weights_model : ``str``, optional
        Archive passed to ``load_archive``; the ``_transformer_model`` of the archived
        model is reused instead of loading ``pretrained_model`` weights.
    layer_freeze_regexes : ``List[str]``, optional
        Transformer parameters whose name matches any of these regexes (``re.search``)
        are frozen even when ``requires_grad`` is true.
    on_load : ``bool``, optional (default=False)
        When true only the configuration is fetched and the transformer is built without
        pretrained weights.  ``_load`` sets this flag; presumably the weights are then
        restored from the serialized model -- verify against ``Model._load``.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Forwarded to the parent constructor.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 transformer_weights_model: str = None,
                 layer_freeze_regexes: List[str] = None,
                 on_load: bool = False,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        if on_load:
            logging.info("Skipping loading of initial Transformer weights")
            transformer_config = RobertaConfig.from_pretrained(
                pretrained_model)
            self._transformer_model = RobertaModel(transformer_config)
        elif transformer_weights_model:
            logging.info("Loading Transformer weights model from %s",
                         transformer_weights_model)
            transformer_model_loaded = load_archive(transformer_weights_model)
            self._transformer_model = transformer_model_loaded.model._transformer_model
        else:
            self._transformer_model = RobertaModel.from_pretrained(
                pretrained_model)

        for name, param in self._transformer_model.named_parameters():
            trainable = requires_grad
            if trainable and layer_freeze_regexes:
                # Freeze the parameter as soon as one pattern matches its name.
                trainable = not any(
                    re.search(pattern, name)
                    for pattern in layer_freeze_regexes)
            param.requires_grad = trainable

        transformer_config = self._transformer_model.config
        num_labels = 2  # For start/end
        self.qa_outputs = Linear(transformer_config.hidden_size, num_labels)

        # Import GPT2 machinery to get from tokens to actual text
        self.byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        # Debug output is printed while this counter is still positive after being
        # decremented at the start of ``forward``, i.e. for the first call only.
        self._debug = 2
        self._padding_value = 1  # The index of the RoBERTa padding token

    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                start_positions: torch.LongTensor = None,
                end_positions: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Score every token as span start / end and, when labels are given, compute the loss.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            ``tokens['tokens']`` holds the input ids; positions equal to the padding index
            are masked out.
        segment_ids : ``torch.LongTensor``, optional
            Only printed in debug mode; not passed to the transformer.
        start_positions, end_positions : ``torch.LongTensor``, optional
            Gold start / end indices.  The loss and the accuracy metrics are computed only
            when both are given.  The caller's tensors are not modified.
        metadata : ``List[Dict[str, Any]]``, optional
            Per-instance dictionaries with the key ``tokens``; ``answer_texts`` is read
            when present.

        Returns
        -------
        A dictionary with ``start_logits`` and ``end_logits`` (unmasked), ``best_span``,
        ``start_probs`` and ``end_probs``; ``loss`` when labels are given; and
        ``best_span_str``, ``exact_match``, ``f1_score`` and ``tokens_texts`` when metadata
        is given.
        """
        self._debug -= 1
        input_ids = tokens['tokens']

        batch_size = input_ids.size(0)
        num_choices = input_ids.size(1)

        tokens_mask = (input_ids != self._padding_value).long()

        if self._debug > 0:
            print(f"batch_size = {batch_size}")
            print(f"num_choices = {num_choices}")
            print(f"tokens_mask = {tokens_mask}")
            print(f"input_ids.size() = {input_ids.size()}")
            print(f"input_ids = {input_ids}")
            print(f"segment_ids = {segment_ids}")
            print(f"start_positions = {start_positions}")
            print(f"end_positions = {end_positions}")

        # Segment ids are not used by RoBERTa
        transformer_outputs = self._transformer_model(
            input_ids=input_ids,
            # token_type_ids=segment_ids,
            attention_mask=tokens_mask)
        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # Masked copies are used for span search, probabilities and accuracy; the raw
        # logits are what is returned and what the loss is computed on.
        span_start_logits = util.replace_masked_values(start_logits,
                                                       tokens_mask, -1e7)
        span_end_logits = util.replace_masked_values(end_logits, tokens_mask,
                                                     -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)
        span_start_probs = util.masked_softmax(span_start_logits, tokens_mask)
        span_end_probs = util.masked_softmax(span_end_logits, tokens_mask)

        output_dict = {
            "start_logits": start_logits,
            "end_logits": end_logits,
            "best_span": best_span,
            "start_probs": span_start_probs,
            "end_probs": span_end_probs,
        }

        if start_positions is not None and end_positions is not None:
            # The labels may arrive with an extra trailing dimension (the original
            # comment attributes this to multi-GPU runs); drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; they are
            # clamped to `ignored_index`, which the loss skips.  The out-of-place clamp
            # leaves the caller's label tensors untouched.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # NOTE(review): a label clamped to `ignored_index` lies one past the last
            # valid class; confirm the accuracy metrics tolerate that value.
            self._span_start_accuracy(span_start_logits, start_positions)
            self._span_end_accuracy(span_end_logits, end_positions)
            self._span_accuracy(
                best_span,
                torch.cat([
                    start_positions.unsqueeze(-1),
                    end_positions.unsqueeze(-1)
                ], -1))

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
            # Should we mask out invalid positions here?
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = total_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['exact_match'] = []
            output_dict['f1_score'] = []
            tokens_texts = []
            for i in range(batch_size):
                tokens_text = metadata[i]['tokens']
                tokens_texts.append(tokens_text)
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                predicted_start = predicted_span[0]
                predicted_end = predicted_span[1]
                # The predicted end index is inclusive.
                predicted_tokens = tokens_text[predicted_start:(predicted_end +
                                                                1)]
                best_span_string = self.convert_tokens_to_string(
                    predicted_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                exact_match = 0
                f1_score = 0
                if answer_texts:
                    # NOTE(review): relies on the metric call returning the
                    # per-example (exact_match, f1) pair -- confirm for SquadEmAndF1.
                    exact_match, f1_score = self._squad_metrics(
                        best_span_string, answer_texts)
                output_dict['exact_match'].append(exact_match)
                output_dict['f1_score'].append(f1_score)
            output_dict['tokens_texts'] = tokens_texts

        if self._debug > 0:
            print(f"output_dict = {output_dict}")

        return output_dict

    def convert_tokens_to_string(self, tokens):
        """
        Join a sequence of token strings and decode the result back to text: every
        character is mapped to a byte through ``self.byte_decoder`` and the bytes are
        decoded as UTF-8, with undecodable sequences replaced.
        """
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c]
                          for c in text]).decode('utf-8', errors='replace')
        return text

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start / end / span accuracy and EM / F1, forwarding ``reset`` to each."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @classmethod
    def _load(cls,
              config: Params,
              serialization_dir: str,
              weights_file: str = None,
              cuda_device: int = -1,
              **kwargs) -> 'Model':
        """
        Delegate to the parent ``_load`` after forcing ``on_load=True`` in the model
        parameters, so the constructor skips loading the initial transformer weights.
        Note that ``config`` is updated in place.
        """
        model_params = config.get('model')
        model_params.update({"on_load": True})
        config.update({'model': model_params})
        return super()._load(config=config,
                             serialization_dir=serialization_dir,
                             weights_file=weights_file,
                             cuda_device=cuda_device,
                             **kwargs)
def __init__(self,
             vocab: Vocabulary,
             model_name: str = None,
             start_attention: Attention = None,
             end_attention: Attention = None,
             text_field_embedder: TextFieldEmbedder = None,
             task_pretrained_file: str = None,
             neg_sample_ratio: float = 0.0,
             max_turn_len: int = 3,
             start_token: str = "[CLS]",
             end_token: str = "[SEP]",
             index_name: str = "bert",
             eps: float = 1e-8,
             seed: int = 42,
             loss_factor: float = 1.0,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: RegularizerApplicator = None):
    """
    Build the embedder, the start / end attention modules and the metrics, and optionally
    warm-start from weights trained on a related task.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Forwarded to the parent constructor.
    model_name : ``str``, optional
        Used to build a ``PretrainedChineseBertMismatchedEmbedder`` when no
        ``text_field_embedder`` is supplied.
    start_attention, end_attention : ``Attention``, optional
        Attention modules for the span start / end; each defaults to a
        ``BilinearAttention`` sized by the embedder's output dimension.
    text_field_embedder : ``TextFieldEmbedder``, optional
        Embedder to use instead of constructing one from ``model_name``.
    task_pretrained_file : ``str``, optional
        Path to a state dict that is loaded non-strictly.  A path that is not an existing
        file is skipped with a warning.
    neg_sample_ratio, start_token, end_token, index_name, eps, loss_factor
        Stored on the instance for use elsewhere in the model.
    max_turn_len : ``int``, optional (default=3)
        Passed to the default embedder as ``max_turn_length``.
    seed : ``int``, optional (default=42)
        Passed to ``seed_everything``.
    initializer : ``InitializerApplicator``, optional
        Stored and applied to the two attention modules only.
    regularizer : ``RegularizerApplicator``, optional
        Forwarded to the parent constructor.

    Raises
    ------
    ValueError
        If both ``model_name`` and ``text_field_embedder`` are ``None``.
    """
    super().__init__(vocab, regularizer)
    if model_name is None and text_field_embedder is None:
        raise ValueError(
            "`model_name` and `text_field_embedder` can't both equal to None."
        )

    # For the pure resolution task we only need the embedding representation of the
    # last layer.  Test against None explicitly: a module's truthiness can depend on
    # ``__len__`` and must not decide whether the caller's embedder is used.
    if text_field_embedder is None:
        text_field_embedder = PretrainedChineseBertMismatchedEmbedder(
            model_name,
            return_all=False,
            output_hidden_states=False,
            max_turn_length=max_turn_len)
    self._text_field_embedder = text_field_embedder

    seed_everything(seed)
    self._neg_sample_ratio = neg_sample_ratio
    self._start_token = start_token
    self._end_token = end_token
    self._index_name = index_name
    self._initializer = initializer

    linear_input_size = self._text_field_embedder.get_output_dim()

    # Use the attention approach to score span start and span end.
    if start_attention is None:
        start_attention = BilinearAttention(vector_dim=linear_input_size,
                                            matrix_dim=linear_input_size)
    self.start_attention = start_attention
    if end_attention is None:
        end_attention = BilinearAttention(vector_dim=linear_input_size,
                                          matrix_dim=linear_input_size)
    self.end_attention = end_attention

    # Metrics for the mask; F-score is the main concern, and we care more about the
    # recall of the `1` class.
    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._rewrite_em = RewriteEM(valid_keys="semr,nr_semr,re_semr")
    self._restore_score = RestorationScore(compute_restore_tokens=True)
    self._metrics = [
        TokenBasedBLEU(mode="1,2"),
        TokenBasedROUGE(mode="1r,2r")
    ]
    self._eps = eps
    self._loss_factor = loss_factor

    # Only the attention modules are (re-)initialized here.
    self._initializer(self.start_attention)
    self._initializer(self.end_attention)

    # Load a model pretrained on another (related) task.
    if task_pretrained_file is not None:
        if os.path.isfile(task_pretrained_file):
            logger.info("loading related task pretrained weights...")
            # Map to CPU so a checkpoint saved on GPU also loads on a CPU-only host;
            # ``load_state_dict`` copies the values onto the parameters' own device.
            # NOTE: ``torch.load`` unpickles the file -- only use trusted checkpoints.
            state_dict = torch.load(task_pretrained_file, map_location="cpu")
            self.load_state_dict(state_dict, strict=False)
        else:
            logger.warning(
                "task pretrained file %s does not exist, skipping it",
                task_pretrained_file)
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites need no branching.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            The masked softmax of the passage-question similarity matrix.
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        # ``max`` over a dim already drops that dim, so no squeeze is needed; an extra
        # ``squeeze(-1)`` here would wrongly remove the passage axis when passage_length == 1.
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Push padded positions to a very negative value so they never win the span search.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Token indices -> character offsets into the original passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span accuracy plus SQuAD EM and F1, optionally resetting them."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Find, per batch row, the ``(start, end)`` pair with ``start <= end`` that maximises
        ``span_start_logits[start] + span_end_logits[end]`` in a single left-to-right pass.

        Returns a ``(batch_size, 2)`` long tensor.  Raises ``ValueError`` if either input
        is not 2-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Running argmax of the start logits over positions 0..j.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
class BertMCQAModel(Model):
    """
    BERT-based multiple-choice QA model that scores *pairs* of answer choices.

    Each input row is one pair; a classifier produces one logit per pair.
    When ``train_comparison_layer`` is set, a small two-layer network maps the
    per-pair probabilities to per-choice logits; otherwise the pair logits are
    trained directly with a binary cross-entropy loss.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 top_layer_only: bool = True,
                 bert_weights_model: str = None,
                 per_choice_loss: bool = False,
                 layer_freeze_regexes: List[str] = None,
                 regularizer: Optional[RegularizerApplicator] = None,
                 use_comparative_bert: bool = True,
                 use_bilinear_classifier: bool = False,
                 train_comparison_layer: bool = False,
                 number_of_choices_compared: int = 0,
                 comparison_layer_hidden_size: int = -1,
                 comparison_layer_use_relu: bool = True) -> None:
        """
        Parameters
        ----------
        pretrained_model : name/path given to ``BertModel.from_pretrained`` when
            ``bert_weights_model`` is not set.
        requires_grad : value assigned to ``requires_grad`` of every BERT parameter.
        top_layer_only : if ``False``, all encoder layers are combined with a ``ScalarMix``.
        bert_weights_model : archive to take the BERT model (and, when present, the
            classifier / scalar mix) from.
        per_choice_loss : stored only.
        layer_freeze_regexes : currently unused.
        use_comparative_bert : classify on first-token and last-token pooled vectors
            instead of the first-token vector alone.
        use_bilinear_classifier : combine the two vectors with ``Bilinear`` rather than
            concatenation + ``Linear``.
        train_comparison_layer : add the pair-probabilities -> choice-logits network;
            requires ``number_of_choices_compared > 1``.
        comparison_layer_hidden_size : ``-1`` means ``number_of_pairs ** 2``.
        comparison_layer_use_relu : ``LeakyReLU`` if true, else ``Tanh``.
        """
        super().__init__(vocab, regularizer)

        self._use_comparative_bert = use_comparative_bert
        self._use_bilinear_classifier = use_bilinear_classifier
        self._train_comparison_layer = train_comparison_layer
        if train_comparison_layer:
            assert number_of_choices_compared > 1
        self._num_choices = number_of_choices_compared
        self._comparison_layer_hidden_size = comparison_layer_hidden_size
        self._comparison_layer_use_relu = comparison_layer_use_relu

        # Bert weights and config
        if bert_weights_model:
            logging.info(f"Loading BERT weights model from {bert_weights_model}")
            bert_model_loaded = load_archive(bert_weights_model)
            self._bert_model = bert_model_loaded.model._bert_model
        else:
            self._bert_model = BertModel.from_pretrained(pretrained_model)

        for param in self._bert_model.parameters():
            param.requires_grad = requires_grad

        bert_config = self._bert_model.config
        self._output_dim = bert_config.hidden_size
        self._dropout = torch.nn.Dropout(bert_config.hidden_dropout_prob)
        self._per_choice_loss = per_choice_loss

        # Bert Classifier selector
        final_output_dim = 1
        if not use_comparative_bert:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_classifier"):
                self._classifier = bert_model_loaded.model._classifier
            else:
                self._classifier = Linear(self._output_dim, final_output_dim)
        else:
            if use_bilinear_classifier:
                self._classifier = Bilinear(self._output_dim, self._output_dim, final_output_dim)
            else:
                self._classifier = Linear(self._output_dim * 2, final_output_dim)
            # Freshly created comparative classifier gets BERT-style initialisation.
            self._classifier.apply(self._bert_model.init_bert_weights)

        # Comparison layer setup
        if self._train_comparison_layer:
            # Ordered pairs of distinct choices.
            number_of_pairs = self._num_choices * (self._num_choices - 1)
            if self._comparison_layer_hidden_size == -1:
                self._comparison_layer_hidden_size = number_of_pairs * number_of_pairs

            self._comparison_layer_1 = Linear(number_of_pairs, self._comparison_layer_hidden_size)
            if self._comparison_layer_use_relu:
                self._comparison_layer_1_activation = torch.nn.LeakyReLU()
            else:
                self._comparison_layer_1_activation = torch.nn.Tanh()
            self._comparison_layer_2 = Linear(self._comparison_layer_hidden_size, self._num_choices)
            self._comparison_layer_2_activation = torch.nn.Softmax()

        # Scalar mix, if necessary
        self._all_layers = not top_layer_only
        if self._all_layers:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_scalar_mix") \
                    and bert_model_loaded.model._scalar_mix is not None:
                self._scalar_mix = bert_model_loaded.model._scalar_mix
            else:
                num_layers = bert_config.num_hidden_layers
                initial_scalar_parameters = num_layers * [0.0]
                initial_scalar_parameters[-1] = 5.0  # Starts with most mass on last layer
                self._scalar_mix = ScalarMix(bert_config.num_hidden_layers,
                                             initial_scalar_parameters=initial_scalar_parameters,
                                             do_layer_norm=False)
        else:
            self._scalar_mix = None

        # Accuracy and loss setup
        if self._train_comparison_layer:
            self._accuracy = CategoricalAccuracy()
            self._loss = torch.nn.CrossEntropyLoss()
        else:
            self._accuracy = BooleanAccuracy()
            self._loss = torch.nn.BCEWithLogitsLoss()
        self._debug = -1

    def _extract_last_token_pooled_output(self, encoded_layers, question_mask):
        """
        Extract the output vector for the last token in the sentence - similarly to how
        pooled_output is extracted for us when calling 'bert_model'.
        We need the question mask to find the last actual (non-padding) token.

        :param encoded_layers: encoder output (the last element is used when all layers
            were requested); expected as (num_pairs, tokens, hidden) after that selection.
        :param question_mask: mask with a leading batch dimension of size 1.
        :return: pooled vectors, one per pair, after the BERT pooler's dense + activation.
        """
        if self._all_layers:
            encoded_layers = encoded_layers[-1]

        # A cool trick to extract the last "True" item in each row.
        # Only the leading batch dim (asserted == 1 by the caller) is dropped; a bare
        # squeeze() would also collapse num_pairs == 1 or a single-token sequence.
        question_mask = question_mask.squeeze(0)
        assert question_mask.dim() == 2
        shifted_matrix = question_mask.roll(-1, 1)
        shifted_matrix[:, -1] = 0
        # 1 exactly where the mask goes from 1 to 0 (or at the final column if still 1).
        last_item_indices = question_mask - shifted_matrix

        # TODO: This row, for some reason, didn't work as expected, but it is much better
        # than the implementation that follows
        # last_token_tensor = encoded_layers[last_item_indices]

        num_pairs, token_number, hidden_size = encoded_layers.size()
        assert last_item_indices.size() == (num_pairs, token_number)
        # Don't worry, expand doesn't allocate new memory, it simply views the tensor differently
        expanded_last_item_indices = last_item_indices.unsqueeze(2).expand(num_pairs, token_number, hidden_size)
        last_token_tensor = encoded_layers.masked_select(expanded_last_item_indices.byte())
        last_token_tensor = last_token_tensor.reshape(num_pairs, hidden_size)

        pooled_output = self._bert_model.pooler.dense(last_token_tensor)
        pooled_output = self._bert_model.pooler.activation(pooled_output)
        return pooled_output

    def forward(self,
                question: Dict[str, torch.LongTensor],
                choice1_indexes: List[int] = None,
                choice2_indexes: List[int] = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Score every choice pair and, when ``label`` is given, compute the loss.

        ``question['bert']`` holds token ids of size (batch_size, num_pairs, max_sentence_length)
        and ``question['bert-type-ids']`` the segment ids.  Only ``batch_size == 1`` is supported
        (asserted).  The returned dict always has ``pair_label_probs``, ``pair_label_logits``,
        ``choice1_indexes`` and ``choice2_indexes``; with the comparison layer it also has
        ``choice_logits``, ``choice_probs`` and ``predicted_choice``; ``loss`` is added when a
        label is supplied.
        """
        self._debug -= 1
        input_ids = question['bert']

        # input_ids.size() == (batch_size, num_pairs, max_sentence_length)
        batch_size, num_pairs, _ = question['bert'].size()
        question_mask = (input_ids != 0).long()

        if self._train_comparison_layer:
            assert num_pairs == self._num_choices * (self._num_choices - 1)

        # Segment ids
        real_segment_ids = question['bert-type-ids'].clone()
        # Change the last 'SEP' to belong to the second answer (for symmetry)
        last_seps = (real_segment_ids.roll(-1) == 2) & (real_segment_ids == 1)
        real_segment_ids[last_seps] = 2
        # Collapse to binary ids: 1 where the id was 0 or 2, 0 elsewhere.
        real_segment_ids = (real_segment_ids == 0) | (real_segment_ids == 2)
        real_segment_ids = real_segment_ids.long()

        # TODO: How to extract last token pooled output if batch size != 1
        assert batch_size == 1

        # Run model
        encoded_layers, first_vectors_pooled_output = self._bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(real_segment_ids),
            attention_mask=util.combine_initial_dims(question_mask),
            output_all_encoded_layers=self._all_layers)

        if self._use_comparative_bert:
            last_vectors_pooled_output = self._extract_last_token_pooled_output(encoded_layers, question_mask)
        else:
            last_vectors_pooled_output = None
        if self._all_layers:
            mixed_layer = self._scalar_mix(encoded_layers, question_mask)
            first_vectors_pooled_output = self._bert_model.pooler(mixed_layer)

        # Apply dropout
        first_vectors_pooled_output = self._dropout(first_vectors_pooled_output)
        if self._use_comparative_bert:
            last_vectors_pooled_output = self._dropout(last_vectors_pooled_output)

        # Classify
        if not self._use_comparative_bert:
            pair_label_logits = self._classifier(first_vectors_pooled_output)
        else:
            if self._use_bilinear_classifier:
                pair_label_logits = self._classifier(first_vectors_pooled_output, last_vectors_pooled_output)
            else:
                all_pooled_output = torch.cat((first_vectors_pooled_output, last_vectors_pooled_output), 1)
                pair_label_logits = self._classifier(all_pooled_output)

        pair_label_logits = pair_label_logits.view(-1, num_pairs)
        pair_label_probs = torch.sigmoid(pair_label_logits)

        output_dict = {}
        pair_label_probs_flat = pair_label_probs.squeeze(1)
        output_dict['pair_label_probs'] = pair_label_probs_flat.view(-1, num_pairs)
        output_dict['pair_label_logits'] = pair_label_logits
        output_dict['choice1_indexes'] = choice1_indexes
        output_dict['choice2_indexes'] = choice2_indexes

        if not self._train_comparison_layer:
            if label is not None:
                label = label.unsqueeze(1)
                label = label.expand(-1, num_pairs)
                # Only pairs that contain the gold choice contribute to the loss.
                relevant_pairs = (choice1_indexes == label) | (choice2_indexes == label)
                relevant_logits = pair_label_logits[relevant_pairs]
                relevant_probs = pair_label_probs[relevant_pairs]
                choice1_is_the_label = (choice1_indexes == label)[relevant_pairs]
                # BCEWithLogitsLoss applies the sigmoid itself, so it must receive the raw
                # logits; feeding it probabilities would apply the sigmoid twice.
                loss = self._loss(relevant_logits, choice1_is_the_label.float())
                self._accuracy(relevant_probs >= 0.5, choice1_is_the_label)
                output_dict["loss"] = loss
            return output_dict
        else:
            choice_logits = self._comparison_layer_2(
                self._comparison_layer_1_activation(self._comparison_layer_1(pair_label_probs)))
            output_dict['choice_logits'] = choice_logits
            output_dict['choice_probs'] = torch.softmax(choice_logits, 1)
            output_dict['predicted_choice'] = torch.argmax(choice_logits, 1)

            if label is not None:
                loss = self._loss(choice_logits, label)
                self._accuracy(choice_logits, label)
                output_dict["loss"] = loss
            return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accuracy metric under the key ``'EM'``, optionally resetting it."""
        return {
            'EM': self._accuracy.get_metric(reset),
        }
class BidirectionalAttentionFlow(Model):
    """
    Span-extraction reader built from ``hops`` rounds of interactive alignment,
    self alignment and recurrent aggregation, followed by a memory-based answer
    pointer that yields start scores, end scores and 3-way yes/no scores.

    Token embeddings are expected to carry 300 word-embedding dims followed by
    40 extra feature dims (they are split as ``[300, 40]`` in ``forward``).
    """

    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            char_field_embedder: TextFieldEmbedder,
            phrase_layer: Seq2SeqEncoder,
            char_rnn: Seq2SeqEncoder,
            hops: int,
            hidden_dim: int,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        """
        Parameters
        ----------
        text_field_embedder : embeds word-level tokens.
        char_field_embedder : embeds characters.
        phrase_layer : encoder run over the concatenated token/char/feature vectors.
        char_rnn : encoder run over each token's characters.
        hops : number of align/aggregate rounds (also passed to the answer pointer).
        hidden_dim : hidden size of the aggregation LSTMs and the answer pointer.
        dropout : dropout probability; ``<= 0`` disables it.
        mask_lstms : if ``False``, no mask is passed to the phrase layer / char RNN.
        initializer : applied to the whole model.
        """
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._char_field_embedder = char_field_embedder
        # Embeds the binary exact-match / lemma-match indicators.
        self._features_embedder = nn.Embedding(2, 5)
        self._phrase_layer = phrase_layer
        self._encoding_dim = phrase_layer.get_output_dim()
        self._char_rnn = char_rnn
        self.hops = hops

        self.interactive_aligners = nn.ModuleList()
        self.interactive_SFUs = nn.ModuleList()
        self.self_aligners = nn.ModuleList()
        self.self_SFUs = nn.ModuleList()
        self.aggregate_rnns = nn.ModuleList()
        for _ in range(hops):
            # interactive aligner
            self.interactive_aligners.append(
                layers.SeqAttnMatch(self._encoding_dim))
            self.interactive_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # self aligner
            self.self_aligners.append(layers.SelfAttnMatch(self._encoding_dim))
            self.self_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # aggregating
            self.aggregate_rnns.append(
                PytorchSeq2SeqWrapper(
                    nn.LSTM(input_size=self._encoding_dim,
                            hidden_size=hidden_dim,
                            num_layers=1,
                            dropout=0.2,
                            bidirectional=True,
                            batch_first=True)))

        # Memory-based Answer Pointer
        self.mem_ans_ptr = layers.MemoryAnsPointer(x_size=self._encoding_dim,
                                                   y_size=self._encoding_dim,
                                                   hidden_size=hidden_dim,
                                                   hop=hops,
                                                   dropout_rate=0.2,
                                                   normalize=True)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites need no branching.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def _encode_chars(self, char_emb, word_mask):
        """Run the char RNN over every token's characters and keep the last char step.

        ``char_emb`` is 4-D (batch, tokens, chars, emb); ``word_mask`` is the
        (batch, tokens) mask or ``None`` when LSTM masking is disabled.
        """
        batch, num_tokens, num_chars, emb_dim = char_emb.size()
        flat_emb = char_emb.reshape((batch * num_tokens, num_chars, emb_dim))
        if word_mask is None:
            flat_mask = None
        else:
            # Every character of a token inherits that token's mask value.
            flat_mask = word_mask.unsqueeze(2).repeat(1, 1, num_chars).reshape(
                (batch * num_tokens, num_chars))
        encoded = self._char_rnn(flat_emb, flat_mask)
        return encoded.reshape((batch, num_tokens, num_chars, -1))[:, :, -1, :]

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno: torch.IntTensor = None,
            question_tf: torch.FloatTensor = None,
            passage_tf: torch.FloatTensor = None,
            q_em_cased: torch.IntTensor = None,
            p_em_cased: torch.IntTensor = None,
            q_em_uncased: torch.IntTensor = None,
            p_em_uncased: torch.IntTensor = None,
            q_in_lemma: torch.IntTensor = None,
            p_in_lemma: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Encode question and passage, run the alignment hops and predict the answer.

        The returned dict contains ``span_start_logits``, ``span_end_logits``,
        ``best_span`` and ``yesno`` (predicted yes/no class per instance); ``loss``
        is added when ``span_start`` is given, and ``best_span_str``,
        ``question_tokens`` and ``passage_tokens`` when ``metadata`` is given.
        """
        x1_c_emb = self._dropout(self._char_field_embedder(passage))
        x2_c_emb = self._dropout(self._char_field_embedder(question))

        token_emb_q = self._dropout(self._text_field_embedder(question))
        token_emb_c = self._dropout(self._text_field_embedder(passage))
        # First 300 dims: word embedding; remaining 40: NER/POS features.
        token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40], dim=2)
        token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40], dim=2)
        question_word_features = torch.cat([
            q_ner_and_pos,
            self._features_embedder(q_em_cased),
            self._features_embedder(q_em_uncased),
            self._features_embedder(q_in_lemma),
            question_tf.unsqueeze(2)
        ], dim=2)
        passage_word_features = torch.cat([
            p_ner_and_pos,
            self._features_embedder(p_em_cased),
            self._features_embedder(p_em_uncased),
            self._features_embedder(p_in_lemma),
            passage_tf.unsqueeze(2)
        ], dim=2)

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # The helper tolerates a None mask, so mask_lstms=False no longer crashes here.
        char_features_c = self._encode_chars(x1_c_emb, passage_lstm_mask)
        char_features_q = self._encode_chars(x2_c_emb, question_lstm_mask)

        emb_question = torch.cat(
            [token_emb_question, char_features_q, question_word_features], dim=2)
        emb_passage = torch.cat(
            [token_emb_passage, char_features_c, passage_word_features], dim=2)

        encoded_question = self._dropout(
            self._phrase_layer(emb_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(emb_passage, passage_lstm_mask))

        batch_size = encoded_question.size(0)
        passage_length = encoded_passage.size(1)

        c_check = encoded_passage
        q = encoded_question
        for i in range(self.hops):
            q_tilde = self.interactive_aligners[i].forward(c_check, q, question_mask)
            c_bar = self.interactive_SFUs[i].forward(
                c_check,
                torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2))
            c_tilde = self.self_aligners[i].forward(c_bar, passage_mask)
            c_hat = self.self_SFUs[i].forward(
                c_bar,
                torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde], 2))
            c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask)

        # Predict
        start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward(
            c_check, q, passage_mask, question_mask)

        best_span, yesno_predict, loc = self.get_best_span(
            start_scores, end_scores, yesno_scores)

        output_dict = {
            "span_start_logits": start_scores,
            "span_end_logits": end_scores,
            "best_span": best_span
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(start_scores, span_start.squeeze(-1))
            self._span_start_accuracy(start_scores, span_start.squeeze(-1))
            loss += nll_loss(end_scores, span_end.squeeze(-1))
            self._span_end_accuracy(end_scores, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))

            # view(batch_size) is already 1-D; an extra squeeze() would produce a 0-d
            # array for batch_size == 1 and break the indexing below.
            gold_ends = span_end.view(batch_size).detach().cpu().numpy()
            # Flat index of each gold end token in the (batch * passage_length) layout.
            gold_span_end_loc = span_start.new(
                [max(gold_ends[i] + i * passage_length, 0) for i in range(batch_size)])
            _yesno = yesno_scores.view(-1, 3).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1)

            # Yes/no accuracy is measured at the *predicted* end position.
            predicted_end = span_start.new(
                [max(loc[i], 0) for i in range(batch_size)])
            _yesno = yesno_scores.view(-1, 3).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno.squeeze(-1))
            output_dict['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Token indices -> character offsets into the original passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        output_dict['yesno'] = yesno_predict
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span/yes-no accuracy plus SQuAD EM and F1."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Replace the ``yesno`` class indices with their ``yesno_labels`` tokens."""
        yesno_tags = [
            # int() so tensor / numpy scalars become usable vocabulary keys.
            self.vocab.get_token_from_index(int(x), namespace="yesno_labels")
            for x in output_dict.pop("yesno")
        ]
        output_dict['yesno'] = yesno_tags
        return output_dict

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor,
                      yesno_scores: torch.Tensor):
        """
        Find, per batch row, the ``(start, end)`` pair with ``start <= end`` that
        maximises ``span_start_logits[start] + span_end_logits[end]``.

        Returns ``(best_word_span, yesno_predict, loc)``: the ``(batch_size, 2)``
        spans, the argmax yes/no class at each chosen end token, and the flat
        index ``end + passage_length * b`` of that end token.  Raises
        ``ValueError`` if either logits tensor is not 2-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                    dtype=torch.long)
        yesno_predict = span_start_logits.new_zeros(batch_size, dtype=torch.long)
        loc = yesno_scores.new_zeros(batch_size, dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()
        yesno_logits = yesno_scores.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Running argmax of the start logits over positions 0..j.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
                    yesno_predict[b] = int(np.argmax(yesno_logits[b, j]))
                    loc[b] = j + passage_length * b
        return best_word_span, yesno_predict, loc
def __init__(self,
             vocab: Vocabulary,
             pretrained_model: Optional[str] = None,
             requires_grad: bool = True,
             top_layer_only: bool = True,
             bert_weights_model: Optional[str] = None,
             per_choice_loss: bool = False,
             layer_freeze_regexes: Optional[List[str]] = None,
             regularizer: Optional[RegularizerApplicator] = None,
             use_comparative_bert: bool = True,
             use_bilinear_classifier: bool = False,
             train_comparison_layer: bool = False,
             number_of_choices_compared: int = 0,
             comparison_layer_hidden_size: int = -1,
             comparison_layer_use_relu: bool = True) -> None:
    """
    Set up the BERT encoder, the classifier head, the optional comparison layer,
    the optional scalar mix, and the metric/loss objects.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed through to the parent constructor.
    pretrained_model : ``str``, optional
        Name/path given to ``BertModel.from_pretrained``; only used when
        ``bert_weights_model`` is not set.
    requires_grad : ``bool``
        Value assigned to ``requires_grad`` on every BERT parameter.
    top_layer_only : ``bool``
        When ``False`` a ``ScalarMix`` over the BERT layers is created (or reused
        from the loaded archive); when ``True`` no scalar mix is kept.
    bert_weights_model : ``str``, optional
        Archive to load with ``load_archive``; its ``_bert_model`` (and, where
        present, its ``_classifier`` / ``_scalar_mix``) are reused.
    per_choice_loss : ``bool``
        Stored on the instance; not otherwise used in this constructor.
    layer_freeze_regexes : ``List[str]``, optional
        Currently ignored: the per-layer freezing logic is disabled, so every BERT
        parameter simply follows ``requires_grad``. A warning is logged if given.
    regularizer : ``RegularizerApplicator``, optional
        Passed through to the parent constructor.
    use_comparative_bert : ``bool``
        Selects the two-input classifier head (``Bilinear`` or a ``Linear`` over
        twice the hidden size) instead of the single-input ``Linear`` head.
    use_bilinear_classifier : ``bool``
        With ``use_comparative_bert``, choose ``Bilinear`` over the wide ``Linear``.
    train_comparison_layer : ``bool``
        Adds the two-layer comparison network and switches the metric/loss to
        ``CategoricalAccuracy`` / ``CrossEntropyLoss``.
    number_of_choices_compared : ``int``
        Number of choices; must be greater than 1 when ``train_comparison_layer``.
    comparison_layer_hidden_size : ``int``
        Hidden width of the comparison network; ``-1`` means "number of ordered
        choice pairs, squared".
    comparison_layer_use_relu : ``bool``
        ``LeakyReLU`` (``True``) or ``Tanh`` (``False``) after the first
        comparison layer.
    """
    super().__init__(vocab, regularizer)

    self._use_comparative_bert = use_comparative_bert
    self._use_bilinear_classifier = use_bilinear_classifier
    self._train_comparison_layer = train_comparison_layer
    if train_comparison_layer:
        assert number_of_choices_compared > 1
    self._num_choices = number_of_choices_compared
    self._comparison_layer_hidden_size = comparison_layer_hidden_size
    self._comparison_layer_use_relu = comparison_layer_use_relu

    # Bert weights and config
    if bert_weights_model:
        logging.info(f"Loading BERT weights model from {bert_weights_model}")
        bert_model_loaded = load_archive(bert_weights_model)
        self._bert_model = bert_model_loaded.model._bert_model
    else:
        self._bert_model = BertModel.from_pretrained(pretrained_model)

    for param in self._bert_model.parameters():
        param.requires_grad = requires_grad

    # Regex-based layer freezing is disabled; tell the caller instead of
    # silently dropping the argument.
    if layer_freeze_regexes:
        logging.warning("layer_freeze_regexes is currently ignored; "
                        "all BERT parameters follow requires_grad")

    bert_config = self._bert_model.config
    self._output_dim = bert_config.hidden_size
    self._dropout = torch.nn.Dropout(bert_config.hidden_dropout_prob)
    self._per_choice_loss = per_choice_loss

    # Bert Classifier selector
    final_output_dim = 1
    if not use_comparative_bert:
        if bert_weights_model and hasattr(bert_model_loaded.model, "_classifier"):
            self._classifier = bert_model_loaded.model._classifier
        else:
            self._classifier = Linear(self._output_dim, final_output_dim)
    else:
        if use_bilinear_classifier:
            self._classifier = Bilinear(self._output_dim, self._output_dim, final_output_dim)
        else:
            self._classifier = Linear(self._output_dim * 2, final_output_dim)
            # NOTE(review): the nesting of this call was reconstructed from
            # whitespace-collapsed source; it is kept on the freshly created
            # comparative Linear head only - confirm against the original file.
            self._classifier.apply(self._bert_model.init_bert_weights)

    # Comparison layer setup
    if self._train_comparison_layer:
        # Ordered pairs of distinct choices.
        number_of_pairs = self._num_choices * (self._num_choices - 1)
        if self._comparison_layer_hidden_size == -1:
            self._comparison_layer_hidden_size = number_of_pairs * number_of_pairs

        self._comparison_layer_1 = Linear(number_of_pairs, self._comparison_layer_hidden_size)
        if self._comparison_layer_use_relu:
            self._comparison_layer_1_activation = torch.nn.LeakyReLU()
        else:
            self._comparison_layer_1_activation = torch.nn.Tanh()
        self._comparison_layer_2 = Linear(self._comparison_layer_hidden_size, self._num_choices)
        # Normalise over the choice axis (the output axis of the Linear above).
        # Leaving ``dim`` unset is deprecated and resolves to axis 0 for 3-D input.
        self._comparison_layer_2_activation = torch.nn.Softmax(dim=-1)

    # Scalar mix, if necessary
    self._all_layers = not top_layer_only
    if self._all_layers:
        if bert_weights_model and hasattr(bert_model_loaded.model, "_scalar_mix") \
                and bert_model_loaded.model._scalar_mix is not None:
            self._scalar_mix = bert_model_loaded.model._scalar_mix
        else:
            num_layers = bert_config.num_hidden_layers
            initial_scalar_parameters = num_layers * [0.0]
            initial_scalar_parameters[-1] = 5.0  # Starts with most mass on last layer
            self._scalar_mix = ScalarMix(bert_config.num_hidden_layers,
                                         initial_scalar_parameters=initial_scalar_parameters,
                                         do_layer_norm=False)
    else:
        self._scalar_mix = None

    # Accuracy and loss setup
    if self._train_comparison_layer:
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
    else:
        self._accuracy = BooleanAccuracy()
        self._loss = torch.nn.BCEWithLogitsLoss()

    self._debug = -1
def __init__(
    self,
    vocab: Vocabulary,
    text_field_embedder: TextFieldEmbedder,
    phrase_layer: Seq2SeqEncoder,
    passage_self_attention: Seq2SeqEncoder,
    semantic_rep_layer: Seq2SeqEncoder,
    contextual_question_layer: Seq2SeqEncoder,
    dropout: float = 0.2,
    mask_lstms: bool = True,
    regularizer: Optional[RegularizerApplicator] = None,
    initializer: InitializerApplicator = InitializerApplicator(),
):
    """
    Wire up the embedder, the supplied encoders, the attention/fusion layers,
    the span-scoring projections, the metrics and the loss, then run
    ``initializer`` over the finished model.

    All projection layers are plain ``torch.nn.Linear`` / ``torch.nn.Bilinear``
    modules without bias, sized by the output dimension of ``phrase_layer``.
    """
    super(MultiGranularityHierarchicalAttentionFusionNetworks, self).__init__(vocab, regularizer)

    def bias_free_linear(out_features: int) -> torch.nn.Linear:
        # Every projection in this model maps from the encoding size, without bias.
        return torch.nn.Linear(in_features=dim, out_features=out_features, bias=False)

    # --- embedding and phrase encoding -------------------------------------
    self._text_field_embedder = text_field_embedder
    self._phrase_layer = phrase_layer
    dim = phrase_layer.get_output_dim()
    self._encoding_dim = dim

    # --- co-attention and its fusion ---------------------------------------
    self._atten_linear_layer = bias_free_linear(dim)
    self._relu = torch.nn.ReLU()
    self._softmax_d1 = torch.nn.Softmax(dim=1)
    self._softmax_d2 = torch.nn.Softmax(dim=2)
    self._atten_fusion = FusionLayer(dim)
    self._tanh = torch.nn.Tanh()
    self._sigmoid = torch.nn.Sigmoid()

    # --- passage self-attention and its fusion -----------------------------
    self._passage_self_attention = passage_self_attention
    self._self_atten_layer = torch.nn.Bilinear(dim, dim, dim, bias=False)
    self._self_atten_fusion = FusionLayer(dim)

    # --- final representation layers ---------------------------------------
    self._semantic_rep_layer = semantic_rep_layer
    self._contextual_question_layer = contextual_question_layer

    # --- span scoring projections ------------------------------------------
    self._vector_linear = bias_free_linear(1)
    self._model_layer_s = bias_free_linear(dim)
    self._model_layer_e = bias_free_linear(dim)

    # --- metrics ------------------------------------------------------------
    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    # With dropout disabled, fall back to an identity callable.
    self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
    self._mask_lstms = mask_lstms

    self._loss = torch.nn.CrossEntropyLoss()

    initializer(self)
class DialogQA(Model):
    """
    This class implements a modified version of BiDAF (with self attention and residual layer,
    from the Clark and Gardner ACL 17 paper) as used in the Question Answering in Context
    (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog, i.e. a list of question answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    residual_encoder : ``Seq2SeqEncoder``
        The encoder applied to the merged attention output; its result feeds the passage
        self-attention whose output is added back to the merged passage representation.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    initializer : ``InitializerApplicator``, optional (default=None)
        If given, applied to the model once all of its layers have been created.
    dropout : ``float``, optional (default=0.2)
        Probability handed to the ``InputVariationalDropout`` that is applied to the embeddings
        and after the encoders.
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    marker_embedding_dim : ``int``, optional (default=10)
        Size of the embedding of a previous-answer marker. The question turn-number embedding
        has size ``marker_embedding_dim * num_context_answers``.
    max_span_length: ``int``, optional (default=30)
        Maximum distance, in tokens, between the start and the end of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction (number of rows of the turn-number embedding).
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        phrase_layer: Seq2SeqEncoder,
        residual_encoder: Seq2SeqEncoder,
        span_start_encoder: Seq2SeqEncoder,
        span_end_encoder: Seq2SeqEncoder,
        initializer: Optional[InitializerApplicator] = None,
        dropout: float = 0.2,
        num_context_answers: int = 0,
        marker_embedding_dim: int = 10,
        max_span_length: int = 30,
        max_turn_length: int = 12,
    ) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim,
                                                       "x,y,x*y")
        self._merge_atten = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(
                max_turn_length, marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding(
                (num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim,
                                                     self._encoding_dim,
                                                     "x,y,x*y")

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        # The phrase layer sees the word embedding plus one marker embedding per
        # context answer, so its input size has to account for both.
        check_dimensions_match(
            phrase_layer.get_input_dim(),
            text_field_embedder.get_output_dim() +
            marker_embedding_dim * num_context_answers,
            "phrase layer input dim",
            "embedding dim + marker dim * num context answers",
        )

        if initializer is not None:
            initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        p1_answer_marker: torch.IntTensor = None,
        p2_answer_marker: torch.IntTensor = None,
        p3_answer_marker: torch.IntTensor = None,
        yesno_list: torch.IntTensor = None,
        followup_list: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the
            previous answer in the dialog, which will be assigned labels such as
            <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in
            passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
            Although it has a default of ``None``, the question-answer mask is derived from this
            tensor (entries ``>= 0`` count as real questions), so it must always be supplied.
        metadata : ``List[Dict[str, Any]]``, optional
            One dictionary per dialog in the batch.  Each is read for the keys
            ``original_passage``, ``token_offsets``, ``instance_id`` (question ids) and
            ``answer_texts_list`` (reference answers per question), which are used to build the
            answer strings and the official F1.  Although it has a default of ``None``, the
            output construction indexes it unconditionally, so it must always be supplied.

        Returns
        -------
        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then
        questions in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            The string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.  Only present when ``span_start`` is given.
        """
        token_character_ids = question["token_characters"]["token_characters"]
        batch_size, max_qa_count, max_q_len, _ = token_character_ids.size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1)
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage)

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = (embedded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1,
                1).view(total_qa_count, passage_length,
                        self._text_field_embedder.get_output_dim()))
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            # Without answer markers the passage is identical for every question
            # of a dialog, so encode it once and tile the result.
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        # ``max`` already removes the question axis.  No further ``squeeze(-1)`` here: it
        # would also remove the passage axis whenever passage_length == 1.
        # Shape: (batch_size * max_qa_count, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                repeated_encoded_passage,
                passage_question_vectors,
                repeated_encoded_passage * passage_question_vectors,
                repeated_encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)

        # Pairwise mask of valid (token, token) positions, minus the diagonal so that a
        # token does not attend to itself.
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length,
            1) * repeated_passage_mask.reshape(total_qa_count, 1,
                                               passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              dtype=torch.bool,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask & ~self_mask

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(
            span_start_logits,
            span_end_logits,
            span_yesno_logits,
            span_followup_logits,
            self._max_span_length,
        )

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits,
                                        repeated_passage_mask),
                span_start.view(-1),
                ignore_index=-1,
            )
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits,
                                        repeated_passage_mask),
                span_end.view(-1),
                ignore_index=-1,
            )
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(
                best_span[:, 0:2],
                torch.stack([span_start, span_end],
                            -1).view(total_qa_count, 2),
                mask=qa_mask.unsqueeze(1).expand(-1, 2),
            )
            # add a select for the right span to compute loss
            # The yes/no and follow-up logits are flattened below, so the three class scores
            # of token ``t`` of question ``i`` live at ``(i * passage_length + t) * 3 + {0,1,2}``.
            # ``view`` already yields a 1-D tensor; a blanket ``squeeze()`` would turn it into
            # a 0-d array (not indexable) when there is exactly one question in total.
            span_end = span_end.view(total_qa_count).data.cpu().numpy()
            gold_span_end_loc = []
            for i in range(0, total_qa_count):
                gold_base = span_end[i] * 3 + i * passage_length * 3
                for class_offset in range(3):
                    gold_span_end_loc.append(max(gold_base + class_offset, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_base = best_span[i][1] * 3 + i * passage_length * 3
                for class_offset in range(3):
                    pred_span_end_loc.append(max(pred_base + class_offset, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            # Loss uses the logits at the gold span end ...
            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            # ... while the accuracy metrics use the logits at the predicted span end.
            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict["best_span_str"] = []
        output_dict["qid"] = []
        output_dict["followup"] = []
        output_dict["yesno"] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]["original_passage"]
            offsets = metadata[i]["token_offsets"]
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad.metric_max_over_ground_truths(
                                    squad.f1_score, best_span_string, refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad.metric_max_over_ground_truths(
                            squad.f1_score, best_span_string, answer_texts)
                self._official_f1(100 * f1_score)
            output_dict["qid"].append(per_dialog_query_id_list)
            output_dict["best_span_str"].append(per_dialog_best_span_list)
            output_dict["yesno"].append(per_dialog_yesno_list)
            output_dict["followup"].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Replace the predicted ``yesno`` and ``followup`` label indices in ``output_dict`` with
        their tokens from the ``yesno_labels`` / ``followup_labels`` vocabulary namespaces.
        """
        yesno_tags = [[
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in yn_list
        ] for yn_list in output_dict.pop("yesno")]
        followup_tags = [[
            self.vocab.get_token_from_index(x, namespace="followup_labels")
            for x in followup_list
        ] for followup_list in output_dict.pop("followup")]
        output_dict["yesno"] = yesno_tags
        output_dict["followup"] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the current value of every tracked metric, resetting them if ``reset``."""
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            "followup": self._span_followup_accuracy.get_metric(reset),
            "f1": self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span_yesno_followup(
        span_start_logits: torch.Tensor,
        span_end_logits: torch.Tensor,
        span_yesno_logits: torch.Tensor,
        span_followup_logits: torch.Tensor,
        max_span_length: int,
    ) -> torch.Tensor:
        """
        Pick, for every row, the highest-scoring span whose end lies at most
        ``max_span_length`` tokens after its start, together with the yes/no and follow-up
        predictions read at the predicted span end token.

        Returns a ``(batch_size, 4)`` long tensor holding
        ``[span_start, span_end, yesno_index, followup_index]`` per row.

        Raises ``ValueError`` if the start or end logits are not 2-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):
            for j in range(passage_length):
                # Track the best start seen so far; it is paired with ``j`` as the end.
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
label = event_item['label'] event_type = event_item['event_type'] logits = predictor.predict(event, event_type)['logits'] predict_out.append(logits) label_y.append(label) if logits == 1: predict_out_f1.append([0, 1]) else: predict_out_f1.append([1, 0]) predict_out = torch.LongTensor(predict_out) predict_out_f1 = torch.LongTensor(predict_out_f1) label_y = torch.LongTensor(label_y) # metrics get_accuracy = BooleanAccuracy() get_f1_score = F1Measure(positive_label=1) get_accuracy(predict_out, label_y) accuracy = get_accuracy.get_metric(reset=False) get_f1_score(predict_out_f1, label_y) precision, recall, f1_measure = get_f1_score.get_metric(reset=False) logger.debug('-------Train Metrics-------') for k in metrics: logger.debug('{}: {}'.format(k, metrics[k])) logger.debug('-------Test Output-------') logger.debug('accuracy: {}'.format(accuracy)) logger.debug('precision: {}'.format(precision)) logger.debug('recall: {}'.format(recall)) logger.debug('f1_measure: {}'.format(f1_measure))
def __init__(
    self,
    vocab: Vocabulary,
    text_field_embedder: TextFieldEmbedder,
    phrase_layer: Seq2SeqEncoder,
    residual_encoder: Seq2SeqEncoder,
    span_start_encoder: Seq2SeqEncoder,
    span_end_encoder: Seq2SeqEncoder,
    initializer: Optional[InitializerApplicator] = None,
    dropout: float = 0.2,
    num_context_answers: int = 0,
    marker_embedding_dim: int = 10,
    max_span_length: int = 30,
    max_turn_length: int = 12,
) -> None:
    """
    Create the attention, merge, marker-embedding and prediction layers around the
    supplied embedder and encoders, check that the phrase layer's input size
    matches what it will be fed, apply ``initializer`` if one is given, and set up
    the metrics and the variational dropout.
    """
    super().__init__(vocab)

    def per_token_linear(in_features: int, out_features: int) -> TimeDistributed:
        # A linear projection wrapped in ``TimeDistributed``.
        return TimeDistributed(torch.nn.Linear(in_features, out_features))

    # Plain configuration and the injected sub-modules.
    self._num_context_answers = num_context_answers
    self._max_span_length = max_span_length
    self._text_field_embedder = text_field_embedder
    self._phrase_layer = phrase_layer
    self._marker_embedding_dim = marker_embedding_dim
    enc_dim = phrase_layer.get_output_dim()
    self._encoding_dim = enc_dim

    # Passage/question attention and the projection merging its four views.
    self._matrix_attention = LinearMatrixAttention(enc_dim, enc_dim, "x,y,x*y")
    self._merge_atten = per_token_linear(enc_dim * 4, enc_dim)

    self._residual_encoder = residual_encoder

    # Marker embeddings exist only when previous answers are part of the input.
    if num_context_answers > 0:
        turn_marker_dim = marker_embedding_dim * num_context_answers
        self._question_num_marker = torch.nn.Embedding(max_turn_length, turn_marker_dim)
        answer_marker_count = (num_context_answers * 4) + 1
        self._prev_ans_marker = torch.nn.Embedding(answer_marker_count, marker_embedding_dim)

    # Passage self-attention and the projection merging its three views.
    self._self_attention = LinearMatrixAttention(enc_dim, enc_dim, "x,y,x*y")

    self._followup_lin = torch.nn.Linear(enc_dim, 3)
    self._merge_self_attention = per_token_linear(enc_dim * 3, enc_dim)

    self._span_start_encoder = span_start_encoder
    self._span_end_encoder = span_end_encoder

    # Per-token prediction heads.
    self._span_start_predictor = per_token_linear(enc_dim, 1)
    self._span_end_predictor = per_token_linear(enc_dim, 1)
    self._span_yesno_predictor = per_token_linear(enc_dim, 3)
    self._span_followup_predictor = TimeDistributed(self._followup_lin)

    expected_phrase_input_dim = (
        text_field_embedder.get_output_dim() + marker_embedding_dim * num_context_answers
    )
    check_dimensions_match(
        phrase_layer.get_input_dim(),
        expected_phrase_input_dim,
        "phrase layer input dim",
        "embedding dim + marker dim * num context answers",
    )

    if initializer is not None:
        initializer(self)

    # Metrics.
    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_yesno_accuracy = CategoricalAccuracy()
    self._span_followup_accuracy = CategoricalAccuracy()

    self._span_gt_yesno_accuracy = CategoricalAccuracy()
    self._span_gt_followup_accuracy = CategoricalAccuracy()

    self._span_accuracy = BooleanAccuracy()
    self._official_f1 = Average()
    self._variational_dropout = InputVariationalDropout(dropout)
class Seq2SeqTask(SequenceGenerationTask):
    """Sequence-to-sequence task whose splits are streamed from ``<split>.tsv`` files.

    The constructor registers a ``BooleanAccuracy`` scorer next to the scorers of the parent
    class and refuses every tokenizer except ``"SplitChars"`` and ``"dummy_tokenizer_name"``.
    """

    def __init__(self, path, max_seq_len, max_targ_v_size, name, **kw):
        """Set up scorers, the target indexer and the per-split file paths.

        :param path: directory that holds ``train.tsv``, ``val.tsv`` and ``test.tsv``.
        :param max_seq_len: passed to ``tokenize_and_truncate`` for source and target.
        :param max_targ_v_size: maximum number of labels returned by ``get_all_labels``.
        :param name: task name; also used to build the metric name and label namespace.
        :raises NotImplementedError: if the tokenizer is not one of the supported ones.
        """
        super().__init__(name, **kw)
        self.scorer2 = BooleanAccuracy()
        self.scorers.append(self.scorer2)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.max_seq_len = max_seq_len
        self._label_namespace = self.name + "_tokens"
        self.max_targ_v_size = max_targ_v_size
        self.target_indexer = {"words": SingleIdTokenIndexer(namespace=self._label_namespace)}
        self.files_by_split = {
            split: os.path.join(path, "%s.tsv" % split) for split in ["train", "val", "test"]
        }

        # The following is necessary since word-level tasks (e.g., MT) haven't been tested, yet.
        if self._tokenizer_name != "SplitChars" and self._tokenizer_name != "dummy_tokenizer_name":
            raise NotImplementedError("For now, Seq2SeqTask only supports character-level tasks.")

    def load_data(self):
        """No-op: data is exposed as an iterable, so nothing is preloaded."""
        pass

    def get_split_text(self, split: str):
        """
        Get split text as iterable of records.
        Split should be one of 'train', 'val', or 'test'.
        """
        return self.get_data_iter(self.files_by_split[split])

    def get_all_labels(self) -> List[str]:
        """Build the target vocabulary from the train and val splits and return it as a list.

        Tokens are ordered by frequency and truncated to ``self.max_targ_v_size`` entries.
        """
        token2freq = collections.Counter()
        for split in ["train", "val"]:
            for _, sequence in self.get_data_iter(self.files_by_split[split]):
                token2freq.update(sequence)
        return [t for t, _ in token2freq.most_common(self.max_targ_v_size)]

    def get_data_iter(self, path):
        """Yield ``(source_tokens, target_tokens)`` pairs read from the TSV file at ``path``.

        Rows with fewer than two columns, or with an empty first or second column, are skipped.
        """
        with codecs.open(path, "r", "utf-8", errors="ignore") as txt_fh:
            for row in txt_fh:
                row = row.strip().split("\t")
                if len(row) < 2 or not row[0] or not row[1]:
                    continue
                src_sent = tokenize_and_truncate(self._tokenizer_name, row[0], self.max_seq_len)
                # NOTE(review): the guard above validates columns 0 and 1, but the target is
                # read from column 2, so a two-column row raises IndexError here. Presumably
                # the files have three columns - confirm against the data before changing
                # either the guard or the index.
                tgt_sent = tokenize_and_truncate(self._tokenizer_name, row[2], self.max_seq_len)
                yield (src_sent, tgt_sent)

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """Yield the records of every non-test split, used to compute vocabulary."""
        for split in self.files_by_split:
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            path = self.files_by_split[split]
            yield from self.get_data_iter(path)

    def count_examples(self):
        """Count the lines of every split file and store them in ``self.example_counts``.

        Computed here because the sentences are streamed rather than loaded. Every line is
        counted, including rows that ``get_data_iter`` would skip.
        """
        example_counts = {}
        for split, split_path in self.files_by_split.items():
            # Use a context manager so the file handle is closed once the lines are counted.
            with codecs.open(split_path, "r", "utf-8", errors="ignore") as txt_fh:
                example_counts[split] = sum(1 for _ in txt_fh)
        self.example_counts = example_counts

    def process_split(
        self, split, indexers, model_preprocessing_interface
    ) -> Iterable[Type[Instance]]:
        """Lazily turn ``(source, target)`` pairs of ``split`` into AllenNLP Instances.

        The source is indexed with ``indexers`` and the target with ``self.target_indexer``;
        both are first passed through ``model_preprocessing_interface.boundary_token_fn``.
        """

        def _make_instance(input_, target):
            d = {
                "inputs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(input_), indexers
                ),
                "targs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(target), self.target_indexer
                ),
            }
            return Instance(d)

        for sent1, sent2 in split:
            yield _make_instance(sent1, sent2)

    def get_metrics(self, reset=False):
        """Return ``{"perplexity": exp(scorer1 metric), "accuracy": scorer2 metric}``."""
        avg_nll = self.scorer1.get_metric(reset)
        acc = self.scorer2.get_metric(reset)
        return {"perplexity": math.exp(avg_nll), "accuracy": acc}

    def update_metrics(self, logits, labels, tagmask=None, predictions=None):
        """Update the accuracy scorer from the first entry along axis 1 of ``predictions``.

        ``logits`` must be ``None`` (the loss is updated elsewhere) and ``predictions`` must be
        given. Predictions, labels and mask are cut to a common length before scoring.
        ``tagmask`` may be ``None``, in which case the scorer is called without a mask.
        """
        # This doesn't require logits for now, since loss is updated in another part.
        assert logits is None and predictions is not None

        if labels.shape[1] < predictions.shape[2]:
            predictions = predictions[:, 0, : labels.shape[1]]
        else:
            predictions = predictions[:, 0, :]
            # Cut labels if predictions (without gold target) are shorter.
            labels = labels[:, : predictions.shape[1]]
            # The mask is optional (it defaults to None), so only slice it when present.
            if tagmask is not None:
                tagmask = tagmask[:, : predictions.shape[1]]
        self.scorer2(predictions, labels, tagmask)

    def get_prediction(self, voc_src, voc_trg, inputs, gold, output):
        """Detokenize input, gold and output ids into strings cut at the first ``<EOS>``.

        ``voc_src`` / ``voc_trg`` are indexed with ``token.item()`` to map ids to tokens.
        Returns the tuple ``(input_string, gold_string, output_string)``.
        """
        tokenizer = get_tokenizer(self._tokenizer_name)

        input_string = tokenizer.detokenize([voc_src[token.item()] for token in inputs]).split(
            "<EOS>"
        )[0]
        gold_string = tokenizer.detokenize([voc_trg[token.item()] for token in gold]).split(
            "<EOS>"
        )[0]
        output_string = tokenizer.detokenize([voc_trg[token.item()] for token in output]).split(
            "<EOS>"
        )[0]

        return input_string, gold_string, output_string
class BidafV4(Model):
    """
    MODIFICATION NOTE:
    This class is a modification of BiDAF. In here we try to see what happens to our results if we
    convert the question encoder into a simple term frequency (bag-of-words) encoder which
    disregards word order. By doing so we analyze whether BiDAF can learn to solve SQuAD without
    having to encode the question sequentially. It has been shown in previous work that BiDAF and
    other models trained on SQuAD do not focus on questions words as we would expect them to.
    For example, they will often focus [this sentence was left unfinished in the original note].

    NOTE(review): in the code below the question is still encoded with ``phrase_layer``; the
    active change relative to the commented-out "original variant" in ``forward`` is the "v4"
    block (re-application of the query2context attention). Confirm that the note above still
    describes this class.

    ORIGINAL DOCSTRING:
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Build the layers, check that their dimensions line up and create the metrics.

        See the class docstring for the meaning of each parameter.  ``initializer`` is applied
        to the model as the last step.
        """
        super(BidafV4, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites do not have to special-case "no dropout".
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
            The loss branch is entered whenever ``span_start`` is given and reads ``span_end``
            too, so the two must be supplied together.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size.  The code below reads the keys ``question_tokens``,
            ``passage_tokens``, ``original_passage`` and ``token_offsets`` from each dictionary,
            plus the optional key ``answer_texts`` (the SQuAD metrics are only updated when it is
            present and non-empty).

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            The softmax over the question dimension of the passage-question similarity matrix.
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        question_tokens, passage_tokens : List
            Copied from ``metadata`` when it is provided.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        # encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage, encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)

        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)

        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size,
        #                                                                                                          passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        new_similarity_matrix = self._matrix_attention(encoded_passage, q2c_vecs)
        # Masked positions get a very negative value so they don't affect the max below.
        masked_similarity = util.replace_masked_values(
            new_similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        # ``max`` already drops the reduced dimension, so no ``squeeze(-1)`` is applied here:
        # it would additionally remove the passage dimension when ``passage_length == 1``.
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- Original variant
        # # We replace masked values with something really negative here, so they don't affect the
        # # max below.
        # masked_similarity = util.replace_masked_values(similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)
        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage,
            passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ], dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v9:
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors],
        #                                  dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage,
            modeled_passage,
            tiled_start_representation,
            modeled_passage * tiled_start_representation
        ], dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Push masked positions to a very negative value before the span search below.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Slice from the first element of the start token's offset entry up to the
                # second element of the end token's offset entry.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span accuracy plus the exact-match and F1 values.

        ``reset`` is forwarded to every underlying metric's ``get_metric``.
        """
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        """Find, per batch entry, the span ``(start, end)`` with ``start <= end`` that
        maximises ``span_start_logits[start] + span_end_logits[end]``.

        Both inputs must be two-dimensional; the result is a long tensor of shape
        ``(batch_size, 2)`` created with ``span_start_logits.new_zeros``.

        Raises
        ------
        ValueError
            If either input is not two-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Keep a running argmax of the start logits over positions 0..j, so the
                # chosen start can never lie after the candidate end ``j``.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
class TransformerQA(Model):
    """
    Registered as `"transformer_qa"`, this class implements a reading comprehension model patterned
    after the proposed model in [Devlin et al](git@github.com:huggingface/transformers.git),
    with improvements borrowed from the SQuAD model in the transformers project.

    It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

    If you want to use this model on SQuAD datasets, you can use it with the
    [`TransformerSquadReader`](../../dataset_readers/transformer_squad#transformersquadreader)
    dataset reader, registered as `"transformer_squad"`.

    Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could
    be more than one instance per question, these metrics are not the official numbers on either SQuAD task.

    To get official numbers for SQuAD v1.1, for example, you can run

    ```
    python -m allennlp_models.rc.tools.transformer_qa_eval
    ```

    # Parameters

    vocab : `Vocabulary`

    transformer_model_name : `str`, optional (default=`'bert-base-cased'`)
        This model chooses the embedder according to this setting. You probably want to make sure this is set to
        the same thing as the reader.
    """

    def __init__(
        self, vocab: Vocabulary, transformer_model_name: str = "bert-base-cased", **kwargs
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": PretrainedTransformerEmbedder(transformer_model_name)}
        )
        # Two outputs per token: one start logit and one end logit.
        self._linear_layer = nn.Linear(self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._per_instance_metrics = SquadEmAndF1()

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        cls_index: torch.LongTensor = None,
        answer_span: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, torch.LongTensor]`
            From a `TextField`. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including
            `[CLS]` and `[SEP]`) has type id 1.

        context_span : `torch.IntTensor`
            From a `SpanField`. This marks the span of word pieces in `question` from which answers can come.

        cls_index : `torch.LongTensor`, optional
            A tensor of shape `(batch_size,)` that provides the index of the `[CLS]` token
            in the `question_with_context` for each instance.

            This is needed because the `[CLS]` token is used to indicate that the question
            is impossible.

            If this is `None`, it's assumed that the `[CLS]` token is at index 0 for each instance
            in the batch.

        answer_span : `torch.IntTensor`, optional
            From a `SpanField`. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.

        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the `metadata` list should be the
            batch size, and each dictionary should have the keys `id`, `question`, `context`,
            `question_tokens`, `context_tokens`, and `answers`.

        # Returns

        `Dict[str, torch.Tensor]` :
            An output dictionary with the following fields:

            - span_start_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span start position.
            - span_end_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span end position (inclusive).
            - best_span_scores (`torch.FloatTensor`) :
              The score for each of the best spans.
            - span_start_probs, span_end_probs (`torch.FloatTensor`) :
              The softmax of the start / end logits, taken before the non-answer positions
              are masked out.
            - best_span_probs (`torch.FloatTensor`) :
              The product of the start and end probabilities at each best span.
            - loss (`torch.FloatTensor`, optional) :
              A scalar loss to be optimised, evaluated against `answer_span`.
            - best_span (`torch.IntTensor`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              The result of a constrained inference over `span_start_logits` and
              `span_end_logits` to find the most probable span.  Shape is `(batch_size, 2)`
              and each offset is a token index, unless the best span for an instance
              was predicted to be the `[CLS]` token, in which case the span will be (-1, -1).
            - best_span_str (`List[str]`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              This is the string from the original passage that the model thinks is the best answer
              to the question.
        """
        embedded_question = self._text_field_embedder(question_with_context)
        # shape: (batch_size, sequence_length, 2)
        logits = self._linear_layer(embedded_question)
        # shape: (batch_size, sequence_length, 1)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        # shape: (batch_size, sequence_length)
        span_start_logits = span_start_logits.squeeze(-1)
        # shape: (batch_size, sequence_length)
        span_end_logits = span_end_logits.squeeze(-1)

        # Create a mask for `question_with_context` to mask out tokens that are not part
        # of the context.
        # shape: (batch_size, sequence_length)
        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context), dtype=torch.bool
        )
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start : end + 1] = True
            # Also unmask the [CLS] token since that token is used to indicate that
            # the question is impossible.
            possible_answer_mask[i, 0 if cls_index is None else cls_index[i]] = True

        # Calculate span start and end probabilities
        # shape: (batch_size, sequence_length)
        span_start_probs = softmax(span_start_logits, dim=-1)
        # shape: (batch_size, sequence_length)
        span_end_probs = softmax(span_end_logits, dim=-1)

        # Replace the masked values with a very negative constant since we're in log-space.
        # shape: (batch_size, sequence_length)
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, possible_answer_mask
        )
        # shape: (batch_size, sequence_length)
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, possible_answer_mask
        )

        # Now calculate the best span.
        # shape: (batch_size, 2)
        best_spans = get_best_span(span_start_logits, span_end_logits)

        # Sum the span start score with the span end score to get an overall score for the span.
        # shape: (batch_size,)
        best_span_scores = torch.gather(
            span_start_logits, 1, best_spans[:, 0].unsqueeze(1)
        ) + torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        # Multiply the start and end probabilities of the best span.
        # shape: (batch_size,)
        best_span_probs = torch.gather(
            span_start_probs, 1, best_spans[:, 0].unsqueeze(1)
        ) * torch.gather(span_end_probs, 1, best_spans[:, 1].unsqueeze(1))
        best_span_probs = best_span_probs.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span_scores": best_span_scores,
            "span_start_probs": span_start_probs,
            "span_end_probs": span_end_probs,
            "best_span_probs": best_span_probs,
        }

        # Compute the loss.
        if answer_span is not None:
            output_dict["loss"] = self._evaluate_span(
                best_spans, span_start_logits, span_end_logits, answer_span
            )

        # Gather the string of the best span and compute the EM and F1 against the gold span,
        # if given.
        if not self.training and metadata is not None:
            (
                output_dict["best_span_str"],
                output_dict["best_span"],
            ) = self._collect_best_span_strings(best_spans, context_span, metadata, cls_index)

        return output_dict

    def _evaluate_span(
        self,
        best_spans: torch.Tensor,
        span_start_logits: torch.Tensor,
        span_end_logits: torch.Tensor,
        answer_span: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate the loss against the `answer_span` and also update the span metrics.

        The returned loss is the mean of the start and end cross-entropy losses, each computed
        with `ignore_index=-1`.
        """
        span_start = answer_span[:, 0]
        span_end = answer_span[:, 1]
        self._span_accuracy(best_spans, answer_span)

        start_loss = cross_entropy(span_start_logits, span_start, ignore_index=-1)
        # Sanity bound: the largest finite value of the loss dtype, capped at 1e9.
        big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
        assert not torch.any(start_loss > big_constant), "Start loss too high"

        end_loss = cross_entropy(span_end_logits, span_end, ignore_index=-1)
        assert not torch.any(end_loss > big_constant), "End loss too high"

        self._span_start_accuracy(span_start_logits, span_start)
        self._span_end_accuracy(span_end_logits, span_end)

        return (start_loss + end_loss) / 2

    def _collect_best_span_strings(
        self,
        best_spans: torch.Tensor,
        context_span: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        cls_index: Optional[torch.LongTensor],
    ) -> Tuple[List[str], torch.Tensor]:
        """
        Collect the string of the best predicted span from the context metadata and
        update `self._per_instance_metrics`, which in the case of SQuAD v1.1 / v2.0
        includes the EM and F1 score.

        This returns a `Tuple[List[str], torch.Tensor]`, where the `List[str]` is the
        predicted answer for each instance in the batch, and the tensor is just the input
        tensor `best_spans` after adjustments so that each answer span corresponds to the
        context tokens only, and not the question tokens. Spans that correspond to the
        `[CLS]` token, i.e. the question was predicted to be impossible, will be set
        to `(-1, -1)`.
        """
        _best_spans = best_spans.detach().cpu().numpy()

        best_span_strings: List[str] = []
        best_span_strings_for_metric: List[str] = []
        answer_strings_for_metric: List[List[str]] = []

        # `cls_index or ...` would call `bool()` on a tensor, which raises for tensors with
        # more than one element and silently discards a one-element tensor holding 0.
        # Test for `None` explicitly instead.
        if cls_index is None:
            cls_indices = [0] * len(metadata)
        else:
            cls_indices = cls_index

        for (metadata_entry, best_span, cspan, cls_ind) in zip(
            metadata,
            _best_spans,
            context_span,
            cls_indices,
        ):
            context_tokens_for_question = metadata_entry["context_tokens"]

            if best_span[0] == cls_ind:
                # Predicting [CLS] is interpreted as predicting the question as unanswerable.
                best_span_string = ""
                # NOTE: even though we've "detached" 'best_spans' above, this still
                # modifies the original tensor in-place.
                # NOTE(review): that presumably only holds while `best_spans` lives on the
                # CPU, where `.cpu()` does not copy - confirm the behaviour for GPU tensors.
                best_span[0], best_span[1] = -1, -1
            else:
                # Shift the span so that it is relative to the first context token.
                best_span -= int(cspan[0])
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                # Walk left until a token with a known character offset is found.
                while (
                    predicted_start >= 0
                    and context_tokens_for_question[predicted_start].idx is None
                ):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text."
                    )
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[predicted_start].idx

                # Walk right until a token with a known character offset is found.
                while (
                    predicted_end < len(context_tokens_for_question)
                    and context_tokens_for_question[predicted_end].idx is None
                ):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text."
                    )
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][character_start:character_end]

            best_span_strings.append(best_span_string)
            answers = metadata_entry.get("answers")
            if answers:
                best_span_strings_for_metric.append(best_span_string)
                answer_strings_for_metric.append(answers)

        if answer_strings_for_metric:
            self._per_instance_metrics(best_span_strings_for_metric, answer_strings_for_metric)

        return best_span_strings, best_spans

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the span accuracies and, outside of training, the per-instance EM and F1.

        `reset` is forwarded to every underlying metric's `get_metric`.
        """
        output = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
        }
        if not self.training:
            exact_match, f1_score = self._per_instance_metrics.get_metric(reset)
            output["per_instance_em"] = exact_match
            output["per_instance_f1"] = f1_score
        return output

    default_predictor = "transformer_qa"
class DialogQA(Model):
    """
    This class implements a modified version of BiDAF (with self attention and a residual layer,
    from the Clark and Gardner ACL 17 paper) as used in the Question Answering in Context
    (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog, i.e. a list of question answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.  Its input dimension must equal
        ``text_field_embedder.get_output_dim() + marker_embedding_dim * num_context_answers``.
    residual_encoder : ``Seq2SeqEncoder``
        The encoder applied to the merged, question-aware passage representation before the
        passage self-attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder applied to the final merged passage representation before predicting the
        span start.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder applied to the final merged passage representation concatenated with the
        span start encoding.  Its output feeds the span end, yes/no and followup predictors.
    initializer : ``InitializerApplicator``
        Called on the model once all layers have been constructed.
    dropout : ``float``, optional (default=0.2)
        Probability passed to the ``InputVariationalDropout`` that is applied to the embeddings
        and to several intermediate encoder outputs.
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    marker_embedding_dim : ``int``, optional (default=10)
        Size of each previous-answer marker embedding.  The question turn-number embedding has
        size ``marker_embedding_dim * num_context_answers``.
    max_span_length: ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction (number of rows in the turn-number embedding table).
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        # Passage-to-question attention and the projection that merges its four-way concatenation
        # back down to ``encoding_dim``.
        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        # The marker embeddings only exist when dialog history is used.
        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        # Passage self-attention and the projection that merges its three-way concatenation.
        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.  Must contain a ``token_characters`` entry, whose size is used
            to read the batch size, the number of questions per dialog and the question length.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``. The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the
            previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>,
            <1_end>.  For more details, look into
            dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in
            passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
            NOTE: although it defaults to ``None``, this tensor is always read to build the mask
            of real (non-padding) questions, so it must be supplied.
        metadata : ``List[Dict[str, Any]]``, optional
            One dictionary per dialog in the batch.  Each dictionary is read for the keys
            ``original_passage``, ``token_offsets``, ``instance_id`` and ``answer_texts_list``,
            which are used to recover the answer strings and to compute the F1 metric.
            NOTE: although it defaults to ``None``, it is indexed unconditionally, so it must be
            supplied.

        Returns
        -------
        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions
        in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            The string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.  Only present when ``span_start`` is given.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        # Padding questions carry a negative followup label; they are masked out of the metrics.
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        # Every question in a dialog shares the dialog's passage, so the passage mask is tiled
        # once per question.
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            # batch_size * max_qa_count, passage_length, word_embed_dim
            repeated_embedded_passage = repeated_embedded_passage.view(
                    total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb],
                                                      dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage,
                                                           p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                    self._phrase_layer(repeated_embedded_passage, repeated_passage_mask))
        else:
            # Without history markers the passage is identical for every question, so it is
            # encoded once and then tiled.
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        # A pair of positions may attend to each other only when both are real tokens ...
        mask = (repeated_passage_mask.reshape(total_qa_count, passage_length, 1)
                * repeated_passage_mask.reshape(total_qa_count, 1, passage_length))
        # ... and a position never attends to itself (the diagonal is zeroed out).
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        # Residual connection.
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        final_merged_passage = final_merged_passage + residual_layer
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        # The yes/no and followup predictors emit 3 logits per passage token, so ``squeeze(-1)``
        # leaves their last dimension in place.
        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())

            # The yes/no and followup logits are read at one passage position per question.  The
            # logits are flattened below, so position ``p`` of question ``i`` occupies the three
            # consecutive flat indices starting at ``p * 3 + i * passage_length * 3``.
            #
            # NOTE: the gold ends are kept 1-d on purpose.  Squeezing them would yield a 0-d array
            # when there is exactly one question in the batch, and indexing it would then fail.
            gold_span_end = span_end.view(total_qa_count).data.cpu().numpy()
            gold_span_end_loc = []
            for i in range(0, total_qa_count):
                gold_base = gold_span_end[i] * 3 + i * passage_length * 3
                for offset in range(3):
                    # Padding questions have a gold end of -1; clamp so the index stays valid.
                    gold_span_end_loc.append(max(gold_base + offset, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_base = best_span[i][1] * 3 + i * passage_length * 3
                for offset in range(3):
                    pred_span_end_loc.append(max(pred_base + offset, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            # The loss uses the logits at the gold span end ...
            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            # ... while the accuracies use the logits at the predicted span end.
            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                # Token offsets are (start, end) character positions into the passage string.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Replaces the ``yesno`` and ``followup`` prediction indices in ``output_dict`` with the
        corresponding tokens from the ``yesno_labels`` and ``followup_labels`` vocabulary
        namespaces, and returns the same (mutated) dictionary.
        """
        yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list]
                      for yn_list in output_dict.pop("yesno")]
        followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels")
                          for x in followup_list]
                         for followup_list in output_dict.pop("followup")]
        output_dict['yesno'] = yesno_tags
        output_dict['followup'] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the current value of every tracked metric, passing ``reset`` through to each
        metric's ``get_metric``.
        """
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'yesno': self._span_yesno_accuracy.get_metric(reset),
                'followup': self._span_followup_accuracy.get_metric(reset),
                'f1': self._official_f1.get_metric(reset), }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        """
        Greedily finds, for every row, a high-scoring ``(start, end)`` span with ``start <= end``
        and ``end - start <= max_span_length``, and reads the yes/no and followup predictions
        (argmax over the last dimension) at the chosen span end.

        Parameters
        ----------
        span_start_logits, span_end_logits : ``torch.Tensor``
            Shape ``(batch_size, passage_length)``.
        span_yesno_logits, span_followup_logits : ``torch.Tensor``
            Indexed as ``[row, token]`` before taking the argmax.
        max_span_length : ``int``
            Largest allowed value of ``end - start``.

        Returns
        -------
        A ``torch.long`` tensor of shape ``(batch_size, 4)`` whose columns are
        ``(span_start, span_end, yesno_prediction, followup_prediction)``.

        Raises
        ------
        ValueError
            If either span logit tensor is not two-dimensional.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Track the best start seen so far, so the start never comes after the end ``j``.
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    # Candidates longer than ``max_span_length`` are skipped.
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
class BidirectionalAttentionFlow(Model): """ This class implements Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). The basic layout is pretty simple: encode words as a combination of word embeddings and a character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of attentions to put question information into the passage word representations (this is the only part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. attention_similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. 
This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, num_highway_layers: int, phrase_layer: Seq2SeqEncoder, attention_similarity_function: SimilarityFunction, modeling_layer: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, dropout: float = 0.2, mask_lstms: bool = True, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder self._highway_layer = TimeDistributed( Highway(text_field_embedder.get_output_dim(), num_highway_layers)) self._phrase_layer = phrase_layer self._matrix_attention = MatrixAttention(attention_similarity_function) self._modeling_layer = modeling_layer self._span_end_encoder = span_end_encoder encoding_dim = phrase_layer.get_output_dim() modeling_dim = modeling_layer.get_output_dim() span_start_input_dim = encoding_dim * 4 + modeling_dim self._span_start_predictor = TimeDistributed( torch.nn.Linear(span_start_input_dim, 1)) span_end_encoding_dim = span_end_encoder.get_output_dim() span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim self._span_end_predictor = TimeDistributed( torch.nn.Linear(span_end_input_dim, 1)) # Bidaf has lots of layer dimensions which need to match up - these # aren't necessarily obvious from the 
configuration files, so we check # here. if modeling_layer.get_input_dim() != 4 * encoding_dim: raise ConfigurationError( "The input dimension to the modeling_layer must be " "equal to 4 times the encoding dimension of the phrase_layer. " "Found {} and 4 * {} respectively.".format( modeling_layer.get_input_dim(), encoding_dim)) if text_field_embedder.get_output_dim() != phrase_layer.get_input_dim( ): raise ConfigurationError( "The output dimension of the text_field_embedder (embedding_dim + " "char_cnn) must match the input dimension of the phrase_encoder. " "Found {} and {}, respectively.".format( text_field_embedder.get_output_dim(), phrase_layer.get_input_dim())) if span_end_encoder.get_input_dim( ) != encoding_dim * 4 + modeling_dim * 3: raise ConfigurationError( "The input dimension of the span_end_encoder should be equal to " "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. " "Found {} and (4 * {} + 3 * {}) " "respectively.".format(span_end_encoder.get_input_dim(), encoding_dim, modeling_dim)) self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._mask_lstms = mask_lstms initializer(self) def forward( self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. 
This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` index. If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key. Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log probabilities of the span end position (inclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. """ embedded_question = self._highway_layer( self._text_field_embedder(question)) embedded_passage = self._highway_layer( self._text_field_embedder(passage)) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None encoded_question = self._dropout( self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout( self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention( encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.last_dim_softmax( passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum( encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
masked_similarity = util.replace_masked_values( passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max( dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax( question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum( encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze( 1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([ encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * tiled_question_passage_vector ], dim=-1) modeled_passage = self._dropout( self._modeling_layer(final_merged_passage, passage_lstm_mask)) modeling_dim = modeled_passage.size(-1) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)) span_start_input = self._dropout( torch.cat([final_merged_passage, modeled_passage], dim=-1)) # Shape: (batch_size, passage_length) span_start_logits = self._span_start_predictor( span_start_input).squeeze(-1) # Shape: (batch_size, passage_length) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) # Shape: (batch_size, modeling_dim) span_start_representation = util.weighted_sum(modeled_passage, span_start_probs) # Shape: (batch_size, passage_length, modeling_dim) tiled_start_representation = span_start_representation.unsqueeze( 1).expand(batch_size, passage_length, modeling_dim) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3) span_end_representation = torch.cat([ final_merged_passage, modeled_passage, tiled_start_representation, modeled_passage * tiled_start_representation ], dim=-1) # Shape: (batch_size, passage_length, encoding_dim) 
encoded_span_end = self._dropout( self._span_end_encoder(span_end_representation, passage_lstm_mask)) # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim) span_end_input = self._dropout( torch.cat([final_merged_passage, encoded_span_end], dim=-1)) span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = self._get_best_span(span_start_logits, span_end_logits) output_dict = { "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span } if span_start is not None: loss = nll_loss( util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) loss += nll_loss( util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) output_dict["loss"] = loss if metadata is not None: output_dict['best_span_str'] = [] for i in range(batch_size): passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].data.cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output_dict['best_span_str'].append(best_span_string) answer_texts = metadata[i].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: exact_match, f1_score = self._squad_metrics.get_metric(reset) return { 
'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'em': exact_match, 'f1': f1_score, } @staticmethod def _get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable: if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError( "Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() max_span_log_prob = [-1e20] * batch_size span_start_argmax = [0] * batch_size best_word_span = Variable(span_start_logits.data.new().resize_( batch_size, 2).fill_(0)).long() span_start_logits = span_start_logits.data.cpu().numpy() span_end_logits = span_end_logits.data.cpu().numpy() for b in range(batch_size): # pylint: disable=invalid-name for j in range(passage_length): val1 = span_start_logits[b, span_start_argmax[b]] if val1 < span_start_logits[b, j]: span_start_argmax[b] = j val1 = span_start_logits[b, j] val2 = span_end_logits[b, j] if val1 + val2 > max_span_log_prob[b]: best_word_span[b, 0] = span_start_argmax[b] best_word_span[b, 1] = j max_span_log_prob[b] = val1 + val2 return best_word_span @classmethod def from_params(cls, vocab: Vocabulary, params: Params) -> 'BidirectionalAttentionFlow': embedder_params = params.pop("text_field_embedder") text_field_embedder = TextFieldEmbedder.from_params( vocab, embedder_params) num_highway_layers = params.pop("num_highway_layers") phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer")) similarity_function = SimilarityFunction.from_params( params.pop("similarity_function")) modeling_layer = Seq2SeqEncoder.from_params( params.pop("modeling_layer")) span_end_encoder = Seq2SeqEncoder.from_params( params.pop("span_end_encoder")) dropout = params.pop('dropout', 0.2) init_params = params.pop('initializer', None) reg_params = params.pop('regularizer', None) initializer = (InitializerApplicator.from_params(init_params) if 
init_params is not None else InitializerApplicator()) regularizer = RegularizerApplicator.from_params( reg_params) if reg_params is not None else None mask_lstms = params.pop('mask_lstms', True) params.assert_empty(cls.__name__) return cls(vocab=vocab, text_field_embedder=text_field_embedder, num_highway_layers=num_highway_layers, phrase_layer=phrase_layer, attention_similarity_function=similarity_function, modeling_layer=modeling_layer, span_end_encoder=span_end_encoder, dropout=dropout, mask_lstms=mask_lstms, initializer=initializer, regularizer=regularizer)
class BidirectionalAttentionFlow_1(Model):
    """
    This class implements a Bayesian version of Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    Every component is built from the configuration object ``cf_a``.  For each
    component that has a ``VB_*`` flag set in ``cf_a`` a variational ("VB")
    layer is created instead of the deterministic one and is also appended to
    ``self.VBmodels``.
    """

    def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None:
        """
        Parameters
        ----------
        vocab : ``Vocabulary``
            Passed straight to the ``Model`` constructor.
        cf_a
            Configuration object; every hyper-parameter below is read from it.
        preloaded_elmo : optional
            An already instantiated ELMo embedder.  When ``None`` and
            ``cf_a.use_ELMO`` is set, one is obtained via ``bidut.download_Elmo``.
        """
        super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)

        # ---------------- Initialize some data structures ----------------
        self.cf_a = cf_a
        # Bayesian data models
        self.VBmodels = []
        self.LinearModels = []

        # ---------------- Text field embedder with ELMo ----------------
        # Used to embed the ``question`` and ``passage`` ``TextFields`` we get
        # as input to the model.
        if cf_a.use_ELMO:
            if preloaded_elmo is not None:
                text_field_embedder = preloaded_elmo
            else:
                text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput)
                # NOTE(review): the collapsed source does not show whether this
                # message was also printed for a preloaded embedder.
                print ("ELMO loaded from disk or downloaded")
        else:
            # NOTE(review): every later step calls methods on the embedder, so
            # a configuration without ELMo fails below with an AttributeError.
            text_field_embedder = None
        self._text_field_embedder = text_field_embedder

        if cf_a.Add_Linear_projection_ELMO:
            if self.cf_a.VB_Linear_projection_ELMO:
                prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
                print ("----------------- Bayesian Linear Projection ELMO --------------")
                linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(), 200, prior = prior)
                self.VBmodels.append(linear_projection_ELMO)
            else:
                linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(), 200)
            self._linear_projection_ELMO = linear_projection_ELMO

        # ---------------- Highway layers ----------------
        # The number of highway layers to use in between embedding the input
        # and passing it through the phrase layer.
        if cf_a.Add_Linear_projection_ELMO:
            Input_dimension_highway = 200
        else:
            Input_dimension_highway = text_field_embedder.get_output_dim()

        num_highway_layers = cf_a.num_highway_layers
        if self.cf_a.VB_highway_layers:
            print ("----------------- Bayesian Highway network --------------")
            prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
            highway_layer = HighwayVB(Input_dimension_highway, num_highway_layers, prior = prior)
            self.VBmodels.append(highway_layer)
        else:
            highway_layer = Highway(Input_dimension_highway, num_highway_layers)
        # Both variants are wrapped: trim_model reaches the VB one via ``._module``.
        self._highway_layer = TimeDistributed(highway_layer)

        # ---------------- Phrase layer ----------------
        # The encoder (with its own internal stacking) that we will use in
        # between embedding tokens and doing the bidirectional attention.
        if cf_a.phrase_layer_dropout > 0:
            dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
        else:
            dropout_phrase_layer = lambda x: x

        phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway,
                                                           hidden_size = cf_a.phrase_layer_hidden_size,
                                                           batch_first=True, bidirectional = True,
                                                           num_layers = cf_a.phrase_layer_num_layers,
                                                           dropout = cf_a.phrase_layer_dropout))
        # Bidirectional LSTM: forward and backward states are concatenated.
        phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
        self._phrase_layer = phrase_layer
        self._dropout_phrase_layer = dropout_phrase_layer

        # ---------------- Matrix attention layer ----------------
        # The similarity function that we will use when comparing encoded
        # passage and question representations.
        if self.cf_a.VB_similarity_function:
            prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
            print ("----------------- Bayesian Similarity matrix --------------")
            similarity_function = LinearSimilarityVB(
                combination = "x,y,x*y",
                tensor_1_dim = phrase_encoding_out_dim,
                tensor_2_dim = phrase_encoding_out_dim,
                prior = prior)
            self.VBmodels.append(similarity_function)
        else:
            similarity_function = LinearSimilarity(
                combination = "x,y,x*y",
                tensor_1_dim = phrase_encoding_out_dim,
                tensor_2_dim = phrase_encoding_out_dim)
        self._matrix_attention = LegacyMatrixAttention(similarity_function)

        # ---------------- Modelling layer ----------------
        # The encoder (with its own internal stacking) that we will use in
        # between the bidirectional attention and predicting span start and end.
        if cf_a.modeling_passage_dropout > 0:
            dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
        else:
            dropout_modeling_passage = lambda x: x

        modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4,
                                                             hidden_size = cf_a.modeling_passage_hidden_size,
                                                             batch_first=True, bidirectional = True,
                                                             num_layers = cf_a.modeling_passage_num_layers,
                                                             dropout = cf_a.modeling_passage_dropout))
        self._modeling_layer = modeling_layer
        self._dropout_modeling_passage = dropout_modeling_passage

        # ---------------- Span start representation ----------------
        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim

        # Linear layer to compute the span start logits.
        if self.cf_a.VB_span_start_predictor_linear:
            prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
            print ("----------------- Bayesian Span Start Predictor--------------")
            span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior = prior)
            self.VBmodels.append(span_start_predictor_linear)
        else:
            span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)
        self._span_start_predictor_linear = span_start_predictor_linear
        self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

        # ---------------- Span end representation ----------------
        # The encoder that we will use to incorporate span start predictions
        # into the passage state before predicting span end.
        if cf_a.span_end_encoder_dropout > 0:
            dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_span_end_encode = lambda x: x

        span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3,
                                                               hidden_size = cf_a.modeling_span_end_hidden_size,
                                                               batch_first=True, bidirectional = True,
                                                               num_layers = cf_a.modeling_span_end_num_layers,
                                                               dropout = cf_a.span_end_encoder_dropout))
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_encoder = span_end_encoder
        self._dropout_span_end_encode = dropout_span_end_encode

        if self.cf_a.VB_span_end_predictor_linear:
            print ("----------------- Bayesian Span End Predictor--------------")
            prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
            span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior)
            self.VBmodels.append(span_end_predictor_linear)
        else:
            span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)
        self._span_end_predictor_linear = span_end_predictor_linear
        self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

        # ---------------- Dropout of the last layers ----------------
        # Fix: the probability must come from ``spans_output_dropout`` (the
        # option tested in the condition), not from ``span_end_encoder_dropout``.
        if cf_a.spans_output_dropout > 0:
            dropout_spans_output = torch.nn.Dropout(p=cf_a.spans_output_dropout)
        else:
            dropout_spans_output = lambda x: x
        self._dropout_spans_output = dropout_spans_output

        # ---------------- Checkings and accuracy ----------------
        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(Input_dimension_highway, phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        # mask_lstms : ``bool``
        # If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x
        # speedup, with only a slight performance decrease, if any.  We haven't experimented
        # much with this yet, but have confirmed that we still get very similar performance
        # with much faster training times.  We still use the mask for all softmaxes, but avoid
        # the shuffling that's required when using masking with pytorch LSTMs.
        self._mask_lstms = cf_a.mask_lstms

        # ---------------- Initialize parameters ----------------
        # They are all initialized when instantiating the components.

        # ---------------- Optimizer ----------------
        self._optimizer = pytut.get_optimizers(self, cf_a)
        # TODO: learning rate scheduler.

    def _embed_text(self, text_field: Dict[str, torch.LongTensor]):
        """Embed one text field: last ELMo representation, optional linear projection, highway."""
        elmo_output = self._text_field_embedder(text_field['character_ids'])
        embedded = elmo_output["elmo_representations"][-1]
        if self.cf_a.Add_Linear_projection_ELMO:
            embedded = self._linear_projection_ELMO(embedded)
        return self._highway_layer(embedded)

    @staticmethod
    def _extract_span_text(instance_metadata: Dict[str, Any], span) -> str:
        """Map a predicted ``(start, end)`` token span to the matching substring of the passage."""
        passage_str = instance_metadata['original_passage']
        offsets = instance_metadata['token_offsets']
        predicted_span = tuple(span.detach().cpu().numpy())
        start_offset = offsets[predicted_span[0]][0]
        end_offset = offsets[predicted_span[1]][1]
        return passage_str[start_offset:end_offset]

    def forward_ensemble(self,  # type: ignore
                         question: Dict[str, torch.LongTensor],
                         passage: Dict[str, torch.LongTensor],
                         span_start: torch.IntTensor = None,
                         span_end: torch.IntTensor = None,
                         metadata: List[Dict[str, Any]] = None,
                         get_sample_level_information = False,
                         num_samples: int = 10) -> Dict[str, torch.Tensor]:
        """
        Run the model once with ``set_posterior_mean(True)`` and ``num_samples``
        more times with ``set_posterior_mean(False)``, then merge the runs.

        Parameters are those of :meth:`forward`, plus:

        num_samples : ``int``, optional (default=10)
            Number of additional forward passes added to the ensemble.

        Returns
        -------
        A dictionary with ``best_span`` (from ``bidut.merge_span_probs``),
        ``best_span_str``, ``models_output`` (the individual outputs) and, when
        ``get_sample_level_information`` is set, ``em_samples``, ``f1_samples``,
        ``span_start_sample_loss`` and ``span_end_sample_loss``.

        Note that every call to :meth:`forward` made here also updates the
        accumulated metrics of the model.
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question, passage, span_start, span_end,
                                          metadata, get_sample_level_information)
        self.set_posterior_mean(False)

        subresults = [most_likely_output]
        for _ in range(num_samples):
            subresults.append(self.forward(question, passage, span_start, span_end,
                                           metadata, get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])
        best_span = bidut.merge_span_probs(subresults)

        output = {
            "best_span": best_span,
            "best_span_str": [],
            "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []

        if metadata is not None:
            for index in range(batch_size):
                best_span_string = self._extract_span_text(metadata[index], best_span[index])
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    # NOTE(review): nesting reconstructed from a collapsed source;
                    # per-sample EM/F1 is only computed when gold answers exist.
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)

        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            if span_start is not None and span_end is not None:
                # The ensemble mean does not depend on the instance: compute it once.
                span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
                span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
                # Fix: nll_loss expects log-probabilities.  Clamp first so that
                # zero-probability positions do not produce -inf.
                span_start_log_probs = torch.log(span_start_probs.clamp(min=1e-13))
                span_end_log_probs = torch.log(span_end_probs.clamp(min=1e-13))
                gold_starts = span_start.squeeze(-1)
                gold_ends = span_end.squeeze(-1)
                for i in range(batch_size):
                    span_start_loss = nll_loss(span_start_log_probs[[i], :], gold_starts[[i]])
                    span_end_loss = nll_loss(span_end_log_probs[[i], :], gold_ends[[i]])
                    output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                    output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False,
                get_attentions = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.  This implementation
            also reads ``question_tokens`` and ``passage_tokens`` from each entry.
        get_sample_level_information : ``bool``, optional (default=False)
            If set, per-instance EM/F1 and per-instance span losses are added to the output.
        get_attentions : ``bool``, optional (default=False)
            If set, the attention and similarity tensors are added to the output.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        # ---------------- Sample Bayesian weights ----------------
        self.sample_posterior()

        # ---------------- Mask computing ----------------
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # ---------------- Embedding + highway layer ----------------
        embedded_question = self._embed_text(question)
        embedded_passage = self._embed_text(passage)
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)

        # ---------------- Phrase layer ----------------
        encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # ---------------- Attention layer ----------------
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # ---------------- Spans layer ----------------
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout_spans_output(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout_span_end_encode(self._span_end_encoder(span_end_representation,
                                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout_spans_output(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)

        best_span = bidut.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            loss = span_start_loss + span_end_loss

            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))

            output_dict["loss"] = loss
            output_dict["span_start_loss"] = span_start_loss
            output_dict["span_end_loss"] = span_end_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            if get_sample_level_information:
                output_dict["em_samples"] = []
                output_dict["f1_samples"] = []

            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                best_span_string = self._extract_span_text(metadata[i], best_span[i])
                output_dict['best_span_str'].append(best_span_string)

                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    # NOTE(review): nesting reconstructed from a collapsed source;
                    # per-sample EM/F1 is only computed when gold answers exist.
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            # Per-sample losses need the gold spans; without them the lists stay empty.
            if span_start is not None and span_end is not None:
                gold_starts = span_start.squeeze(-1)
                gold_ends = span_end.squeeze(-1)
                for i in range(batch_size):
                    span_start_loss = nll_loss(
                        util.masked_log_softmax(span_start_logits[[i], :], passage_mask[[i], :]),
                        gold_starts[[i]])
                    span_end_loss = nll_loss(
                        util.masked_log_softmax(span_end_logits[[i], :], passage_mask[[i], :]),
                        gold_ends[[i]])
                    output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                    output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))

        if get_attentions:
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["simmilarity"] = passage_question_similarity

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return start/end/span accuracy plus SQuAD EM and F1, optionally resetting the metrics."""
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    def train_batch(self,  # type: ignore
                    question: Dict[str, torch.LongTensor],
                    passage: Dict[str, torch.LongTensor],
                    span_start: torch.IntTensor = None,
                    span_end: torch.IntTensor = None,
                    metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Run one optimisation step on a batch and return the output of :meth:`forward`.

        It is enough to just compute the total loss because the normal weights do not depend on the
        KL Divergence
        """
        # Now we can just compute both losses which will build the dynamic graph
        output = self.forward(question, passage, span_start, span_end, metadata)
        data_loss = output["loss"]

        KL_div = self.get_KL_divergence()
        total_loss = self.combine_losses(data_loss, KL_div)

        self.zero_grad()  # zeroes the gradient buffers of all parameters
        total_loss.backward()

        if self._optimizer is None:
            # Manual gradient-descent step.
            # NOTE(review): ``self.lr`` is not assigned anywhere in this class.
            parameters = filter(lambda p: p.requires_grad, self.parameters())
            with torch.no_grad():
                for f in parameters:
                    # Parameters that did not take part in the graph have no gradient.
                    if f.grad is not None:
                        f.data.sub_(f.grad.data * self.lr)
        else:
            self._optimizer.step()
            self._optimizer.zero_grad()
        return output

    def fill_batch_training_information(self, training_logger, output_batch):
        """
        Function to fill the the training_logger for each batch.

        training_logger: Dictionary that will hold all the training info
        output_batch: Output from training the batch
        """
        train_log = training_logger["train"]
        train_log["span_start_loss_batch"].append(output_batch["span_start_loss"].detach().cpu().numpy())
        train_log["span_end_loss_batch"].append(output_batch["span_end_loss"].detach().cpu().numpy())
        train_log["loss_batch"].append(output_batch["loss"].detach().cpu().numpy())

        # Training metrics (running values, not reset here):
        metrics = self.get_metrics()
        for name in ("start_acc", "end_acc", "span_acc", "em", "f1"):
            train_log[name + "_batch"].append(metrics[name])

    def fill_epoch_training_information(self, training_logger, device,
                                        validation_iterable, num_batches_validation):
        """
        Fill the information per each epoch.

        Stores (and resets) the accumulated training metrics, then evaluates
        ``num_batches_validation`` batches drawn from ``validation_iterable``
        with ``set_posterior_mean(True)`` and stores the validation metrics
        and the mean data loss.  A batch whose forward pass raises
        ``RuntimeError`` is retried up to ``Ntrials_CUDA`` times, after which
        the process exits.
        """
        Ntrials_CUDA = 100
        metric_names = ("start_acc", "end_acc", "span_acc", "em", "f1")

        # Training Epoch final metrics
        metrics = self.get_metrics(reset = True)
        for name in metric_names:
            training_logger["train"][name].append(metrics[name])

        self.set_posterior_mean(True)
        self.eval()

        data_loss_validation = 0
        with torch.no_grad():
            # Compute the validation accuracy by using all the Validation dataset but in batches.
            for _ in range(num_batches_validation):
                tensor_dict = next(validation_iterable)

                trial_index = 0
                while True:
                    try:
                        tensor_dict = pytut.move_to_device(tensor_dict, device)  # Move the tensor to cuda
                        output_batch = self.forward(**tensor_dict)
                        break
                    except RuntimeError as er:
                        print (er.args)
                        torch.cuda.empty_cache()
                        time.sleep(5)
                        torch.cuda.empty_cache()
                        trial_index += 1
                        if trial_index == Ntrials_CUDA:
                            print ("Too many failed trials to allocate in memory")
                            send_error_email(str(er.args))
                            # Fix: this is a failure, report a non-zero status.
                            sys.exit(1)

                data_loss_validation += output_batch["loss"].detach().cpu().numpy()

                # Memory management.
                # NOTE(review): grouping of these statements under the flag is
                # reconstructed from a collapsed source.
                if self.cf_a.force_free_batch_memory:
                    del tensor_dict["question"]
                    del tensor_dict["passage"]
                    del tensor_dict
                    del output_batch
                    torch.cuda.empty_cache()
                if self.cf_a.force_call_garbage_collector:
                    gc.collect()

        data_loss_validation = data_loss_validation/num_batches_validation

        # Validation final metrics
        metrics = self.get_metrics(reset = True)
        for name in metric_names:
            training_logger["validation"][name].append(metrics[name])
        training_logger["validation"]["data_loss"].append(data_loss_validation)

        self.train()
        self.set_posterior_mean(False)

    def trim_model(self, mu_sigma_ratio = 2):
        """
        Call ``Vil.trim_LinearVB_weights`` on every variational layer enabled in the config.

        Returns four lists with one entry per trimmed layer: weight sizes,
        removed weights, bias sizes and removed biases, in the order
        ELMo projection, highway, similarity, span start, span end.
        """
        vb_layers = []
        if self.cf_a.VB_Linear_projection_ELMO:
            vb_layers.append(self._linear_projection_ELMO)
        if self.cf_a.VB_highway_layers:
            vb_layers.append(self._highway_layer._module.VBmodels[0])
        if self.cf_a.VB_similarity_function:
            vb_layers.append(self._matrix_attention._similarity_function)
        if self.cf_a.VB_span_start_predictor_linear:
            vb_layers.append(self._span_start_predictor_linear)
        if self.cf_a.VB_span_end_predictor_linear:
            vb_layers.append(self._span_end_predictor_linear)

        total_size_w = []
        total_removed_w = []
        total_size_b = []
        total_removed_b = []
        for VBmodel in vb_layers:
            # Fix: trim each layer exactly once.  The previous code trimmed most
            # layers twice and reported the statistics of the second call.
            size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio)
            total_size_w.append(size_w)
            total_removed_w.append(removed_w)
            total_size_b.append(size_b)
            total_removed_b.append(removed_b)

        return total_size_w, total_removed_w, total_size_b, total_removed_b

    # ---------------- Bayesian necessary functions ----------------
    sample_posterior = GeneralVBModel.sample_posterior
    get_KL_divergence = GeneralVBModel.get_KL_divergence
    set_posterior_mean = GeneralVBModel.set_posterior_mean
    combine_losses = GeneralVBModel.combine_losses

    def save_VB_weights(self):
        """
        Function that saves only the VB weights of the model.

        NOTE(review): unfinished stub.  ``pretrained_dict`` is the ``Ellipsis``
        placeholder, so the ``.items()`` call below raises ``AttributeError``;
        the source of the weights still has to be filled in.
        """
        pretrained_dict = ...
        model_dict = self.state_dict()

        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        self.load_state_dict(pretrained_dict)
def test_incorrect_gold_labels_shape_catches_exceptions(self, device: str):
    """Calling the metric with labels shaped unlike the predictions must raise ``ValueError``."""
    metric = BooleanAccuracy()
    # The two tensors deliberately differ in their last dimension (7 vs. 8).
    predicted = torch.rand([5, 7], device=device)
    mismatched_gold = torch.rand([5, 8], device=device)
    with pytest.raises(ValueError):
        metric(predicted, mismatched_gold)
def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        char_field_embedder: TextFieldEmbedder,
        phrase_layer: Seq2SeqEncoder,
        char_rnn: Seq2SeqEncoder,
        hops: int,
        hidden_dim: int,
        dropout: float = 0.2,
        mask_lstms: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None) -> None:
    """
    Assemble the embedders, the per-hop aligner stacks, the answer pointer and the metrics.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Forwarded to the parent constructor.
    text_field_embedder : ``TextFieldEmbedder``
        Stored as ``self._text_field_embedder``.
    char_field_embedder : ``TextFieldEmbedder``
        Stored as ``self._char_field_embedder``.
    phrase_layer : ``Seq2SeqEncoder``
        Stored as ``self._phrase_layer``; its output dimension sizes every aligner,
        SFU, aggregation LSTM input and the answer pointer.
    char_rnn : ``Seq2SeqEncoder``
        Stored as ``self._char_rnn``.
    hops : ``int``
        Number of (interactive aligner, SFU, self aligner, SFU, LSTM) groups to build.
    hidden_dim : ``int``
        Hidden size of each aggregation LSTM and of the answer pointer.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, ``self._dropout`` is a ``torch.nn.Dropout`` with this
        probability; otherwise it is the identity function.
    mask_lstms : ``bool``, optional (default=True)
        Stored as ``self._mask_lstms``.
    initializer : ``InitializerApplicator``, optional
        Applied to the fully constructed model as the last step.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Forwarded to the parent constructor.
    """
    super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

    # Embedding components.
    self._text_field_embedder = text_field_embedder
    self._char_field_embedder = char_field_embedder
    self._features_embedder = nn.Embedding(2, 5)

    self._phrase_layer = phrase_layer
    enc_dim = phrase_layer.get_output_dim()
    self._encoding_dim = enc_dim
    self._char_rnn = char_rnn
    self.hops = hops

    # One container per component type; each receives one entry per hop.
    self.interactive_aligners = nn.ModuleList()
    self.interactive_SFUs = nn.ModuleList()
    self.self_aligners = nn.ModuleList()
    self.self_SFUs = nn.ModuleList()
    self.aggregate_rnns = nn.ModuleList()

    for _ in range(hops):
        # Interactive aligner followed by its fusion unit.
        self.interactive_aligners.append(layers.SeqAttnMatch(enc_dim))
        self.interactive_SFUs.append(layers.SFU(enc_dim, 3 * enc_dim))
        # Self aligner followed by its fusion unit.
        self.self_aligners.append(layers.SelfAttnMatch(enc_dim))
        self.self_SFUs.append(layers.SFU(enc_dim, 3 * enc_dim))
        # Aggregating bidirectional LSTM.
        hop_lstm = nn.LSTM(input_size=enc_dim,
                           hidden_size=hidden_dim,
                           num_layers=1,
                           dropout=0.2,
                           bidirectional=True,
                           batch_first=True)
        self.aggregate_rnns.append(PytorchSeq2SeqWrapper(hop_lstm))

    # Memory-based answer pointer.
    self.mem_ans_ptr = layers.MemoryAnsPointer(x_size=enc_dim,
                                               y_size=enc_dim,
                                               hidden_size=hidden_dim,
                                               hop=hops,
                                               dropout_rate=0.2,
                                               normalize=True)

    # Metrics.
    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_yesno_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    # Identity stand-in keeps call sites uniform when dropout is disabled.
    self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
    self._mask_lstms = mask_lstms

    initializer(self)
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             phrase_layer: Seq2SeqEncoder,
             residual_encoder: Seq2SeqEncoder,
             span_start_encoder: Seq2SeqEncoder,
             span_end_encoder: Seq2SeqEncoder,
             initializer: InitializerApplicator,
             dropout: float = 0.2,
             num_context_answers: int = 0,
             marker_embedding_dim: int = 10,
             max_span_length: int = 30,
             max_turn_length: int = 12) -> None:
    """
    Build the attention layers, optional marker embeddings, span predictors and metrics.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Forwarded to the parent constructor.
    text_field_embedder : ``TextFieldEmbedder``
        Stored as ``self._text_field_embedder``; its output dimension is part of the
        dimension check against ``phrase_layer``.
    phrase_layer : ``Seq2SeqEncoder``
        Stored as ``self._phrase_layer``; its output dimension sizes all linear and
        attention layers created here.
    residual_encoder, span_start_encoder, span_end_encoder : ``Seq2SeqEncoder``
        Stored on the instance unchanged.
    initializer : ``InitializerApplicator``
        Applied to the model after all layers above exist and before the metrics
        and the variational dropout are created.
    dropout : ``float``, optional (default=0.2)
        Passed to ``InputVariationalDropout``.
    num_context_answers : ``int``, optional (default=0)
        When positive, question-number and previous-answer marker embeddings are created.
    marker_embedding_dim : ``int``, optional (default=10)
        Embedding size of the markers.
    max_span_length : ``int``, optional (default=30)
        Stored as ``self._max_span_length``.
    max_turn_length : ``int``, optional (default=12)
        Number of rows of the question-number marker embedding.
    """
    super().__init__(vocab)

    # Plain configuration values.
    self._num_context_answers = num_context_answers
    self._max_span_length = max_span_length
    self._text_field_embedder = text_field_embedder
    self._phrase_layer = phrase_layer
    self._marker_embedding_dim = marker_embedding_dim

    dim = phrase_layer.get_output_dim()
    self._encoding_dim = dim

    # Passage/question attention and the projection that merges its outputs.
    self._matrix_attention = LinearMatrixAttention(dim, dim, 'x,y,x*y')
    self._merge_atten = TimeDistributed(torch.nn.Linear(dim * 4, dim))

    self._residual_encoder = residual_encoder

    # Marker embeddings only exist when previous answers are used.
    if num_context_answers > 0:
        marker_width = marker_embedding_dim * num_context_answers
        self._question_num_marker = torch.nn.Embedding(max_turn_length, marker_width)
        self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1,
                                                   marker_embedding_dim)

    self._self_attention = LinearMatrixAttention(dim, dim, 'x,y,x*y')

    self._followup_lin = torch.nn.Linear(dim, 3)
    self._merge_self_attention = TimeDistributed(torch.nn.Linear(dim * 3, dim))

    self._span_start_encoder = span_start_encoder
    self._span_end_encoder = span_end_encoder

    # Per-token prediction heads.
    self._span_start_predictor = TimeDistributed(torch.nn.Linear(dim, 1))
    self._span_end_predictor = TimeDistributed(torch.nn.Linear(dim, 1))
    self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(dim, 3))
    self._span_followup_predictor = TimeDistributed(self._followup_lin)

    # The phrase layer consumes the token embedding concatenated with the markers.
    expected_phrase_input = (text_field_embedder.get_output_dim() +
                             marker_embedding_dim * num_context_answers)
    check_dimensions_match(phrase_layer.get_input_dim(),
                           expected_phrase_input,
                           "phrase layer input dim",
                           "embedding dim + marker dim * num context answers")

    initializer(self)

    # Metrics.
    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_yesno_accuracy = CategoricalAccuracy()
    self._span_followup_accuracy = CategoricalAccuracy()

    self._span_gt_yesno_accuracy = CategoricalAccuracy()
    self._span_gt_followup_accuracy = CategoricalAccuracy()

    self._span_accuracy = BooleanAccuracy()
    self._official_f1 = Average()
    self._variational_dropout = InputVariationalDropout(dropout)
def test_does_not_divide_by_zero_with_no_count(self, device: str):
    """A metric that has never been called is expected to report approximately 0.0."""
    untouched_metric = BooleanAccuracy()
    reported = untouched_metric.get_metric()
    assert reported == pytest.approx(0.0)
def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo=None) -> None:
    """
    Build every layer of the model from the configuration object ``cf_a``.

    Each configurable linear component (ELMo projection, highway layers, similarity
    function, span start / span end predictors) is created in its variational
    ("VB") flavour when the matching ``cf_a.VB_*`` flag is set, and every VB
    component is also appended to ``self.VBmodels``.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Forwarded to the parent constructor.
    cf_a
        Configuration object. Attributes read here: ``regularizer``, ``use_ELMO``,
        ``ELMO_num_layers``, ``ELMO_droput``, ``Add_Linear_projection_ELMO``,
        ``num_highway_layers``, the ``phrase_layer_*``, ``modeling_passage_*``,
        ``modeling_span_end_*`` sizes, the ``*_dropout`` values, ``mask_lstms``,
        and for each Bayesian component a ``VB_<name>`` flag plus a
        ``VB_<name>_prior`` mapping that is expanded into ``Vil.Prior``.
    preloaded_elmo : optional (default=``None``)
        An already constructed text field embedder. When given (and
        ``cf_a.use_ELMO`` is set) it is used instead of calling
        ``bidut.download_Elmo``.
    """
    super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)

    # ---- Basic data structures ----
    self.cf_a = cf_a
    # Bayesian data models
    self.VBmodels = []
    self.LinearModels = []

    # ############## TEXT FIELD EMBEDDER with ELMO ####################
    # text_field_embedder : ``TextFieldEmbedder``
    #     Used to embed the ``question`` and ``passage`` ``TextFields`` we get as
    #     input to the model.
    if cf_a.use_ELMO:
        if preloaded_elmo is not None:
            text_field_embedder = preloaded_elmo
        else:
            text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput)
            print("ELMO loaded from disk or downloaded")
    else:
        text_field_embedder = None

    self._text_field_embedder = text_field_embedder

    if cf_a.Add_Linear_projection_ELMO:
        if self.cf_a.VB_Linear_projection_ELMO:
            prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
            print("----------------- Bayesian Linear Projection ELMO --------------")
            linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(), 200, prior=prior)
            self.VBmodels.append(linear_projection_ELMO)
        else:
            linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(), 200)

        self._linear_projection_ELMO = linear_projection_ELMO

    # ############## Highway layers ####################
    # num_highway_layers : ``int``
    #     The number of highway layers to use in between embedding the input and
    #     passing it through the phrase layer.
    if cf_a.Add_Linear_projection_ELMO:
        # The projection above maps the embedder output to 200 dimensions.
        Input_dimension_highway = 200
    else:
        Input_dimension_highway = text_field_embedder.get_output_dim()

    num_highway_layers = cf_a.num_highway_layers
    if self.cf_a.VB_highway_layers:
        print("----------------- Bayesian Highway network --------------")
        prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
        highway_layer = HighwayVB(Input_dimension_highway, num_highway_layers, prior=prior)
        self.VBmodels.append(highway_layer)
    else:
        highway_layer = Highway(Input_dimension_highway, num_highway_layers)
    highway_layer = TimeDistributed(highway_layer)

    self._highway_layer = highway_layer

    # ############## Phrase layer ####################
    # phrase_layer : ``Seq2SeqEncoder``
    #     The encoder (with its own internal stacking) that we will use in between
    #     embedding tokens and doing the bidirectional attention.
    if cf_a.phrase_layer_dropout > 0:
        dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
    else:
        dropout_phrase_layer = lambda x: x

    phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway,
                                                       hidden_size=cf_a.phrase_layer_hidden_size,
                                                       batch_first=True,
                                                       bidirectional=True,
                                                       num_layers=cf_a.phrase_layer_num_layers,
                                                       dropout=cf_a.phrase_layer_dropout))

    # Bidirectional LSTM: forward and backward states are concatenated.
    phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
    self._phrase_layer = phrase_layer
    self._dropout_phrase_layer = dropout_phrase_layer

    # ############## Matrix attention layer ####################
    # similarity_function : ``SimilarityFunction``
    #     The similarity function that we will use when comparing encoded passage
    #     and question representations.
    if self.cf_a.VB_similarity_function:
        prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
        print("----------------- Bayesian Similarity matrix --------------")
        similarity_function = LinearSimilarityVB(combination="x,y,x*y",
                                                 tensor_1_dim=phrase_encoding_out_dim,
                                                 tensor_2_dim=phrase_encoding_out_dim,
                                                 prior=prior)
        self.VBmodels.append(similarity_function)
    else:
        similarity_function = LinearSimilarity(combination="x,y,x*y",
                                               tensor_1_dim=phrase_encoding_out_dim,
                                               tensor_2_dim=phrase_encoding_out_dim)

    matrix_attention = LegacyMatrixAttention(similarity_function)
    self._matrix_attention = matrix_attention

    # ############## Modelling Layer ####################
    # modeling_layer : ``Seq2SeqEncoder``
    #     The encoder (with its own internal stacking) that we will use in between
    #     the bidirectional attention and predicting span start and end.
    if cf_a.modeling_passage_dropout > 0:
        dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
    else:
        dropout_modeling_passage = lambda x: x

    modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4,
                                                         hidden_size=cf_a.modeling_passage_hidden_size,
                                                         batch_first=True,
                                                         bidirectional=True,
                                                         num_layers=cf_a.modeling_passage_num_layers,
                                                         dropout=cf_a.modeling_passage_dropout))

    self._modeling_layer = modeling_layer
    self._dropout_modeling_passage = dropout_modeling_passage

    # ############## Span Start Representation #####################
    encoding_dim = phrase_layer.get_output_dim()
    modeling_dim = modeling_layer.get_output_dim()
    span_start_input_dim = encoding_dim * 4 + modeling_dim

    # Linear layer that scores every passage position as span start.
    if self.cf_a.VB_span_start_predictor_linear:
        prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
        print("----------------- Bayesian Span Start Predictor--------------")
        span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior=prior)
        self.VBmodels.append(span_start_predictor_linear)
    else:
        span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)

    self._span_start_predictor_linear = span_start_predictor_linear
    self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

    # ############## Span End Representation #####################
    # span_end_encoder : ``Seq2SeqEncoder``
    #     The encoder that we will use to incorporate span start predictions into
    #     the passage state before predicting span end.
    if cf_a.span_end_encoder_dropout > 0:
        dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
    else:
        dropout_span_end_encode = lambda x: x

    span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3,
                                                           hidden_size=cf_a.modeling_span_end_hidden_size,
                                                           batch_first=True,
                                                           bidirectional=True,
                                                           num_layers=cf_a.modeling_span_end_num_layers,
                                                           dropout=cf_a.span_end_encoder_dropout))

    span_end_encoding_dim = span_end_encoder.get_output_dim()
    span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim

    self._span_end_encoder = span_end_encoder
    self._dropout_span_end_encode = dropout_span_end_encode

    if self.cf_a.VB_span_end_predictor_linear:
        print("----------------- Bayesian Span End Predictor--------------")
        prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
        span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior=prior)
        self.VBmodels.append(span_end_predictor_linear)
    else:
        span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)

    self._span_end_predictor_linear = span_end_predictor_linear
    self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

    # ---- Dropout of the last layers ----
    # The probability must be the one that gates the branch; the previous code
    # reused ``span_end_encoder_dropout`` here, silently ignoring the configured
    # ``spans_output_dropout`` value.
    if cf_a.spans_output_dropout > 0:
        dropout_spans_output = torch.nn.Dropout(p=cf_a.spans_output_dropout)
    else:
        dropout_spans_output = lambda x: x

    self._dropout_spans_output = dropout_spans_output

    # ---- Checks and accuracy ----
    # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
    # obvious from the configuration files, so we check here.
    check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                           "modeling layer input dim", "4 * encoding dim")
    check_dimensions_match(Input_dimension_highway, phrase_layer.get_input_dim(),
                           "text field embedder output dim", "phrase layer input dim")
    check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                           "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    # mask_lstms : ``bool``, optional (default=True)
    #     If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x
    #     speedup, with only a slight performance decrease, if any. We haven't experimented
    #     much with this yet, but have confirmed that we still get very similar performance
    #     with much faster training times. We still use the mask for all softmaxes, but avoid
    #     the shuffling that's required when using masking with pytorch LSTMs.
    self._mask_lstms = cf_a.mask_lstms

    # ---- Parameter initialisation ----
    # All parameters are initialised when the components above are instantiated.

    # ---- Optimizer ----
    optimizer = pytut.get_optimizers(self, cf_a)
    self._optimizer = optimizer
print ("-------------- LOGITS OF BOTH SPANS and BEST SPAN ---------------") span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = BidirectionalAttentionFlow_1.get_best_span(span_start_logits, span_end_logits) print ("best_spans", best_span) """ ------------------------------ GET LOSES AND ACCURACIES ----------------------------------- """ span_start_accuracy_function = CategoricalAccuracy() span_end_accuracy_function = CategoricalAccuracy() span_accuracy_function = BooleanAccuracy() squad_metrics_function = SquadEmAndF1() # Compute the loss for training. if span_start is not None: span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) loss = span_start_loss + span_end_loss span_start_accuracy_function(span_start_logits, span_start.squeeze(-1)) span_end_accuracy_function(span_end_logits, span_end.squeeze(-1)) span_accuracy_function(best_span, torch.stack([span_start, span_end], -1)) span_start_accuracy = span_start_accuracy_function.get_metric() span_end_accuracy = span_end_accuracy_function.get_metric() span_accuracy = span_accuracy_function.get_metric()
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            # Identity stand-in so call sites do not need to branch on dropout.
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        # ``max(dim=-1)`` already drops the question axis.  No extra ``squeeze(-1)`` here:
        # it would also drop the passage axis whenever passage_length == 1.
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        # Use the class's own static method; a bare ``get_best_span`` only resolves if a
        # module-level function of that name happens to be imported.
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                            span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                             span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['best_span_indices'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Character span: start of the first token to end of the last token.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                output_dict['best_span_indices'].append([start_offset, end_offset])
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the current values of all tracked metrics.

        ``reset`` is forwarded to every underlying metric's ``get_metric``.  The result
        has the keys ``start_acc``, ``end_acc``, ``span_acc``, ``em`` and ``f1``.
        """
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        """
        Pick, for every batch element, the highest scoring span whose end is not before its start.

        Both inputs must be two-dimensional, ``(batch_size, passage_length)``; otherwise a
        ``ValueError`` is raised.  The result stacks the start and end indices along a new
        last dimension, giving shape ``(batch_size, 2)``.
        """
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.  log(1) = 0 keeps valid entries, log(0) = -inf
        # removes invalid ones from the argmax.
        span_log_mask = torch.triu(torch.ones((passage_length, passage_length),
                                              device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             num_highway_layers: int,
             phrase_layer: Seq2SeqEncoder,
             attention_similarity_function: SimilarityFunction,
             residual_encoder: Seq2SeqEncoder,
             span_start_encoder: Seq2SeqEncoder,
             span_end_encoder: Seq2SeqEncoder,
             feed_forward: FeedForward,
             dropout: float = 0.2,
             mask_lstms: bool = True,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    """Build the layers, learnable weight vectors and metrics of the model.

    Parameters
    ----------
    vocab : ``Vocabulary``
        Passed to the base-class constructor.
    text_field_embedder : ``TextFieldEmbedder``
        Stored on the model; its output dimension sizes the highway layer.
    num_highway_layers : ``int``
        Number of layers given to ``Highway``.
    phrase_layer : ``Seq2SeqEncoder``
        Stored on the model.  Its output dimension (``encoding_dim``) sizes
        every linear layer and weight vector created here.
    attention_similarity_function : ``SimilarityFunction``
        Currently unused by this constructor; kept so that existing callers
        and configurations keep working.
    residual_encoder : ``Seq2SeqEncoder``
        Stored on the model.
    span_start_encoder : ``Seq2SeqEncoder``
        Stored on the model.
    span_end_encoder : ``Seq2SeqEncoder``
        Stored on the model.
    feed_forward : ``FeedForward``
        Stored on the model.
    dropout : ``float``, optional (default = 0.2)
        If greater than zero a ``torch.nn.Dropout`` with this probability is
        used; otherwise dropout is the identity function.
    mask_lstms : ``bool``, optional (default = True)
        Stored on the model.
    initializer : ``InitializerApplicator``, optional
        Applied to the fully constructed model as the last step.
    regularizer : ``RegularizerApplicator``, optional (default = None)
        Passed to the base-class constructor.
    """
    super(ModelSQUAD, self).__init__(vocab, regularizer)

    self._text_field_embedder = text_field_embedder
    self._highway_layer = TimeDistributed(
        Highway(text_field_embedder.get_output_dim(), num_highway_layers))
    self._phrase_layer = phrase_layer
    self._residual_encoder = residual_encoder
    self._span_end_encoder = span_end_encoder
    self._span_start_encoder = span_start_encoder
    self._feed_forward = feed_forward

    encoding_dim = phrase_layer.get_output_dim()

    # NOTE: sub-modules and parameters are created in a fixed order on
    # purpose; reordering would change parameter registration order and the
    # sequence of random draws for a given seed.
    self._span_start_predictor = TimeDistributed(
        torch.nn.Linear(encoding_dim, 1))
    self._span_end_predictor = TimeDistributed(
        torch.nn.Linear(encoding_dim, 1))
    self._no_answer_predictor = TimeDistributed(
        torch.nn.Linear(encoding_dim, 1))

    self._linear_layer = TimeDistributed(
        torch.nn.Linear(4 * encoding_dim, encoding_dim))
    self._residual_linear_layer = TimeDistributed(
        torch.nn.Linear(3 * encoding_dim, encoding_dim))

    # Limit of the uniform initialization range.  The formula has the form
    # of a Xavier/Glorot uniform bound for a (3 * encoding_dim) -> 1
    # projection -- presumably each group of three vectors below acts as one
    # split weight vector; confirm against the forward pass.
    bound = math.sqrt(6 / (encoding_dim * 3 + 1))

    def new_weight_vector() -> torch.nn.Parameter:
        """Return a length-``encoding_dim`` parameter drawn from U(-bound, bound)."""
        # ``torch.Tensor(n)`` only allocates storage; ``uniform_`` fills it.
        vector = torch.nn.Parameter(torch.Tensor(encoding_dim))
        vector.data.uniform_(-bound, bound)
        return vector

    self._w_p = new_weight_vector()
    self._w_q = new_weight_vector()
    self._w_pq = new_weight_vector()

    self._w_x = new_weight_vector()
    self._w_y = new_weight_vector()
    self._w_xy = new_weight_vector()

    self._span_start_accuracy = CategoricalAccuracy()
    self._span_end_accuracy = CategoricalAccuracy()
    self._span_accuracy = BooleanAccuracy()
    self._squad_metrics = SquadEmAndF1()

    if dropout > 0:
        self._dropout = torch.nn.Dropout(p=dropout)
    else:
        # No dropout requested: use a pass-through so callers need no branch.
        self._dropout = lambda x: x
    self._mask_lstms = mask_lstms

    initializer(self)