def forward(
    self, text_field_input: "TextFieldTensors", num_wrapping_dims: int = 0, **kwargs
) -> torch.Tensor:
    """
    Embed each sub-dictionary of ``text_field_input`` with its matching
    ``TokenEmbedder`` and concatenate the results along the last dimension.

    # Parameters

    text_field_input : `TextFieldTensors`
        Mapping from embedder key to a dict of tensors produced by the
        corresponding `TokenIndexer`.
    num_wrapping_dims : `int`, optional (default = `0`)
        Number of extra leading dimensions; each embedder is wrapped in
        `TimeDistributed` once per wrapping dimension.
    kwargs :
        Extra keyword arguments forwarded to an embedder when its
        `forward` signature declares a parameter of the same name.

    # Returns

    `torch.Tensor` — the non-None embedder outputs concatenated on ``dim=-1``.

    # Raises

    `ConfigurationError` if the embedder keys and the input keys differ.
    """
    if self._token_embedders.keys() != text_field_input.keys():
        message = "Mismatched token keys: %s and %s" % (
            str(self._token_embedders.keys()),
            str(text_field_input.keys()),
        )
        raise ConfigurationError(message)

    embedded_representations = []
    for key in self._ordered_embedder_keys:
        # getattr (rather than a dict lookup) so pytorch's submodule
        # registration works correctly on multiple GPUs.
        embedder = getattr(self, "token_embedder_{}".format(key))
        # BUG FIX: the signature inspection and the per-parameter loop were
        # missing here, leaving `param` undefined; restored from the
        # canonical implementation so kwargs are routed by parameter name.
        forward_params = inspect.signature(embedder.forward).parameters
        forward_params_values = {}
        missing_tensor_args = set()
        for param in forward_params.keys():
            if param in kwargs:
                forward_params_values[param] = kwargs[param]
            else:
                missing_tensor_args.add(param)

        for _ in range(num_wrapping_dims):
            embedder = TimeDistributed(embedder)

        tensors: Dict[str, torch.Tensor] = text_field_input[key]
        if len(tensors) == 1 and len(missing_tensor_args) == 1:
            # One tensor and one tensor argument: pass it positionally,
            # no name match required.
            token_vectors = embedder(list(tensors.values())[0], **forward_params_values)
        else:
            # Several tensors: argument names must match the TokenIndexer's keys.
            token_vectors = embedder(**tensors, **forward_params_values)
        # BUG FIX: the tail of this function (appending the result and the
        # final concatenation) was truncated; restored below.
        if token_vectors is not None:
            # A None return value is allowed and simply skipped.
            embedded_representations.append(token_vectors)
    return torch.cat(embedded_representations, dim=-1)
def forward(self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs) -> torch.Tensor:
    """
    Run every token embedder on its matching sub-dictionary of
    ``text_field_input`` and concatenate the results on the last dimension.

    Keys of ``text_field_input`` must match the configured embedders, except
    that embedder-only extra keys are tolerated when each such embedder is an
    ``EmptyEmbedder``. Each embedder may also receive entries of ``kwargs``
    whose names appear in its ``forward`` signature, and is wrapped in
    ``TimeDistributed`` once per ``num_wrapping_dims``.
    """
    if sorted(self._token_embedders.keys()) != sorted(text_field_input.keys()):
        message = "Mismatched token keys: %s and %s" % (
            str(self._token_embedders.keys()),
            str(text_field_input.keys()),
        )
        embedder_keys = set(self._token_embedders.keys())
        input_keys = set(text_field_input.keys())
        extra_keys = embedder_keys - input_keys
        # Keys configured on the embedder but absent from the input are
        # acceptable only when every one of them maps to an EmptyEmbedder;
        # any other mismatch is a configuration error.
        harmless_extras = embedder_keys > input_keys and all(
            isinstance(emb, EmptyEmbedder)
            for name, emb in self._token_embedders.items()
            if name in extra_keys
        )
        if not harmless_extras:
            raise ConfigurationError(message)

    outputs = []
    for key in self._ordered_embedder_keys:
        # getattr so that pytorch's submodule machinery keeps working
        # when the model is replicated across multiple GPUs.
        embedder = getattr(self, "token_embedder_{}".format(key))
        if isinstance(embedder, EmptyEmbedder):
            # Placeholder embedder: contributes nothing to the output.
            continue

        signature_params = inspect.signature(embedder.forward).parameters
        routed_kwargs = {}
        unmatched = set()
        for name in signature_params.keys():
            if name in kwargs:
                routed_kwargs[name] = kwargs[name]
            else:
                unmatched.add(name)

        for _ in range(num_wrapping_dims):
            embedder = TimeDistributed(embedder)

        tensors: Dict[str, torch.Tensor] = text_field_input[key]
        if len(tensors) == 1 and len(unmatched) == 1:
            # A single tensor and a single tensor argument: pass the tensor
            # positionally, no name match required.
            vectors = embedder(list(tensors.values())[0], **routed_kwargs)
        else:
            # Several tensors: argument names must match the TokenIndexer's
            # keys — there is no easy way around requiring that.
            vectors = embedder(**tensors, **routed_kwargs)
        if vectors is not None:
            # Embedders are allowed to return None; such results are skipped.
            outputs.append(vectors)
    return torch.cat(outputs, dim=-1)
