def test_accuracy_computation(self):
        accuracy = BooleanAccuracy()
        predictions = torch.Tensor([[0, 1],
                                    [2, 3],
                                    [4, 5],
                                    [6, 7]])
        targets = torch.Tensor([[0, 1],
                                [2, 2],
                                [4, 5],
                                [7, 7]])
        accuracy(predictions, targets)
        assert accuracy.get_metric() == 2. / 4

        mask = torch.ones(4, 2)
        mask[1, 1] = 0
        accuracy(predictions, targets, mask)
        assert accuracy.get_metric() == 5. / 8

        targets[1, 1] = 3
        accuracy(predictions, targets)
        assert accuracy.get_metric() == 8. / 12

        accuracy.reset()
        accuracy(predictions, targets)
        assert accuracy.get_metric() == 3. / 4
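The arithmetic in the test above follows from BooleanAccuracy's exact-match semantics: a prediction row only counts as correct when every unmasked element equals the target, and counts accumulate across calls until reset(). A minimal stateless sketch (our own reimplementation for illustration, not the AllenNLP source) that reproduces the per-call numbers:

import torch

def boolean_accuracy(predictions, targets, mask=None):
    # Zero out masked positions on both sides so they always compare equal.
    if mask is not None:
        predictions = predictions * mask
        targets = targets * mask
    # A row is correct only if all of its elements match.
    correct = (predictions == targets).all(dim=1).sum().item()
    return correct / predictions.size(0)

predictions = torch.Tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
targets = torch.Tensor([[0, 1], [2, 2], [4, 5], [7, 7]])
assert boolean_accuracy(predictions, targets) == 2. / 4        # rows 0 and 2 match
mask = torch.ones(4, 2)
mask[1, 1] = 0
assert boolean_accuracy(predictions, targets, mask) == 3. / 4  # masking hides row 1's mismatch

The 5/8 in the test is this 3/4 batch accumulated onto the earlier 2/4 (5 correct rows out of 8 seen).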
Example #2
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)
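The check_dimensions_match calls above fail fast at construction time instead of surfacing as opaque shape errors deep inside forward. A sketch of the contract being relied on (the real helper lives in allennlp.common.checks; the body below is an illustrative assumption, not the library source):

from allennlp.common.checks import ConfigurationError

def check_dimensions_match(dim_1: int, dim_2: int,
                           dim_1_name: str, dim_2_name: str) -> None:
    # Raise a ConfigurationError naming both dimensions if they disagree.
    if dim_1 != dim_2:
        raise ConfigurationError(f"{dim_1_name} must match {dim_2_name}, "
                                 f"but got {dim_1} and {dim_2} instead")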
Example #3
class BertSpanPointerResolution(Model):
    """该模型同时预测mask位置以及span的起始位置"""
    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str = None,
                 start_attention: Attention = None,
                 end_attention: Attention = None,
                 text_field_embedder: TextFieldEmbedder = None,
                 task_pretrained_file: str = None,
                 neg_sample_ratio: float = 0.0,
                 max_turn_len: int = 3,
                 start_token: str = "[CLS]",
                 end_token: str = "[SEP]",
                 index_name: str = "bert",
                 eps: float = 1e-8,
                 seed: int = 42,
                 loss_factor: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None):
        super().__init__(vocab, regularizer)
        if model_name is None and text_field_embedder is None:
            raise ValueError(
                "`model_name` and `text_field_embedder` cannot both be None."
            )
        # A pure resolution task only needs the last layer's embedding representation
        self._text_field_embedder = text_field_embedder or PretrainedChineseBertMismatchedEmbedder(
            model_name,
            return_all=False,
            output_hidden_states=False,
            max_turn_length=max_turn_len)

        seed_everything(seed)
        self._neg_sample_ratio = neg_sample_ratio
        self._start_token = start_token
        self._end_token = end_token
        self._index_name = index_name
        self._initializer = initializer

        linear_input_size = self._text_field_embedder.get_output_dim()
        # Attention-based approach for the span pointers
        self.start_attention = start_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        self.end_attention = end_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        # Metrics for the mask prediction; we mainly look at F-score, with particular attention to recall of the `1` class
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._rewrite_em = RewriteEM(valid_keys="semr,nr_semr,re_semr")
        self._restore_score = RestorationScore(compute_restore_tokens=True)
        self._metrics = [
            TokenBasedBLEU(mode="1,2"),
            TokenBasedROUGE(mode="1r,2r")
        ]
        self._eps = eps
        self._loss_factor = loss_factor

        self._initializer(self.start_attention)
        self._initializer(self.end_attention)

        # Load weights pretrained on other tasks, if provided
        if task_pretrained_file is not None and os.path.isfile(
                task_pretrained_file):
            logger.info("loading related task pretrained weights...")
            self.load_state_dict(torch.load(task_pretrained_file),
                                 strict=False)

    def _calc_loss(self, span_start_logits: torch.Tensor,
                   span_end_logits: torch.Tensor, use_mask_label: torch.Tensor,
                   start_label: torch.Tensor, end_label: torch.Tensor,
                   best_spans: torch.Tensor):
        batch_size = start_label.size(0)
        # Standard cross-entropy loss
        loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-1)
        # --- Compute the loss for the start and end labels ---
        # Select the start/end labels at positions where mask_label == 1
        # [B_mask, ]
        span_start_label = start_label.masked_select(
            use_mask_label.to(dtype=torch.bool))
        span_end_label = end_label.masked_select(
            use_mask_label.to(dtype=torch.bool))
        # Mask out invalid (-1) labels before computing accuracy
        train_span_mask = (span_start_label != -1)

        # [B_mask, 2]
        answer_spans = torch.stack([span_start_label, span_end_label], dim=-1)
        self._span_accuracy(
            best_spans, answer_spans,
            train_span_mask.unsqueeze(-1).expand_as(best_spans))

        # -- Compute start_loss --
        start_losses = loss_fct(span_start_logits, span_start_label)
        # start_label_weight = self._calc_loss_weight(span_start_label)  # compute label weights
        start_loss = torch.sum(start_losses) / batch_size
        # Sanity-check the loss value
        big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
        if torch.any(start_loss > big_constant):
            logger.critical("Start loss too high (%r)", start_loss)
            logger.critical("span_start_logits: %r", span_start_logits)
            logger.critical("span_start: %r", span_start_label)
            assert False

        # -- Compute end_loss --
        end_losses = loss_fct(span_end_logits, span_end_label)
        # end_label_weight = self._calc_loss_weight(span_end_label)   # compute label weights
        end_loss = torch.sum(end_losses) / batch_size
        if torch.any(end_loss > big_constant):
            logger.critical("End loss too high (%r)", end_loss)
            logger.critical("span_end_logits: %r", span_end_logits)
            logger.critical("span_end: %r", span_end_label)
            assert False

        span_loss = (start_loss + end_loss) / 2

        self._span_start_accuracy(span_start_logits, span_start_label,
                                  train_span_mask)
        self._span_end_accuracy(span_end_logits, span_end_label,
                                train_span_mask)

        loss = span_loss
        return loss

    def _calc_loss_weight(self, label: torch.Tensor):
        label_mask = (label != 0).to(torch.float16)
        label_weight = label_mask * self._loss_factor + 1.0
        return label_weight

    def _get_rewrite_result(self, use_mask_label: torch.Tensor,
                            best_spans: torch.Tensor, query_lens: torch.Tensor,
                            context_lens: torch.Tensor,
                            metadata: List[Dict[str, Any]]):
        # Convert both labels to numpy
        # [B, query_len]
        use_mask_label = use_mask_label.detach().cpu().numpy()
        # [B_mask, 2]
        best_spans = best_spans.detach().cpu().numpy().tolist()

        predict_rewrite_results = []
        for cur_query_len, cur_context_len, cur_query_mask_labels, mdata in zip(
                query_lens, context_lens, use_mask_label, metadata):
            context_tokens = mdata['context_tokens']
            query_tokens = mdata['query_tokens']
            cur_rewrite_result = copy.deepcopy(query_tokens)
            already_insert_tokens = 0  # number of tokens inserted so far
            already_insert_min_start = cur_context_len  # smallest start among spans inserted so far
            already_insert_max_end = 0  # largest end among spans inserted so far
            # Walk over the mask labels; whenever a label is 1, recover the corresponding span string
            for i in range(cur_query_len):
                cur_mask_label = cur_query_mask_labels[i]
                # Only insert content when the predicted label is 1
                if cur_mask_label:
                    predict_start, predict_end = best_spans.pop(0)

                    # Skip if both are 0
                    if predict_start == 0 and predict_end == 0:
                        continue
                    # Skip if the start exceeds the context length
                    if predict_start >= cur_context_len:
                        continue
                    # Skip if the candidate span lies inside a previously inserted span
                    if predict_start >= already_insert_min_start and predict_end <= already_insert_max_end:
                        continue
                    # Correct the positions
                    if predict_start < 0 or context_tokens[
                            predict_start] == self._start_token:
                        predict_start = 1

                    if predict_end >= cur_context_len:
                        predict_end = cur_context_len - 1

                    # Extract the predicted span
                    predict_span_tokens = context_tokens[
                        predict_start:predict_end + 1]
                    # Update the smallest start and largest end inserted so far
                    if predict_start < already_insert_min_start:
                        already_insert_min_start = predict_start
                    if predict_end > already_insert_max_end:
                        already_insert_max_end = predict_end
                    # Trim the predicted span: keep only the tokens before the end_token
                    try:
                        index = predict_span_tokens.index(self._end_token)
                        predict_span_tokens = predict_span_tokens[:index]
                    except ValueError:
                        # end_token not present in the span; keep it as-is
                        pass

                    # Compute the insertion position for the current span:
                    # inserting after the current position needs +1,
                    # inserting before it does not
                    cur_insert_index = i + already_insert_tokens
                    cur_rewrite_result = cur_rewrite_result[:cur_insert_index] + \
                        predict_span_tokens + cur_rewrite_result[cur_insert_index:]
                    # Track how many tokens have been inserted
                    already_insert_tokens += len(predict_span_tokens)

            cur_rewrite_result = cur_rewrite_result[:-1]
            # Evaluate on the string form
            # rather than the list-of-tokens form
            cur_rewrite_string = "".join(cur_rewrite_result)
            rewrite_tokens = mdata.get("rewrite_tokens", None)
            if rewrite_tokens is not None:
                rewrite_string = "".join(rewrite_tokens)
                # Drop the [SEP] token
                query_string = "".join(query_tokens[:-1])
                self._rewrite_em(cur_rewrite_string, rewrite_string,
                                 query_string)
                # Additional metrics
                for metric in self._metrics:
                    metric(cur_rewrite_result, rewrite_tokens)
                # Get restore_tokens and compute the corresponding metric
                restore_tokens = mdata.get("restore_tokens", None)
                self._restore_score(cur_rewrite_result,
                                    rewrite_tokens,
                                    queries=query_tokens[:-1],
                                    restore_tokens=restore_tokens)

            predict_rewrite_results.append("".join(cur_rewrite_result))
        return predict_rewrite_results

    @overrides
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                context_lens: torch.Tensor,
                query_lens: torch.Tensor,
                mask_label: Optional[torch.Tensor] = None,
                start_label: Optional[torch.Tensor] = None,
                end_label: Optional[torch.Tensor] = None,
                metadata: List[Dict[str, Any]] = None):
        # concat the context and query to the encoder
        # get the indexers first
        indexers = context_ids.keys()
        dialogue_ids = {}

        # Get the max lengths of the context and the query
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # get the outputs of the dialogue
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # get the outputs of the query and context
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        # ------- Compute the span predictions -------
        # For every masked token position in the query we want the content that
        # should be inserted after it, i.e. the start and end positions of the
        # corresponding span in the context.
        # Accordingly, expand the context to [b, query_len, context_len, embed_size]
        context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
            -1, query_len, -1, -1).contiguous()
        # [b, query_len, context_len]
        context_expand_mask = context_mask.unsqueeze(dim=1).expand(
            -1, query_len, -1).contiguous()

        # Combine the three parts above;
        # this covers every position in the query
        span_embed_size = context_last_layer.size(-1)

        if self.training and self._neg_sample_ratio > 0.0:
            # Sample negatives from positions where the mask is 0
            # [B*query_len, ]
            sample_mask_label = mask_label.view(-1)
            # Flattened length and the number of negatives to sample
            mask_length = sample_mask_label.size(0)
            mask_sum = int(
                torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
            mask_sum = max(10, mask_sum)
            # Indices of the sampled negatives
            neg_indexes = torch.randint(low=0,
                                        high=mask_length,
                                        size=(mask_sum, ))
            # Keep at most mask_length indices
            neg_indexes = neg_indexes[:mask_length]
            # Set the mask to 1 at the sampled negative positions
            sample_mask_label[neg_indexes] = 1
            # [B, query_len]
            use_mask_label = sample_mask_label.view(
                -1, query_len).to(dtype=torch.bool)
            # Filter out the padded positions of the query, [B, query_len]
            use_mask_label = use_mask_label & query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # Select the usable parts of the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # Select the usable query vectors
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
        else:
            use_mask_label = query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # Select the usable parts of the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # Select the usable query vectors
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()

        # Probability of each context position for the span
        # [B_mask, context_len]
        span_start_probs = self.start_attention(span_query_vector,
                                                span_context_matrix,
                                                span_context_mask)
        span_end_probs = self.end_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)

        span_start_logits = torch.log(span_start_probs + self._eps)
        span_end_logits = torch.log(span_end_probs + self._eps)

        # [B_mask, 2]; in the last dim, index 0 is the start position and index 1 the end
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # Compute the score of each best span
        best_span_scores = (
            torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
            torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
        # [B_mask, ]
        best_span_scores = best_span_scores.squeeze(1)

        # Collect the important tensors into the output
        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_spans": best_spans,
            "best_span_scores": best_span_scores
        }

        # If labels are given, use them to compute the loss
        if start_label is not None:
            loss = self._calc_loss(span_start_logits, span_end_logits,
                                   use_mask_label, start_label, end_label,
                                   best_spans)
            output_dict["loss"] = loss
        if metadata is not None:
            predict_rewrite_results = self._get_rewrite_result(
                use_mask_label, best_spans, query_lens, context_lens, metadata)
            output_dict['rewrite_results'] = predict_rewrite_results
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {}
        metrics["span_acc"] = self._span_accuracy.get_metric(reset)
        for metric in self._metrics:
            metrics.update(metric.get_metric(reset))
        metrics.update(self._rewrite_em.get_metric(reset))
        metrics.update(self._restore_score.get_metric(reset))
        return metrics

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        new_output_dict = {}
        new_output_dict["rewrite_results"] = output_dict["rewrite_results"]
        return new_output_dict
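The forward pass above repeatedly uses a masked_select(...).view(...) pattern to pack only the query positions whose mask label is 1 into a smaller [B_mask, ...] batch. A self-contained sketch of that trick with toy shapes:

import torch

B, query_len, context_len, embed_size = 2, 3, 4, 5
# Expanded context, as in the model: [B, query_len, context_len, embed_size]
context = torch.randn(B, query_len, context_len, embed_size)
use_mask_label = torch.tensor([[True, False, True],
                               [False, False, True]])  # 3 selected positions overall

span_mask = use_mask_label.unsqueeze(-1).unsqueeze(-1)  # [B, query_len, 1, 1]
# masked_select broadcasts the mask and returns a flat tensor,
# so we reshape it back into a packed [B_mask, context_len, embed_size] batch.
packed = context.masked_select(span_mask).view(-1, context_len, embed_size)
assert packed.shape == (3, context_len, embed_size)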
Example #4
    def __init__(self, **kwargs):
        super(MySeq2Seq, self).__init__(**kwargs)

        self.acc = BooleanAccuracy()
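Example #4 only constructs the metric. In a typical AllenNLP-style model it would then be fed batches in forward and read out in get_metrics; a hypothetical sketch of that wiring (predicted_tokens, target_tokens and target_mask are assumed names, not part of the original snippet):

from typing import Dict
import torch
from allennlp.training.metrics import BooleanAccuracy

class MySeq2Seq:  # sketch only; the original subclasses a real model base class
    def __init__(self):
        self.acc = BooleanAccuracy()

    def forward(self, predicted_tokens: torch.Tensor,
                target_tokens: torch.Tensor,
                target_mask: torch.Tensor) -> None:
        # Accumulate exact-match counts batch by batch.
        self.acc(predicted_tokens, target_tokens, target_mask)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Report the running accuracy; reset=True clears the counts.
        return {"acc": self.acc.get_metric(reset)}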
Example #5
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these
        # aren't necessarily obvious from the configuration files, so we check
        # here.
        if modeling_layer.get_input_dim() != 4 * encoding_dim:
            raise ConfigurationError(
                "The input dimension to the modeling_layer must be "
                "equal to 4 times the encoding dimension of the phrase_layer. "
                "Found {} and 4 * {} respectively.".format(
                    modeling_layer.get_input_dim(), encoding_dim))
        if text_field_embedder.get_output_dim() != phrase_layer.get_input_dim():
            raise ConfigurationError(
                "The output dimension of the text_field_embedder (embedding_dim + "
                "char_cnn) must match the input dimension of the phrase_encoder. "
                "Found {} and {}, respectively.".format(
                    text_field_embedder.get_output_dim(),
                    phrase_layer.get_input_dim()))

        if span_end_encoder.get_input_dim() != encoding_dim * 4 + modeling_dim * 3:
            raise ConfigurationError(
                "The input dimension of the span_end_encoder should be equal to "
                "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. "
                "Found {} and (4 * {} + 3 * {}) "
                "respectively.".format(span_end_encoder.get_input_dim(),
                                       encoding_dim, modeling_dim))

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)
Example #6
class RNet(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 passage_encoder: Seq2SeqEncoder,
                 pair_encoder: AttentionEncoder,
                 self_encoder: AttentionEncoder,
                 output_layer: QAOutputLayer,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 share_encoder: bool = False):

        super().__init__(vocab, regularizer)
        self.text_field_embedder = text_field_embedder
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.pair_encoder = pair_encoder
        self.self_encoder = self_encoder
        self.output_layer = output_layer

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self.share_encoder = share_encoder
        self.loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        question_embedded = self.text_field_embedder(question)
        passage_embedded = self.text_field_embedder(passage)

        question_mask = get_text_field_mask(question).byte()
        passage_mask = get_text_field_mask(passage).byte()

        question_encoded = self.question_encoder(question_embedded,
                                                 question_mask)

        if self.share_encoder:
            passage_encoded = self.question_encoder(passage_embedded,
                                                    passage_mask)
        else:
            passage_encoded = self.passage_encoder(passage_embedded,
                                                   passage_mask)

        passage_encoded = self.pair_encoder(passage_encoded, passage_mask,
                                            question_encoded, question_mask)
        passage_encoded = self.self_encoder(passage_encoded, passage_mask,
                                            passage_encoded, passage_mask)

        span_start_logits, span_end_logits = self.output_layer(
            question_encoded, question_mask, passage_encoded, passage_mask)

        # Calculating loss and making prediction
        # Following code is copied from allennlp.models.BidirectionalAttentionFlow
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            batch_size = question_embedded.size(0)
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
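get_best_span above is the classic linear scan: it carries the best start seen so far and pairs it with every candidate end, which enforces start <= end in O(passage_length) time. The same constrained argmax can be written as a vectorized sketch over the full (start, end) score matrix; equivalent result, but O(passage_length^2) memory:

import torch

def get_best_span_vectorized(span_start_logits: torch.Tensor,
                             span_end_logits: torch.Tensor) -> torch.Tensor:
    # scores[b, i, j] = start_logit[b, i] + end_logit[b, j]
    batch_size, passage_length = span_start_logits.size()
    scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Disallow spans with end < start by masking out the lower triangle.
    valid = torch.triu(torch.ones(passage_length, passage_length, dtype=torch.bool))
    scores = scores.masked_fill(~valid, float("-inf"))
    best = scores.view(batch_size, -1).argmax(dim=1)
    # Recover (start, end) indices from the flattened argmax.
    starts = torch.div(best, passage_length, rounding_mode="floor")
    ends = best % passage_length
    return torch.stack([starts, ends], dim=-1)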
Example #7
class TransformerQA(Model):
    def __init__(self,
                 serialization_dir: str,
                 pretrained_model: str,
                 tokenizer_wrapper: HFTokenizerWrapper,
                 enable_no_answer: bool = False,
                 force_yes_no: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self._tokenizer_wrapper = tokenizer_wrapper

        self._enable_no_answer = enable_no_answer
        self.force_yes_no = force_yes_no

        self._qa_model = AutoModelForQuestionAnswering.from_pretrained(
            pretrained_model, return_dict=True)
        self._qa_model.resize_token_embeddings(len(
            tokenizer_wrapper.tokenizer))

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._boolq_accuracy = Squad2EmAndF1()
        self._per_instance_metrics = Squad2EmAndF1()

        # Initializer placeholder
        self._tokenizer_wrapper.tokenizer = self._tokenizer_wrapper.load(
            serialization_dir, pending=True)
        self._tokenizer_wrapper.save(serialization_dir)
        self._qa_model.resize_token_embeddings(len(
            tokenizer_wrapper.tokenizer))

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        yes_no_span: torch.IntTensor = None,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, torch.LongTensor]`
            From a ``TextField``. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and
            [SEP]) has type id 1.
        context_span : `torch.IntTensor`
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which answers can come.
        answer_span : `torch.IntTensor`, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.
        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the ``metadata`` list should be the
            batch size, and each dictionary should have the keys ``id``, ``question``, ``context``,
            ``question_tokens``, ``context_tokens``, and ``answers``.

        # Returns

        An output dictionary consisting of:

        span_start_logits : `torch.FloatTensor`
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : `torch.FloatTensor`
            The result of `softmax(span_start_logits)`.
        span_end_logits : `torch.FloatTensor`
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : `torch.FloatTensor`
            The result of ``softmax(span_end_logits)``.
        best_span : `torch.IntTensor`
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        best_span_scores : `torch.FloatTensor`
            The score for each of the best spans.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        best_span_str : `List[str]`
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        outputs = self._qa_model(**question_with_context)
        span_start_logits = outputs["start_logits"]
        span_end_logits = outputs["end_logits"]

        with torch.no_grad():
            possible_answer_mask = torch.zeros_like(
                question_with_context["input_ids"],
                dtype=torch.bool,
            )
            if not self.force_yes_no:
                for i, (start, end) in enumerate(context_span):
                    if start != -1 and end != -1:
                        possible_answer_mask[i, start:end + 1] = True

            if yes_no_span is not None:
                for i, (start, end) in enumerate(yes_no_span):
                    if start != -1 and end != -1:
                        possible_answer_mask[i, start:end + 1] = True

            for i in range(len(possible_answer_mask)):
                assert any(possible_answer_mask[i])

            # Replace the masked values with a very negative constant.
            context_masked_span_start_logits = replace_masked_values_with_big_negative_number(
                span_start_logits, possible_answer_mask)
            context_masked_span_end_logits = replace_masked_values_with_big_negative_number(
                span_end_logits, possible_answer_mask)
            best_spans = get_best_span(context_masked_span_start_logits,
                                       context_masked_span_end_logits)
            best_span_scores = torch.gather(
                context_masked_span_start_logits, 1,
                best_spans[:, 0].unsqueeze(1)) + torch.gather(
                    context_masked_span_end_logits, 1,
                    best_spans[:, 1].unsqueeze(1))
            best_span_scores = best_span_scores.squeeze(1)

            output_dict = {
                "best_span": best_spans,
                "best_span_scores": best_span_scores,
                "yes_scores": span_start_logits[:, yes_no_span[:, 0]] +
                              span_end_logits[:, yes_no_span[:, 0]],
                "no_scores": span_start_logits[:, yes_no_span[:, 1]] +
                             span_end_logits[:, yes_no_span[:, 1]],
            }
            if self._enable_no_answer:
                no_answer_scores = span_start_logits[:, 0] + span_end_logits[:, 0]
                output_dict.update({"no_answer_scores": no_answer_scores})

        # Compute metrics and set loss
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            if self._enable_no_answer:
                span_mask &= span_start != 0

            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            self._span_start_accuracy(context_masked_span_start_logits,
                                      span_start, span_mask)
            self._span_end_accuracy(context_masked_span_end_logits, span_end,
                                    span_mask)

            if self._enable_no_answer:
                possible_answer_mask[:, 0] = True
            # Replace the masked values with a very negative constant.
            masked_span_start_logits = replace_masked_values_with_big_negative_number(
                span_start_logits, possible_answer_mask)
            masked_span_end_logits = replace_masked_values_with_big_negative_number(
                span_end_logits, possible_answer_mask)

            loss_fct = CrossEntropyLoss(ignore_index=-1)
            start_loss = loss_fct(masked_span_start_logits, span_start)
            end_loss = loss_fct(masked_span_end_logits, span_end)
            total_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = total_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()

            output_dict["best_span_str"] = []
            for i, (metadata_entry,
                    best_span) in enumerate(zip(metadata, best_spans)):
                best_span_string = TokensInterpreter.extract_span_string_from_origin_texts(
                    Span(*best_span),
                    [
                        metadata_entry["modified_question"],
                        metadata_entry["context"]
                    ],
                    metadata_entry["offset_mapping"],
                    metadata_entry["special_tokens_mask"],
                )

                if self.force_yes_no:
                    if (output_dict["yes_scores"][i].item()
                            > output_dict["no_scores"][i].item()):
                        overriding_best_span_string = "yes"
                    else:
                        overriding_best_span_string = "no"
                    if overriding_best_span_string != best_span_string:
                        best_span_string = overriding_best_span_string

                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if answers is not None and len(answers) > 0:
                    if self._enable_no_answer:
                        final_pred = (best_span_string if
                                      best_span_scores[i] > no_answer_scores[i]
                                      else "")
                    else:
                        final_pred = best_span_string

                    if metadata_entry["is_boolq"]:
                        self._boolq_accuracy(final_pred, answers)
                    else:
                        self._per_instance_metrics(final_pred, answers)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "boolq_acc": self._boolq_accuracy.get_metric(reset)["em"],
        }
        metrics.update(self._per_instance_metrics.get_metric(reset))
        return metrics

    default_predictor = "transformer_qa_v2"
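Example #7 keeps disallowed positions out of the prediction by pushing their logits toward the dtype minimum before the argmax, so they can never win and contribute ~0 probability under a softmax. A sketch of the helper's effect (the real replace_masked_values_with_big_negative_number lives in allennlp.nn.util; the exact constant used there may differ):

import torch

def replace_masked_values_with_big_negative_number(logits: torch.Tensor,
                                                   mask: torch.Tensor) -> torch.Tensor:
    # Where mask is False, substitute (close to) the smallest value the dtype
    # can represent; softmax then assigns those positions ~0 probability.
    return logits.masked_fill(~mask, torch.finfo(logits.dtype).min)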
Example #8
class MultiGranularityHierarchicalAttentionFusionNetworks(Model):
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            phrase_layer: Seq2SeqEncoder,
            passage_self_attention: Seq2SeqEncoder,
            semantic_rep_layer: Seq2SeqEncoder,
            contextual_question_layer: Seq2SeqEncoder,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            regularizer: Optional[RegularizerApplicator] = None,
            initializer: InitializerApplicator = InitializerApplicator(),
    ):

        super(MultiGranularityHierarchicalAttentionFusionNetworks,
              self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        # self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
        #                                               num_highway_layers))
        self._encoding_dim = self._phrase_layer.get_output_dim()
        # self._atten_linear_layer = TimeDistributed(torch.nn.Linear(in_features=self._encoding_dim,
        #                                                            out_features=self._encoding_dim, bias=False))
        self._atten_linear_layer = torch.nn.Linear(
            in_features=self._encoding_dim,
            out_features=self._encoding_dim,
            bias=False)
        self._relu = torch.nn.ReLU()
        self._softmax_d1 = torch.nn.Softmax(dim=1)
        self._softmax_d2 = torch.nn.Softmax(dim=2)

        self._atten_fusion = FusionLayer(self._encoding_dim)

        self._tanh = torch.nn.Tanh()
        self._sigmoid = torch.nn.Sigmoid()

        self._passage_self_attention = passage_self_attention

        # self._self_atten_layer = SelfAttentionLayer(self._encoding_dim)
        self._self_atten_layer = torch.nn.Bilinear(self._encoding_dim,
                                                   self._encoding_dim,
                                                   self._encoding_dim,
                                                   bias=False)
        self._self_atten_fusion = FusionLayer(self._encoding_dim)

        self._semantic_rep_layer = semantic_rep_layer
        self._contextual_question_layer = contextual_question_layer

        # self._vector_linear = TimeDistributed(
        #     torch.nn.Linear(in_features=self._encoding_dim, out_features=1, bias=False))
        #
        # self._model_layer_s = TimeDistributed(
        #     torch.nn.Linear(in_features=self._encoding_dim, out_features=self._encoding_dim, bias=False))
        # self._model_layer_e = TimeDistributed(
        #     torch.nn.Linear(in_features=self._encoding_dim, out_features=self._encoding_dim, bias=False))
        self._vector_linear = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=1,
                                              bias=False)

        self._model_layer_s = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=self._encoding_dim,
                                              bias=False)
        self._model_layer_e = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=self._encoding_dim,
                                              bias=False)
        # self._model_layer_s = torch.nn.Bilinear(in1_features=self._encoding_dim, in2_features=self._encoding_dim, out_features=)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        self._mask_lstms = mask_lstms
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_question = self._text_field_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        embedded_passage = self._text_field_embedder(passage)
        passage_mask = util.get_text_field_mask(passage).float()

        batch_size = embedded_passage.size(0)
        passage_length = embedded_passage.size(1)
        question_length = embedded_question.size(1)

        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # Shape(batch_size, question_length, encoding_dim)
        u_q = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoding_dim = u_q.size(-1)
        # Shape(batch_size, passage_length, encoding_dim)
        u_p = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        u_q = self._relu(self._atten_linear_layer(u_q))
        u_p = self._relu(self._atten_linear_layer(u_p))
        # Shape(batch_size, question_length, passage_length)
        # S_{ij} computes the similarity (attention weights)
        # between the i-th word of the question and the j-th word of the passage
        s = torch.bmm(u_q, u_p.transpose(2, 1))
        # Shape(batch_size, passage_length, encoding_dim)
        # P to Q
        q_ = attention_weight_sum_batch(
            util.masked_softmax(s.transpose(2, 1),
                                passage_lstm_mask.unsqueeze(-1),
                                dim=2), u_q)
        # Shape(batch_size, question_length, encoding_dim)
        # Q to P
        p_ = attention_weight_sum_batch(
            util.masked_softmax(s, question_lstm_mask.unsqueeze(-1), dim=2),
            u_p)
        pp = self._atten_fusion(u_p, q_)
        # Shape(batch_size, question_length, encoding_dim)
        qq = self._atten_fusion(u_q, p_)
        # Shape(batch_size, passage_length, encoding_dim)
        d = self._passage_self_attention(pp, passage_lstm_mask)
        # Shape(batch_size, passage_length, encoding_dim)
        l = self._self_atten_layer(d, d)
        l = self._softmax_d2(l)
        # Shape(batch_size, passage_length, encoding_dim)
        d_ = l * d
        # Shape(batch_size, passage_length, encoding_dim)
        dd = self._self_atten_fusion(d, d_)
        # Shape(batch_size, passage_length, encoding_dim)
        ddd = self._semantic_rep_layer(dd, passage_lstm_mask)
        # Shape(batch_size, question_length, encoding_dim)
        qqq = self._contextual_question_layer(qq, question_lstm_mask)
        # Shape(batch_size, question_length, 1) -> (batch_size, question_length)
        # gamma = util.masked_softmax(self._vector_linear(qqq), question_lstm_mask.unsqueeze(-1), dim=2).squeeze(-1)
        qqq_tmp = self._vector_linear(qqq).squeeze(-1)
        gamma = self._softmax_d1(qqq_tmp)
        # Shape(batch_size, 1, encoding_dim):
        # (1, question_length) @ (question_length, encoding_dim) per batch element
        vec_q = torch.bmm(gamma.unsqueeze(1), qqq)
        # model & output layer
        # Shape(batch_size, 1, passage_length)
        vec_q_tmp = self._model_layer_s(vec_q)
        p_start = util.masked_softmax(torch.bmm(vec_q_tmp,
                                                ddd.transpose(2,
                                                              1)).squeeze(1),
                                      passage_lstm_mask,
                                      dim=1)
        # p_start = torch.bmm(vec_q_tmp, ddd.transpose(2, 1)).squeeze(1)
        # p_start = self._softmax_d1(p_start)
        span_start_logits = p_start
        # span_start_probs = util.masked_softmax(span_start_logits, passage_lstm_mask)
        # p_end = self._end_vector_matrix_bilinear(vec_q, ddd.permute(0, 2, 1))
        p_end = util.masked_softmax(torch.bmm(self._model_layer_e(vec_q),
                                              ddd.transpose(2, 1)).squeeze(1),
                                    passage_lstm_mask,
                                    dim=1)
        span_end_logits = p_end
        # span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, 1e-7)
        # span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, 1e-7)
        # span_end_probs = util.masked_softmax(span_end_logits, passage_lstm_mask)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        print("span_start_logits")
        print(span_start_logits)
        print("span_end_logits")
        print(span_end_logits)

        output = dict()
        output['best_span'] = best_span

        # Compute the loss for training
        if span_start is not None:
            # loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            # self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            # loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            # self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            # self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            loss = self._loss(span_start_logits, span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += self._loss(span_end_logits, span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            print(loss)
            output['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output['question_tokens'] = question_tokens
            output['passage_tokens'] = passage_tokens
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
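Example #8 relies on a FusionLayer whose definition is not shown. In the hierarchical attention fusion literature this model follows, "fusion" is commonly a gated combination of an input x with its attended companion y; a hypothetical sketch under that assumption (not the author's actual layer):

import torch

class FusionLayer(torch.nn.Module):
    """Gated fusion of x with an attended vector y (one common formulation)."""
    def __init__(self, dim: int):
        super().__init__()
        self._projection = torch.nn.Linear(4 * dim, dim)
        self._gate = torch.nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        features = torch.cat([x, y, x * y, x - y], dim=-1)
        fused = torch.tanh(self._projection(features))  # candidate fused representation
        gate = torch.sigmoid(self._gate(features))      # how much of the fusion to keep
        return gate * fused + (1 - gate) * x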
Example #9
class EdgeProbingTask(Task):
    """ Generic class for fine-grained edge probing.

    Acts as a classifier, but with multiple targets for each input text.

    Targets are of the form (span1, span2, label), where span1 and span2 are
    half-open token intervals [i, j).

    Subclass this for each dataset, or use register_task with appropriate kw
    args.
    """
    @property
    def _tokenizer_suffix(self):
        """ Suffix to make sure we use the correct source files,
        based on the given tokenizer.
        """
        if self.tokenizer_name:
            return ".retokenized." + self.tokenizer_name
        else:
            return ""

    def tokenizer_is_supported(self, tokenizer_name):
        """ Check if the tokenizer is supported for this task. """
        # Assume all tokenizers supported; if retokenized data not found
        # for this particular task, we'll just crash on file loading.
        return True

    def __init__(
        self,
        path: str,
        max_seq_len: int,
        name: str,
        label_file: str = None,
        files_by_split: Dict[str, str] = None,
        is_symmetric: bool = False,
        single_sided: bool = False,
        **kw,
    ):
        """Construct an edge probing task.

        path, max_seq_len, and name are passed by the code in preprocess.py;
        remaining arguments should be provided by a subclass constructor or via
        @register_task.

        Args:
            path: data directory
            max_seq_len: maximum sequence length (currently ignored)
            name: task name
            label_file: relative path to labels file
            files_by_split: split name ('train', 'val', 'test') mapped to
                relative filenames (e.g. 'train': 'train.json')
            is_symmetric: if true, span1 and span2 are assumed to be the same
                type and share parameters. Otherwise, we learn a separate
                projection layer and attention weight for each.
            single_sided: if true, only use span1.
        """
        super().__init__(name, **kw)

        assert label_file is not None
        assert files_by_split is not None
        self._files_by_split = {
            split: os.path.join(path, fname) + self._tokenizer_suffix
            for split, fname in files_by_split.items()
        }
        self.path = path
        self.label_file = label_file
        self.max_seq_len = max_seq_len
        self.is_symmetric = is_symmetric
        self.single_sided = single_sided

        self._iters_by_split = None
        self.all_labels = None
        self.n_classes = None

        # see add_task_label_namespace in preprocess.py
        self._label_namespace = self.name + "_labels"

        # Scorers
        #  self.acc_scorer = CategoricalAccuracy()  # multiclass accuracy
        self.mcc_scorer = FastMatthews()
        self.acc_scorer = BooleanAccuracy()  # binary accuracy
        self.f1_scorer = F1Measure(positive_label=1)  # binary F1 overall
        self.val_metric = "%s_f1" % self.name  # TODO: switch to MCC?
        self.val_metric_decreases = False

    def _load_labels(self):
        label_file = os.path.join(self.path, self.label_file)
        self.all_labels = list(utils.load_lines(label_file))
        self.n_classes = len(self.all_labels)

    @classmethod
    def _stream_records(cls, filename):
        skip_ctr = 0
        total_ctr = 0
        for record in utils.load_json_data(filename):
            total_ctr += 1
            # Skip records with empty targets.
            # TODO(ian): don't do this if generating negatives!
            if not record.get("targets", None):
                skip_ctr += 1
                continue
            yield record
        log.info(
            "Read=%d, Skip=%d, Total=%d from %s",
            total_ctr - skip_ctr,
            skip_ctr,
            total_ctr,
            filename,
        )

    @staticmethod
    def merge_preds(record: Dict, preds: Dict) -> Dict:
        """ Merge predictions into record, in-place.

        List-valued predictions should align to targets,
        and are attached to the corresponding target entry.

        Non-list predictions are attached to the top-level record.
        """
        record["preds"] = {}
        for target in record["targets"]:
            target["preds"] = {}
        for key, val in preds.items():
            if isinstance(val, list):
                assert len(val) == len(record["targets"])
                for i, target in enumerate(record["targets"]):
                    target["preds"][key] = val[i]
            else:
                # non-list predictions, attach to top-level preds
                record["preds"][key] = val
        return record
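
    # A minimal sketch (hypothetical data) of merge_preds: list-valued
    # predictions align one-to-one with targets, scalars attach to the record.
    #
    #   record = {"targets": [{}, {}]}
    #   preds = {"proba": [0.9, 0.2], "loss": 0.5}
    #   EdgeProbingTask.merge_preds(record, preds)
    #   # record["targets"][0]["preds"] == {"proba": 0.9}
    #   # record["preds"] == {"loss": 0.5}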

    def load_data(self):
        self._load_labels()
        iters_by_split = collections.OrderedDict()
        for split, filename in self._files_by_split.items():
            #  # Lazy-load using RepeatableIterator.
            #  loader = functools.partial(utils.load_json_data,
            #                             filename=filename)
            #  iter = serialize.RepeatableIterator(loader)
            iters_by_split[split] = list(self._stream_records(filename))
        self._iters_by_split = iters_by_split

    def get_split_text(self, split: str):
        """ Get split text as iterable of records.

        Split should be one of 'train', 'val', or 'test'.
        """
        return self._iters_by_split[split]

    @classmethod
    def get_num_examples(cls, split_text):
        """ Return number of examples in the result of get_split_text.

        Subclass can override this if data is not stored in column format.
        """
        return len(split_text)

    @classmethod
    def _make_span_field(cls, s, text_field, offset=1):
        # Spans are half-open [i, j) over the original tokens; SpanField is
        # inclusive, and the default offset of 1 accounts for the [CLS]/SOS
        # token prepended by _pad_tokens.
        return SpanField(s[0] + offset, s[1] - 1 + offset, text_field)

    def _pad_tokens(self, tokens):
        """Pad tokens according to the current tokenization style."""
        if self.tokenizer_name.startswith("bert-"):
            # standard padding for BERT; see
            # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/extract_features.py#L85  # noqa
            return ["[CLS]"] + tokens + ["[SEP]"]
        else:
            return [utils.SOS_TOK] + tokens + [utils.EOS_TOK]
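
    # For example (illustrative): ["john", "ate"] becomes
    # ["[CLS]", "john", "ate", "[SEP]"] under a "bert-*" tokenizer, and
    # [SOS_TOK, "john", "ate", EOS_TOK] otherwise.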

    def make_instance(self, record, idx, indexers) -> Instance:
        """Convert a single record to an AllenNLP Instance."""
        tokens = record["text"].split()  # already space-tokenized by Moses
        tokens = self._pad_tokens(tokens)
        text_field = sentence_to_text_field(tokens, indexers)

        d = {}
        d["idx"] = MetadataField(idx)

        d["input1"] = text_field

        d["span1s"] = ListField([
            self._make_span_field(t["span1"], text_field, 1)
            for t in record["targets"]
        ])
        if not self.single_sided:
            d["span2s"] = ListField([
                self._make_span_field(t["span2"], text_field, 1)
                for t in record["targets"]
            ])

        # Always use multilabel targets, so be sure each label is a list.
        labels = [
            utils.wrap_singleton_string(t["label"]) for t in record["targets"]
        ]
        d["labels"] = ListField([
            MultiLabelField(label_set,
                            label_namespace=self._label_namespace,
                            skip_indexing=False) for label_set in labels
        ])
        return Instance(d)

    def process_split(self, records, indexers) -> Iterable[Instance]:
        """ Process split text into a list of AllenNLP Instances. """
        def _map_fn(r, idx):
            return self.make_instance(r, idx, indexers)

        return map(_map_fn, records, itertools.count())

    def get_all_labels(self) -> List[str]:
        return self.all_labels

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """ Yield sentences, used to compute vocabulary. """
        for split, records in self._iters_by_split.items():
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            for record in records:
                yield record["text"].split()

    def get_metrics(self, reset=False):
        """Get metrics specific to the task"""
        metrics = {}
        metrics["mcc"] = self.mcc_scorer.get_metric(reset)
        metrics["acc"] = self.acc_scorer.get_metric(reset)
        precision, recall, f1 = self.f1_scorer.get_metric(reset)
        metrics["precision"] = precision
        metrics["recall"] = recall
        metrics["f1"] = f1
        return metrics
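
# A hypothetical concrete task built on the class above, following the
# "subclass or register_task" note in its docstring (the register_task
# arguments shown here are assumed, not verified against the source):
#
#   @register_task("edges-coref", rel_path="edges/coref",
#                  label_file="labels.txt",
#                  files_by_split={"train": "train.json",
#                                  "val": "dev.json",
#                                  "test": "test.json"},
#                  is_symmetric=True)
#   class CorefEdgeTask(EdgeProbingTask):
#       pass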
Example #10
0
class CMVDiscriminator(FeedForward):
    def __init__(self,
                 input_dim: int,
                 num_layers: int,
                 hidden_dims: Union[int, Sequence[int]],
                 activations: Union[Activation, Sequence[Activation]],
                 dropout: Union[float, Sequence[float]] = 0.0,
                 gate_bias: float = -2) -> None:

        super(CMVDiscriminator,
              self).__init__(input_dim, num_layers, hidden_dims, activations,
                             dropout)

        if not isinstance(hidden_dims, list):
            hidden_dims = [hidden_dims] * num_layers

        # One gate per transition between consecutive layers; the leading
        # None keeps the list aligned with self._linear_layers when the two
        # are zipped in _get_hidden.
        gate_layers = [None]
        for layer_input_dim, layer_output_dim in zip(hidden_dims[:-1],
                                                     hidden_dims[1:]):
            gate_layer = torch.nn.Linear(layer_input_dim, layer_output_dim)
            # A negative bias (default -2) starts the gates mostly closed, so
            # each layer initially carries its input through largely unchanged.
            gate_layer.bias.data.fill_(gate_bias)
            gate_layers.append(gate_layer)

        self._gate_layers = torch.nn.ModuleList(gate_layers)

        # FeedForward requires an Activation, so we just use the identity.
        self._output_feedforward = FeedForward(hidden_dims[-1], 1, 1,
                                               lambda x: x)

        self._accuracy = BooleanAccuracy()

    def _get_hidden(self, output):
        layers = list(
            zip(self._linear_layers, self._activations, self._dropout,
                self._gate_layers))
        # The first layer has no gate.
        layer, activation, dropout, _ = layers[0]
        output = dropout(activation(layer(output)))

        for layer, activation, dropout, gate in layers[1:]:
            # Highway-style update: output = g * f(x) + (1 - g) * x.
            gate_output = torch.sigmoid(gate(output))
            new_output = dropout(activation(layer(output)))
            output = gate_output * new_output + (1 - gate_output) * output

        return output

    def forward(self, real_output, fake_output=None):

        real_hidden = self._get_hidden(real_output)

        real_value = self._output_feedforward(real_hidden)
        labels = torch.ones(real_hidden.size(0), device=real_value.device)

        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            real_value.view(-1), labels)

        predictions = torch.sigmoid(real_value) > 0.5

        if fake_output is not None:
            fake_hidden = self._get_hidden(fake_output)
            fake_value = self._output_feedforward(fake_hidden)
            fake_labels = torch.zeros(fake_hidden.size(0),
                                      device=fake_value.device)

            loss += torch.nn.functional.binary_cross_entropy_with_logits(
                fake_value.view(-1), fake_labels)

            predictions = torch.cat(
                [predictions, torch.sigmoid(fake_value) > 0.5])
            labels = torch.cat([labels, fake_labels])

        self._accuracy(predictions, labels.byte())

        return {'loss': loss, 'predictions': predictions, 'labels': labels}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'accuracy': self._accuracy.get_metric(reset)}
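
# A minimal usage sketch (hypothetical dimensions, not from the source): the
# discriminator scores a batch of "real" encodings against a batch of "fake"
# ones and returns a combined binary cross-entropy loss for adversarial
# training.
#
#   disc = CMVDiscriminator(input_dim=64, num_layers=3, hidden_dims=64,
#                           activations=torch.nn.functional.relu)
#   out = disc(torch.randn(8, 64), fake_output=torch.randn(8, 64))
#   out['loss'].backward()
#   disc.get_metrics()  # {'accuracy': ...}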
Example #11
0
class BertQA(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sim_text_field_embedder: TextFieldEmbedder,
                 loss_weights: Dict,
                 sim_class_weights: List,
                 pretrained_sim_path: str = None,
                 use_scenario_encoding: bool = True,
                 sim_pretraining: bool = False,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BertQA, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        if use_scenario_encoding:
            self._sim_text_field_embedder = sim_text_field_embedder
        self.loss_weights = loss_weights
        self.sim_class_weights = sim_class_weights
        self.use_scenario_encoding = use_scenario_encoding
        self.sim_pretraining = sim_pretraining

        if self.sim_pretraining and not self.use_scenario_encoding:
            raise ValueError(
                "Cannot pretrain the Scenario Interpretation Module when "
                "use_scenario_encoding is False.")

        embedding_dim = self._text_field_embedder.get_output_dim()
        self._action_predictor = torch.nn.Linear(embedding_dim, 4)
        self._sim_token_label_predictor = torch.nn.Linear(embedding_dim, 4)
        self._span_predictor = torch.nn.Linear(embedding_dim, 2)
        self._action_accuracy = CategoricalAccuracy()
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_loss_metric = Average()
        self._action_loss_metric = Average()
        self._sim_loss_metric = Average()
        self._sim_yes_f1 = F1Measure(2)
        self._sim_no_f1 = F1Measure(3)

        if use_scenario_encoding and pretrained_sim_path is not None:
            logger.info("Loading pretrained model..")
            self.load_state_dict(torch.load(pretrained_sim_path))
            for param in self._sim_text_field_embedder.parameters():
                param.requires_grad = False

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def get_passage_representation(self, bert_output, bert_input):
        # Shape: (batch_size, bert_input_len)
        input_type_ids = self.get_input_type_ids(
            bert_input['bert-type-ids'], bert_input['bert-offsets'],
            self._text_field_embedder._token_embedders['bert']).float()
        # Shape: (batch_size, bert_input_len)
        input_mask = util.get_text_field_mask(bert_input).float()
        passage_mask = input_mask - input_type_ids  # works only with one [SEP]
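        # Illustrative layout (assuming the passage is the first segment):
        #   tokens:        [CLS] p1  p2  [SEP] q1  q2  [SEP] [PAD]
        #   type ids:        0    0   0    0    1   1    1     0
        #   input mask:      1    1   1    1    1   1    1     0
        #   passage mask:    1    1   1    1    0   0    0     0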
        # Shape: (batch_size, bert_input_len, embedding_dim)
        passage_representation = bert_output * passage_mask.unsqueeze(2)
        # Keep only positions that are part of the passage for at least one
        # instance in the batch.
        keep = passage_mask.sum(dim=0) > 0
        # Shape: (batch_size, passage_len, embedding_dim)
        passage_representation = passage_representation[:, keep, :]
        # Shape: (batch_size, passage_len)
        passage_mask = passage_mask[:, keep]

        return passage_representation, passage_mask

    def forward(
            self,  # type: ignore
            bert_input: Dict[str, torch.LongTensor],
            sim_bert_input: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        bert_input : Dict[str, torch.LongTensor]
            From a ``TextField``.  The input to the main BERT encoder; the
            model predicts the answer span over the passage portion of this
            input.
        sim_bert_input : Dict[str, torch.LongTensor]
            From a ``TextField``.  The input to the Scenario Interpretation
            Module; during training it also carries the gold token labels
            under ``scenario_gold_encoding``.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the original passage text, token offsets into the
            passage, and (for evaluation) the gold span and action for each instance in the
            batch.  The length of this list should be the batch size.
        label : ``torch.LongTensor``, optional
            From a ``LabelField``.  The gold dialogue action; if given, the action
            classification loss is added to the output dictionary.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        if self.use_scenario_encoding:
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_bert_input_token_labels_wp = sim_bert_input[
                'scenario_gold_encoding']
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input)
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_input_mask_wp = (sim_bert_input['bert'] != 0).float()
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[
                'bert-type-ids'].float()  # works only with one [SEP]
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_passage_representation_wp = (
                sim_bert_output_wp * sim_passage_mask_wp.unsqueeze(2))
            # Keep only positions that belong to the passage for at least one
            # instance in the batch.
            keep_wp = sim_passage_mask_wp.sum(dim=0) > 0
            # Shape: (batch_size, passage_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_passage_representation_wp[:, keep_wp, :]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[:, keep_wp]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_mask_wp = sim_passage_mask_wp[:, keep_wp]

            # Shape: (batch_size, passage_len_wp, 4)
            sim_token_logits_wp = self._sim_token_label_predictor(
                sim_passage_representation_wp)

            if span_start is not None:  # during training and validation
                class_weights = torch.tensor(self.sim_class_weights,
                                             device=sim_token_logits_wp.device,
                                             dtype=torch.float)
                sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4),
                                         sim_passage_token_labels_wp.view(-1),
                                         ignore_index=0,
                                         weight=class_weights)
                self._sim_loss_metric(sim_loss.item())
                self._sim_yes_f1(sim_token_logits_wp,
                                 sim_passage_token_labels_wp,
                                 sim_passage_mask_wp)
                self._sim_no_f1(sim_token_logits_wp,
                                sim_passage_token_labels_wp,
                                sim_passage_mask_wp)
                if self.sim_pretraining:
                    return {'loss': sim_loss}

            if not self.sim_pretraining:
                # Shape: (batch_size, passage_len_wp)
                bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax(
                    dim=2)) * sim_passage_mask_wp.long()
                bert_input_wp_len = bert_input['history_encoding'].size(1)
                if bert_input['scenario_encoding'].size(1) > bert_input_wp_len:
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = bert_input[
                        'scenario_encoding'][:, :bert_input_wp_len]
                else:
                    batch_size = bert_input['scenario_encoding'].size(0)
                    difference = bert_input_wp_len - bert_input[
                        'scenario_encoding'].size(1)
                    zeros = torch.zeros(
                        batch_size,
                        difference,
                        dtype=bert_input['scenario_encoding'].dtype,
                        device=bert_input['scenario_encoding'].device)
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = torch.cat(
                        [bert_input['scenario_encoding'], zeros], dim=1)

        # Shape: (batch_size, bert_input_len + 1, embedding_dim)
        bert_output = self._text_field_embedder(bert_input)
        # Shape: (batch_size, embedding_dim)
        pooled_output = bert_output[:, 0]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        bert_output = bert_output[:, 1:, :]
        # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len)
        passage_representation, passage_mask = self.get_passage_representation(
            bert_output, bert_input)

        # Shape: (batch_size, 4)
        action_logits = self._action_predictor(pooled_output)
        # Shape: (batch_size, passage_len, 2)
        span_logits = self._span_predictor(passage_representation)
        # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1)
        span_start_logits, span_end_logits = span_logits.split(1, dim=2)
        # Shape: (batch_size, passage_len)
        span_start_logits = span_start_logits.squeeze(2)
        # Shape: (batch_size, passage_len)
        span_end_logits = span_end_logits.squeeze(2)

        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "pooled_output": pooled_output,
            "passage_representation": passage_representation,
            "action_logits": action_logits,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if self.use_scenario_encoding:
            output_dict["sim_token_logits"] = sim_token_logits_wp

        # Compute the loss for training (and for validation)
        if span_start is not None:
            # Shape: (batch_size,)
            span_loss = nll_loss(util.masked_log_softmax(
                span_start_logits, passage_mask),
                                 span_start.squeeze(1),
                                 reduction='none')
            # Shape: (batch_size,)
            span_loss += nll_loss(util.masked_log_softmax(
                span_end_logits, passage_mask),
                                  span_end.squeeze(1),
                                  reduction='none')
            # Shape: (batch_size,)
            more_mask = (label == self.vocab.get_token_index(
                'More', namespace="labels")).float()
            # Shape: (batch_size,)
            span_loss = (span_loss * more_mask).sum() / (more_mask.sum() +
                                                         1e-6)
            if more_mask.sum() > 1e-7:
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(1), more_mask)
                self._span_end_accuracy(span_end_logits, span_end.squeeze(1),
                                        more_mask)
                # Shape: (batch_size, 2)
                span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long()
                self._span_accuracy(best_span,
                                    torch.cat([span_start, span_end], dim=1),
                                    span_acc_mask)

            action_loss = cross_entropy(action_logits, label)
            self._action_accuracy(action_logits, label)

            self._span_loss_metric(span_loss.item())
            self._action_loss_metric(action_loss.item())
            output_dict['loss'] = self.loss_weights[
                'span_loss'] * span_loss + self.loss_weights[
                    'action_loss'] * action_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if not self.training:  # true during validation and test
            output_dict['best_span_str'] = []
            batch_size = len(metadata)
            for i in range(batch_size):
                passage_text = metadata[i]['passage_text']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_str = passage_text[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_str)
                if 'gold_span' in metadata[i]:
                    if metadata[i]['action'] == 'More':
                        gold_span = metadata[i]['gold_span']
                        self._squad_metrics(best_span_str, [gold_span])
        return output_dict

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        action_probs = softmax(output_dict['action_logits'], dim=1)
        output_dict['action_probs'] = action_probs

        predictions = action_probs.cpu().data.numpy()
        argmax_indices = numpy.argmax(predictions, axis=1)
        labels = [
            self.vocab.get_token_from_index(x, namespace="labels")
            for x in argmax_indices
        ]
        output_dict['label'] = labels
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        if self.use_scenario_encoding:
            sim_loss = self._sim_loss_metric.get_metric(reset)
            _, _, yes_f1 = self._sim_yes_f1.get_metric(reset)
            _, _, no_f1 = self._sim_no_f1.get_metric(reset)

        if self.sim_pretraining:
            return {'sim_macro_f1': (yes_f1 + no_f1) / 2}

        try:
            action_acc = self._action_accuracy.get_metric(reset)
        except ZeroDivisionError:
            action_acc = 0
        try:
            start_acc = self._span_start_accuracy.get_metric(reset)
        except ZeroDivisionError:
            start_acc = 0
        try:
            end_acc = self._span_end_accuracy.get_metric(reset)
        except ZeroDivisionError:
            end_acc = 0
        try:
            span_acc = self._span_accuracy.get_metric(reset)
        except ZeroDivisionError:
            span_acc = 0

        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        span_loss = self._span_loss_metric.get_metric(reset)
        action_loss = self._action_loss_metric.get_metric(reset)
        agg_metric = span_acc + action_acc * 0.45

        metrics = {
            'action_acc': action_acc,
            'span_acc': span_acc,
            'span_loss': span_loss,
            'action_loss': action_loss,
            'agg_metric': agg_metric
        }

        if self.use_scenario_encoding:
            metrics['sim_macro_f1'] = (yes_f1 + no_f1) / 2

        if not self.training:  # during validation
            metrics['em'] = exact_match
            metrics['f1'] = f1_score

        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
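
    # For example (illustrative): with passage_length = 3, a flattened argmax
    # of 5 decodes to the span (5 // 3, 5 % 3) == (1, 2).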

    def get_input_type_ids(self, type_ids, offsets, embedder):
        "Converts (bsz, seq_len_wp) to (bsz, seq_len_wp) by indexing."
        batch_size = type_ids.size(0)
        full_seq_len = type_ids.size(1)
        if full_seq_len > embedder.max_pieces:  # Recombine if we had used sliding window approach
            assert batch_size == 1 and type_ids.max() > 0
            num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero(
            ).size(0)
            select_indices = embedder.indices_to_select(
                full_seq_len, num_question_tokens)
            type_ids = type_ids[:, select_indices]

        range_vector = util.get_range_vector(
            batch_size, device=util.get_device_of(type_ids)).unsqueeze(1)
        type_ids = type_ids[range_vector, offsets]
        return type_ids
Example #12
0
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        span_end_encoder: Seq2SeqEncoder,
        dropout: float = 0.2,
        mask_lstms: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = matrix_attention
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(
            modeling_layer.get_input_dim(),
            4 * encoding_dim,
            "modeling layer input dim",
            "4 * encoding dim",
        )
        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            phrase_layer.get_input_dim(),
            "text field embedder output dim",
            "phrase layer input dim",
        )
        check_dimensions_match(
            span_end_encoder.get_input_dim(),
            4 * encoding_dim + 3 * modeling_dim,
            "span end encoder input dim",
            "4 * encoding dim + 3 * modeling dim",
        )

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)
Example #13
0
class QaNet(Model):
    """
    This class implements Adams Wei Yu's `QANet Model <https://openreview.net/forum?id=B14TlG-RW>`_
    for machine reading comprehension published at ICLR 2018.

    The overall architecture of QANet is very similar to BiDAF. The main difference is that QANet
    replaces the RNN encoder with CNN + self-attention. There are also some minor differences in the
    modeling layer and output layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the passage-question attention.
    matrix_attention_layer : ``MatrixAttention``
        The matrix attention function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    dropout_prob : ``float``, optional (default=0.1)
        If greater than 0, we will apply dropout with this probability between layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention_layer: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        dropout_prob: float = 0.1,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim, encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim, encoding_in_dim)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4, modeling_in_dim)
        self._modeling_layer = modeling_layer

        self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()
        self._dropout = torch.nn.Dropout(p=dropout_prob) if dropout_prob > 0 else lambda x: x

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask)
        )
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask)
        )

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            )
        )

        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask)
            )
            modeled_passage_list.append(modeled_passage)
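        # modeled_passage_list is now [projected_input, M1, M2, M3]; following
        # QANet, span start scores are computed from (M1, M2) and span end
        # scores from (M1, M3) below.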

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)
            )
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)
            )
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }
Example #14
0
class BERT_QA(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 dropout: float = 0.0,
                 max_span_length: int = 30,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._max_span_length = max_span_length

        self.qa_outputs = torch.nn.Linear(
            self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._span_qa_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            context: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # The `context` is the concatenation of `question` and `passage`, so we just use `context`.
        batch_size, num_of_passage_tokens = context['tokens'].size()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
        # start and end spans.
        embedded_passage = self._text_field_embedder(context)
        passage_length = embedded_passage.size(1)
        logits = self.qa_outputs(embedded_passage)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Adding some masks with numerically stable values
        passage_mask = util.get_text_field_mask(passage).float()
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, 1, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            batch_size, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_start_probs = util.masked_softmax(span_start_logits,
                                               repeated_passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)
        span_end_probs = util.masked_softmax(span_end_logits,
                                             repeated_passage_mask)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict: Dict[str, Any] = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end],
                                                     -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on span qa and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])

                passage_words = metadata[i]["paragraph_words"]
                answer_offset = metadata[i]["answer_offset"]
                tok_to_word_index = metadata[i]["tok_to_word_index"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_position = tok_to_word_index[predicted_span[0] -
                                                   answer_offset]
                end_position = tok_to_word_index[predicted_span[1] -
                                                 answer_offset]
                best_span_str = " ".join(
                    passage_words[start_position:end_position + 1])
                output_dict["best_span_str"].append(best_span_str)
                answer_text = metadata[i].get("answer_text", [])
                if answer_text:
                    answer_text = [answer_text]
                    self._span_qa_metrics(best_span_str, answer_text)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._span_qa_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = (torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0))
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
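The flattened-argmax decoding above is compact but easy to misread. Here is a minimal, self-contained sketch (torch only; toy numbers, not from the source) of how the triu mask and the divmod arithmetic recover the best (start, end) pair:

import torch

# Toy logits: batch of 1, passage of 4 tokens.
span_start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.2]])
span_end_logits = torch.tensor([[0.0, 0.5, 3.0, 0.1]])

batch_size, passage_length = span_start_logits.size()
# Entry [b, i, j] scores the span starting at i and ending at j.
span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
# log(1) = 0 keeps the upper triangle; log(0) = -inf kills spans that end before they start.
span_log_mask = torch.triu(torch.ones(passage_length, passage_length)).log()
valid_span_log_probs = span_log_probs + span_log_mask

best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)  # flat index into a 4x4 grid
print(best_spans // passage_length, best_spans % passage_length)   # tensor([1]) tensor([2])

The row index of the flattened position is the span start and the column index is the span end, which is exactly what the stacked return value above packs into shape (batch_size, 2).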
Example #15
class RobertaSpanPredictionModel(Model):
    """

    """
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 transformer_weights_model: str = None,
                 layer_freeze_regexes: List[str] = None,
                 on_load: bool = False,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        if on_load:
            logging.info(f"Skipping loading of initial Transformer weights")
            transformer_config = RobertaConfig.from_pretrained(
                pretrained_model)
            self._transformer_model = RobertaModel(transformer_config)

        elif transformer_weights_model:
            logging.info(
                f"Loading Transformer weights model from {transformer_weights_model}"
            )
            transformer_model_loaded = load_archive(transformer_weights_model)
            self._transformer_model = transformer_model_loaded.model._transformer_model
        else:
            self._transformer_model = RobertaModel.from_pretrained(
                pretrained_model)

        for name, param in self._transformer_model.named_parameters():
            grad = requires_grad
            if layer_freeze_regexes and grad:
                grad = not any(
                    [bool(re.search(r, name)) for r in layer_freeze_regexes])
            param.requires_grad = grad

        transformer_config = self._transformer_model.config
        num_labels = 2  # For start/end
        self.qa_outputs = Linear(transformer_config.hidden_size, num_labels)

        # Import GPT-2 byte-decoder machinery to get from tokens to actual text
        self.byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._debug = 2
        self._padding_value = 1  # The index of the RoBERTa padding token

    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                start_positions: torch.LongTensor = None,
                end_positions: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:

        self._debug -= 1
        input_ids = tokens['tokens']

        batch_size = input_ids.size(0)
        num_choices = input_ids.size(1)

        tokens_mask = (input_ids != self._padding_value).long()

        if self._debug > 0:
            print(f"batch_size = {batch_size}")
            print(f"num_choices = {num_choices}")
            print(f"tokens_mask = {tokens_mask}")
            print(f"input_ids.size() = {input_ids.size()}")
            print(f"input_ids = {input_ids}")
            print(f"segment_ids = {segment_ids}")
            print(f"start_positions = {start_positions}")
            print(f"end_positions = {end_positions}")

        # Segment ids are not used by RoBERTa

        transformer_outputs = self._transformer_model(
            input_ids=input_ids,
            # token_type_ids=segment_ids,
            attention_mask=tokens_mask)
        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        span_start_logits = util.replace_masked_values(start_logits,
                                                       tokens_mask, -1e7)
        span_end_logits = util.replace_masked_values(end_logits, tokens_mask,
                                                     -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)
        span_start_probs = util.masked_softmax(span_start_logits, tokens_mask)
        span_end_probs = util.masked_softmax(span_end_logits, tokens_mask)
        output_dict = {
            "start_logits": start_logits,
            "end_logits": end_logits,
            "best_span": best_span
        }
        output_dict["start_probs"] = span_start_probs
        output_dict["end_probs"] = span_end_probs

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds an extra dimension; squeeze it out
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions fall outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            self._span_start_accuracy(span_start_logits, start_positions)
            self._span_end_accuracy(span_end_logits, end_positions)
            self._span_accuracy(
                best_span,
                torch.cat([
                    start_positions.unsqueeze(-1),
                    end_positions.unsqueeze(-1)
                ], -1))

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
            # Should we mask out invalid positions here?
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = total_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['exact_match'] = []
            output_dict['f1_score'] = []
            tokens_texts = []
            for i in range(batch_size):
                tokens_text = metadata[i]['tokens']
                tokens_texts.append(tokens_text)
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                predicted_start = predicted_span[0]
                predicted_end = predicted_span[1]
                predicted_tokens = tokens_text[predicted_start:(predicted_end +
                                                                1)]
                best_span_string = self.convert_tokens_to_string(
                    predicted_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                exact_match = 0
                f1_score = 0
                if answer_texts:
                    exact_match, f1_score = self._squad_metrics(
                        best_span_string, answer_texts)
                output_dict['exact_match'].append(exact_match)
                output_dict['f1_score'].append(f1_score)
            output_dict['tokens_texts'] = tokens_texts

        if self._debug > 0:
            print(f"output_dict = {output_dict}")

        return output_dict

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c]
                          for c in text]).decode('utf-8', errors='replace')
        return text

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @classmethod
    def _load(cls,
              config: Params,
              serialization_dir: str,
              weights_file: str = None,
              cuda_device: int = -1,
              **kwargs) -> 'Model':
        model_params = config.get('model')
        model_params.update({"on_load": True})
        config.update({'model': model_params})
        return super()._load(config=config,
                             serialization_dir=serialization_dir,
                             weights_file=weights_file,
                             cuda_device=cuda_device,
                             **kwargs)
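A hedged sketch of what the byte_decoder built from bytes_to_unicode() is for: RoBERTa/GPT-2 vocabularies are byte-level, so convert_tokens_to_string must map each unicode glyph back to its raw byte before UTF-8 decoding. The import path below is an assumption based on recent HuggingFace transformers releases; older packages expose bytes_to_unicode from a different module.

# Assumed location in recent `transformers`; adjust for your version.
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

def convert_tokens_to_string(tokens):
    # Same logic as the method above: glyphs -> bytes -> UTF-8 text.
    text = ''.join(tokens)
    return bytearray([byte_decoder[c] for c in text]).decode('utf-8', errors='replace')

# 'Ġ' is the byte-level stand-in for a leading space in GPT-2/RoBERTa vocabularies.
print(repr(convert_tokens_to_string(['Ġhello', 'Ġworld'])))  # ' hello world'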
Example #16
    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str = None,
                 start_attention: Attention = None,
                 end_attention: Attention = None,
                 text_field_embedder: TextFieldEmbedder = None,
                 task_pretrained_file: str = None,
                 neg_sample_ratio: float = 0.0,
                 max_turn_len: int = 3,
                 start_token: str = "[CLS]",
                 end_token: str = "[SEP]",
                 index_name: str = "bert",
                 eps: float = 1e-8,
                 seed: int = 42,
                 loss_factor: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None):
        super().__init__(vocab, regularizer)
        if model_name is None and text_field_embedder is None:
            raise ValueError(
                "`model_name` and `text_field_embedder` cannot both be None."
            )
        # For a pure resolution task, we only need the embedding representations from the last layer
        self._text_field_embedder = text_field_embedder or PretrainedChineseBertMismatchedEmbedder(
            model_name,
            return_all=False,
            output_hidden_states=False,
            max_turn_length=max_turn_len)

        seed_everything(seed)
        self._neg_sample_ratio = neg_sample_ratio
        self._start_token = start_token
        self._end_token = end_token
        self._index_name = index_name
        self._initializer = initializer

        linear_input_size = self._text_field_embedder.get_output_dim()
        # Attention-based approach for predicting span positions
        self.start_attention = start_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        self.end_attention = end_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        # Metrics for the mask: we mainly care about F-score, and especially the recall of the `1` class
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._rewrite_em = RewriteEM(valid_keys="semr,nr_semr,re_semr")
        self._restore_score = RestorationScore(compute_restore_tokens=True)
        self._metrics = [
            TokenBasedBLEU(mode="1,2"),
            TokenBasedROUGE(mode="1r,2r")
        ]
        self._eps = eps
        self._loss_factor = loss_factor

        self._initializer(self.start_attention)
        self._initializer(self.end_attention)

        # Load weights pretrained on a related task
        if task_pretrained_file is not None and os.path.isfile(
                task_pretrained_file):
            logger.info("loading related task pretrained weights...")
            self.load_state_dict(torch.load(task_pretrained_file),
                                 strict=False)
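The strict=False flag in the load_state_dict call above is what makes loading weights from a related task safe: matching parameters are copied, and any mismatches are reported rather than raised. A minimal sketch with hypothetical stand-in modules (torch only):

import torch

pretrained = torch.nn.Linear(4, 4)                   # stands in for the task-pretrained model
target = torch.nn.Sequential(torch.nn.Linear(4, 4))  # stands in for the current model

state = {"0." + k: v for k, v in pretrained.state_dict().items()}
state["extra.weight"] = torch.zeros(1)  # a key the target model doesn't have

# strict=False loads "0.weight"/"0.bias" and reports the leftover key.
result = target.load_state_dict(state, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['extra.weight']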
Example #17
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Track the running argmax of the start logits over positions 0..j.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                # Pair the best start seen so far with the current end position.
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
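This loop-based decoder is a single left-to-right pass: it keeps a running argmax of the start logits and pairs it with each candidate end, so it should find the same maximum-scoring valid span as the flattened-matrix decoder in the earlier example. A self-contained cross-check under that assumption (torch only; with continuous random logits, tie-breaking differences are vanishingly unlikely):

import torch

def best_span_loop(start_logits, end_logits):
    batch_size, length = start_logits.size()
    best = torch.zeros(batch_size, 2, dtype=torch.long)
    for b in range(batch_size):
        best_score, start_argmax = -1e20, 0
        for j in range(length):
            if start_logits[b, j] > start_logits[b, start_argmax]:
                start_argmax = j  # best start seen up to position j
            score = start_logits[b, start_argmax] + end_logits[b, j]
            if score > best_score:
                best[b, 0], best[b, 1], best_score = start_argmax, j, score
    return best

def best_span_matrix(start_logits, end_logits):
    batch_size, length = start_logits.size()
    scores = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)
    scores = scores + torch.triu(torch.ones(length, length)).log()
    flat = scores.view(batch_size, -1).argmax(-1)
    return torch.stack([flat // length, flat % length], dim=-1)

start, end = torch.randn(8, 30), torch.randn(8, 30)
assert torch.equal(best_span_loop(start, end), best_span_matrix(start, end))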
Example #18
class BertMCQAModel(Model):
    """
    """
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 top_layer_only: bool = True,
                 bert_weights_model: str = None,
                 per_choice_loss: bool = False,
                 layer_freeze_regexes: List[str] = None,
                 regularizer: Optional[RegularizerApplicator] = None,
                 use_comparative_bert: bool = True,
                 use_bilinear_classifier: bool = False,
                 train_comparison_layer: bool = False,
                 number_of_choices_compared: int = 0,
                 comparison_layer_hidden_size: int = -1,
                 comparison_layer_use_relu: bool = True) -> None:
        super().__init__(vocab, regularizer)

        self._use_comparative_bert = use_comparative_bert
        self._use_bilinear_classifier = use_bilinear_classifier
        self._train_comparison_layer = train_comparison_layer
        if train_comparison_layer:
            assert number_of_choices_compared > 1
            self._num_choices = number_of_choices_compared
            self._comparison_layer_hidden_size = comparison_layer_hidden_size
            self._comparison_layer_use_relu = comparison_layer_use_relu

        # Bert weights and config
        if bert_weights_model:
            logging.info(f"Loading BERT weights model from {bert_weights_model}")
            bert_model_loaded = load_archive(bert_weights_model)
            self._bert_model = bert_model_loaded.model._bert_model
        else:
            self._bert_model = BertModel.from_pretrained(pretrained_model)

        for param in self._bert_model.parameters():
            param.requires_grad = requires_grad
        #for name, param in self._bert_model.named_parameters():
        #    grad = requires_grad
        #    if layer_freeze_regexes and grad:
        #        grad = not any([bool(re.search(r, name)) for r in layer_freeze_regexes])
        #    param.requires_grad = grad

        bert_config = self._bert_model.config
        self._output_dim = bert_config.hidden_size
        self._dropout = torch.nn.Dropout(bert_config.hidden_dropout_prob)
        self._per_choice_loss = per_choice_loss

        # Bert Classifier selector
        final_output_dim = 1
        if not use_comparative_bert:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_classifier"):
                self._classifier = bert_model_loaded.model._classifier
            else:
                self._classifier = Linear(self._output_dim, final_output_dim)
        else:
            if use_bilinear_classifier:
                self._classifier = Bilinear(self._output_dim, self._output_dim, final_output_dim)
            else:
                self._classifier = Linear(self._output_dim * 2, final_output_dim)
        self._classifier.apply(self._bert_model.init_bert_weights)

        # Comparison layer setup
        if self._train_comparison_layer:
            number_of_pairs = self._num_choices * (self._num_choices - 1)
            if self._comparison_layer_hidden_size == -1:
                self._comparison_layer_hidden_size = number_of_pairs * number_of_pairs

            self._comparison_layer_1 = Linear(number_of_pairs, self._comparison_layer_hidden_size)
            if self._comparison_layer_use_relu:
                self._comparison_layer_1_activation = torch.nn.LeakyReLU()
            else:
                self._comparison_layer_1_activation = torch.nn.Tanh()
            self._comparison_layer_2 = Linear(self._comparison_layer_hidden_size, self._num_choices)
            self._comparison_layer_2_activation = torch.nn.Softmax()

        # Scalar mix, if necessary
        self._all_layers = not top_layer_only
        if self._all_layers:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_scalar_mix") \
                    and bert_model_loaded.model._scalar_mix is not None:
                self._scalar_mix = bert_model_loaded.model._scalar_mix
            else:
                num_layers = bert_config.num_hidden_layers
                initial_scalar_parameters = num_layers * [0.0]
                initial_scalar_parameters[-1] = 5.0  # Starts with most mass on last layer
                self._scalar_mix = ScalarMix(bert_config.num_hidden_layers,
                                             initial_scalar_parameters=initial_scalar_parameters,
                                             do_layer_norm=False)
        else:
            self._scalar_mix = None

        # Accuracy and loss setup
        if self._train_comparison_layer:
            self._accuracy = CategoricalAccuracy()
            self._loss = torch.nn.CrossEntropyLoss()
        else:
            self._accuracy = BooleanAccuracy()
            self._loss = torch.nn.BCEWithLogitsLoss()
        self._debug = -1

    def _extract_last_token_pooled_output(self, encoded_layers, question_mask):
        """
        Extract the output vector for the last token in the sentence -
            similarly to how pooled_output is extracted for us when calling 'bert_model'.
        We need the question mask to find the last actual (non-padding) token
        :return:
        """

        if self._all_layers:
            encoded_layers = encoded_layers[-1]

        # A cool trick to extract the last "True" item in each row
        question_mask = question_mask.squeeze()
        # We already asserted batch_size == 1 in forward, but it doesn't hurt to check the shape
        assert question_mask.dim() == 2
        shifted_matrix = question_mask.roll(-1, 1)
        shifted_matrix[:, -1] = 0
        last_item_indices = question_mask - shifted_matrix

        # TODO: This line, for some reason, didn't work as expected, but it would be much better than the implementation that follows
        # last_token_tensor = encoded_layers[last_item_indices]

        num_pairs, token_number, hidden_size = encoded_layers.size()
        assert last_item_indices.size() == (num_pairs, token_number)
        # Don't worry: expand doesn't allocate new memory; it simply views the tensor differently
        expanded_last_item_indices = last_item_indices.unsqueeze(2).expand(num_pairs, token_number, hidden_size)
        last_token_tensor = encoded_layers.masked_select(expanded_last_item_indices.byte())
        last_token_tensor = last_token_tensor.reshape(num_pairs, hidden_size)

        pooled_output = self._bert_model.pooler.dense(last_token_tensor)
        pooled_output = self._bert_model.pooler.activation(pooled_output)

        return pooled_output

    def forward(self,
                question: Dict[str, torch.LongTensor],
                choice1_indexes: List[int] = None,
                choice2_indexes: List[int] = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:

        self._debug -= 1
        input_ids = question['bert']

        # input_ids.size() == (batch_size, num_pairs, max_sentence_length)
        batch_size, num_pairs, _ = question['bert'].size()
        question_mask = (input_ids != 0).long()

        if self._train_comparison_layer:
            assert num_pairs == self._num_choices * (self._num_choices - 1)

        # Segment ids
        real_segment_ids = question['bert-type-ids'].clone()
        # Change the last 'SEP' to belong to the second answer (for symmetry)
        last_seps = (real_segment_ids.roll(-1) == 2) & (real_segment_ids == 1)
        real_segment_ids[last_seps] = 2
        # Update segment ids so that they are '1' for answers and '0' for the question
        real_segment_ids = (real_segment_ids == 0) | (real_segment_ids == 2)
        real_segment_ids = real_segment_ids.long()

        # TODO: How to extract last token pooled output if batch size != 1
        assert batch_size == 1

        # Run model
        encoded_layers, first_vectors_pooled_output = self._bert_model(input_ids=util.combine_initial_dims(input_ids),
                                            token_type_ids=util.combine_initial_dims(real_segment_ids),
                                            attention_mask=util.combine_initial_dims(question_mask),
                                            output_all_encoded_layers=self._all_layers)

        if self._use_comparative_bert:
            last_vectors_pooled_output = self._extract_last_token_pooled_output(encoded_layers, question_mask)
        else:
            last_vectors_pooled_output = None
        if self._all_layers:
            mixed_layer = self._scalar_mix(encoded_layers, question_mask)
            first_vectors_pooled_output = self._bert_model.pooler(mixed_layer)

        # Apply dropout
        first_vectors_pooled_output = self._dropout(first_vectors_pooled_output)
        if self._use_comparative_bert:
            last_vectors_pooled_output = self._dropout(last_vectors_pooled_output)

        # Classify
        if not self._use_comparative_bert:
            pair_label_logits = self._classifier(first_vectors_pooled_output)
        else:
            if self._use_bilinear_classifier:
                pair_label_logits = self._classifier(first_vectors_pooled_output, last_vectors_pooled_output)
            else:
                all_pooled_output = torch.cat((first_vectors_pooled_output, last_vectors_pooled_output), 1)
                pair_label_logits = self._classifier(all_pooled_output)

        pair_label_logits = pair_label_logits.view(-1, num_pairs)

        pair_label_probs = torch.sigmoid(pair_label_logits)

        output_dict = {}
        pair_label_probs_flat = pair_label_probs.squeeze(1)
        output_dict['pair_label_probs'] = pair_label_probs_flat.view(-1, num_pairs)
        output_dict['pair_label_logits'] = pair_label_logits
        output_dict['choice1_indexes'] = choice1_indexes
        output_dict['choice2_indexes'] = choice2_indexes

        if not self._train_comparison_layer:
            if label is not None:
                label = label.unsqueeze(1)
                label = label.expand(-1, num_pairs)
                relevant_pairs = (choice1_indexes == label) | (choice2_indexes == label)
                relevant_probs = pair_label_probs[relevant_pairs]
                choice1_is_the_label = (choice1_indexes == label)[relevant_pairs]
                # choice1_is_the_label = choice1_is_the_label.type_as(relevant_logits)

                loss = self._loss(relevant_probs, choice1_is_the_label.float())
                self._accuracy(relevant_probs >= 0.5, choice1_is_the_label)
                output_dict["loss"] = loss

            return output_dict
        else:
            choice_logits = self._comparison_layer_2(self._comparison_layer_1_activation(self._comparison_layer_1(
                pair_label_probs)))
            output_dict['choice_logits'] = choice_logits
            output_dict['choice_probs'] = torch.softmax(choice_logits, 1)
            output_dict['predicted_choice'] = torch.argmax(choice_logits, 1)

            if label is not None:
                loss = self._loss(choice_logits, label)
                self._accuracy(choice_logits, label)
                output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'EM': self._accuracy.get_metric(reset),
        }
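The roll-and-subtract trick in _extract_last_token_pooled_output deserves a tiny demonstration: shifting the mask left by one and subtracting leaves a 1 exactly at each row's last non-padding position. A toy sketch (torch only, toy mask):

import torch

# 1 = real token, 0 = padding; rows with different true lengths.
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])

shifted = mask.roll(-1, 1)  # shift each row left by one position
shifted[:, -1] = 0          # the wrapped-around entry is not a real neighbour
print(mask - shifted)
# tensor([[0, 0, 1, 0, 0],
#         [0, 0, 0, 0, 1]])

Expanding that indicator to the hidden size and calling masked_select then pulls out exactly one hidden vector per row, as the method above does.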
Example #19
class BidirectionalAttentionFlow(Model):
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            char_field_embedder: TextFieldEmbedder,
            # num_highway_layers: int,
            phrase_layer: Seq2SeqEncoder,
            char_rnn: Seq2SeqEncoder,
            hops: int,
            hidden_dim: int,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._char_field_embedder = char_field_embedder
        self._features_embedder = nn.Embedding(2, 5)
        # self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim() + 5 * 3,
        #                                               num_highway_layers))
        self._phrase_layer = phrase_layer
        self._encoding_dim = phrase_layer.get_output_dim()
        # self._stacked_brnn = PytorchSeq2SeqWrapper(
        #     StackedBidirectionalLstm(input_size=self._encoding_dim, hidden_size=hidden_dim,
        #                              num_layers=3, recurrent_dropout_probability=0.2))
        self._char_rnn = char_rnn

        self.hops = hops

        self.interactive_aligners = nn.ModuleList()
        self.interactive_SFUs = nn.ModuleList()
        self.self_aligners = nn.ModuleList()
        self.self_SFUs = nn.ModuleList()
        self.aggregate_rnns = nn.ModuleList()
        for i in range(hops):
            # interactive aligner
            self.interactive_aligners.append(
                layers.SeqAttnMatch(self._encoding_dim))
            self.interactive_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # self aligner
            self.self_aligners.append(layers.SelfAttnMatch(self._encoding_dim))
            self.self_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # aggregating
            self.aggregate_rnns.append(
                PytorchSeq2SeqWrapper(
                    nn.LSTM(input_size=self._encoding_dim,
                            hidden_size=hidden_dim,
                            num_layers=1,
                            dropout=0.2,
                            bidirectional=True,
                            batch_first=True)))

        # Memory-based Answer Pointer
        self.mem_ans_ptr = layers.MemoryAnsPointer(x_size=self._encoding_dim,
                                                   y_size=self._encoding_dim,
                                                   hidden_size=hidden_dim,
                                                   hop=hops,
                                                   dropout_rate=0.2,
                                                   normalize=True)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno: torch.IntTensor = None,
            question_tf: torch.FloatTensor = None,
            passage_tf: torch.FloatTensor = None,
            q_em_cased: torch.IntTensor = None,
            p_em_cased: torch.IntTensor = None,
            q_em_uncased: torch.IntTensor = None,
            p_em_uncased: torch.IntTensor = None,
            q_in_lemma: torch.IntTensor = None,
            p_in_lemma: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        x1_c_emb = self._dropout(self._char_field_embedder(passage))
        x2_c_emb = self._dropout(self._char_field_embedder(question))

        # embedded_question = torch.cat([self._dropout(self._text_field_embedder(question)),
        #                                self._features_embedder(q_em_cased),
        #                                self._features_embedder(q_em_uncased),
        #                                self._features_embedder(q_in_lemma),
        #                                question_tf.unsqueeze(2)], dim=2)
        # embedded_passage = torch.cat([self._dropout(self._text_field_embedder(passage)),
        #                               self._features_embedder(p_em_cased),
        #                               self._features_embedder(p_em_uncased),
        #                               self._features_embedder(p_in_lemma),
        #                               passage_tf.unsqueeze(2)], dim=2)
        token_emb_q = self._dropout(self._text_field_embedder(question))
        token_emb_c = self._dropout(self._text_field_embedder(passage))
        token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40],
                                                        dim=2)
        token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40],
                                                       dim=2)
        question_word_features = torch.cat([
            q_ner_and_pos,
            self._features_embedder(q_em_cased),
            self._features_embedder(q_em_uncased),
            self._features_embedder(q_in_lemma),
            question_tf.unsqueeze(2)
        ], dim=2)
        passage_word_features = torch.cat([
            p_ner_and_pos,
            self._features_embedder(p_em_cased),
            self._features_embedder(p_em_uncased),
            self._features_embedder(p_in_lemma),
            passage_tf.unsqueeze(2)
        ], dim=2)

        # embedded_question = self._highway_layer(embedded_q)
        # embedded_passage = self._highway_layer(embedded_q)

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        char_features_c = self._char_rnn(
            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1),
                              x1_c_emb.size(2), x1_c_emb.size(3))),
            passage_lstm_mask.unsqueeze(2).repeat(
                1, 1, x1_c_emb.size(2)).reshape(
                    (x1_c_emb.size(0) * x1_c_emb.size(1),
                     x1_c_emb.size(2)))).reshape(
                         (x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2),
                          -1))[:, :, -1, :]
        char_features_q = self._char_rnn(
            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1),
                              x2_c_emb.size(2), x2_c_emb.size(3))),
            question_lstm_mask.unsqueeze(2).repeat(
                1, 1, x2_c_emb.size(2)).reshape(
                    (x2_c_emb.size(0) * x2_c_emb.size(1),
                     x2_c_emb.size(2)))).reshape(
                         (x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2),
                          -1))[:, :, -1, :]

        # token_emb_q, char_emb_q, question_word_features = torch.split(embedded_question, [300, 300, 56], dim=2)
        # token_emb_c, char_emb_c, passage_word_features = torch.split(embedded_passage, [300, 300, 56], dim=2)

        # char_features_q = self._char_rnn(char_emb_q, question_lstm_mask)
        # char_features_c = self._char_rnn(char_emb_c, passage_lstm_mask)

        emb_question = torch.cat(
            [token_emb_question, char_features_q, question_word_features],
            dim=2)
        emb_passage = torch.cat(
            [token_emb_passage, char_features_c, passage_word_features], dim=2)

        encoded_question = self._dropout(
            self._phrase_layer(emb_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(emb_passage, passage_lstm_mask))

        batch_size = encoded_question.size(0)
        passage_length = encoded_passage.size(1)

        encoding_dim = encoded_question.size(-1)

        # c_check = self._stacked_brnn(encoded_passage, passage_lstm_mask)
        # q = self._stacked_brnn(encoded_question, question_lstm_mask)
        c_check = encoded_passage
        q = encoded_question
        for i in range(self.hops):
            q_tilde = self.interactive_aligners[i].forward(
                c_check, q, question_mask)
            c_bar = self.interactive_SFUs[i].forward(
                c_check,
                torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2))
            c_tilde = self.self_aligners[i].forward(c_bar, passage_mask)
            c_hat = self.self_SFUs[i].forward(
                c_bar, torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde],
                                 2))
            c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask)

        # Predict
        start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward(
            c_check, q, passage_mask, question_mask)

        best_span, yesno_predict, loc = self.get_best_span(
            start_scores, end_scores, yesno_scores)

        output_dict = {
            "span_start_logits": start_scores,
            "span_end_logits": end_scores,
            "best_span": best_span
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(start_scores, span_start.squeeze(-1))
            self._span_start_accuracy(start_scores, span_start.squeeze(-1))
            loss += nll_loss(end_scores, span_end.squeeze(-1))
            self._span_end_accuracy(end_scores, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

            gold_span_end_loc = []
            span_end = span_end.view(batch_size).squeeze().data.cpu().numpy()
            for i in range(batch_size):
                gold_span_end_loc.append(
                    max(span_end[i] + i * passage_length, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1)

            pred_span_end_loc = []
            for i in range(batch_size):
                pred_span_end_loc.append(max(loc[i], 0))
            predicted_end = span_start.new(pred_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(0,
                                                           predicted_end).view(
                                                               -1, 3)
            self._span_yesno_accuracy(_yesno, yesno.squeeze(-1))

            output_dict['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['yesno'] = yesno_predict
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        yesno_tags = [
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in output_dict.pop("yesno")
        ]
        output_dict['yesno'] = yesno_tags
        return output_dict

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor,
                      yesno_scores: torch.Tensor):
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)
        yesno_predict = span_start_logits.new_zeros(batch_size,
                                                    dtype=torch.long)
        loc = yesno_scores.new_zeros(batch_size, dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()
        yesno_logits = yesno_scores.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
                    yesno_predict[b] = int(np.argmax(yesno_logits[b, j]))
                    loc[b] = j + passage_length * b
        return best_word_span, yesno_predict, loc
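# A minimal, self-contained sketch (not part of the snippet above) of the
# running-argmax span search that `get_best_span` performs per batch row: by
# tracking the best start seen so far, the best (start <= end) pair is found
# in one O(passage_length) pass rather than an O(passage_length^2) scan.
import numpy as np

def best_span_1d(start_logits: np.ndarray, end_logits: np.ndarray):
    best, best_score, start_argmax = (0, 0), -1e20, 0
    for j in range(len(start_logits)):
        if start_logits[j] > start_logits[start_argmax]:
            start_argmax = j  # best start among positions <= j
        score = start_logits[start_argmax] + end_logits[j]
        if score > best_score:
            best, best_score = (start_argmax, j), score
    return best

assert best_span_1d(np.array([0., 3., 1.]), np.array([1., 0., 2.])) == (1, 2)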
Example #20
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 top_layer_only: bool = True,
                 bert_weights_model: str = None,
                 per_choice_loss: bool = False,
                 layer_freeze_regexes: List[str] = None,
                 regularizer: Optional[RegularizerApplicator] = None,
                 use_comparative_bert: bool = True,
                 use_bilinear_classifier: bool = False,
                 train_comparison_layer: bool = False,
                 number_of_choices_compared: int = 0,
                 comparison_layer_hidden_size: int = -1,
                 comparison_layer_use_relu: bool = True) -> None:
        super().__init__(vocab, regularizer)

        self._use_comparative_bert = use_comparative_bert
        self._use_bilinear_classifier = use_bilinear_classifier
        self._train_comparison_layer = train_comparison_layer
        if train_comparison_layer:
            assert number_of_choices_compared > 1
            self._num_choices = number_of_choices_compared
            self._comparison_layer_hidden_size = comparison_layer_hidden_size
            self._comparison_layer_use_relu = comparison_layer_use_relu

        # Bert weights and config
        if bert_weights_model:
            logging.info(f"Loading BERT weights model from {bert_weights_model}")
            bert_model_loaded = load_archive(bert_weights_model)
            self._bert_model = bert_model_loaded.model._bert_model
        else:
            self._bert_model = BertModel.from_pretrained(pretrained_model)

        for param in self._bert_model.parameters():
            param.requires_grad = requires_grad
        #for name, param in self._bert_model.named_parameters():
        #    grad = requires_grad
        #    if layer_freeze_regexes and grad:
        #        grad = not any([bool(re.search(r, name)) for r in layer_freeze_regexes])
        #    param.requires_grad = grad

        bert_config = self._bert_model.config
        self._output_dim = bert_config.hidden_size
        self._dropout = torch.nn.Dropout(bert_config.hidden_dropout_prob)
        self._per_choice_loss = per_choice_loss

        # Bert Classifier selector
        final_output_dim = 1
        if not use_comparative_bert:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_classifier"):
                self._classifier = bert_model_loaded.model._classifier
            else:
                self._classifier = Linear(self._output_dim, final_output_dim)
        else:
            if use_bilinear_classifier:
                self._classifier = Bilinear(self._output_dim, self._output_dim, final_output_dim)
            else:
                self._classifier = Linear(self._output_dim * 2, final_output_dim)
        self._classifier.apply(self._bert_model.init_bert_weights)

        # Comparison layer setup
        if self._train_comparison_layer:
            number_of_pairs = self._num_choices * (self._num_choices - 1)
            if self._comparison_layer_hidden_size == -1:
                self._comparison_layer_hidden_size = number_of_pairs * number_of_pairs

            self._comparison_layer_1 = Linear(number_of_pairs, self._comparison_layer_hidden_size)
            if self._comparison_layer_use_relu:
                self._comparison_layer_1_activation = torch.nn.LeakyReLU()
            else:
                self._comparison_layer_1_activation = torch.nn.Tanh()
            self._comparison_layer_2 = Linear(self._comparison_layer_hidden_size, self._num_choices)
            self._comparison_layer_2_activation = torch.nn.Softmax(dim=-1)

        # Scalar mix, if necessary
        self._all_layers = not top_layer_only
        if self._all_layers:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_scalar_mix") \
                    and bert_model_loaded.model._scalar_mix is not None:
                self._scalar_mix = bert_model_loaded.model._scalar_mix
            else:
                num_layers = bert_config.num_hidden_layers
                initial_scalar_parameters = num_layers * [0.0]
                initial_scalar_parameters[-1] = 5.0  # Starts with most mass on last layer
                self._scalar_mix = ScalarMix(bert_config.num_hidden_layers,
                                             initial_scalar_parameters=initial_scalar_parameters,
                                             do_layer_norm=False)
        else:
            self._scalar_mix = None

        # Accuracy and loss setup
        if self._train_comparison_layer:
            self._accuracy = CategoricalAccuracy()
            self._loss = torch.nn.CrossEntropyLoss()
        else:
            self._accuracy = BooleanAccuracy()
            self._loss = torch.nn.BCEWithLogitsLoss()
        self._debug = -1
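# A hedged, self-contained sketch of the comparison-layer shapes configured
# above (the value num_choices = 4 is purely illustrative): every ordered
# pair of distinct choices contributes one score, so with 4 choices the
# layers map 12 pair scores to a hidden layer and back to 4 choice logits.
import torch

num_choices = 4
number_of_pairs = num_choices * (num_choices - 1)  # 12 ordered pairs
hidden_size = number_of_pairs * number_of_pairs    # the default used above
layer_1 = torch.nn.Linear(number_of_pairs, hidden_size)
layer_2 = torch.nn.Linear(hidden_size, num_choices)
scores = layer_2(torch.nn.LeakyReLU()(layer_1(torch.randn(2, number_of_pairs))))
assert scores.shape == (2, num_choices)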
Example #21
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            phrase_layer: Seq2SeqEncoder,
            passage_self_attention: Seq2SeqEncoder,
            semantic_rep_layer: Seq2SeqEncoder,
            contextual_question_layer: Seq2SeqEncoder,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            regularizer: Optional[RegularizerApplicator] = None,
            initializer: InitializerApplicator = InitializerApplicator(),
    ):

        super(MultiGranularityHierarchicalAttentionFusionNetworks,
              self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        # self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
        #                                               num_highway_layers))
        self._encoding_dim = self._phrase_layer.get_output_dim()
        # self._atten_linear_layer = TimeDistributed(torch.nn.Linear(in_features=self._encoding_dim,
        #                                                            out_features=self._encoding_dim, bias=False))
        self._atten_linear_layer = torch.nn.Linear(
            in_features=self._encoding_dim,
            out_features=self._encoding_dim,
            bias=False)
        self._relu = torch.nn.ReLU()
        self._softmax_d1 = torch.nn.Softmax(dim=1)
        self._softmax_d2 = torch.nn.Softmax(dim=2)

        self._atten_fusion = FusionLayer(self._encoding_dim)

        self._tanh = torch.nn.Tanh()
        self._sigmoid = torch.nn.Sigmoid()

        self._passage_self_attention = passage_self_attention

        # self._self_atten_layer = SelfAttentionLayer(self._encoding_dim)
        self._self_atten_layer = torch.nn.Bilinear(self._encoding_dim,
                                                   self._encoding_dim,
                                                   self._encoding_dim,
                                                   bias=False)
        self._self_atten_fusion = FusionLayer(self._encoding_dim)

        self._semantic_rep_layer = semantic_rep_layer
        self._contextual_question_layer = contextual_question_layer

        # self._vector_linear = TimeDistributed(
        #     torch.nn.Linear(in_features=self._encoding_dim, out_features=1, bias=False))
        #
        # self._model_layer_s = TimeDistributed(
        #     torch.nn.Linear(in_features=self._encoding_dim, out_features=self._encoding_dim, bias=False))
        # self._model_layer_e = TimeDistributed(
        #     torch.nn.Linear(in_features=self._encoding_dim, out_features=self._encoding_dim, bias=False))
        self._vector_linear = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=1,
                                              bias=False)

        self._model_layer_s = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=self._encoding_dim,
                                              bias=False)
        self._model_layer_e = torch.nn.Linear(in_features=self._encoding_dim,
                                              out_features=self._encoding_dim,
                                              bias=False)
        # self._model_layer_s = torch.nn.Bilinear(in1_features=self._encoding_dim, in2_features=self._encoding_dim, out_features=)
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        self._mask_lstms = mask_lstms
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)
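# FusionLayer is not defined in this snippet. A hedged sketch of one common
# fusion formulation (our reading of the hierarchical attention fusion paper,
# not necessarily this repository's exact implementation): a candidate state
# m = tanh(W1 [x; y; x*y; x-y]) gated against the original input x.
import torch

class SimpleFusion(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self._m = torch.nn.Linear(4 * dim, dim)
        self._g = torch.nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        features = torch.cat([x, y, x * y, x - y], dim=-1)
        gate = torch.sigmoid(self._g(features))  # how much to update
        return gate * torch.tanh(self._m(features)) + (1 - gate) * x

fused = SimpleFusion(8)(torch.randn(2, 5, 8), torch.randn(2, 5, 8))
assert fused.shape == (2, 5, 8)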
Example #22
class DialogQA(Model):
    """
    This class implements a modified version of the BiDAF model (with self attention and a
    residual layer, from the Clark and Gardner ACL 2017 paper) as used in the
    Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog: a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length : ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        phrase_layer: Seq2SeqEncoder,
        residual_encoder: Seq2SeqEncoder,
        span_start_encoder: Seq2SeqEncoder,
        span_end_encoder: Seq2SeqEncoder,
        initializer: Optional[InitializerApplicator] = None,
        dropout: float = 0.2,
        num_context_answers: int = 0,
        marker_embedding_dim: int = 10,
        max_span_length: int = 30,
        max_turn_length: int = 12,
    ) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim,
                                                       "x,y,x*y")
        self._merge_atten = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(
                max_turn_length, marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding(
                (num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim,
                                                     self._encoding_dim,
                                                     "x,y,x*y")

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(
            phrase_layer.get_input_dim(),
            text_field_embedder.get_output_dim() +
            marker_embedding_dim * num_context_answers,
            "phrase layer input dim",
            "embedding dim + marker dim * num context answers",
        )

        if initializer is not None:
            initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        p1_answer_marker: torch.IntTensor = None,
        p2_answer_marker: torch.IntTensor = None,
        p3_answer_marker: torch.IntTensor = None,
        yesno_list: torch.IntTensor = None,
        followup_list: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor of shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens that belong to the
            previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two turns back in the passage.
        p3_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three turns back in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries. Each entry is a nested
        list, because the outer list iterates over dialogs and the inner one over the
        questions in each dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of continuation marker prediction indices
            (y: yes, m: maybe follow up, n: don't follow up).
        yesno : List[List[int]]
            A list of lists of affirmation marker prediction indices
            (y: yes, x: not a yes/no question, n: no).
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        token_character_ids = question["token_characters"]["token_characters"]
        batch_size, max_qa_count, max_q_len, _ = token_character_ids.size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1)
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage)

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = (embedded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1,
                1).view(total_qa_count, passage_length,
                        self._text_field_embedder.get_output_dim()))
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                repeated_encoded_passage,
                passage_question_vectors,
                repeated_encoded_passage * passage_question_vectors,
                repeated_encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)

        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              dtype=torch.bool,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask & ~self_mask

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(
            span_start_logits,
            span_end_logits,
            span_yesno_logits,
            span_followup_logits,
            self._max_span_length,
        )

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits,
                                        repeated_passage_mask),
                span_start.view(-1),
                ignore_index=-1,
            )
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits,
                                        repeated_passage_mask),
                span_end.view(-1),
                ignore_index=-1,
            )
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(
                best_span[:, 0:2],
                torch.stack([span_start, span_end],
                            -1).view(total_qa_count, 2),
                mask=qa_mask.unsqueeze(1).expand(-1, 2),
            )
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute the F1 score and prepare the output dictionary.
        output_dict["best_span_str"] = []
        output_dict["qid"] = []
        output_dict["followup"] = []
        output_dict["yesno"] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]["original_passage"]
            offsets = metadata[i]["token_offsets"]
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 against each set of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad.metric_max_over_ground_truths(
                                    squad.f1_score, best_span_string, refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad.metric_max_over_ground_truths(
                            squad.f1_score, best_span_string, answer_texts)
                self._official_f1(100 * f1_score)
            output_dict["qid"].append(per_dialog_query_id_list)
            output_dict["best_span_str"].append(per_dialog_best_span_list)
            output_dict["yesno"].append(per_dialog_yesno_list)
            output_dict["followup"].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        yesno_tags = [[
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in yn_list
        ] for yn_list in output_dict.pop("yesno")]
        followup_tags = [[
            self.vocab.get_token_from_index(x, namespace="followup_labels")
            for x in followup_list
        ] for followup_list in output_dict.pop("followup")]
        output_dict["yesno"] = yesno_tags
        output_dict["followup"] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            "followup": self._span_followup_accuracy.get_metric(reset),
            "f1": self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span_yesno_followup(
        span_start_logits: torch.Tensor,
        span_end_logits: torch.Tensor,
        span_yesno_logits: torch.Tensor,
        span_followup_logits: torch.Tensor,
        max_span_length: int,
    ) -> torch.Tensor:
        # Returns the index of the highest-scoring span that is no longer than
        # ``max_span_length`` tokens, as well as the yesno prediction bit and followup
        # prediction bit from the predicted span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
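# A hedged, self-contained sketch of the flat-index arithmetic used in forward()
# above: span_yesno_logits has shape (total_qa_count, passage_length, 3), so in
# the flattened view the three scores for QA pair i at end token j start at
# offset i * passage_length * 3 + j * 3. Selecting those offsets and reshaping
# to (-1, 3) recovers one score row per QA pair at its chosen end token.
import torch

total_qa_count, passage_length = 2, 5
logits = torch.arange(total_qa_count * passage_length * 3).view(
    total_qa_count, passage_length, 3)
end_tokens = [4, 1]  # one (gold or predicted) end-token index per QA pair
locs = []
for i, j in enumerate(end_tokens):
    base = i * passage_length * 3 + j * 3
    locs.extend([base, base + 1, base + 2])
picked = logits.view(-1).index_select(0, torch.tensor(locs)).view(-1, 3)
assert torch.equal(picked[0], logits[0, 4]) and torch.equal(picked[1], logits[1, 1])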
Example #23
        label = event_item['label']
        event_type = event_item['event_type']
        logits = predictor.predict(event, event_type)['logits']
        predict_out.append(logits)
        label_y.append(label)
        if logits == 1:
            predict_out_f1.append([0, 1])
        else:
            predict_out_f1.append([1, 0])

    predict_out = torch.LongTensor(predict_out)
    predict_out_f1 = torch.LongTensor(predict_out_f1)
    label_y = torch.LongTensor(label_y)

    # metrics
    get_accuracy = BooleanAccuracy()
    get_f1_score = F1Measure(positive_label=1)

    get_accuracy(predict_out, label_y)
    accuracy = get_accuracy.get_metric(reset=False)
    get_f1_score(predict_out_f1, label_y)
    precision, recall, f1_measure = get_f1_score.get_metric(reset=False)

    logger.debug('-------Train Metrics-------')
    for k in metrics:
        logger.debug('{}: {}'.format(k, metrics[k]))
    logger.debug('-------Test Output-------')
    logger.debug('accuracy: {}'.format(accuracy))
    logger.debug('precision: {}'.format(precision))
    logger.debug('recall: {}'.format(recall))
    logger.debug('f1_measure: {}'.format(f1_measure))
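# A note on the metric choice above: BooleanAccuracy counts an instance as
# correct only when *all* of its predicted elements match the gold values; for
# the 1-D label tensors here that reduces to plain exact-match accuracy. A
# minimal self-contained check (assuming AllenNLP is installed):
import torch
from allennlp.training.metrics import BooleanAccuracy

acc = BooleanAccuracy()
acc(torch.tensor([[1, 0], [1, 1]]), torch.tensor([[1, 0], [0, 1]]))
assert acc.get_metric() == 0.5  # the first row fully matches, the second does not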
Example #25
class Seq2SeqTask(SequenceGenerationTask):
    """Sequence-to-sequence Task"""

    def __init__(self, path, max_seq_len, max_targ_v_size, name, **kw):
        super().__init__(name, **kw)
        self.scorer2 = BooleanAccuracy()
        self.scorers.append(self.scorer2)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.max_seq_len = max_seq_len
        self._label_namespace = self.name + "_tokens"
        self.max_targ_v_size = max_targ_v_size
        self.target_indexer = {"words": SingleIdTokenIndexer(namespace=self._label_namespace)}
        self.files_by_split = {
            split: os.path.join(path, "%s.tsv" % split) for split in ["train", "val", "test"]
        }

        # The following is necessary since word-level tasks (e.g., MT) haven't been tested yet.
        if self._tokenizer_name != "SplitChars" and self._tokenizer_name != "dummy_tokenizer_name":
            raise NotImplementedError("For now, Seq2SeqTask only supports character-level tasks.")

    def load_data(self):
        # Data is exposed as iterable: no preloading
        pass

    def get_split_text(self, split: str):
        """
        Get split text as iterable of records.
        Split should be one of 'train', 'val', or 'test'.
        """
        return self.get_data_iter(self.files_by_split[split])

    def get_all_labels(self) -> List[str]:
        """ Build character vocabulary and return it as a list """
        token2freq = collections.Counter()
        for split in ["train", "val"]:
            for _, sequence in self.get_data_iter(self.files_by_split[split]):
                for token in sequence:
                    token2freq[token] += 1
        return [t for t, _ in token2freq.most_common(self.max_targ_v_size)]

    def get_data_iter(self, path):
        """ Load data """
        with codecs.open(path, "r", "utf-8", errors="ignore") as txt_fh:
            for row in txt_fh:
                row = row.strip().split("\t")
                if len(row) < 2 or not row[0] or not row[1]:
                    continue
                src_sent = tokenize_and_truncate(self._tokenizer_name, row[0], self.max_seq_len)
                # Column 1 holds the target; the guard above checks row[0] and row[1].
                tgt_sent = tokenize_and_truncate(self._tokenizer_name, row[1], self.max_seq_len)
                yield (src_sent, tgt_sent)

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """ Yield sentences, used to compute vocabulary. """
        for split in self.files_by_split:
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            path = self.files_by_split[split]
            yield from self.get_data_iter(path)

    def count_examples(self):
        """ Compute here b/c we're streaming the sentences. """
        example_counts = {}
        for split, split_path in self.files_by_split.items():
            example_counts[split] = sum(
                1 for _ in codecs.open(split_path, "r", "utf-8", errors="ignore")
            )
        self.example_counts = example_counts

    def process_split(
        self, split, indexers, model_preprocessing_interface
    ) -> Iterable[Type[Instance]]:
        """ Process split text into a list of AllenNLP Instances. """

        def _make_instance(input_, target):
            d = {
                "inputs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(input_), indexers
                ),
                "targs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(target), self.target_indexer
                ),
            }
            return Instance(d)

        for sent1, sent2 in split:
            yield _make_instance(sent1, sent2)

    def get_metrics(self, reset=False):
        """Get metrics specific to the task"""
        avg_nll = self.scorer1.get_metric(reset)
        acc = self.scorer2.get_metric(reset)
        return {"perplexity": math.exp(avg_nll), "accuracy": acc}

    def update_metrics(self, logits, labels, tagmask=None, predictions=None):
        # This doesn't require logits for now, since loss is updated in another part.
        assert logits is None and predictions is not None

        if labels.shape[1] < predictions.shape[2]:
            predictions = predictions[:, 0, : labels.shape[1]]
        else:
            predictions = predictions[:, 0, :]
            # Cut labels if predictions (without gold target) are shorter.
            labels = labels[:, : predictions.shape[1]]
            tagmask = tagmask[:, : predictions.shape[1]]
        self.scorer2(predictions, labels, tagmask)
        return

    def get_prediction(self, voc_src, voc_trg, inputs, gold, output):
        tokenizer = get_tokenizer(self._tokenizer_name)

        input_string = tokenizer.detokenize([voc_src[token.item()] for token in inputs]).split(
            "<EOS>"
        )[0]
        gold_string = tokenizer.detokenize([voc_trg[token.item()] for token in gold]).split(
            "<EOS>"
        )[0]
        output_string = tokenizer.detokenize([voc_trg[token.item()] for token in output]).split(
            "<EOS>"
        )[0]

        return input_string, gold_string, output_string
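# A small, self-contained sketch of the vocabulary construction performed by
# get_all_labels above: count token frequencies over the target sequences and
# keep the max_targ_v_size most common entries (toy inputs, for illustration).
import collections

token2freq = collections.Counter()
for sequence in ["abbab", "c"]:  # stand-ins for character-level target sequences
    token2freq.update(sequence)
vocab = [t for t, _ in token2freq.most_common(3)]
assert vocab == ["b", "a", "c"]  # b appears 3 times, a twice, c once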
Example #26
class BidafV4(Model):
    """
    MODIFICATION NOTE:
    This class is a modification of BiDAF. Here we try to see what happens to our results
    if we convert the question encoder into a simple term-frequency (bag-of-words) encoder which
    disregards word order. By doing so we analyze whether BiDAF can learn to solve SQuAD without
    having to encode the question sequentially. It has been shown in previous work that BiDAF and
    other models trained on SQuAD do not focus on question words as we would expect them to. For
    example, they will often focus

    ORIGINAL DOCSTRING:
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:

        super(BidafV4, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(),
                               4 * encoding_dim, "modeling layer input dim",
                               "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               phrase_layer.get_input_dim(),
                               "text field embedder output dim",
                               "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)
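
# (Sketch, separate from BidafV4 above.) The MODIFICATION NOTE describes
# replacing the sequential question encoder with an order-invariant
# bag-of-words encoder; one minimal way to realize that, assuming nothing
# about this repository's actual code, is a masked mean over token embeddings:
import torch

def bow_encode(embedded: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # embedded: (batch, num_tokens, dim); mask: (batch, num_tokens), 1 = real token
    mask = mask.unsqueeze(-1).float()
    return (embedded * mask).sum(1) / mask.sum(1).clamp(min=1.0)

pooled = bow_encode(torch.randn(2, 7, 16), torch.ones(2, 7))
assert pooled.shape == (2, 16)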

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))

        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        # encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
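        # (softmax over the passage dimension for each question token; after the
        # transpose, q2c_attention has shape (batch_size, question_length, passage_length),
        # so q2c_vecs below holds one passage-summary vector per question token)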
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)

        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)

        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        new_similarity_matrix = self._matrix_attention(encoded_passage,
                                                       q2c_vecs)
        masked_similarity = util.replace_masked_values(
            new_similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- Original variant
        # # We replace masked values with something really negative here, so they don't affect the
        # # max below.
        # masked_similarity = util.replace_masked_values(similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)

        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v9:
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors],
        #                                  dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
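
# Illustrative usage sketch for ``get_best_span`` (made-up logits; the enclosing model
# class is assumed to be importable as ``BidirectionalAttentionFlow``, a name chosen
# here only for demonstration).  The decoder keeps a running argmax of the start
# logits, so it returns the highest-scoring span with start <= end:
import torch

start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.2]])
end_logits = torch.tensor([[0.1, 0.2, 0.3, 2.0]])
best = BidirectionalAttentionFlow.get_best_span(start_logits, end_logits)
assert best.tolist() == [[1, 3]]  # start at token 1, inclusive end at token 3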
Example #27
0
class TransformerQA(Model):
    """
    Registered as `"transformer_qa"`, this class implements a reading comprehension model patterned
    after the proposed model in [Devlin et al](git@github.com:huggingface/transformers.git),
    with improvements borrowed from the SQuAD model in the transformers project.

    It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

    If you want to use this model on SQuAD datasets, you can use it with the
    [`TransformerSquadReader`](../../dataset_readers/transformer_squad#transformersquadreader)
    dataset reader, registered as `"transformer_squad"`.

    Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could
    be more than one instance per question, these metrics are not the official numbers on either SQuAD task.

    To get official numbers for SQuAD v1.1, for example, you can run

    ```
    python -m allennlp_models.rc.tools.transformer_qa_eval
    ```

    # Parameters

    vocab : `Vocabulary`

    transformer_model_name : `str`, optional (default=`'bert-base-cased'`)
        This model chooses the embedder according to this setting. You probably want to make sure this is set to
        the same thing as the reader.
    """

    def __init__(
        self, vocab: Vocabulary, transformer_model_name: str = "bert-base-cased", **kwargs
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": PretrainedTransformerEmbedder(transformer_model_name)}
        )
        self._linear_layer = nn.Linear(self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._per_instance_metrics = SquadEmAndF1()

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        cls_index: torch.LongTensor = None,
        answer_span: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, Dict[str, torch.LongTensor]]`
            From a `TextField`. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including
            `[CLS]` and `[SEP]`) has type id 1.

        context_span : `torch.IntTensor`
            From a `SpanField`. This marks the span of word pieces in `question_with_context` from
            which answers can come.

        cls_index : `torch.LongTensor`, optional
            A tensor of shape `(batch_size,)` that provides the index of the `[CLS]` token
            in the `question_with_context` for each instance.

            This is needed because the `[CLS]` token is used to indicate that the question
            is impossible.

            If this is `None`, it's assumed that the `[CLS]` token is at index 0 for each instance
            in the batch.

        answer_span : `torch.IntTensor`, optional
            From a `SpanField`. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.

        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the `metadata` list should be the
            batch size, and each dictionary should have the keys `id`, `question`, `context`,
            `question_tokens`, `context_tokens`, and `answers`.

        # Returns

        `Dict[str, torch.Tensor]` :
            An output dictionary with the following fields:

            - span_start_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span start position.
            - span_end_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span end position (inclusive).
            - best_span_scores (`torch.FloatTensor`) :
              The score for each of the best spans.
            - loss (`torch.FloatTensor`, optional) :
              A scalar loss to be optimised, evaluated against `answer_span`.
            - best_span (`torch.IntTensor`, optional) :
              Provided when the model is not in training mode and sufficient metadata is given for the instance.
              The result of a constrained inference over `span_start_logits` and
              `span_end_logits` to find the most probable span.  Shape is `(batch_size, 2)`
              and each offset is a token index, unless the best span for an instance
              was predicted to be the `[CLS]` token, in which case the span will be (-1, -1).
            - best_span_str (`List[str]`, optional) :
              Provided when the model is not in training mode and sufficient metadata is given for the instance.
              This is the string from the original passage that the model thinks is the best answer
              to the question.

        """
        embedded_question = self._text_field_embedder(question_with_context)
        # shape: (batch_size, sequence_length, 2)
        logits = self._linear_layer(embedded_question)
        # shape: (batch_size, sequence_length, 1)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        # shape: (batch_size, sequence_length)
        span_start_logits = span_start_logits.squeeze(-1)
        # shape: (batch_size, sequence_length)
        span_end_logits = span_end_logits.squeeze(-1)

        # Create a mask for `question_with_context` to mask out tokens that are not part
        # of the context.
        # shape: (batch_size, sequence_length)
        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context), dtype=torch.bool
        )
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start : end + 1] = True
            # Also unmask the [CLS] token since that token is used to indicate that
            # the question is impossible.
            possible_answer_mask[i, 0 if cls_index is None else cls_index[i]] = True

        # Calculate span start and end probabilities
        # shape: (batch_size, sequence_length)
        span_start_probs = softmax(span_start_logits, dim=-1)
        # shape: (batch_size, sequence_length)
        span_end_probs = softmax(span_end_logits, dim=-1)

        # Replace the masked values with a very negative constant since we're in log-space.
        # shape: (batch_size, sequence_length)
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, possible_answer_mask
        )
        # shape: (batch_size, sequence_length)
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, possible_answer_mask
        )

        # Now calculate the best span.
        # shape: (batch_size, 2)
        best_spans = get_best_span(span_start_logits, span_end_logits)

        # Sum the span start score with the span end score to get an overall score for the span.
        # shape: (batch_size,)
        best_span_scores = torch.gather(
            span_start_logits, 1, best_spans[:, 0].unsqueeze(1)
        ) + torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        best_span_probs = torch.gather(
            span_start_probs, 1, best_spans[:, 0].unsqueeze(1)
        ) * torch.gather(span_end_probs, 1, best_spans[:, 1].unsqueeze(1))
        best_span_probs = best_span_probs.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span_scores": best_span_scores,
            "span_start_probs": span_start_probs,
            "span_end_probs": span_end_probs,
            "best_span_probs": best_span_probs,
        }

        # Compute the loss.
        if answer_span is not None:
            output_dict["loss"] = self._evaluate_span(
                best_spans, span_start_logits, span_end_logits, answer_span
            )

        # Gather the string of the best span and compute the EM and F1 against the gold span,
        # if given.
        if not self.training and metadata is not None:
            (
                output_dict["best_span_str"],
                output_dict["best_span"],
            ) = self._collect_best_span_strings(best_spans, context_span, metadata, cls_index)

        return output_dict

    def _evaluate_span(
        self,
        best_spans: torch.Tensor,
        span_start_logits: torch.Tensor,
        span_end_logits: torch.Tensor,
        answer_span: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate the loss against the `answer_span` and also update the span metrics.
        """
        span_start = answer_span[:, 0]
        span_end = answer_span[:, 1]
        self._span_accuracy(best_spans, answer_span)

        start_loss = cross_entropy(span_start_logits, span_start, ignore_index=-1)
        big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
        assert not torch.any(start_loss > big_constant), "Start loss too high"

        end_loss = cross_entropy(span_end_logits, span_end, ignore_index=-1)
        assert not torch.any(end_loss > big_constant), "End loss too high"

        self._span_start_accuracy(span_start_logits, span_start)
        self._span_end_accuracy(span_end_logits, span_end)

        return (start_loss + end_loss) / 2

    def _collect_best_span_strings(
        self,
        best_spans: torch.Tensor,
        context_span: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        cls_index: Optional[torch.LongTensor],
    ) -> Tuple[List[str], torch.Tensor]:
        """
        Collect the string of the best predicted span from the context metadata and
        update `self._per_instance_metrics`, which in the case of SQuAD v1.1 / v2.0
        includes the EM and F1 score.

        This returns a `Tuple[List[str], torch.Tensor]`, where the `List[str]` is the
        predicted answer for each instance in the batch, and the tensor is just the input
        tensor `best_spans` after adjustments so that each answer span corresponds to the
        context tokens only, and not the question tokens. Spans that correspond to the
        `[CLS]` token, i.e. the question was predicted to be impossible, will be set
        to `(-1, -1)`.
        """
        _best_spans = best_spans.detach().cpu().numpy()

        best_span_strings: List[str] = []
        best_span_strings_for_metric: List[str] = []
        answer_strings_for_metric: List[List[str]] = []

        for (metadata_entry, best_span, cspan, cls_ind) in zip(
            metadata,
            _best_spans,
            context_span,
            cls_index or (0 for _ in range(len(metadata))),
        ):
            context_tokens_for_question = metadata_entry["context_tokens"]

            if best_span[0] == cls_ind:
                # Predicting [CLS] is interpreted as predicting the question as unanswerable.
                best_span_string = ""
                # NOTE: even though we've "detached" 'best_spans' above, this still
                # modifies the original tensor in-place.
                best_span[0], best_span[1] = -1, -1
            else:
                best_span -= int(cspan[0])
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (
                    predicted_start >= 0
                    and context_tokens_for_question[predicted_start].idx is None
                ):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text."
                    )
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[predicted_start].idx

                while (
                    predicted_end < len(context_tokens_for_question)
                    and context_tokens_for_question[predicted_end].idx is None
                ):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text."
                    )
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][character_start:character_end]

            best_span_strings.append(best_span_string)
            answers = metadata_entry.get("answers")
            if answers:
                best_span_strings_for_metric.append(best_span_string)
                answer_strings_for_metric.append(answers)

        if answer_strings_for_metric:
            self._per_instance_metrics(best_span_strings_for_metric, answer_strings_for_metric)

        return best_span_strings, best_spans

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        output = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
        }
        if not self.training:
            exact_match, f1_score = self._per_instance_metrics.get_metric(reset)
            output["per_instance_em"] = exact_match
            output["per_instance_f1"] = f1_score
        return output

    default_predictor = "transformer_qa"
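
# A minimal, pure-PyTorch sketch of the context-masking step in ``forward`` above (toy
# sizes; the variable values are invented and the allennlp helpers are replaced with
# plain ``masked_fill``).  Only word pieces inside ``context_span``, plus the [CLS]
# position, may be predicted as span endpoints:
import torch

seq_len = 8
logits = torch.randn(2, seq_len)
context_span = torch.tensor([[3, 6], [2, 5]])  # inclusive (start, end) per instance
possible_answer_mask = torch.zeros(2, seq_len, dtype=torch.bool)
for i, (start, end) in enumerate(context_span):
    possible_answer_mask[i, start : end + 1] = True
    possible_answer_mask[i, 0] = True  # [CLS] stays available to mark "impossible"
masked_logits = logits.masked_fill(~possible_answer_mask, torch.finfo(logits.dtype).min)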
Example #28
0
class DialogQA(Model):
    """
    This class implements a modified version of the BiDAF model
    (with self-attention and a residual layer, from the Clark and Gardner ACL 2017 paper), as used in
    the Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog: a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length : ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length : ``int``, optional (default=12)
        Maximum number of question-answer turns in a dialog.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor of shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens that belong to the
            previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two turns back in the passage.
        p3_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three turns back in the passage.
        yesno_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following fields.
        Each field is a nested list: the outer list iterates over the dialogs in the batch and the
        inner list over the questions within each dialog.

        qid : List[List[str]]
            A nested list of question ids.
        followup : List[List[int]]
            A nested list of predicted continuation-marker indices
            (y: yes, m: maybe follow up, n: don't follow up).
        yesno : List[List[int]]
            A nested list of predicted affirmation-marker indices
            (y: yes, x: not a yes/no question, n: no).
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
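            # question_num_ind now has shape (total_qa_count, max_q_len): every token of the
            # t-th question in a dialog carries turn index t, so the marker embedding below
            # injects the turn position into the question representation.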
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
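            # span_yesno_logits has shape (total_qa_count, passage_length, 3); once flattened,
            # element (i, j, k) sits at index i * passage_length * 3 + j * 3 + k.  The three
            # consecutive offsets computed below therefore select the full 3-way logits at the
            # gold (and, further down, the predicted) span-end token of each question.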
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] \
                      for yn_list in output_dict.pop("yesno")]
        followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list] \
                         for followup_list in output_dict.pop("followup")]
        output_dict['yesno'] = yesno_tags
        output_dict['followup'] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'yesno': self._span_yesno_accuracy.get_metric(reset),
                'followup': self._span_followup_accuracy.get_metric(reset),
                'f1': self._official_f1.get_metric(reset), }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the index of the highest-scoring span that is no longer than
        # ``max_span_length`` tokens, as well as the yesno and followup prediction bits
        # from the predicted span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
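
# Illustrative check of the span-length constraint (made-up logits; assumes the
# ``DialogQA`` class above is in scope).  With ``max_span_length=1`` the decoder
# must skip the otherwise best (0, 3) span and keep the short (0, 0) one:
import torch

start = torch.tensor([[5.0, 0.0, 0.0, 0.0]])
end = torch.tensor([[0.0, 0.0, 0.0, 5.0]])
neutral = torch.zeros(1, 4, 3)  # uniform yesno / followup logits
span = DialogQA._get_best_span_yesno_followup(start, end, neutral, neutral, max_span_length=1)
assert span[0, :2].tolist() == [0, 0]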
Example #29
0
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    attention_similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these
        # aren't necessarily obvious from the configuration files, so we check
        # here.
        if modeling_layer.get_input_dim() != 4 * encoding_dim:
            raise ConfigurationError(
                "The input dimension to the modeling_layer must be "
                "equal to 4 times the encoding dimension of the phrase_layer. "
                "Found {} and 4 * {} respectively.".format(
                    modeling_layer.get_input_dim(), encoding_dim))
        if text_field_embedder.get_output_dim() != phrase_layer.get_input_dim():
            raise ConfigurationError(
                "The output dimension of the text_field_embedder (embedding_dim + "
                "char_cnn) must match the input dimension of the phrase_encoder. "
                "Found {} and {}, respectively.".format(
                    text_field_embedder.get_output_dim(),
                    phrase_layer.get_input_dim()))

        if span_end_encoder.get_input_dim() != encoding_dim * 4 + modeling_dim * 3:
            raise ConfigurationError(
                "The input dimension of the span_end_encoder should be equal to "
                "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. "
                "Found {} and (4 * {} + 3 * {}) "
                "respectively.".format(span_end_encoder.get_input_dim(),
                                       encoding_dim, modeling_dim))

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self._get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span
        }
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        if metadata is not None:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def _get_best_span(span_start_logits: Variable,
                       span_end_logits: Variable) -> Variable:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new().resize_(
            batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
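    # Note on _get_best_span: the double loop above is a linear-time constrained
    # argmax.  For each end position j it tracks the best start seen so far
    # (span_start_argmax), which is equivalent to maximising
    # span_start_logits[b, i] + span_end_logits[b, j] over all pairs i <= j.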

    @classmethod
    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'BidirectionalAttentionFlow':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)
        num_highway_layers = params.pop("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(
            params.pop("similarity_function"))
        modeling_layer = Seq2SeqEncoder.from_params(
            params.pop("modeling_layer"))
        span_end_encoder = Seq2SeqEncoder.from_params(
            params.pop("span_end_encoder"))
        dropout = params.pop('dropout', 0.2)

        init_params = params.pop('initializer', None)
        reg_params = params.pop('regularizer', None)
        initializer = (InitializerApplicator.from_params(init_params)
                       if init_params is not None else InitializerApplicator())
        regularizer = RegularizerApplicator.from_params(
            reg_params) if reg_params is not None else None

        mask_lstms = params.pop('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   modeling_layer=modeling_layer,
                   span_end_encoder=span_end_encoder,
                   dropout=dropout,
                   mask_lstms=mask_lstms,
                   initializer=initializer,
                   regularizer=regularizer)
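The dimension bookkeeping that this constructor enforces is easier to see with concrete numbers.  A minimal sketch follows; the sizes are illustrative assumptions, not values from any original configuration.  A bidirectional phrase layer with hidden size 100 gives encoding_dim = 200, so the modeling layer must accept 800 inputs and the span end encoder 1400:

# Hypothetical sizes, chosen only to illustrate the constraints checked above.
encoding_dim = 200           # phrase_layer.get_output_dim()
modeling_dim = 200           # modeling_layer.get_output_dim()
span_end_encoding_dim = 200  # span_end_encoder.get_output_dim()

modeling_layer_input_dim = 4 * encoding_dim                       # 800
span_end_encoder_input_dim = 4 * encoding_dim + 3 * modeling_dim  # 1400
span_start_input_dim = 4 * encoding_dim + modeling_dim            # 1000
span_end_input_dim = 4 * encoding_dim + span_end_encoding_dim     # 1000

assert modeling_layer_input_dim == 800
assert span_end_encoder_input_dim == 1400
assert span_start_input_dim == 1000 and span_end_input_dim == 1000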
Example #30
class BidirectionalAttentionFlow_1(Model):
    """
    This class implements a Bayesian version of Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).
    """
    
    def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None:
        super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)
        
        """
        Initialize some data structures 
        """
        self.cf_a = cf_a
        # Bayesian data models
        self.VBmodels = []
        self.LinearModels = []
        """
        ############## TEXT FIELD EMBEDDER with ELMO ####################
        text_field_embedder : ``TextFieldEmbedder``
            Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
        """
        if cf_a.use_ELMO:
            if preloaded_elmo is not None:
                text_field_embedder = preloaded_elmo
            else:
                text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput)
                print("ELMO loaded from disk or downloaded")
        else:
            text_field_embedder = None
        
#        embedder_out_dim  = text_field_embedder.get_output_dim()
        self._text_field_embedder = text_field_embedder
        
        if(cf_a.Add_Linear_projection_ELMO):
            if (self.cf_a.VB_Linear_projection_ELMO):
                prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
                print ("----------------- Bayesian Linear Projection ELMO --------------")
                linear_projection_ELMO =  LinearVB(text_field_embedder.get_output_dim(), 200, prior = prior)
                self.VBmodels.append(linear_projection_ELMO)
            else:
                linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(), 200)
        
            self._linear_projection_ELMO = linear_projection_ELMO
            
        """
        ############## Highway layers ####################
        num_highway_layers : ``int``
            The number of highway layers to use in between embedding the input and passing it through
            the phrase layer.
        """
        
        if cf_a.Add_Linear_projection_ELMO:
            Input_dimension_highway = 200
        else:
            Input_dimension_highway = text_field_embedder.get_output_dim()
            
        num_highway_layers = cf_a.num_highway_layers
        # Highway layer, optionally Bayesian
        if (self.cf_a.VB_highway_layers):
            print ("----------------- Bayesian Highway network  --------------")
            prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
            highway_layer = HighwayVB(Input_dimension_highway,
                                                          num_highway_layers, prior = prior)
            self.VBmodels.append(highway_layer)
        else:
            
            highway_layer = Highway(Input_dimension_highway,
                                                          num_highway_layers)
        highway_layer = TimeDistributed(highway_layer)
        
        self._highway_layer = highway_layer
        
        """
        ############## Phrase layer ####################
        phrase_layer : ``Seq2SeqEncoder``
            The encoder (with its own internal stacking) that we will use in between embedding tokens
            and doing the bidirectional attention.
        """
        if cf_a.phrase_layer_dropout > 0:       ## Create dropout layer
            dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
        else:
            dropout_phrase_layer = lambda x: x
        
        phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway, hidden_size = cf_a.phrase_layer_hidden_size, 
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.phrase_layer_num_layers, dropout = cf_a.phrase_layer_dropout))
        
        phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
        self._phrase_layer = phrase_layer
        self._dropout_phrase_layer = dropout_phrase_layer
        
        """
        ############## Matrix attention layer ####################
        similarity_function : ``SimilarityFunction``
            The similarity function that we will use when comparing encoded passage and question
            representations.
        """
        
        # Similarity function for the matrix attention, optionally Bayesian
        if (self.cf_a.VB_similarity_function):
            prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
            print ("----------------- Bayesian Similarity matrix --------------")
            similarity_function = LinearSimilarityVB(
                  combination = "x,y,x*y",
                  tensor_1_dim =  phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim, prior = prior)
            self.VBmodels.append(similarity_function)
        else:
            similarity_function = LinearSimilarity(
                  combination = "x,y,x*y",
                  tensor_1_dim =  phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim)
            
        matrix_attention = LegacyMatrixAttention(similarity_function)
        self._matrix_attention = matrix_attention
        
        """
        ############## Modelling Layer ####################
        modeling_layer : ``Seq2SeqEncoder``
            The encoder (with its own internal stacking) that we will use in between the bidirectional
            attention and predicting span start and end.
        """
        ## Create dropout layer
        if cf_a.modeling_passage_dropout > 0:       ## Create dropout layer
            dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
        else:
            dropout_modeling_passage = lambda x: x
        
        modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4, hidden_size = cf_a.modeling_passage_hidden_size, 
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_passage_num_layers, dropout = cf_a.modeling_passage_dropout))

        self._modeling_layer = modeling_layer
        self._dropout_modeling_passage = dropout_modeling_passage
        
        """
        ############## Span Start Representation #####################
        span_end_encoder : ``Seq2SeqEncoder``
            The encoder that we will use to incorporate span start predictions into the passage state
            before predicting span end.
        """
        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        
        # Linear layer to compute the span start
        if (self.cf_a.VB_span_start_predictor_linear):
            prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
            print ("----------------- Bayesian Span Start Predictor--------------")
            span_start_predictor_linear =  LinearVB(span_start_input_dim, 1, prior = prior)
            self.VBmodels.append(span_start_predictor_linear)
        else:
            span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)
            
        self._span_start_predictor_linear = span_start_predictor_linear
        self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

        """
        ############## Span End Representation #####################
        """
        
        ## Create dropout layer
        if cf_a.span_end_encoder_dropout > 0:
            dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_span_end_encode = lambda x: x
        
        span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3, hidden_size = cf_a.modeling_span_end_hidden_size, 
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_span_end_num_layers, dropout = cf_a.span_end_encoder_dropout))
   
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        
        self._span_end_encoder = span_end_encoder
        self._dropout_span_end_encode = dropout_span_end_encode
        
        if (self.cf_a.VB_span_end_predictor_linear):
            print ("----------------- Bayesian Span End Predictor--------------")
            prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
            span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior)
            self.VBmodels.append(span_end_predictor_linear) 
        else:
            span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)
        
        self._span_end_predictor_linear = span_end_predictor_linear
        self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

        """
        Dropput last layers
        """
        if cf_a.spans_output_dropout > 0:
            dropout_spans_output = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_spans_output = lambda x: x
        
        self._dropout_spans_output = dropout_spans_output
        
        """
        Checkings and accuracy
        """
        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(Input_dimension_highway , phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        """
        mask_lstms : ``bool``, optional (default=True)
            If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
            with only a slight performance decrease, if any.  We haven't experimented much with this
            yet, but have confirmed that we still get very similar performance with much faster
            training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
            required when using masking with pytorch LSTMs.
        """
        self._mask_lstms = cf_a.mask_lstms

    
        """
        ################### Initialize parameters ##############################
        """
        #### THEY ARE ALL INITIALIZED WHEN INSTANTING THE COMPONENTS ###
    
        """
        ####################### OPTIMIZER ################
        """
        optimizer = pytut.get_optimizers(self, cf_a)
        self._optimizer = optimizer
        #### TODO: Learning rate scheduler ####
        #scheduler = optim.ReduceLROnPlateau(optimizer, 'max')
    
    def forward_ensemble(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False) -> Dict[str, torch.Tensor]:
        """
        Sample 10 times and add them together
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information)
        self.set_posterior_mean(False)
       
        subresults = [most_likely_output]
        for i in range(10):
           subresults.append(self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)
        
        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if (get_sample_level_information):
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
                        
        if (get_sample_level_information):
            # Add information about the individual samples for future analysis
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            # Average the span distributions over the sampled passes once, outside
            # the loop.  nll_loss expects log-probabilities, so take the log of the
            # averaged probabilities before computing the per-sample losses.
            span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
            span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
            for i in range(batch_size):
                span_start_loss = nll_loss(torch.log(span_start_probs[[i], :]), span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(torch.log(span_end_probs[[i], :]), span_end.squeeze(-1)[[i]])

                output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output
    
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False,
                get_attentions = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        
        """
        #################### Sample Bayesian weights ##################
        """
        self.sample_posterior()
        
        """
        ################## MASK COMPUTING ########################
        """
                
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None
        
        """
        ###################### EMBEDDING + HIGHWAY LAYER ########################
        """
#        self.cf_a.use_ELMO
        
        if(self.cf_a.Add_Linear_projection_ELMO):
            embedded_question = self._highway_layer(self._linear_projection_ELMO (self._text_field_embedder(question['character_ids'])["elmo_representations"][-1]))
            embedded_passage = self._highway_layer(self._linear_projection_ELMO(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1]))
        else:
            embedded_question = self._highway_layer(self._text_field_embedder(question['character_ids'])["elmo_representations"][-1])
            embedded_passage = self._highway_layer(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1])
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        
        """
        ###################### phrase_layer LAYER ########################
        """

        encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        """
        ###################### Attention LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)
        
        """
        ###################### Spans LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout_spans_output(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout_span_end_encode(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout_spans_output(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        
        best_span = bidut.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            
            span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            loss = span_start_loss + span_end_loss

            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            
            output_dict["loss"] = loss
            output_dict["span_start_loss"] = span_start_loss
            output_dict["span_end_loss"] = span_end_loss
            
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            if (get_sample_level_information):
                output_dict["em_samples"] = []
                output_dict["f1_samples"] = []
                
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)
                        
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            
        if (get_sample_level_information):
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            for i in range (batch_size):
                span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i],:], passage_mask[[i],:]), span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i],:], passage_mask[[i],:]), span_end.squeeze(-1)[[i]])
                
                output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        if(get_attentions):
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["simmilarity"] = passage_question_similarity
            
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }
    
    def train_batch(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        
        """
        It is enough to just compute the total loss because the normal weights 
        do not depend on the KL Divergence
        """
        # Now we can just compute both losses which will build the dynamic graph
        
        output = self.forward(question,passage,span_start,span_end,metadata )
        data_loss = output["loss"]
        
        KL_div = self.get_KL_divergence()
        total_loss =  self.combine_losses(data_loss, KL_div)
        
        self.zero_grad()     # zeroes the gradient buffers of all parameters
        total_loss.backward()
        
        if self._optimizer is None:
            # No optimizer configured: fall back to a manual SGD update.
            parameters = filter(lambda p: p.requires_grad, self.parameters())
            with torch.no_grad():
                for f in parameters:
                    f.data.sub_(f.grad.data * self.lr)
        else:
            self._optimizer.step()
            self._optimizer.zero_grad()
            
        return output
    
    def fill_batch_training_information(self, training_logger,
                                        output_batch):
        """
        Function to fill the the training_logger for each batch. 
        training_logger: Dictionary that will hold all the training info
        output_batch: Output from training the batch
        """
        training_logger["train"]["span_start_loss_batch"].append(output_batch["span_start_loss"].detach().cpu().numpy())
        training_logger["train"]["span_end_loss_batch"].append(output_batch["span_end_loss"].detach().cpu().numpy())
        training_logger["train"]["loss_batch"].append(output_batch["loss"].detach().cpu().numpy())
        # Training metrics:
        metrics = self.get_metrics()
        training_logger["train"]["start_acc_batch"].append(metrics["start_acc"])
        training_logger["train"]["end_acc_batch"].append(metrics["end_acc"])
        training_logger["train"]["span_acc_batch"].append(metrics["span_acc"])
        training_logger["train"]["em_batch"].append(metrics["em"])
        training_logger["train"]["f1_batch"].append(metrics["f1"])
    
        
    def fill_epoch_training_information(self, training_logger,device,
                                        validation_iterable, num_batches_validation):
        """
        Fill the information per each epoch
        """
        Ntrials_CUDA = 100
        # Training Epoch final metrics
        metrics = self.get_metrics(reset = True)
        training_logger["train"]["start_acc"].append(metrics["start_acc"])
        training_logger["train"]["end_acc"].append(metrics["end_acc"])
        training_logger["train"]["span_acc"].append(metrics["span_acc"])
        training_logger["train"]["em"].append(metrics["em"])
        training_logger["train"]["f1"].append(metrics["f1"])
        
        self.set_posterior_mean(True)
        self.eval()
        
        data_loss_validation = 0
        loss_validation = 0
        with torch.no_grad():
            # Compute the validation accuracy by using all the Validation dataset but in batches.
            for j in range(num_batches_validation):
                tensor_dict = next(validation_iterable)
                
                trial_index = 0
                while True:
                    try:
                        tensor_dict = pytut.move_to_device(tensor_dict, device)  # Move the tensors to CUDA
                        output_batch = self.forward(**tensor_dict)
                        break
                    except RuntimeError as er:
                        print(er.args)
                        torch.cuda.empty_cache()
                        time.sleep(5)
                        torch.cuda.empty_cache()
                        trial_index += 1
                        if trial_index == Ntrials_CUDA:
                            print("Too many failed trials to allocate in memory")
                            send_error_email(str(er.args))
                            sys.exit(0)
                
                data_loss_validation += output_batch["loss"].detach().cpu().numpy() 
                        
                ## Memory management!!
            if (self.cf_a.force_free_batch_memory):
                del tensor_dict["question"]; del tensor_dict["passage"]
                del tensor_dict
                del output_batch
                torch.cuda.empty_cache()
            if (self.cf_a.force_call_garbage_collector):
                gc.collect()
                
            data_loss_validation = data_loss_validation/num_batches_validation
#            loss_validation = loss_validation/num_batches_validation
    
        # Validation epoch final metrics
        metrics = self.get_metrics(reset = True)
        training_logger["validation"]["start_acc"].append(metrics["start_acc"])
        training_logger["validation"]["end_acc"].append(metrics["end_acc"])
        training_logger["validation"]["span_acc"].append(metrics["span_acc"])
        training_logger["validation"]["em"].append(metrics["em"])
        training_logger["validation"]["f1"].append(metrics["f1"])
        
        training_logger["validation"]["data_loss"].append(data_loss_validation)
        self.train()
        self.set_posterior_mean(False)
    
    def trim_model(self, mu_sigma_ratio = 2):
        
        total_size_w = []
        total_removed_w = []
        total_size_b = []
        total_removed_b = []
        
        if (self.cf_a.VB_Linear_projection_ELMO):
                VBmodel = self._linear_projection_ELMO
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)
                
        if (self.cf_a.VB_highway_layers):
                VBmodel = self._highway_layer._module.VBmodels[0]
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)
                
        if (self.cf_a.VB_similarity_function):
                VBmodel = self._matrix_attention._similarity_function
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)
                
        if (self.cf_a.VB_span_start_predictor_linear):
                VBmodel = self._span_start_predictor_linear
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)
                
        if (self.cf_a.VB_span_end_predictor_linear):
                VBmodel = self._span_end_predictor_linear
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)
                
        
        return total_size_w, total_removed_w, total_size_b, total_removed_b

    
    """
    BAYESIAN NECESSARY FUNCTIONS
    """
    sample_posterior = GeneralVBModel.sample_posterior
    get_KL_divergence = GeneralVBModel.get_KL_divergence
    set_posterior_mean = GeneralVBModel.set_posterior_mean
    combine_losses = GeneralVBModel.combine_losses
    
    def save_VB_weights(self):
        """
        Function that saves only the VB weights of the model.
        """
        pretrained_dict = ...
        model_dict = self.state_dict()
        
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict) 
        # 3. load the new state dict
        self.load_state_dict(model_dict)
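forward_ensemble above delegates the combination of the sampled passes to bidut.merge_span_probs.  A minimal sketch of what such a merge could look like, assuming each subresult carries span_start_probs and span_end_probs (this is an assumption about the helper's behaviour, not its original implementation): average the distributions over the passes, then rerun the constrained span search on the averaged log-probabilities:

import torch
from typing import Dict, List

def merge_span_probs_sketch(subresults: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Average span distributions over sampled passes and decode the best span."""
    span_start_probs = sum(r["span_start_probs"] for r in subresults) / len(subresults)
    span_end_probs = sum(r["span_end_probs"] for r in subresults) / len(subresults)
    start_logits = torch.log(span_start_probs + 1e-13)
    end_logits = torch.log(span_end_probs + 1e-13)

    batch_size, passage_length = start_logits.size()
    best_span = start_logits.new_zeros((batch_size, 2), dtype=torch.long)
    for b in range(batch_size):
        best, argmax_start = -1e20, 0
        for j in range(passage_length):
            # Same linear-time constrained argmax as in the examples above.
            if start_logits[b, j] > start_logits[b, argmax_start]:
                argmax_start = j
            score = start_logits[b, argmax_start] + end_logits[b, j]
            if score > best:
                best = score
                best_span[b, 0], best_span[b, 1] = argmax_start, j
    return best_span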
Example #31
    def test_incorrect_gold_labels_shape_catches_exceptions(self, device: str):
        accuracy = BooleanAccuracy()
        predictions = torch.rand([5, 7], device=device)
        incorrect_shape_labels = torch.rand([5, 8], device=device)
        with pytest.raises(ValueError):
            accuracy(predictions, incorrect_shape_labels)
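The test above relies on the metric validating shapes before comparing tensors.  A minimal sketch of that guard, mirroring the behaviour the test asserts rather than AllenNLP's exact source:

import torch

class BooleanAccuracySketch:
    """Counts a row as correct only if every element matches exactly."""

    def __init__(self) -> None:
        self._correct_count = 0.0
        self._total_count = 0.0

    def __call__(self, predictions: torch.Tensor, gold_labels: torch.Tensor) -> None:
        # The guard the test above exercises: mismatched shapes raise ValueError.
        if gold_labels.size() != predictions.size():
            raise ValueError(f"gold_labels must have shape == predictions.size() "
                             f"but found tensor of shape: {gold_labels.size()}")
        correct = predictions.eq(gold_labels).long().prod(dim=-1)
        self._correct_count += correct.sum().item()
        self._total_count += correct.numel()

    def get_metric(self) -> float:
        return self._correct_count / max(self._total_count, 1.0)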
Example #32
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            char_field_embedder: TextFieldEmbedder,
            # num_highway_layers: int,
            phrase_layer: Seq2SeqEncoder,
            char_rnn: Seq2SeqEncoder,
            hops: int,
            hidden_dim: int,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._char_field_embedder = char_field_embedder
        self._features_embedder = nn.Embedding(2, 5)
        # self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim() + 5 * 3,
        #                                               num_highway_layers))
        self._phrase_layer = phrase_layer
        self._encoding_dim = phrase_layer.get_output_dim()
        # self._stacked_brnn = PytorchSeq2SeqWrapper(
        #     StackedBidirectionalLstm(input_size=self._encoding_dim, hidden_size=hidden_dim,
        #                              num_layers=3, recurrent_dropout_probability=0.2))
        self._char_rnn = char_rnn

        self.hops = hops

        self.interactive_aligners = nn.ModuleList()
        self.interactive_SFUs = nn.ModuleList()
        self.self_aligners = nn.ModuleList()
        self.self_SFUs = nn.ModuleList()
        self.aggregate_rnns = nn.ModuleList()
        for i in range(hops):
            # interactive aligner
            self.interactive_aligners.append(
                layers.SeqAttnMatch(self._encoding_dim))
            self.interactive_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # self aligner
            self.self_aligners.append(layers.SelfAttnMatch(self._encoding_dim))
            self.self_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # aggregating
            self.aggregate_rnns.append(
                PytorchSeq2SeqWrapper(
                    nn.LSTM(input_size=self._encoding_dim,
                            hidden_size=hidden_dim,
                            num_layers=1,
                            dropout=0.2,
                            bidirectional=True,
                            batch_first=True)))

        # Memory-based Answer Pointer
        self.mem_ans_ptr = layers.MemoryAnsPointer(x_size=self._encoding_dim,
                                                   y_size=self._encoding_dim,
                                                   hidden_size=hidden_dim,
                                                   hop=hops,
                                                   dropout_rate=0.2,
                                                   normalize=True)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)
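
Note that the per-hop layers above live in ``nn.ModuleList`` containers rather than plain Python lists, so their parameters are registered with the model. A toy sketch of that pattern, with hypothetical names:

import torch
import torch.nn as nn

class HopStack(nn.Module):
    def __init__(self, hops: int, dim: int):
        super().__init__()
        # ModuleList (unlike a plain list) registers each block's parameters.
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(hops))

    def forward(self, x):
        for block in self.blocks:
            x = torch.relu(block(x))
        return x

stack = HopStack(hops=2, dim=8)
assert len(list(stack.parameters())) == 4  # 2 hops x (weight, bias)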
Example #33
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)
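
For reference, the previous-answer marker table above holds ``(num_context_answers * 4) + 1`` rows. A quick shape check with made-up sizes (the meaning of individual marker ids is not specified here):

import torch

num_context_answers, marker_embedding_dim = 2, 10
prev_ans_marker = torch.nn.Embedding(num_context_answers * 4 + 1, marker_embedding_dim)
marker_ids = torch.tensor([0, 1, 5, 8])  # hypothetical ids; 0 could mean "no marker"
assert prev_ans_marker(marker_ids).shape == (4, 10)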
Example #34
    def test_does_not_divide_by_zero_with_no_count(self, device: str):
        accuracy = BooleanAccuracy()
        assert accuracy.get_metric() == pytest.approx(0.0)
Example #35
    def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None:
        super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)
        
        """
        Initialize some data structures 
        """
        self.cf_a = cf_a
        # Bayesian data models
        self.VBmodels = []
        self.LinearModels = []
        """
        ############## TEXT FIELD EMBEDDER with ELMO ####################
        text_field_embedder : ``TextFieldEmbedder``
            Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
        """
        if (cf_a.use_ELMO):
            if (preloaded_elmo is not None):
                text_field_embedder = preloaded_elmo
            else:
                text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput)
                print("ELMO loaded from disk or downloaded")
        else:
            text_field_embedder = None
        
#        embedder_out_dim  = text_field_embedder.get_output_dim()
        self._text_field_embedder = text_field_embedder
        
        if (cf_a.Add_Linear_projection_ELMO):
            if (self.cf_a.VB_Linear_projection_ELMO):
                prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
                print ("----------------- Bayesian Linear Projection ELMO --------------")
                linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(), 200, prior=prior)
                self.VBmodels.append(linear_projection_ELMO)
            else:
                linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(), 200)
        
            self._linear_projection_ELMO = linear_projection_ELMO
            
        """
        ############## Highway layers ####################
        num_highway_layers : ``int``
            The number of highway layers to use in between embedding the input and passing it through
            the phrase layer.
        """
        
        if (cf_a.Add_Linear_projection_ELMO):
            Input_dimension_highway = 200
        else:
            Input_dimension_highway = text_field_embedder.get_output_dim()
            
        num_highway_layers = cf_a.num_highway_layers
        # Highway layer (optionally Bayesian)
        if (self.cf_a.VB_highway_layers):
            print ("----------------- Bayesian Highway network  --------------")
            prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
            highway_layer = HighwayVB(Input_dimension_highway,
                                      num_highway_layers, prior=prior)
            self.VBmodels.append(highway_layer)
        else:
            highway_layer = Highway(Input_dimension_highway,
                                    num_highway_layers)
        highway_layer = TimeDistributed(highway_layer)
        
        self._highway_layer = highway_layer
        
        """
        ############## Phrase layer ####################
        phrase_layer : ``Seq2SeqEncoder``
            The encoder (with its own internal stacking) that we will use in between embedding tokens
            and doing the bidirectional attention.
        """
        ## Create dropout layer
        if cf_a.phrase_layer_dropout > 0:
            dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
        else:
            dropout_phrase_layer = lambda x: x
        
        phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(
            Input_dimension_highway, hidden_size=cf_a.phrase_layer_hidden_size,
            batch_first=True, bidirectional=True,
            num_layers=cf_a.phrase_layer_num_layers, dropout=cf_a.phrase_layer_dropout))
        
        phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
        self._phrase_layer = phrase_layer
        self._dropout_phrase_layer = dropout_phrase_layer
        
        """
        ############## Matrix attention layer ####################
        similarity_function : ``SimilarityFunction``
            The similarity function that we will use when comparing encoded passage and question
            representations.
        """
        
        # Similarity function (optionally Bayesian)
        if (self.cf_a.VB_similarity_function):
            prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
            print ("----------------- Bayesian Similarity matrix --------------")
            similarity_function = LinearSimilarityVB(
                combination="x,y,x*y",
                tensor_1_dim=phrase_encoding_out_dim,
                tensor_2_dim=phrase_encoding_out_dim, prior=prior)
            self.VBmodels.append(similarity_function)
        else:
            similarity_function = LinearSimilarity(
                combination="x,y,x*y",
                tensor_1_dim=phrase_encoding_out_dim,
                tensor_2_dim=phrase_encoding_out_dim)
            
        matrix_attention = LegacyMatrixAttention(similarity_function)
        self._matrix_attention = matrix_attention
        
        """
        ############## Modelling Layer ####################
        modeling_layer : ``Seq2SeqEncoder``
            The encoder (with its own internal stacking) that we will use in between the bidirectional
            attention and predicting span start and end.
        """
        ## Create dropout layer
        if cf_a.modeling_passage_dropout > 0:
            dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
        else:
            dropout_modeling_passage = lambda x: x
        
        modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(
            phrase_encoding_out_dim * 4, hidden_size=cf_a.modeling_passage_hidden_size,
            batch_first=True, bidirectional=True,
            num_layers=cf_a.modeling_passage_num_layers, dropout=cf_a.modeling_passage_dropout))

        self._modeling_layer = modeling_layer
        self._dropout_modeling_passage = dropout_modeling_passage
        
        """
        ############## Span Start Representation #####################
        span_end_encoder : ``Seq2SeqEncoder``
            The encoder that we will use to incorporate span start predictions into the passage state
            before predicting span end.
        """
        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        
        # Linear layer to compute the span start
        if (self.cf_a.VB_span_start_predictor_linear):
            prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
            print ("----------------- Bayesian Span Start Predictor--------------")
            span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior=prior)
            self.VBmodels.append(span_start_predictor_linear)
        else:
            span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)
            
        self._span_start_predictor_linear = span_start_predictor_linear
        self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

        """
        ############## Span End Representation #####################
        """
        
        ## Create dropout layer
        if cf_a.span_end_encoder_dropout > 0:
            dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_span_end_encode = lambda x: x
        
        span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(
            encoding_dim * 4 + modeling_dim * 3, hidden_size=cf_a.modeling_span_end_hidden_size,
            batch_first=True, bidirectional=True,
            num_layers=cf_a.modeling_span_end_num_layers, dropout=cf_a.span_end_encoder_dropout))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        
        self._span_end_encoder = span_end_encoder
        self._dropout_span_end_encode = dropout_span_end_encode
        
        if (self.cf_a.VB_span_end_predictor_linear):
            print ("----------------- Bayesian Span End Predictor--------------")
            prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
            span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior)
            self.VBmodels.append(span_end_predictor_linear) 
        else:
            span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)
        
        self._span_end_predictor_linear = span_end_predictor_linear
        self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

        """
        Dropput last layers
        """
        if cf_a.spans_output_dropout > 0:
            dropout_spans_output = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_spans_output = lambda x: x
        
        self._dropout_spans_output = dropout_spans_output
        
        """
        Checkings and accuracy
        """
        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(Input_dimension_highway, phrase_layer.get_input_dim(),
                               "highway layer input dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        """
        mask_lstms : ``bool``, optional (default=True)
            If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
            with only a slight performance decrease, if any.  We haven't experimented much with this
            yet, but have confirmed that we still get very similar performance with much faster
            training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
            required when using masking with pytorch LSTMs.
        """
        self._mask_lstms = cf_a.mask_lstms

    
        """
        ################### Initialize parameters ##############################
        """
        #### THEY ARE ALL INITIALIZED WHEN INSTANTIATING THE COMPONENTS ###
    
        """
        ####################### OPTIMIZER ################
        """
        optimizer = pytut.get_optimizers(self, cf_a)
        self._optimizer = optimizer
Example #36
print ("-------------- LOGITS OF BOTH SPANS and BEST SPAN ---------------")

span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)

best_span = BidirectionalAttentionFlow_1.get_best_span(span_start_logits, span_end_logits)

print ("best_spans", best_span)

"""
------------------------------ GET LOSES AND ACCURACIES -----------------------------------
"""
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training.
if span_start is not None:
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss
    
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()
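
As used above, ``BooleanAccuracy`` only counts a span as correct when both endpoints match. A quick illustration with made-up spans:

import torch

span_acc = BooleanAccuracy()
predicted_spans = torch.tensor([[3, 5], [1, 2]])
gold_spans = torch.tensor([[3, 5], [1, 4]])  # second span's end differs
span_acc(predicted_spans, gold_spans)
print(span_acc.get_metric())  # 0.5 -- only one of the two spans fully matches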
Example #37
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(),
                               4 * encoding_dim, "modeling layer input dim",
                               "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               phrase_layer.get_input_dim(),
                               "text field embedder output dim",
                               "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['best_span_indices'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                output_dict['best_span_indices'].append(
                    [start_offset, end_offset])
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
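
A tiny sanity check of the flatten-and-argmax decoding above, using illustrative logits; the best valid cell is start 1, end 2, recovered from flattened index 5 as (5 // 3, 5 % 3):

import torch

start_logits = torch.tensor([[0.1, 2.0, 0.3]])
end_logits = torch.tensor([[0.0, 0.5, 1.5]])
best = BidirectionalAttentionFlow.get_best_span(start_logits, end_logits)
assert best.tolist() == [[1, 2]]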
Example #38
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 feed_forward: FeedForward,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(ModelSQUAD, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        #self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._residual_encoder = residual_encoder
        self._span_end_encoder = span_end_encoder
        self._span_start_encoder = span_start_encoder
        self._feed_forward = feed_forward

        encoding_dim = phrase_layer.get_output_dim()
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))
        self._no_answer_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))

        #self._self_matrix_attention = MatrixAttention(attention_similarity_function)
        self._linear_layer = TimeDistributed(
            torch.nn.Linear(4 * encoding_dim, encoding_dim))
        self._residual_linear_layer = TimeDistributed(
            torch.nn.Linear(3 * encoding_dim, encoding_dim))

        self._w_p = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_q = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_pq = torch.nn.Parameter(torch.Tensor(encoding_dim))
        std = math.sqrt(6 / (encoding_dim * 3 + 1))
        self._w_p.data.uniform_(-std, std)
        self._w_q.data.uniform_(-std, std)
        self._w_pq.data.uniform_(-std, std)

        self._w_x = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_y = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_xy = torch.nn.Parameter(torch.Tensor(encoding_dim))
        #std = math.sqrt(6/ (encoding_dim*3 + 1))
        self._w_x.data.uniform_(-std, std)
        self._w_y.data.uniform_(-std, std)
        self._w_xy.data.uniform_(-std, std)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)
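
The uniform bounds used for the attention weight vectors above appear to follow the Xavier/Glorot rule, bound = sqrt(6 / (fan_in + fan_out)), here with fan_in = 3 * encoding_dim and fan_out = 1. A stand-alone sketch of the same initialization:

import math
import torch

encoding_dim = 100  # illustrative size
w = torch.empty(encoding_dim)
bound = math.sqrt(6 / (encoding_dim * 3 + 1))
w.uniform_(-bound, bound)
assert w.abs().max() <= bound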