Example #1
0
    def forward(
        self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs
    ) -> torch.Tensor:
        if self._token_embedders.keys() != text_field_input.keys():
            message = "Mismatched token keys: %s and %s" % (
                str(self._token_embedders.keys()),
                str(text_field_input.keys()),
            )
            raise ConfigurationError(message)

        embedded_representations = []
        for key in self._ordered_embedder_keys:
            embedder = getattr(self, "token_embedder_{}".format(key))
            forward_params = inspect.signature(embedder.forward).parameters
            forward_params_values = {}
            missing_tensor_args = set()
            for param in forward_params.keys():
                if param in kwargs:
                    forward_params_values[param] = kwargs[param]
                else:
                    missing_tensor_args.add(param)

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)

            tensors: Dict[str, torch.Tensor] = text_field_input[key]
            if len(tensors) == 1 and len(missing_tensor_args) == 1:
                token_vectors = embedder(list(tensors.values())[0], **forward_params_values)
            else:
                token_vectors = embedder(**tensors, **forward_params_values)
            if token_vectors is not None:
                embedded_representations.append(token_vectors)
        return torch.cat(embedded_representations, dim=-1)
Example #2
0
    def forward(self,
                text_field_input: TextFieldTensors,
                num_wrapping_dims: int = 0,
                **kwargs) -> torch.Tensor:
        if sorted(self._token_embedders.keys()) != sorted(
                text_field_input.keys()):
            message = "Mismatched token keys: %s and %s" % (
                str(self._token_embedders.keys()),
                str(text_field_input.keys()),
            )
            embedder_keys = set(self._token_embedders.keys())
            input_keys = set(text_field_input.keys())
            if embedder_keys > input_keys and all(
                    isinstance(embedder, EmptyEmbedder)
                    for name, embedder in self._token_embedders.items()
                    if name in embedder_keys - input_keys):
                # Allow extra embedders that are only in the token embedders (but not input) and are empty to pass
                # config check
                pass
            else:
                raise ConfigurationError(message)

        embedded_representations = []
        for key in self._ordered_embedder_keys:
            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format(key))
            if isinstance(embedder, EmptyEmbedder):
                # Skip empty embedders
                continue
            forward_params = inspect.signature(embedder.forward).parameters
            forward_params_values = {}
            missing_tensor_args = set()
            for param in forward_params.keys():
                if param in kwargs:
                    forward_params_values[param] = kwargs[param]
                else:
                    missing_tensor_args.add(param)

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)

            tensors: Dict[str, torch.Tensor] = text_field_input[key]
            if len(tensors) == 1 and len(missing_tensor_args) == 1:
                # If there's only one tensor argument to the embedder, and we just have one tensor to
                # embed, we can just pass in that tensor, without requiring a name match.
                token_vectors = embedder(
                    list(tensors.values())[0], **forward_params_values)
            else:
                # If there are multiple tensor arguments, we have to require matching names from the
                # TokenIndexer.  I don't think there's an easy way around that.
                token_vectors = embedder(**tensors, **forward_params_values)
            if token_vectors is not None:
                # To handle some very rare use cases, we allow the return value of the embedder to
                # be None; we just skip it in that case.
                embedded_representations.append(token_vectors)
        return torch.cat(embedded_representations, dim=-1)
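A minimal standalone sketch of the kwargs-filtering idiom used above (the ToyEmbedder class and the tensor values are hypothetical, purely for illustration):

import inspect

import torch


class ToyEmbedder(torch.nn.Module):
    def forward(self, tokens: torch.Tensor, type_ids: torch.Tensor = None) -> torch.Tensor:
        return tokens.float()


embedder = ToyEmbedder()
kwargs = {"type_ids": torch.zeros(2, 3, dtype=torch.long), "unrelated": 42}
# Keep only the kwargs that embedder.forward() actually accepts; splatting
# "unrelated" into the call would raise a TypeError.
forward_params = inspect.signature(embedder.forward).parameters
forward_params_values = {name: value for name, value in kwargs.items() if name in forward_params}
token_vectors = embedder(torch.ones(2, 3, dtype=torch.long), **forward_params_values)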
Example #3
0
    def forward(self,
                text_field_input: TextFieldTensors,
                augment: int,
                difficulty_step: int,
                num_wrapping_dims: int = 0,
                **kwargs) -> torch.Tensor:
        if self._token_embedders.keys() != text_field_input.keys():
            message = "Mismatched token keys: %s and %s" % (
                str(self._token_embedders.keys()),
                str(text_field_input.keys()),
            )
            raise ConfigurationError(message)

        embedded_representations = []
        for key in self._ordered_embedder_keys:
            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, "token_embedder_{}".format(key))
            forward_params = inspect.signature(embedder.forward).parameters
            forward_params_values = {}
            missing_tensor_args = set()
            for param in forward_params.keys():
                if param in kwargs:
                    forward_params_values[param] = kwargs[param]
                else:
                    missing_tensor_args.add(param)

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)

            tensors: Dict[str, torch.Tensor] = text_field_input[key]
            if len(tensors) == 1 and len(missing_tensor_args) == 1:
                # If there's only one tensor argument to the embedder, and we just have one tensor
                # to embed, we can just pass in that tensor, without requiring a name match.
                masked_lm_loss, token_vectors = embedder(
                    augment, difficulty_step,
                    list(tensors.values())[0], **forward_params_values)
            else:
                # If there are multiple tensor arguments, we have to require matching names from
                # the TokenIndexer. I don't think there's an easy way around that.
                masked_lm_loss, token_vectors = embedder(
                    augment, difficulty_step, **tensors,
                    **forward_params_values)
            if token_vectors is not None:
                # To handle some very rare use cases, we allow the return value of the embedder to
                # be None; we just skip it in that case.
                embedded_representations.append(token_vectors)
        return masked_lm_loss, torch.cat(embedded_representations, dim=-1)
Example #4
0
    def get_step_state(self,
                       inputs: TextFieldTensors) -> Dict[str, torch.Tensor]:
        """
        Create a `state` dictionary for `BeamSearch` from the `TextFieldTensors` inputs
        to the `NextTokenLm` model.

        By default this assumes the `TextFieldTensors` has a single `TokenEmbedder`,
        and just "flattens" the `TextFieldTensors` by returning the `TokenEmbedder`
        sub-dictionary.

        If you have `TextFieldTensors` with more than one `TokenEmbedder` sub-dictionary,
        you'll need to override this class.
        """
        assert len(inputs) == 1, (
            "'get_step_state()' assumes a single token embedder by default, "
            "you'll need to override this method to handle more than one")

        key = list(inputs.keys())[0]

        # We can't just `return inputs[key]` because we might want to modify the state
        # dictionary (add or remove fields) without accidentally modifying the inputs
        # dictionary.
        return {k: v for (k, v) in inputs[key].items()}
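A standalone replica of the flattening that get_step_state performs (the "tokens" indexer name and the tensor values are assumptions for illustration):

from typing import Dict

import torch

# TextFieldTensors is Dict[str, Dict[str, torch.Tensor]]: indexer name -> tensor name -> tensor.
inputs: Dict[str, Dict[str, torch.Tensor]] = {
    "tokens": {
        "token_ids": torch.tensor([[101, 2023, 102]]),
        "mask": torch.tensor([[True, True, True]]),
    }
}
assert len(inputs) == 1
key = list(inputs.keys())[0]
# Shallow-copy the sub-dictionary so callers can add or remove state fields
# without mutating the original inputs.
state = {k: v for (k, v) in inputs[key].items()}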
Example #5
0
    def forward(
            self,
            context_ids: TextFieldTensors,
            query_ids: TextFieldTensors,
            context_lens: torch.Tensor,
            query_lens: torch.Tensor,
            mask_label: Optional[torch.Tensor] = None,
            cls_label: Optional[torch.Tensor] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # concat the context and query to the encoder
        # get the indexers first
        indexers = context_ids.keys()
        dialogue_ids = {}

        # get the lengths of context and query
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # get the outputs of the dialogue
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # get the outputs of the query and context
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        output_dict = {}
        # --------- cls task: decide whether a rewrite is needed ------------------
        if self._cls_task:
            # get the cls representation, [B, embed_size]
            cls_embed = context_last_layer[:, 0]
            # classify through a linear layer, [B, 2]
            cls_logits = self._cls_linear(cls_embed)
            output_dict["cls_logits"] = cls_logits
        else:
            cls_logits = None

        # --------- mask task: predict which positions in the query need filling -----------
        if self._mask_task:
            # pass through a linear layer, [B, _len, 2]
            mask_logits = self._mask_linear(query_last_layer)
            output_dict["mask_logits"] = mask_logits
        else:
            mask_logits = None

        if cls_label is not None:
            output_dict["loss"] = self._calc_loss(cls_label, mask_label,
                                                  cls_logits, mask_logits,
                                                  query_mask)

        return output_dict
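The [B, _len] masks above are built from the length tensors; a sketch of an equivalent construction (an illustration, not AllenNLP's actual get_mask_from_sequence_lengths source):

import torch


def mask_from_lengths(lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    # lengths: [B]; returns a [B, max_length] bool mask, True at valid positions.
    positions = torch.arange(max_length, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


context_lens = torch.tensor([3, 5])
context_mask = mask_from_lengths(context_lens, int(torch.max(context_lens).item()))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])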
Example #6
0
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                extend_context_ids: torch.Tensor,
                extend_query_ids: torch.Tensor,
                context_len: torch.Tensor,
                query_len: torch.Tensor,
                oovs_len: torch.Tensor,
                rewrite_input_ids: Optional[TextFieldTensors] = None,
                rewrite_target_ids: Optional[TextFieldTensors] = None,
                extend_rewrite_ids: Optional[torch.Tensor] = None,
                rewrite_len: Optional[torch.Tensor] = None,
                metadata: Optional[List[Dict[str, Any]]] = None):
        """
        这里的通用的id都是allennlp中默认的TextFieldTensors类型
        而extend_context_ids则是我们在数据预处理时转换好的
        context, query和rewrite等_len,主要用于获取mask向量
        """
        # 获取context和query的token_ids
        context_token_ids = context_ids[self._index_name]["token_ids"]
        query_token_ids = query_ids[self._index_name]["token_ids"]

        context_mask = context_ids[self._index_name]["mask"]
        query_mask = query_ids[self._index_name]["mask"]

        # get the extended context and query ids
        extend_context_ids = context_token_ids + extend_context_ids.to(
            dtype=torch.long)
        extend_query_ids = query_token_ids + extend_query_ids.to(
            dtype=torch.long)
        # ---------- compute the output of the BERT encoder ---------------
        # the context and query need to be concatenated together for encoding
        indexers = context_ids.keys()
        dialogue_ids = {}
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # compute the encoding
        dialogue_output = self._text_field_embedder(dialogue_ids)
        context_output, query_output, dec_init_state = self._run_encoder(
            dialogue_output, context_mask, query_mask)

        output_dict = {"metadata": metadata}
        if self.training:
            rewrite_input_token_ids = rewrite_input_ids[
                self._index_name]["token_ids"]
            rewrite_input_mask = rewrite_input_ids[self._index_name]["mask"]
            rewrite_target_ids = rewrite_target_ids[
                self._index_name]["token_ids"]
            rewrite_target_ids = rewrite_target_ids + extend_rewrite_ids.to(
                dtype=torch.long)

            # [B, rewrite_len, encoder_output_dim]
            rewrite_embed = self._get_embeddings(rewrite_input_token_ids)
            new_output_dict = self._forward_step(
                context_output, query_output, context_mask, query_mask,
                rewrite_embed, rewrite_target_ids, rewrite_len,
                rewrite_input_mask, extend_context_ids, extend_query_ids,
                oovs_len, dec_init_state)
            output_dict.update(new_output_dict)
        else:
            batch_hyps = self._run_inference(context_output,
                                             query_output,
                                             context_mask,
                                             query_mask,
                                             extend_context_ids,
                                             extend_query_ids,
                                             oovs_len,
                                             dec_init_state=dec_init_state)
            # get the result of each instance
            output_dict['hypothesis'] = batch_hyps
            output_dict = self.get_rewrite_string(output_dict)
            output_dict["loss"] = torch.tensor(0)
        return output_dict
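The concat-context-and-query loop shared by the last few examples, as a minimal standalone sketch (the "bert" indexer name and the toy tensors are assumptions):

import torch

context_ids = {"bert": {"token_ids": torch.tensor([[1, 2]]), "mask": torch.tensor([[True, True]])}}
query_ids = {"bert": {"token_ids": torch.tensor([[3]]), "mask": torch.tensor([[True]])}}

dialogue_ids = {}
for indexer in context_ids.keys():
    dialogue_ids[indexer] = {}
    for key in context_ids[indexer].keys():
        # concatenate context and query along the sequence-length dimension
        dialogue_ids[indexer][key] = torch.cat(
            [context_ids[indexer][key], query_ids[indexer][key]], dim=1)
# dialogue_ids["bert"]["token_ids"] -> tensor([[1, 2, 3]])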
Example #7
0
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                context_lens: torch.Tensor,
                query_lens: torch.Tensor,
                mask_label: Optional[torch.Tensor] = None,
                start_label: Optional[torch.Tensor] = None,
                end_label: Optional[torch.Tensor] = None,
                metadata: List[Dict[str, Any]] = None):
        # concat the context and query to the encoder
        # get the indexers first
        indexers = context_ids.keys()
        dialogue_ids = {}

        # get the lengths of context and query
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # get the outputs of the dialogue
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # get the outputs of the query and context
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        # ------- compute the span prediction results -------
        # For every masked position in the query we want to know what content
        # should be inserted after that token, i.e. the start and end positions
        # of the corresponding span in the context.
        # likewise, expand the context to [b, query_len, context_len, embed_size]
        context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
            -1, query_len, -1, -1).contiguous()
        # [b, query_len, context_len]
        context_expand_mask = context_mask.unsqueeze(dim=1).expand(
            -1, query_len, -1).contiguous()

        # concatenate the three parts above;
        # this covers every position in the query
        span_embed_size = context_last_layer.size(-1)

        if self.training and self._neg_sample_ratio > 0.0:
            # sample from the positions where the mask is 0
            # [B*query_len, ]
            sample_mask_label = mask_label.view(-1)
            # get the flattened length and the number of negative samples to draw
            mask_length = sample_mask_label.size(0)
            mask_sum = int(
                torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
            mask_sum = max(10, mask_sum)
            # get the indices of the negative samples to draw
            neg_indexes = torch.randint(low=0,
                                        high=mask_length,
                                        size=(mask_sum, ))
            # restrict to the valid length range
            neg_indexes = neg_indexes[:mask_length]
            # set the mask at the sampled negative positions to 1
            sample_mask_label[neg_indexes] = 1
            # [B, query_len]
            use_mask_label = sample_mask_label.view(
                -1, query_len).to(dtype=torch.bool)
            # filter out the padded part of the query, [B, query_len]
            use_mask_label = use_mask_label & query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # select the usable content from the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # select the usable vectors from the query
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
        else:
            use_mask_label = query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # select the usable content from the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # select the usable vectors from the query
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()

        # get the span logits for each context position
        # [B_mask, context_len]
        span_start_probs = self.start_attention(span_query_vector,
                                                span_context_matrix,
                                                span_context_mask)
        span_end_probs = self.end_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)

        span_start_logits = torch.log(span_start_probs + self._eps)
        span_end_logits = torch.log(span_end_probs + self._eps)

        # [B_mask, 2]; in the last dimension, the first value is the start position and the second is the end position
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # compute the score of each best_span
        best_span_scores = (
            torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
            torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
        # [B_mask, ]
        best_span_scores = best_span_scores.squeeze(1)

        # write the important information into the output
        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_spans": best_spans,
            "best_span_scores": best_span_scores
        }

        # if labels exist, use them to compute the loss
        if start_label is not None:
            loss = self._calc_loss(span_start_logits, span_end_logits,
                                   use_mask_label, start_label, end_label,
                                   best_spans)
            output_dict["loss"] = loss
        if metadata is not None:
            predict_rewrite_results = self._get_rewrite_result(
                use_mask_label, best_spans, query_lens, context_lens, metadata)
            output_dict['rewrite_results'] = predict_rewrite_results
        return output_dict
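How the best_span_scores gather at the end works, on toy logits (a sketch; get_best_span itself is taken as given here, and the numbers are made up):

import torch

span_start_logits = torch.log(torch.tensor([[0.1, 0.7, 0.2]]))
span_end_logits = torch.log(torch.tensor([[0.2, 0.2, 0.6]]))
best_spans = torch.tensor([[1, 2]])  # suppose get_best_span picked start=1, end=2

# score of a span = log p(start) + log p(end), gathered at the chosen indices
best_span_scores = (
    torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
    torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
best_span_scores = best_span_scores.squeeze(1)  # [B_mask, ]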