def forward(self, text_field_input: TextFieldTensors, augment: int, difficulty_step: int, num_wrapping_dims: int = 0, **kwargs) -> torch.Tensor:
    """
    Variant of the text-field-embedder forward pass whose embedders also take
    ``augment`` and ``difficulty_step`` and return a masked-LM loss alongside
    the token vectors.

    Returns a ``(masked_lm_loss, embeddings)`` pair where ``embeddings`` is
    the concatenation of all non-None embedder outputs on the last dimension.
    NOTE(review): only the loss from the last key in
    ``self._ordered_embedder_keys`` is returned — presumably a single
    embedder is configured; confirm with callers.
    """
    if self._token_embedders.keys() != text_field_input.keys():
        message = "Mismatched token keys: %s and %s" % (
            str(self._token_embedders.keys()),
            str(text_field_input.keys()),
        )
        raise ConfigurationError(message)

    pieces = []
    for key in self._ordered_embedder_keys:
        # getattr so that pytorch's submodule handling keeps working when
        # the model runs on multiple GPUs.
        embedder = getattr(self, "token_embedder_{}".format(key))

        accepted = inspect.signature(embedder.forward).parameters
        routed_kwargs = {name: kwargs[name] for name in accepted if name in kwargs}
        leftover = {name for name in accepted if name not in kwargs}

        for _ in range(num_wrapping_dims):
            embedder = TimeDistributed(embedder)

        tensors: Dict[str, torch.Tensor] = text_field_input[key]
        if len(tensors) == 1 and len(leftover) == 1:
            # One tensor, one tensor argument: pass it positionally,
            # no name match required.
            masked_lm_loss, token_vectors = embedder(
                augment, difficulty_step, list(tensors.values())[0], **routed_kwargs
            )
        else:
            # Several tensors: argument names must match the TokenIndexer's keys.
            masked_lm_loss, token_vectors = embedder(
                augment, difficulty_step, **tensors, **routed_kwargs
            )
        if token_vectors is not None:
            # A None result is allowed and simply skipped.
            pieces.append(token_vectors)
    return masked_lm_loss, torch.cat(pieces, dim=-1)
def get_step_state(self, inputs: "TextFieldTensors") -> Dict[str, torch.Tensor]:
    """
    Create a `state` dictionary for `BeamSearch` from the `TextFieldTensors`
    inputs to the `NextTokenLm` model.

    By default this assumes the `TextFieldTensors` has a single
    `TokenEmbedder`, and just "flattens" the `TextFieldTensors` by returning
    the `TokenEmbedder` sub-dictionary.

    If you have `TextFieldTensors` with more than one `TokenEmbedder`
    sub-dictionary, you'll need to override this class.
    """
    assert len(inputs) == 1, (
        "'get_step_state()' assumes a single token embedder by default, "
        "you'll need to override this method to handle more than one")
    # next(iter(...)) fetches the single key without building a list.
    key = next(iter(inputs))
    # Return a shallow copy so callers can add or remove state entries
    # without accidentally mutating the inputs dictionary.
    return dict(inputs[key])
def forward(
        self,
        context_ids: TextFieldTensors,
        query_ids: TextFieldTensors,
        context_lens: torch.Tensor,
        query_lens: torch.Tensor,
        mask_label: Optional[torch.Tensor] = None,
        cls_label: Optional[torch.Tensor] = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Encode the concatenated context+query dialogue and run the model's two
    optional heads:

    * cls head  — binary logits from the first token of the context part,
      deciding whether the query needs rewriting.
    * mask head — per-token binary logits over the query part, marking
      query positions that need to be filled.

    `mask_label`/`cls_label` are training labels; a `loss` entry is added
    only when `cls_label` is given. `metadata` is unused in this method.
    Returns a dict with `cls_logits` and/or `mask_logits` (and `loss`).
    """
    # Concatenate context and query before encoding; fetch the indexers first.
    indexers = context_ids.keys()
    dialogue_ids = {}
    # Max context / query lengths in this batch.
    context_len = torch.max(context_lens).item()
    query_len = torch.max(query_lens).item()
    # [B, _len]
    # NOTE(review): context_mask is computed but never used below.
    context_mask = get_mask_from_sequence_lengths(context_lens, context_len)
    query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
    for indexer in indexers:
        # Gather the matching tensors of context and query.
        dialogue_ids[indexer] = {}
        for key in context_ids[indexer].keys():
            context = context_ids[indexer][key]
            query = query_ids[indexer][key]
            # Concatenate context and query along the length dimension.
            dialogue = torch.cat([context, query], dim=1)
            dialogue_ids[indexer][key] = dialogue
    # Encode the whole dialogue; the embedder may either be an AllenNLP
    # TextFieldEmbedder (takes the nested dict) or a raw module (takes
    # the tensors of the configured index name as keyword arguments).
    if isinstance(self._text_field_embedder, TextFieldEmbedder):
        embedder_outputs = self._text_field_embedder(dialogue_ids)
    else:
        embedder_outputs = self._text_field_embedder(
            **dialogue_ids[self._index_name])
    # Split the encoder output back into context and query parts.
    # [B, _len, embed_size]
    context_last_layer = embedder_outputs[:, :context_len].contiguous()
    query_last_layer = embedder_outputs[:, context_len:].contiguous()
    output_dict = {}
    # --------- cls task: decide whether the query needs rewriting ---------
    if self._cls_task:
        # CLS representation, [B, embed_size]
        cls_embed = context_last_layer[:, 0]
        # Linear classifier, [B, 2]
        cls_logits = self._cls_linear(cls_embed)
        output_dict["cls_logits"] = cls_logits
    else:
        cls_logits = None
    # --------- mask task: find query positions that need filling ----------
    if self._mask_task:
        # Linear layer, [B, _len, 2]
        mask_logits = self._mask_linear(query_last_layer)
        output_dict["mask_logits"] = mask_logits
    else:
        mask_logits = None
    # Loss is only computed when cls_label is provided.
    # NOTE(review): when only the mask task is active and cls_label is None,
    # no loss is produced — confirm that is intended.
    if cls_label is not None:
        output_dict["loss"] = self._calc_loss(cls_label,
                                              mask_label,
                                              cls_logits,
                                              mask_logits,
                                              query_mask)
    return output_dict
def forward(self, context_ids: TextFieldTensors, query_ids: TextFieldTensors, extend_context_ids: torch.Tensor, extend_query_ids: torch.Tensor, context_len: torch.Tensor, query_len: torch.Tensor, oovs_len: torch.Tensor, rewrite_input_ids: Optional[TextFieldTensors] = None, rewrite_target_ids: Optional[TextFieldTensors] = None, extend_rewrite_ids: Optional[torch.Tensor] = None, rewrite_len: Optional[torch.Tensor] = None, metadata: Optional[List[Dict[str, Any]]] = None): """ 这里的通用的id都是allennlp中默认的TextFieldTensors类型 而extend_context_ids则是我们在数据预处理时转换好的 context, query和rewrite等_len,主要用于获取mask向量 """ # 获取context和query的token_ids context_token_ids = context_ids[self._index_name]["token_ids"] query_token_ids = query_ids[self._index_name]["token_ids"] context_mask = context_ids[self._index_name]["mask"] query_mask = query_ids[self._index_name]["mask"] # get the extended context and query ids extend_context_ids = context_token_ids + extend_context_ids.to( dtype=torch.long) extend_query_ids = query_token_ids + extend_query_ids.to( dtype=torch.long) # ---------- bert编码器计算输出 --------------- # 需要将context和query拼接在一起编码 indexers = context_ids.keys() dialogue_ids = {} for indexer in indexers: # get the various variables of context and query dialogue_ids[indexer] = {} for key in context_ids[indexer].keys(): context = context_ids[indexer][key] query = query_ids[indexer][key] # concat the context and query in the length dim dialogue = torch.cat([context, query], dim=1) dialogue_ids[indexer][key] = dialogue # 计算编码 dialogue_output = self._text_field_embedder(dialogue_ids) context_output, query_output, dec_init_state = self._run_encoder( dialogue_output, context_mask, query_mask) output_dict = {"metadata": metadata} if self.training: rewrite_input_token_ids = rewrite_input_ids[ self._index_name]["token_ids"] rewrite_input_mask = rewrite_input_ids[self._index_name]["mask"] rewrite_target_ids = rewrite_target_ids[ self._index_name]["token_ids"] rewrite_target_ids = rewrite_target_ids + 
extend_rewrite_ids.to( dtype=torch.long) # [B, rewrite_len, encoder_output_dim] rewrite_embed = self._get_embeddings(rewrite_input_token_ids) new_output_dict = self._forward_step( context_output, query_output, context_mask, query_mask, rewrite_embed, rewrite_target_ids, rewrite_len, rewrite_input_mask, extend_context_ids, extend_query_ids, oovs_len, dec_init_state) output_dict.update(new_output_dict) else: batch_hyps = self._run_inference(context_output, query_output, context_mask, query_mask, extend_context_ids, extend_query_ids, oovs_len, dec_init_state=dec_init_state) # get the result of each instance output_dict['hypothesis'] = batch_hyps output_dict = self.get_rewrite_string(output_dict) output_dict["loss"] = torch.tensor(0) return output_dict
def forward(self,
            context_ids: TextFieldTensors,
            query_ids: TextFieldTensors,
            context_lens: torch.Tensor,
            query_lens: torch.Tensor,
            mask_label: Optional[torch.Tensor] = None,
            start_label: Optional[torch.Tensor] = None,
            end_label: Optional[torch.Tensor] = None,
            metadata: List[Dict[str, Any]] = None):
    """
    Encode the concatenated context+query dialogue, then, for each selected
    query position, predict the start/end of the context span whose content
    should be inserted after that position.

    `mask_label` marks (during training) which query positions carry spans;
    negative positions are additionally sampled by `_neg_sample_ratio`.
    `start_label`/`end_label` are the gold span endpoints; when present a
    `loss` entry is added. When `metadata` is present the predicted rewrite
    strings are added as `rewrite_results`.
    """
    # Concatenate context and query before encoding; fetch the indexers first.
    indexers = context_ids.keys()
    dialogue_ids = {}
    # Max context / query lengths in this batch.
    context_len = torch.max(context_lens).item()
    query_len = torch.max(query_lens).item()
    # [B, _len]
    context_mask = get_mask_from_sequence_lengths(context_lens, context_len)
    query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
    for indexer in indexers:
        # Gather the matching tensors of context and query.
        dialogue_ids[indexer] = {}
        for key in context_ids[indexer].keys():
            context = context_ids[indexer][key]
            query = query_ids[indexer][key]
            # Concatenate context and query along the length dimension.
            dialogue = torch.cat([context, query], dim=1)
            dialogue_ids[indexer][key] = dialogue
    # Encode the dialogue; the embedder may be an AllenNLP TextFieldEmbedder
    # (takes the nested dict) or a raw module (takes keyword tensors).
    if isinstance(self._text_field_embedder, TextFieldEmbedder):
        embedder_outputs = self._text_field_embedder(dialogue_ids)
    else:
        embedder_outputs = self._text_field_embedder(
            **dialogue_ids[self._index_name])
    # Split the encoder output back into context and query parts.
    # [B, _len, embed_size]
    context_last_layer = embedder_outputs[:, :context_len].contiguous()
    query_last_layer = embedder_outputs[:, context_len:].contiguous()
    # ------- span prediction -------
    # For every mask position in the query we want the content to insert
    # after that token, i.e. the start and end of a span in the context.
    # Expand the context to [b, query_len, context_len, embed_size] so each
    # query position sees the whole context.
    context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
        -1, query_len, -1, -1).contiguous()
    # [b, query_len, context_len]
    context_expand_mask = context_mask.unsqueeze(dim=1).expand(
        -1, query_len, -1).contiguous()
    # The three parts above are combined below, covering every query position.
    span_embed_size = context_last_layer.size(-1)
    if self.training and self._neg_sample_ratio > 0.0:
        # Sample negatives among the zero entries of mask_label.
        # [B*query_len, ]
        # NOTE(review): view(-1) shares storage with mask_label, so the
        # in-place write below mutates the caller's tensor — confirm intended.
        sample_mask_label = mask_label.view(-1)
        # Flattened length and the number of negatives to sample.
        mask_length = sample_mask_label.size(0)
        mask_sum = int(
            torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
        mask_sum = max(10, mask_sum)
        # Indices of the sampled negatives (clipped to the valid length).
        neg_indexes = torch.randint(low=0,
                                    high=mask_length,
                                    size=(mask_sum, ))
        neg_indexes = neg_indexes[:mask_length]
        # Mark the sampled negative positions as active.
        sample_mask_label[neg_indexes] = 1
        # [B, query_len]
        use_mask_label = sample_mask_label.view(
            -1, query_len).to(dtype=torch.bool)
        # Filter out the padded part of the query, [B, query_len].
        use_mask_label = use_mask_label & query_mask
        span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
        # Select the usable context rows,
        # [B_mask, context_len, span_embed_size]
        span_context_matrix = context_last_layer.masked_select(
            span_mask).view(-1, context_len, span_embed_size).contiguous()
        # Select the usable query vectors.
        span_query_vector = query_last_layer.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, span_embed_size).contiguous()
        span_context_mask = context_expand_mask.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
    else:
        # Evaluation (or no sampling): use every non-padded query position.
        use_mask_label = query_mask
        span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
        # Select the usable context rows,
        # [B_mask, context_len, span_embed_size]
        span_context_matrix = context_last_layer.masked_select(
            span_mask).view(-1, context_len, span_embed_size).contiguous()
        # Select the usable query vectors.
        span_query_vector = query_last_layer.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, span_embed_size).contiguous()
        span_context_mask = context_expand_mask.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
    # Logits of the span over every context position.
    # [B_mask, context_len]
    span_start_probs = self.start_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)
    span_end_probs = self.end_attention(span_query_vector,
                                        span_context_matrix,
                                        span_context_mask)
    span_start_logits = torch.log(span_start_probs + self._eps)
    span_end_logits = torch.log(span_end_probs + self._eps)
    # [B_mask, 2] — last dim holds the start position then the end position.
    best_spans = get_best_span(span_start_logits, span_end_logits)
    # Score of each best span: start logit + end logit.
    best_span_scores = (
        torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
        torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
    # [B_mask, ]
    best_span_scores = best_span_scores.squeeze(1)
    # Collect the important outputs.
    output_dict = {
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_spans": best_spans,
        "best_span_scores": best_span_scores
    }
    # If labels are available, compute the loss from them.
    if start_label is not None:
        loss = self._calc_loss(span_start_logits, span_end_logits,
                               use_mask_label, start_label, end_label,
                               best_spans)
        output_dict["loss"] = loss
    if metadata is not None:
        predict_rewrite_results = self._get_rewrite_result(
            use_mask_label, best_spans, query_lens, context_lens, metadata)
        output_dict['rewrite_results'] = predict_rewrite_results
    return output_dict