def forward(self, inputs: torch.Tensor, offsets: torch.Tensor = None) -> torch.Tensor:
    """
    Embed a batch of byte-pair encodings with the transformer, optionally reducing
    the per-byte-pair output to one vector per original word.

    Parameters
    ----------
    inputs: ``torch.Tensor``, required
        A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
        for the current batch. Ids equal to 0 are treated as padding when building the mask.
    offsets: ``torch.Tensor``, optional (default = ``None``)
        A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
        for the current batch. When omitted, the embeddings of all byte pairs are
        returned instead of one embedding per word.

    Returns
    -------
    ``torch.Tensor``
        With ``offsets``: ``(batch_size, max_sequence_length, embedding_dim)``, the embedding
        at each offset. Without ``offsets``: ``(batch_size, n, embedding_dim)``, where ``n`` is
        the largest number of non-zero ids in any row of ``inputs``.
    """
    # pylint: disable=arguments-differ
    batch_size, num_timesteps = inputs.size()

    # the transformer embedding consists of the byte pair embeddings,
    # the special embeddings and the position embeddings.
    # the position embeddings are always at least self._transformer.n_ctx,
    # but may be longer.
    # the transformer "vocab" consists of the actual vocab and the
    # positional encodings. Here we want the count of just the former.
    vocab_size = self._transformer.vocab_size - self._transformer.n_ctx

    # vocab_size, vocab_size + 1, ...
    positional_encodings = get_range_vector(num_timesteps, device=get_device_of(inputs)) + vocab_size

    # Combine the inputs with positional encodings
    batch_tensor = torch.stack([
            inputs,  # (batch_size, num_timesteps)
            positional_encodings.expand(batch_size, num_timesteps)
    ], dim=-1)

    byte_pairs_mask = inputs != 0

    # Embeddings is num_output_layers x (batch_size, num_timesteps, embedding_dim)
    layer_activations = self._transformer(batch_tensor)

    # Output of scalar_mix is (batch_size, num_timesteps, embedding_dim)
    if self._top_layer_only:
        mix = layer_activations[-1]
    else:
        mix = self._scalar_mix(layer_activations, byte_pairs_mask)

    if offsets is None:
        # Return every byte pair, trimmed to the longest row's count of non-zero ids.
        # `.item()` turns the 0-dim count into a plain int slice bound.
        seq_len = int(byte_pairs_mask.long().sum(dim=1).max().item())
        return mix[:, :seq_len]

    # These embeddings are one per byte-pair, but we want one per original _word_.
    # So we choose the embedding corresponding to the last byte pair for each word,
    # which is captured by the ``offsets`` input.
    range_vector = get_range_vector(batch_size, device=get_device_of(mix)).unsqueeze(1)
    return mix[range_vector, offsets]
def create_cached_cnn_embeddings(self, tokens: List[str]) -> None:
    """
    Precompute context-independent vectors for ``tokens``.

    Only ``self._token_embedder`` (the character-level encoder) is run; its output is
    cached so later lookups can use an embedding table instead of recomputing it.

    After this call the following attributes are set:

    _word_embedding : ``Embedding``
        A table built from the vectors of ``tokens``, in order, with ``padding_index=0``
        and ``trainable`` taken from ``self._requires_grad``.
    _bos_embedding : ``torch.Tensor``
        The vector computed for ``ELMoCharacterMapper.bos_token``.
    _eos_embedding : ``torch.Tensor``
        The vector computed for ``ELMoCharacterMapper.eos_token``.

    Parameters
    ----------
    tokens : ``List[str]``, required.
        The tokens to precompute vectors for.
    """
    # BOS and EOS are prepended so their vectors come out of the same pass; they are
    # split back out of rows 0 and 1 at the end.
    all_tokens = [ELMoCharacterMapper.bos_token, ELMoCharacterMapper.eos_token] + tokens
    num_tokens = len(all_tokens)
    timesteps = 32
    batch_size = 32

    token_chunks = lazy_groups_of(iter(all_tokens), timesteps)
    device = get_device_of(next(self.parameters()))

    flattened_batches = []
    for chunk_batch in lazy_groups_of(token_chunks, batch_size):
        # Shape (batch_size, timesteps, 50)
        character_ids = batch_to_ids(chunk_batch)
        # NOTE: If the model is already on a GPU, move the ids there too. When this runs
        # from the constructor it will most likely be on the CPU, which is acceptable
        # because only a few convolutions are involved.
        if device >= 0:
            character_ids = character_ids.cuda(device)

        embedder_output = self._token_embedder(character_ids)
        vectors, _ = remove_sentence_boundaries(embedder_output["token_embedding"],
                                                embedder_output["mask"])
        flattened_batches.append(vectors.view(-1, vectors.size(-1)))

    # The final batch may be padded, leaving extra rows at the end; keep one row per token.
    full_embedding = torch.cat(flattened_batches, 0)[:num_tokens, :]
    embedding = full_embedding[2:num_tokens, :]
    vocab_size, embedding_dim = list(embedding.size())

    from allennlp.modules.token_embedders import Embedding  # type: ignore
    self._bos_embedding = full_embedding[0, :]
    self._eos_embedding = full_embedding[1, :]
    self._word_embedding = Embedding(vocab_size,  # type: ignore
                                     embedding_dim,
                                     weight=embedding.data,
                                     trainable=self._requires_grad,
                                     padding_index=0)
def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
    """
    Build the ``GrammarStatelet`` used for decoding a single batch instance.

    As part of this, the ``valid_actions`` dictionary is built. It maps every non-terminal
    to embedded representations of the actions that expand it. None of the tensors built
    here are batched; the input describes a single instance.

    Only global production rules are supported. For each non-terminal, the ``rule_id``
    tensors of its rules are concatenated and embedded with both ``self._action_embedder``
    and ``self._output_action_embedder``.

    Parameters
    ----------
    possible_actions : ``List[ProductionRule]``
        From the input to ``forward`` for a single batch instance. Actions whose ``rule``
        string is empty are skipped.

    Returns
    -------
    ``GrammarStatelet``
        Starting from ``'statement'``, built with ``reverse_productions=True``.

    Raises
    ------
    ValueError
        If any non-empty action is not a global rule.
    """
    device = util.get_device_of(self._action_embedder.weight)
    # TODO(Mark): This type is pure \(- . ^)/
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}

    actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
    for i, action in enumerate(possible_actions):
        if action.rule == "":
            continue
        if not action.is_global_rule:
            raise ValueError("The sql parser doesn't support non-global actions yet.")
        actions_grouped_by_nonterminal[action.nonterminal].append((action, i))

    for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
        # Each group is non-empty (it only exists because something was appended) and every
        # rule in it is global, so all of a non-terminal's actions go into one 'global' entry.
        global_action_tensor = torch.cat(
                [production_rule.rule_id for production_rule, _ in production_rule_arrays],
                dim=0).long()
        global_action_ids = [action_index for _, action_index in production_rule_arrays]
        if device >= 0:
            global_action_tensor = global_action_tensor.to(device)

        global_input_embeddings = self._action_embedder(global_action_tensor)
        global_output_embeddings = self._output_action_embedder(global_action_tensor)
        translated_valid_actions[key] = {
                'global': (global_input_embeddings, global_output_embeddings, global_action_ids)
        }

    return GrammarStatelet(['statement'],
                           translated_valid_actions,
                           self.is_nonterminal,
                           reverse_productions=True)
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Run BERT over a batch of wordpiece ids and return one embedding per wordpiece, or
    one per selected wordpiece when ``offsets`` is supplied.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, max_sequence_length) tensor of wordpiece ids. Ids equal to 0
        are excluded from the attention mask.
    offsets : ``torch.LongTensor``, optional
        Indices of the wordpiece to keep for each original token. BERT produces one
        embedding per wordpiece, but callers often want one per token. Depending on the
        indexer, the offset may point at the first or the last wordpiece of each token.
        For example, "Definitely not" split into ["Def", "##in", "##ite", "##ly", "not"]
        gives 5 input ids, and "last wordpiece" offsets of [3, 4]. With offsets, the
        output contains one embedding per token; without them, the full per-wordpiece
        tensor is returned.
    token_type_ids : ``torch.LongTensor``, optional
        Segment ids: 0 for tokens of the first sentence, 1 for the second (as in the
        BERT paper). Defaults to all zeros, which is also what you get if the indexer
        does not produce them.
    """
    # pylint: disable=arguments-differ
    segment_ids = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids
    attention_mask = (input_ids != 0).long()

    encoder_layers, _ = self.bert_model(input_ids, attention_mask, segment_ids)

    # Either a learned mix of every layer, or just the top layer.
    if self._scalar_mix is None:
        mixed = encoder_layers[-1]
    else:
        mixed = self._scalar_mix(encoder_layers, attention_mask)

    if offsets is not None:
        # Shape (batch_size, 1): pairs each row of `offsets` with its batch element.
        rows = util.get_range_vector(input_ids.size(0), device=util.get_device_of(mixed)).unsqueeze(1)
        return mixed[rows, offsets]
    return mixed
def _get_head_tags(self,
                   head_tag_representation: torch.Tensor,
                   child_tag_representation: torch.Tensor,
                   head_indices: torch.Tensor) -> torch.Tensor:
    """
    Score dependency tags for each (head, child) arc.

    ``head_indices`` may hold gold or predicted heads, depending on whether this is
    called for the loss or for inference.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length, tag_representation_dim); the head side of
        each arc is gathered from here.
    child_tag_representation : ``torch.Tensor``, required
        Shape (batch_size, sequence_length, tag_representation_dim); the child side of
        each arc.
    head_indices : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length); the head position of every word.

    Returns
    -------
    head_tag_logits : ``torch.Tensor``
        Shape (batch_size, sequence_length, num_head_tags); tag logits for each arc.
    """
    num_sentences = head_tag_representation.size(0)
    # Shape (batch_size, 1): one row index per sentence, broadcast against head_indices.
    sentence_rows = get_range_vector(num_sentences,
                                     get_device_of(head_tag_representation)).unsqueeze(1)

    # Advanced (numpy-style) indexing: for every word, pull the representation at its head's
    # position within the same sentence. See
    # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
    # Shape (batch_size, sequence_length, tag_representation_dim)
    gathered_heads = head_tag_representation[sentence_rows, head_indices]
    gathered_heads = gathered_heads.contiguous()

    # Shape (batch_size, sequence_length, num_head_tags)
    tag_logits = self.tag_bilinear(gathered_heads, child_tag_representation)
    return tag_logits
def _get_prediction_device(self) -> int: """ This method checks the device of the model parameters to determine the cuda_device this model should be run on for predictions. If there are no parameters, it returns -1. Returns ------- The cuda device this model should run on for predictions. """ devices = {util.get_device_of(param) for param in self.parameters()} if len(devices) > 1: devices_string = ", ".join(str(x) for x in devices) raise ConfigurationError(f"Parameters have mismatching cuda_devices: {devices_string}") elif len(devices) == 1: return devices.pop() else: return -1
def flatten_and_batch_shift_indices(indices: torch.Tensor,
                                    sequence_length: int) -> torch.Tensor:
    """
    This is a subroutine for :func:`~batched_index_select`. The given ``indices`` of size
    ``(batch_size, d_1, ..., d_n)`` index into dimension 1 (the sequence dimension) of a
    target tensor of size ``(batch_size, sequence_length, embedding_size)``. This function
    returns a vector that correctly indexes into the flattened target. The sequence length
    of the target must be provided to compute the appropriate offsets.

    .. code-block:: python

        indices = torch.ones([2, 3], dtype=torch.long)
        # Sequence length of the target tensor.
        sequence_length = 10
        shifted_indices = flatten_and_batch_shift_indices(indices, sequence_length)
        # Indices into the second element in the batch are correctly shifted
        # to take into account that the target tensor will be flattened before
        # the indices are applied.
        assert shifted_indices.tolist() == [1, 1, 1, 11, 11, 11]

    Parameters
    ----------
    indices : ``torch.LongTensor``, required.
    sequence_length : ``int``, required.
        The length of the sequence the indices index into.
        This must be the second dimension of the tensor.

    Returns
    -------
    offset_indices : ``torch.LongTensor``
        Shape ``(batch_size * d_1 * ... * d_n)``.

    Raises
    ------
    ValueError
        If any index is negative or ``>= sequence_length``. After flattening, such an index
        would silently address a different batch element instead of failing.
    """
    # Guard with numel(): min()/max() raise on empty tensors.
    if indices.numel() > 0 and (int(indices.min()) < 0 or int(indices.max()) >= sequence_length):
        raise ValueError(f"All elements in indices should be in range (0, {sequence_length - 1})")

    # Shape: (batch_size, 1, ..., 1) -- one offset per batch element, broadcast against indices.
    offsets = torch.arange(indices.size(0), device=indices.device) * sequence_length
    offsets = offsets.view(-1, *([1] * (indices.dim() - 1)))

    # Shape: (batch_size * d_1 * ... * d_n)
    return (indices + offsets).view(-1)
def _get_prediction_device(self) -> int: """ This method checks the device of the model parameters to determine the cuda_device this model should be run on for predictions. If there are no parameters, it returns -1. Returns ------- The cuda device this model should run on for predictions. """ devices = {util.get_device_of(param) for param in self.parameters()} if len(devices) > 1: devices_string = ", ".join(str(x) for x in devices) raise ConfigurationError( f"Parameters have mismatching cuda_devices: {devices_string}") elif len(devices) == 1: return devices.pop() else: return -1
def label_scores(self, encoded_text: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor:
    """
    Compute edge-label logits for a fixed tree structure (given by ``head_indices``)
    over a batch of sentences.

    Parameters
    ----------
    encoded_text : torch.Tensor, required
        The encoded sentences, with an artificial root node (head sentinel) at the
        beginning; shape (batch_size, sequence length, encoding dim).
    head_indices : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length); the head of every word (predicted or gold).

    Returns
    -------
    edge_label_logits : ``torch.Tensor``
        Shape (batch_size, sequence_length, num_head_tags); logits over labels for each
        given arc.
    """
    # Shape (batch_size, sequence_length, tag_representation_dim) for both projections.
    head_side = self._dropout(self.head_label_feedforward(encoded_text))
    child_side = self._dropout(self.child_label_feedforward(encoded_text))

    # Shape (batch_size, 1): one row index per sentence.
    sentence_rows = get_range_vector(head_side.size(0), get_device_of(head_side)).unsqueeze(1)

    # Advanced (numpy-style) indexing: for each word, take the head-side representation
    # at its head's position in the same sentence. See
    # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
    # Shape (batch_size, sequence_length, tag_representation_dim)
    heads_per_word = head_side[sentence_rows, head_indices].contiguous()

    # Shape (batch_size, sequence_length, num_head_tags)
    return self.label_out_layer(self.activation(heads_per_word + child_side))
def get_next_sentence_output(use_fp16, input_tensor, next_sentence_feedforward, labels):
    """Get loss and log probs for the next sentence prediction.

    Parameters
    ----------
    use_fp16 : bool
        If true, the one-hot target matrix is built in half precision, otherwise float32.
    input_tensor : torch.Tensor
        Passed unchanged to ``next_sentence_feedforward``.
    next_sentence_feedforward : callable
        Maps ``input_tensor`` to logits whose last dimension has 2 classes. Presumably
        shape (batch_size, 2); TODO confirm.
    labels : torch.Tensor
        One class id (0 or 1) per example; flattened with ``view(-1, 1)``. Any integer
        dtype is accepted.

    Returns
    -------
    (loss, per_example_loss, log_probs)
        ``loss`` is the sum of ``per_example_loss``. ``per_example_loss`` is the negative
        log-probability of each example's label. ``log_probs`` is the log-softmax of the logits.
    """
    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.
    logits = next_sentence_feedforward(input_tensor)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # scatter_ requires int64 indices; .long() also lets int32 labels through.
    labels = labels.view(-1, 1).long()
    target_dtype = torch.half if use_fp16 else torch.float
    # Allocating zeros directly on labels.device covers CPU and CUDA alike, without the
    # legacy torch.cuda.FloatTensor constructor or an uninitialised buffer.
    one_hot_labels = torch.zeros(labels.size(0), 2, dtype=target_dtype, device=labels.device)
    one_hot_labels.scatter_(1, labels, 1)

    per_example_loss = -(one_hot_labels * log_probs).sum(-1)
    loss = torch.sum(per_example_loss)
    return (loss, per_example_loss, log_probs)
def forward(self, inputs: torch.Tensor, mask: torch.Tensor, span: torch.Tensor) -> torch.Tensor:
    """
    Produce a marker that is 1 at each sequence's span-start position and 0 elsewhere.

    ``mask`` is accepted for interface compatibility and is not used.

    Parameters
    ----------
    inputs : torch.Tensor
        Unpacked as (batch_size, seq_len, d); only its sizes, device and dtype are used.
    mask : torch.Tensor
        Unused.
    span : torch.Tensor
        Column 0 is read as the start position for each batch element (B x 2 per the
        inline note).

    Returns
    -------
    torch.Tensor
        Shape (batch_size, seq_len, 1).
    """
    # pylint: disable=arguments-differ,unused-argument
    # input -> [B x seq_len x d], offset -> [B x 2]
    batch_size, seq_len, _ = inputs.size()
    # Shape (batch_size, 1)
    span_starts = span[:, 0].unsqueeze(-1)
    # Shape (batch_size, seq_len); every row is 0 .. seq_len - 1.
    positions = util.get_range_vector(seq_len, util.get_device_of(inputs)).repeat((batch_size, 1))
    is_start = positions == span_starts

    markers = inputs.new_ones((batch_size, seq_len), requires_grad=True) * is_start.float()
    return markers.unsqueeze(-1)
def _scoped_pool_span_mask(positions: torch.Tensor, span: torch.Tensor) -> torch.Tensor:
    """Return a long mask, 1 where ``span[:, 0] <= position <= span[:, 1]``, with a trailing unit dim.

    Given ``positions`` of shape (batch_size, seq_len), the result is (batch_size, seq_len, 1).
    """
    start = span[:, 0].unsqueeze(dim=1)
    end = span[:, 1].unsqueeze(dim=1)
    return ((positions >= start) & (positions <= end)).unsqueeze(-1).long()


def scoped_pool(tokens: torch.Tensor,
                mask: torch.Tensor,
                pooling: str,
                pooling_scopes: List[PoolingScope],
                is_bidirectional: bool = False,
                head: torch.Tensor = None,
                tail: torch.Tensor = None) -> torch.Tensor:
    """
    Pool ``tokens`` once per requested scope and concatenate the results on the last dim.

    Scopes are handled in a fixed order: SEQUENCE (using ``mask``), then HEAD (positions
    between ``head[:, 0]`` and ``head[:, 1]``, inclusive), then TAIL (same, for ``tail``).

    Parameters
    ----------
    tokens : torch.Tensor
        Unpacked as (batch_size, seq_len, d) when entity scopes are requested.
    mask : torch.Tensor
        Sequence mask used for the SEQUENCE scope.
    pooling : str
        Forwarded to ``pool``.
    pooling_scopes : List[PoolingScope]
        Which scopes to pool over; at least one must be present.
    is_bidirectional : bool
        Forwarded to ``pool``.
    head, tail : torch.Tensor, optional
        Start/end offsets in columns 0 and 1. Both are required when HEAD or TAIL is requested.

    Returns
    -------
    torch.Tensor
        The per-scope pooled tensors concatenated along ``dim=-1``.
    """
    pooling_masks = []
    if PoolingScope.SEQUENCE in pooling_scopes:
        pooling_masks.append(mask.unsqueeze(-1))

    wants_head = PoolingScope.HEAD in pooling_scopes
    wants_tail = PoolingScope.TAIL in pooling_scopes
    if wants_head or wants_tail:
        assert head is not None and tail is not None, \
            "head and tail offsets are required for pooling on entities"

        batch_size, seq_len, _ = tokens.size()
        # Shape (batch_size, seq_len); every row is 0 .. seq_len - 1.
        positions = util.get_range_vector(seq_len, util.get_device_of(tokens)).repeat((batch_size, 1))
        if wants_head:
            pooling_masks.append(_scoped_pool_span_mask(positions, head))
        if wants_tail:
            pooling_masks.append(_scoped_pool_span_mask(positions, tail))

    assert pooling_masks, "At least one pooling scope must be defined"

    # A distinct loop name keeps the `mask` parameter from being shadowed.
    pooled = [pool(tokens, scope_mask, dim=1, pooling=pooling, is_bidirectional=is_bidirectional)
              for scope_mask in pooling_masks]
    return torch.cat(pooled, dim=-1)
def _get_head_tags(self,
                   head_tag_representation: torch.Tensor,
                   child_tag_representation: torch.Tensor,
                   head_indices: torch.Tensor) -> torch.Tensor:
    """
    Compute tag logits for each word's arc to its head.

    For every position, the head-side representation at ``head_indices`` is gathered
    from the same sentence and scored against the child-side representation with
    ``self.tag_bilinear``.

    Parameters
    ----------
    head_tag_representation : torch.Tensor
        Indexed as (batch, position, ...); gathered at the head positions.
    child_tag_representation : torch.Tensor
        Passed straight to ``self.tag_bilinear`` as the child side.
    head_indices : torch.Tensor
        Head position of every word, presumably (batch_size, sequence_length); TODO confirm.

    Returns
    -------
    torch.Tensor
        Whatever ``self.tag_bilinear`` returns; presumably (batch_size, sequence_length,
        num_head_tags); TODO confirm.
    """
    # Shape (batch_size, 1): broadcast against head_indices so each row indexes its own sentence.
    rows = get_range_vector(head_tag_representation.size(0),
                            get_device_of(head_tag_representation)).unsqueeze(1)
    # Numpy-style advanced indexing; see
    # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
    heads = head_tag_representation[rows, head_indices].contiguous()
    return self.tag_bilinear(heads, child_tag_representation)
def _convert_actionscores_to_probs(batch_actionseq_scores: List[List[torch.Tensor]]) -> List[torch.FloatTensor]: """ Normalize program scores in a beam for an instance to get probabilities Returns: --------- List[torch.FloatTensor]: For each instance, a tensor the size of number of predicted programs containing normalized probabilities """ # Convert batch_action_scores to a single tensor the size of number of programs for each instance device_id = allenutil.get_device_of(batch_actionseq_scores[0][0]) # Inside List[torch.Tensor] is a list of scalar-tensor with prob of each program for this instance # The prob is normalized across the programs in the beam batch_actionseq_probs = [] for score_list in batch_actionseq_scores: scores_astensor = allenutil.move_to_device(torch.cat([x.view(1) for x in score_list]), device_id) # allenutil.masked_softmax(scores_astensor, mask=None) action_probs = torch.nn.functional.softmax(scores_astensor, dim=-1) batch_actionseq_probs.append(action_probs) return batch_actionseq_probs
def forward(self, inputs: torch.Tensor, mask: torch.Tensor, span: torch.Tensor) -> torch.Tensor:
    """
    Embed each position by its distance from the span start.

    The index looked up for position ``p`` is ``1 + self._n_position + p - start``, where
    ``start`` is ``span[:, 0]`` for that batch element. Positions where ``mask`` is zero
    get index 0 instead.

    Parameters
    ----------
    inputs : torch.Tensor
        Unpacked as (batch_size, seq_len, d); only its sizes and device are used.
    mask : torch.Tensor
        Multiplied (as long) into the indices; presumably (batch_size, seq_len) with 0 on padding.
    span : torch.Tensor
        Column 0 holds the span start per batch element.

    Returns
    -------
    torch.Tensor
        ``self._embedding`` applied to the (batch_size, seq_len) index tensor.
    """
    # pylint: disable=arguments-differ
    # input -> [B x seq_len x d], offset -> [B x 2]
    batch_size, seq_len, _ = inputs.size()
    # Shape (batch_size, 1)
    span_starts = span[:, 0].unsqueeze(dim=1)
    # Shape (batch_size, seq_len); every row is 0 .. seq_len - 1.
    positions = util.get_range_vector(seq_len, util.get_device_of(inputs)).repeat((batch_size, 1))

    shifted = 1 + self._n_position + positions - span_starts
    # Zero out padding so it won't receive a positional embedding.
    shifted = shifted * mask.long()
    return self._embedding(shifted)
def common_step(self, batch, phase="train"):
    """
    Run the model on ``batch``, update the ``{phase}_stat`` metric, log and return the loss.

    The loss is ``parent_arc_nll + parent_tag_nll``. When ``phase`` is "train" or MST
    decoding is disabled (``self.args.use_mst`` falsy), the metric is fed argmax parent
    positions and tags. Otherwise it receives the raw probabilities together with the
    annotation metadata.

    Args:
        batch: mapping with the keys read below.
        phase: metric/log prefix, e.g. "train".

    Returns:
        The scalar loss that was logged as ``{phase}_loss``.
    """
    field_names = ("token_ids", "type_ids", "offsets", "wordpiece_mask", "span_idx", "span_tag",
                   "pos_tags", "word_mask", "mrc_mask", "meta_data", "child_arcs", "child_tags")
    (token_ids, type_ids, offsets, wordpiece_mask, span_idx, span_tag,
     pos_tags, word_mask, mrc_mask, meta_data, child_arcs, child_tags) = (batch[name] for name in field_names)

    model_outputs = self(token_ids, type_ids, offsets, wordpiece_mask, span_idx, span_tag,
                         child_arcs, child_tags, pos_tags, word_mask, mrc_mask)
    (parent_probs, parent_tag_probs, child_probs, child_tag_probs,
     parent_arc_nll, parent_tag_nll, _child_arc_loss, _child_tag_loss) = model_outputs

    # TODO: fix the child bug, then add child_arc_loss + child_tag_loss back into the loss.
    loss = parent_arc_nll + parent_tag_nll

    eval_mask = self._get_mask_for_eval(mask=word_mask, pos_tags=pos_tags)
    # [bsz]
    batch_rows = get_range_vector(span_idx.size(0), get_device_of(span_idx))
    gold_positions = span_idx[:, 0]
    # [bsz]: evaluation mask value at each gold position.
    eval_mask = eval_mask[batch_rows, gold_positions]

    if phase == "train" or not self.args.use_mst:
        # [bsz]
        pred_positions = parent_probs.argmax(1)
        metric = getattr(self, f"{phase}_stat")
        # [bsz]: best tag at the predicted parent position.
        pred_tags = parent_tag_probs[batch_rows, pred_positions].argmax(1)
        metric.update(
            pred_positions.unsqueeze(-1),  # [bsz, 1]
            pred_tags.unsqueeze(-1),  # [bsz, 1]
            gold_positions.unsqueeze(-1),  # [bsz, 1]
            span_tag.unsqueeze(-1),  # [bsz, 1]
            eval_mask.unsqueeze(-1),  # [bsz, 1]
        )
    else:
        metric = getattr(self, f"{phase}_stat")
        annotation_ids = meta_data["ann_idx"]
        word_ids = meta_data["word_idx"]
        sentence_lengths = [len(words) for words in meta_data["words"]]
        metric.update(annotation_ids, word_ids, sentence_lengths,
                      parent_probs, parent_tag_probs, child_probs, child_tag_probs,
                      span_idx, span_tag, eval_mask)

    self.log(f'{phase}_loss', loss)
    return loss
def concat_features(self, emb_z, token_indices, span_len):
    """
    Build per-window features by concatenating four groups.

    A window covers ``span_len + 1`` consecutive positions. Each position inside a window
    gets, in this order:
    its contextual vector; the BILOU features from ``self.get_BILOU_features``; the
    difference between the window's first and last vectors; and an embedding of its offset
    within the window (``self.index_embeds``).

    Args:
        emb_z (batch_size, sentence_len, featsdim): contextualized word representations
        token_indices: Dict[str, LongTensor], indices of different fields; only forwarded
            to ``self.get_BILOU_features``
        span_len: a number (from 0); windows cover ``span_len + 1`` positions

    Returns:
        Tensor of shape ``(batch_size, span_len + 1, sentence_len - span_len, feature_dim)``.
        The BILOU features are assumed to be
        ``(batch_size, sentence_len - span_len, span_len + 1, k)`` so they concatenate on
        the last dim (TODO confirm).
    """
    batch_size, sent_len, hidden_dim = emb_z.size()
    num_spans = sent_len - span_len

    # Shape (batch_size, num_spans, span_len + 1, hidden_dim); window i covers i .. i + span_len.
    span_exprs = torch.stack([emb_z[:, i:i + span_len + 1] for i in range(num_spans)], dim=1)

    # First-minus-last difference of each window, repeated at every position in the window.
    endpoint_vec = (span_exprs[:, :, 0] - span_exprs[:, :, span_len]).unsqueeze(2).expand(
        batch_size, num_spans, span_len + 1, hidden_dim)

    # Offsets 0 .. span_len, created on emb_z's device so CPU inputs work as well as CUDA.
    index = torch.arange(span_len + 1, device=emb_z.device)
    index = self.index_embeds(index).unsqueeze(0).unsqueeze(0).expand(
        batch_size, num_spans, span_len + 1, self.index_embeds_dim)

    BILOU_features = self.get_BILOU_features(token_indices, sent_len, span_len)
    new_emb = torch.cat((span_exprs, BILOU_features, endpoint_vec, index), 3)
    return new_emb.transpose(1, 2).contiguous()
def generate_embeddings_for_pooling(sequence_tensor, span_starts, span_ends):
    """
    Gather the token embeddings covered by each span, padded to the widest span in the batch.

    Parameters
    ----------
    sequence_tensor : (B, L, E) embeddings to gather from.
    span_starts : (B, L) span start indices.
    span_ends : (B, L) span end indices, exclusive: ``span_ends - 1`` is the last token used.

    Returns
    -------
    (span_embeddings, span_mask)
        ``span_embeddings`` has shape (batch_size, num_spans, max_batch_span_width, embedding_dim).
        ``span_mask`` is a float tensor with 1.0 where a slot holds a real token of the span.
        Slot ``j`` holds the token at ``last_index - j``, so slots run backwards from the span end.
    """
    starts = span_starts.unsqueeze(-1)
    # Convert exclusive ends into inclusive last-token indices.
    inclusive_ends = (span_ends - 1).unsqueeze(-1)
    widths = inclusive_ends - starts
    max_width = widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    slot_offsets = util.get_range_vector(
        max_width, util.get_device_of(sequence_tensor)).view(1, 1, -1)

    # Broadcast comparison: each span sees offsets 0 .. max_width - 1. The comparison is <=
    # because the ends are inclusive, so an offset equal to the width is still inside the span.
    # Offsets that would reach before the start of the sequence (< 0) are masked out too.
    raw_indices = inclusive_ends - slot_offsets
    span_mask = (slot_offsets <= widths).float() * (raw_indices >= 0).float()

    # Clamp negative indices to 0 so indexing is safe; those slots are already masked.
    safe_indices = torch.nn.functional.relu(raw_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_indices = util.flatten_and_batch_shift_indices(safe_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, safe_indices, flat_indices)
    return span_embeddings, span_mask
def device_id(self):
    """Apparently meant to report a device id via ``allenutil.get_device_of``.

    NOTE(review): as written this calls ``allenutil.get_device_of()`` with no argument
    and discards the result, so it can at best return ``None``. Elsewhere in this file
    ``get_device_of`` is always given a tensor. If it requires one, this call raises
    ``TypeError``.
    TODO: pass the tensor whose device is wanted and ``return`` the result.
    """
    allenutil.get_device_of()
def _create_grammar_state(
        self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
    """
    Build the ``GrammarStatelet`` used for decoding a single batch instance.

    As part of this, the ``valid_actions`` dictionary is built. It maps every non-terminal
    to embedded representations of the actions that expand it. None of the tensors built
    here are batched; the input describes a single instance.

    Only global production rules are supported. For each non-terminal, the ``rule_id``
    tensors of its rules are concatenated and embedded with both ``self._action_embedder``
    and ``self._output_action_embedder``.

    Parameters
    ----------
    possible_actions : ``List[ProductionRule]``
        From the input to ``forward`` for a single batch instance. Actions whose ``rule``
        string is empty are skipped.

    Returns
    -------
    ``GrammarStatelet``
        Starting from ``'statement'``, built with ``reverse_productions=True``.

    Raises
    ------
    ValueError
        If any non-empty action is not a global rule.
    """
    device = util.get_device_of(self._action_embedder.weight)
    # TODO(Mark): This type is pure \(- . ^)/
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor,
                                                         List[int]]]] = {}

    actions_grouped_by_nonterminal: Dict[str, List[Tuple[
        ProductionRule, int]]] = defaultdict(list)
    for i, action in enumerate(possible_actions):
        if action.rule == "":
            continue
        if not action.is_global_rule:
            raise ValueError(
                "The sql parser doesn't support non-global actions yet.")
        actions_grouped_by_nonterminal[action.nonterminal].append(
            (action, i))

    for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
        # Each group is non-empty (it only exists because something was appended) and
        # every rule in it is global, so all of a non-terminal's actions go into a single
        # 'global' entry.
        global_action_tensor = torch.cat(
            [production_rule.rule_id for production_rule, _ in production_rule_arrays],
            dim=0).long()
        global_action_ids = [
            action_index for _, action_index in production_rule_arrays
        ]
        if device >= 0:
            global_action_tensor = global_action_tensor.to(device)

        global_input_embeddings = self._action_embedder(
            global_action_tensor)
        global_output_embeddings = self._output_action_embedder(
            global_action_tensor)
        translated_valid_actions[key] = {
            'global': (global_input_embeddings, global_output_embeddings,
                       global_action_ids)
        }

    return GrammarStatelet(['statement'],
                           translated_valid_actions,
                           self.is_nonterminal,
                           reverse_productions=True)
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField`` with one wrapping dimension (one entry per question in the
        dialog).  It must contain a ``token_characters`` entry, whose
        ``(batch_size, max_qa_count, max_q_len, _)`` size determines the batch layout.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer within the passage.  This is an `inclusive` token
        index; entries of -1 are ignored by the loss.  If this is given, we compute a loss
        that gets included in the output dictionary, and ``span_end``, ``yesno_list`` and
        ``followup_list`` must be given as well.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer within the passage.  This is an `inclusive` token
        index; entries of -1 are ignored by the loss.
    p1_answer_marker : ``torch.IntTensor``, optional
        Used only when num_context_answers > 0.  A tensor of shape
        [batch_size, max_qa_count, max_passage_length].  Most passage tokens are labelled
        'O', except the tokens belonging to the previous answer in the dialog, which are
        labelled with tags such as <1_start>, <1_in>, <1_end>.  For more details, see
        dataset_readers/util/make_reading_comprehension_instance_quac.
    p2_answer_marker : ``torch.IntTensor``, optional
        Used only when num_context_answers > 1.  Like ``p1_answer_marker``, but marks the
        answer from two turns back.
    p3_answer_marker : ``torch.IntTensor``, optional
        Used only when num_context_answers > 2.  Like ``p1_answer_marker``, but marks the
        answer from three turns back.
    yesno_list : ``torch.IntTensor``, optional
        One of the outputs we are trying to predict: a three-way classification
        (yes / no / not a yes-no question).  Entries of -1 are ignored by the loss.
    followup_list : ``torch.IntTensor``, optional
        One of the outputs we are trying to predict: a three-way classification
        (follow up / maybe follow up / don't follow up).  Entries of -1 are ignored by the
        loss, and those QA slots are also masked out of the accuracy metrics.
    metadata : ``List[Dict[str, Any]]``
        Required despite the ``None`` default, because the output dictionary is built from
        it.  One dictionary per dialog in the batch, read for the keys
        ``original_passage``, ``token_offsets`` (character offsets of each passage token),
        ``instance_id`` (one id per question in the dialog) and ``answer_texts_list`` (the
        reference answer strings for each question, used to compute the official F1).

    Returns
    -------
    An output dictionary with the entries below.  Each is a nested list: the outer list
    iterates over dialogs, and the inner list over the questions in that dialog.
    qid : List[List[str]]
        The question ids.
    followup : List[List[int]]
        The predicted continuation-marker index
        (y: follow up, m: maybe follow up, n: don't follow up).
    yesno : List[List[int]]
        The predicted affirmation-marker index
        (y: yes, x: not a yes/no question, n: no).
    best_span_str : List[List[str]]
        The string from the original passage that the model thinks is the best answer to
        each question.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; present only when ``span_start`` is given.
    """
    batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
    total_qa_count = batch_size * max_qa_count
    embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
    embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                  self._text_field_embedder.get_output_dim())
    embedded_question = self._variational_dropout(embedded_question)
    embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
    passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
    question_mask = question_mask.reshape(total_qa_count, max_q_len)
    passage_mask = util.get_text_field_mask(passage).float()

    # Each QA slot of a dialog sees the same passage, so the passage mask is repeated once
    # per slot: (batch_size * max_qa_count, passage_length).
    repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
    repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

    if self._num_context_answers > 0:
        # Encode question turn number inside the dialog into question embedding.
        question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
        question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
        question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
        question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
        question_num_marker_emb = self._question_num_marker(question_num_ind)
        embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

        # Encode the previous answers in passage embedding.
        # Shape: (batch_size * max_qa_count, passage_length, word_embed_dim)
        repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1).view(
                total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
        p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
        p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
        repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
        if self._num_context_answers > 1:
            p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
            p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb],
                                                  dim=-1)
            if self._num_context_answers > 2:
                p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                      dim=-1)

        repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage, repeated_passage_mask))
    else:
        encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                 passage_length,
                                                                 self._encoding_dim)

    encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # ``max(dim=-1)`` already drops the question dimension.  An extra ``squeeze(-1)`` would
    # also drop the passage dimension when passage_length == 1.
    # Shape: (batch_size * max_qa_count, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
    # Shape: (batch_size * max_qa_count, encoding_dim)
    question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                passage_length,
                                                                                self._encoding_dim)

    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([repeated_encoded_passage,
                                      passage_question_vectors,
                                      repeated_encoded_passage * passage_question_vectors,
                                      repeated_encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

    residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                      repeated_passage_mask))
    self_attention_matrix = self._self_attention(residual_layer, residual_layer)

    # Passage-to-passage mask: both positions unpadded, and a token never attends to itself.
    mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
        * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
    self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
    self_mask = self_mask.reshape(1, passage_length, passage_length)
    mask = mask * (1 - self_mask)

    self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

    # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
    self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
    self_attention_vecs = torch.cat([self_attention_vecs,
                                     residual_layer,
                                     residual_layer * self_attention_vecs],
                                    dim=-1)
    residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

    final_merged_passage = final_merged_passage + residual_layer
    # batch_size * maxqa_pair_len * max_passage_len * 200
    final_merged_passage = self._variational_dropout(final_merged_passage)
    start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
    span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

    end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                     repeated_passage_mask)
    span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

    span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
    span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

    span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
    # batch_size * maxqa_len_pair, max_document_len
    span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

    best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                   span_yesno_logits, span_followup_logits,
                                                   self._max_span_length)

    output_dict: Dict[str, Any] = {}

    # Compute the loss.
    if span_start is not None:
        # QA slots whose follow-up label is -1 (the ignore_index used below) are excluded from
        # the accuracy metrics.  This needs labels, so it is computed only in this branch.
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)

        loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask),
                        span_start.view(-1),
                        ignore_index=-1)
        self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
        loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask),
                         span_end.view(-1),
                         ignore_index=-1)
        self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
        self._span_accuracy(best_span[:, 0:2],
                            torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                            mask=qa_mask.unsqueeze(1).expand(-1, 2).long())

        # Select the yes/no and follow-up logits at the gold and at the predicted span end.
        # The flat indexing below treats each QA slot as passage_length * 3 consecutive
        # logits, with the 3 class logits for end position e starting at
        # (slot * passage_length + e) * 3.
        # Indices are clamped at 0, so gold ends of -1 (whose yes/no and follow-up labels are
        # ignored by the loss) stay in range.  The indices are built vectorised rather than
        # element by element in Python.  This also avoids ``squeeze()`` turning a single
        # QA slot into a 0-d array.
        index_device = span_yesno_logits.device
        class_offsets = torch.arange(3, device=index_device)
        slot_offsets = torch.arange(total_qa_count, device=index_device) * (passage_length * 3)
        gold_span_end_loc = ((span_end.view(total_qa_count).long() * 3 + slot_offsets).unsqueeze(1)
                             + class_offsets).clamp(min=0).view(-1)
        predicted_end = ((best_span[:, 1].long() * 3 + slot_offsets).unsqueeze(1)
                         + class_offsets).clamp(min=0).view(-1)

        _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
        loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

        _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
        self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
        output_dict["loss"] = loss

    # Compute F1 and prepare the output dictionary.
    output_dict['best_span_str'] = []
    output_dict['qid'] = []
    output_dict['followup'] = []
    output_dict['yesno'] = []
    best_span_cpu = best_span.detach().cpu().numpy()
    for i in range(batch_size):
        passage_str = metadata[i]['original_passage']
        offsets = metadata[i]['token_offsets']
        per_dialog_best_span_list = []
        per_dialog_yesno_list = []
        per_dialog_followup_list = []
        per_dialog_query_id_list = []
        for per_dialog_query_index, (iid, answer_texts) in enumerate(
                zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
            # Reset per question: a question without reference answers must not re-report the
            # previous question's F1.
            f1_score = 0.0
            # best_span columns: start token, end token, yes/no index, follow-up index.
            predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]

            yesno_pred = predicted_span[2]
            followup_pred = predicted_span[3]
            per_dialog_yesno_list.append(yesno_pred)
            per_dialog_followup_list.append(followup_pred)
            per_dialog_query_id_list.append(iid)

            best_span_string = passage_str[start_offset:end_offset]
            per_dialog_best_span_list.append(best_span_string)
            if answer_texts:
                if len(answer_texts) > 1:
                    t_f1 = []
                    # Compute F1 over N-1 human references and average the scores.
                    for answer_index in range(len(answer_texts)):
                        idxes = list(range(len(answer_texts)))
                        idxes.pop(answer_index)
                        refs = [answer_texts[z] for z in idxes]
                        t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                             best_span_string,
                                                                             refs))
                    f1_score = 1.0 * sum(t_f1) / len(t_f1)
                else:
                    f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                        best_span_string,
                                                                        answer_texts)
            self._official_f1(100 * f1_score)
        output_dict['qid'].append(per_dialog_query_id_list)
        output_dict['best_span_str'].append(per_dialog_best_span_list)
        output_dict['yesno'].append(per_dialog_yesno_list)
        output_dict['followup'].append(per_dialog_followup_list)
    return output_dict
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.  Id 0 is treated
        as padding when building the attention mask.
    offsets : ``torch.LongTensor``, optional
        The BERT embeddings are one per wordpiece.  However, you will often want one per
        original token.  In that case, ``offsets`` holds the index of the desired wordpiece
        for each original token.  Depending on how your token indexer is configured, this
        could be the position of the last wordpiece for each token, or the position of the
        first wordpiece for each token.

        For example, suppose you have the sentence "Definitely not" and the corresponding
        wordpieces are ["Def", "##in", "##ite", "##ly", "not"].  Then input_ids holds 5
        wordpiece ids, and the "last wordpiece" offsets are [3, 4].

        If offsets are provided, the returned tensor contains only the wordpiece embeddings
        at those positions, which gives one embedding per token.
        If offsets are not provided, the entire tensor of wordpiece embeddings is returned.
    token_type_ids : ``torch.LongTensor``, optional
        If an input consists of two sentences (as in the BERT paper), tokens from the first
        sentence should have type 0 and tokens from the second sentence should have type 1.
        If you don't provide this (the default BertIndexer doesn't), it is assumed to be
        all 0s.

    Returns
    -------
    ``torch.Tensor``
        (batch_size, ..., max_sequence_length, embedding_dim) when ``offsets`` is None,
        otherwise (batch_size, ..., orig_sequence_length, embedding_dim).
    """
    # pylint: disable=arguments-differ
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)

    input_mask = (input_ids != 0).long()

    # input_ids may have extra dimensions, so we reshape down to 2-d
    # before calling the BERT model and then reshape back at the end.
    # The flattened mask is reused for the scalar mix below, because the encoder layers
    # returned by bert_model carry the flattened batch dimension.  For 2-d inputs this is the
    # same tensor as ``input_mask``.
    input_mask_2d = util.combine_initial_dims(input_mask)
    all_encoder_layers, _ = self.bert_model(input_ids=util.combine_initial_dims(input_ids),
                                            token_type_ids=util.combine_initial_dims(token_type_ids),
                                            attention_mask=input_mask_2d)
    if self._scalar_mix is not None:
        mix = self._scalar_mix(all_encoder_layers, input_mask_2d)
    else:
        mix = all_encoder_layers[-1]

    # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

    if offsets is None:
        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
        return util.uncombine_initial_dims(mix, input_ids.size())

    # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
    offsets2d = util.combine_initial_dims(offsets)
    # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
    range_vector = util.get_range_vector(offsets2d.size(0),
                                         device=util.get_device_of(mix)).unsqueeze(1)
    # selected embeddings is (batch_size * d1 * ... * dn, orig_sequence_length, embedding_dim)
    selected_embeddings = mix[range_vector, offsets2d]

    return util.uncombine_initial_dims(selected_embeddings, offsets.size())
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField`` with one wrapping dimension (one entry per question in the
        dialog).  It must contain a ``token_characters`` entry, whose
        ``(batch_size, max_qa_count, max_q_len, _)`` size determines the batch layout.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer within the passage.  This is an `inclusive` token
        index; entries of -1 are ignored by the loss.  If this is given, we compute a loss
        that gets included in the output dictionary, and ``span_end``, ``yesno_list`` and
        ``followup_list`` must be given as well.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer within the passage.  This is an `inclusive` token
        index; entries of -1 are ignored by the loss.
    p1_answer_marker : ``torch.IntTensor``, optional
        Used only when num_context_answers > 0.  A tensor of shape
        [batch_size, max_qa_count, max_passage_length].  Most passage tokens are labelled
        'O', except the tokens belonging to the previous answer in the dialog, which are
        labelled with tags such as <1_start>, <1_in>, <1_end>.  For more details, see
        dataset_readers/util/make_reading_comprehension_instance_quac.
    p2_answer_marker : ``torch.IntTensor``, optional
        Used only when num_context_answers > 1.  Like ``p1_answer_marker``, but marks the
        answer from two turns back.
    p3_answer_marker : ``torch.IntTensor``, optional
        Used only when num_context_answers > 2.  Like ``p1_answer_marker``, but marks the
        answer from three turns back.
    yesno_list : ``torch.IntTensor``, optional
        One of the outputs we are trying to predict: a three-way classification
        (yes / no / not a yes-no question).  Entries of -1 are ignored by the loss.
    followup_list : ``torch.IntTensor``, optional
        One of the outputs we are trying to predict: a three-way classification
        (follow up / maybe follow up / don't follow up).  Entries of -1 are ignored by the
        loss, and those QA slots are also masked out of the accuracy metrics.
    metadata : ``List[Dict[str, Any]]``
        Required despite the ``None`` default, because the output dictionary is built from
        it.  One dictionary per dialog in the batch, read for the keys
        ``original_passage``, ``token_offsets`` (character offsets of each passage token),
        ``instance_id`` (one id per question in the dialog) and ``answer_texts_list`` (the
        reference answer strings for each question, used to compute the official F1).

    Returns
    -------
    An output dictionary with the entries below.  Each is a nested list: the outer list
    iterates over dialogs, and the inner list over the questions in that dialog.
    qid : List[List[str]]
        The question ids.
    followup : List[List[int]]
        The predicted continuation-marker index
        (y: follow up, m: maybe follow up, n: don't follow up).
    yesno : List[List[int]]
        The predicted affirmation-marker index
        (y: yes, x: not a yes/no question, n: no).
    best_span_str : List[List[str]]
        The string from the original passage that the model thinks is the best answer to
        each question.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; present only when ``span_start`` is given.
    """
    batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
    total_qa_count = batch_size * max_qa_count
    embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
    embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                  self._text_field_embedder.get_output_dim())
    embedded_question = self._variational_dropout(embedded_question)
    embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
    passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
    question_mask = question_mask.reshape(total_qa_count, max_q_len)
    passage_mask = util.get_text_field_mask(passage).float()

    # Each QA slot of a dialog sees the same passage, so the passage mask is repeated once
    # per slot: (batch_size * max_qa_count, passage_length).
    repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
    repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

    if self._num_context_answers > 0:
        # Encode question turn number inside the dialog into question embedding.
        question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
        question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
        question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
        question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
        question_num_marker_emb = self._question_num_marker(question_num_ind)
        embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

        # Encode the previous answers in passage embedding.
        # Shape: (batch_size * max_qa_count, passage_length, word_embed_dim)
        repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1).view(
                total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
        p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
        p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
        repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
        if self._num_context_answers > 1:
            p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
            p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb],
                                                  dim=-1)
            if self._num_context_answers > 2:
                p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                      dim=-1)

        repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage, repeated_passage_mask))
    else:
        encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                 passage_length,
                                                                 self._encoding_dim)

    encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # ``max(dim=-1)`` already drops the question dimension.  An extra ``squeeze(-1)`` would
    # also drop the passage dimension when passage_length == 1.
    # Shape: (batch_size * max_qa_count, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
    # Shape: (batch_size * max_qa_count, encoding_dim)
    question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                passage_length,
                                                                                self._encoding_dim)

    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([repeated_encoded_passage,
                                      passage_question_vectors,
                                      repeated_encoded_passage * passage_question_vectors,
                                      repeated_encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

    residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                      repeated_passage_mask))
    self_attention_matrix = self._self_attention(residual_layer, residual_layer)

    # Passage-to-passage mask: both positions unpadded, and a token never attends to itself.
    mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
        * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
    self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
    self_mask = self_mask.reshape(1, passage_length, passage_length)
    mask = mask * (1 - self_mask)

    self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

    # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
    self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
    self_attention_vecs = torch.cat([self_attention_vecs,
                                     residual_layer,
                                     residual_layer * self_attention_vecs],
                                    dim=-1)
    residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

    final_merged_passage = final_merged_passage + residual_layer
    # batch_size * maxqa_pair_len * max_passage_len * 200
    final_merged_passage = self._variational_dropout(final_merged_passage)
    start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
    span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

    end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                     repeated_passage_mask)
    span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

    span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
    span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

    span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
    # batch_size * maxqa_len_pair, max_document_len
    span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

    best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                   span_yesno_logits, span_followup_logits,
                                                   self._max_span_length)

    output_dict: Dict[str, Any] = {}

    # Compute the loss.
    if span_start is not None:
        # QA slots whose follow-up label is -1 (the ignore_index used below) are excluded from
        # the accuracy metrics.  This needs labels, so it is computed only in this branch.
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)

        loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask),
                        span_start.view(-1),
                        ignore_index=-1)
        self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
        loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask),
                         span_end.view(-1),
                         ignore_index=-1)
        self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
        self._span_accuracy(best_span[:, 0:2],
                            torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                            mask=qa_mask.unsqueeze(1).expand(-1, 2).long())

        # Select the yes/no and follow-up logits at the gold and at the predicted span end.
        # The flat indexing below treats each QA slot as passage_length * 3 consecutive
        # logits, with the 3 class logits for end position e starting at
        # (slot * passage_length + e) * 3.
        # Indices are clamped at 0, so gold ends of -1 (whose yes/no and follow-up labels are
        # ignored by the loss) stay in range.  The indices are built vectorised rather than
        # element by element in Python.  This also avoids ``squeeze()`` turning a single
        # QA slot into a 0-d array.
        index_device = span_yesno_logits.device
        class_offsets = torch.arange(3, device=index_device)
        slot_offsets = torch.arange(total_qa_count, device=index_device) * (passage_length * 3)
        gold_span_end_loc = ((span_end.view(total_qa_count).long() * 3 + slot_offsets).unsqueeze(1)
                             + class_offsets).clamp(min=0).view(-1)
        predicted_end = ((best_span[:, 1].long() * 3 + slot_offsets).unsqueeze(1)
                         + class_offsets).clamp(min=0).view(-1)

        _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
        loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

        _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
        self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
        output_dict["loss"] = loss

    # Compute F1 and prepare the output dictionary.
    output_dict['best_span_str'] = []
    output_dict['qid'] = []
    output_dict['followup'] = []
    output_dict['yesno'] = []
    best_span_cpu = best_span.detach().cpu().numpy()
    for i in range(batch_size):
        passage_str = metadata[i]['original_passage']
        offsets = metadata[i]['token_offsets']
        per_dialog_best_span_list = []
        per_dialog_yesno_list = []
        per_dialog_followup_list = []
        per_dialog_query_id_list = []
        for per_dialog_query_index, (iid, answer_texts) in enumerate(
                zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
            # Reset per question: a question without reference answers must not re-report the
            # previous question's F1.
            f1_score = 0.0
            # best_span columns: start token, end token, yes/no index, follow-up index.
            predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]

            yesno_pred = predicted_span[2]
            followup_pred = predicted_span[3]
            per_dialog_yesno_list.append(yesno_pred)
            per_dialog_followup_list.append(followup_pred)
            per_dialog_query_id_list.append(iid)

            best_span_string = passage_str[start_offset:end_offset]
            per_dialog_best_span_list.append(best_span_string)
            if answer_texts:
                if len(answer_texts) > 1:
                    t_f1 = []
                    # Compute F1 over N-1 human references and average the scores.
                    for answer_index in range(len(answer_texts)):
                        idxes = list(range(len(answer_texts)))
                        idxes.pop(answer_index)
                        refs = [answer_texts[z] for z in idxes]
                        t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                             best_span_string,
                                                                             refs))
                    f1_score = 1.0 * sum(t_f1) / len(t_f1)
                else:
                    f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                        best_span_string,
                                                                        answer_texts)
            self._official_f1(100 * f1_score)
        output_dict['qid'].append(per_dialog_query_id_list)
        output_dict['best_span_str'].append(per_dialog_best_span_list)
        output_dict['yesno'].append(per_dialog_yesno_list)
        output_dict['followup'].append(per_dialog_followup_list)
    return output_dict
def _construct_loss(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    head_indices: torch.Tensor,
                    head_tags: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the arc and tag negative log likelihoods of the gold dependency tree.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length, tag_representation_dim); the head-side
        input to the dependency tag classifier.
    child_tag_representation : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length, tag_representation_dim); the child-side
        input to the dependency tag classifier.
    attended_arcs : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length, sequence_length); unnormalised scores for
        attaching each word (dimension 1) to each candidate head (dimension 2).
    head_indices : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length); the gold head index of every word.
    head_tags : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length); the gold dependency label of every word.
    mask : ``torch.Tensor``, required.
        Shape (batch_size, sequence_length); non-zero for unpadded positions.

    Returns
    -------
    arc_nll : ``torch.Tensor``
        Negative log likelihood of the gold heads.
    tag_nll : ``torch.Tensor``
        Negative log likelihood of the gold dependency labels.

    Both values are averaged over the unpadded positions, excluding position 0
    (the symbolic ROOT token) of every sequence.
    """
    batch_size, sequence_length, _ = attended_arcs.size()
    float_mask = mask.float()

    # Log-normalise over candidate heads, then zero out padded rows and columns.
    # Shape: (batch_size, sequence_length, sequence_length)
    arc_log_probs = (masked_log_softmax(attended_arcs, mask)
                     * float_mask.unsqueeze(2)
                     * float_mask.unsqueeze(1))

    # Tag scores for every word paired with its gold head.
    # Shape: (batch_size, sequence_length, num_head_tags)
    tag_logits = self._get_head_tags(head_tag_representation,
                                     child_tag_representation,
                                     head_indices)
    tag_log_probs = masked_log_softmax(tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)

    # Advanced-indexing grids: batch index as a column vector, word index per row.
    # Shapes: (batch_size, 1) and (batch_size, sequence_length)
    device = get_device_of(attended_arcs)
    batch_index = get_range_vector(batch_size, device).unsqueeze(1)
    word_index = (get_range_vector(sequence_length, device)
                  .view(1, sequence_length)
                  .expand(batch_size, sequence_length)
                  .long())

    # Gold log probabilities with position 0 (the symbolic ROOT token, whose head
    # is never predicted) dropped.
    # Shape: (batch_size, sequence_length - 1)
    gold_arc_log_probs = arc_log_probs[batch_index, word_index, head_indices][:, 1:]
    gold_tag_log_probs = tag_log_probs[batch_index, word_index, head_tags][:, 1:]

    # Every sequence contributes one ROOT position that is not scored.
    num_scored = (mask.sum() - batch_size).float()
    return -gold_arc_log_probs.sum() / num_scored, -gold_tag_log_probs.sum() / num_scored
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document. Padding spans carry negative indices.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Per-document metadata. Each entry's ``original_text`` is copied to the output, and
        the whole list is passed to the mention-recall and CoNLL coreference metrics.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised: the negative marginal log likelihood of the gold
        antecedents, summed over the non-padding top spans. Only present when
        ``span_labels`` is given.
    document : ``List[str]``, optional
        The ``original_text`` of each document. Only present when ``metadata`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # ``spans[:, :, 0]`` is already 2-D. Squeezing the last dimension (as this code
    # used to do) collapses the span dimension when there is exactly one candidate span.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans,
                                          top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents. Note that this is independent
    # of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep,
                                         max_antecedents,
                                         util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                      valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                          top_span_mention_scores,
                                                          candidate_antecedent_mention_scores,
                                                          valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        # Shape: (batch_size, num_spans_to_keep)
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs)
        # Padded top spans (possible when a document has fewer valid spans than
        # num_spans_to_keep) must not contribute to the loss.
        negative_marginal_log_likelihood = (
            negative_marginal_log_likelihood * top_span_mask.squeeze(-1).float()).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
) -> Dict[str, torch.Tensor]:  # add matrix from datareader
    # pylint: disable=arguments-differ
    """
    Jointly score coreference, relation extraction and (optionally) NER for a batch.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document. Padding spans carry negative indices.
    metadata : ``List[Dict[str, Any]]``, required.
        Per-document metadata. This method reads ``doc_tokens`` (a list of sentences, used to
        find the longest sentence) and ``original_text``, and passes the whole list to the
        coreference metrics.
    doc_span_offsets : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1) indexing into
        ``spans`` for each sentence; negative entries are treated as padding.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    doc_truth_spans : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, max_sentences, max_truth_spans, 1). Not read here.
    doc_spans_in_truth : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1) mapping each
        sentence span to its row in ``doc_relation_labels``. Needed when ``doc_relations``
        is given.
    doc_relation_labels : ``torch.Tensor``, optional (default = None)
        A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans) of gold
        relation labels. Needed when ``doc_relations`` is given.
    truth_spans : ``List[Set[Tuple[int, int]]]``, optional (default = None)
        Gold spans, passed to the relation-extraction mention recall metric.
    doc_relations : optional (default = None)
        Gold relations. When given, the relation-extraction loss and metrics are computed.
    doc_ner_labels : ``torch.IntTensor``, optional (default = None)
        Gold NER labels; presumably (batch_size, max_sentences, max_spans_per_sentence) to match
        ``doc_span_mask`` -- TODO confirm.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    relex_scores, top_relex_spans : ``torch.Tensor``
        Relation scores between pruned sentence spans, and those spans.
    coref_loss, relex_loss, ner_loss, ner_scores : ``torch.Tensor``, optional
        Present when the corresponding gold inputs are given.
    loss : ``torch.FloatTensor``, optional
        The weighted sum of the computed losses. Only present when at least one loss term
        was computed.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(
        self._text_field_embedder(text))

    batch_size = len(spans)
    document_length = text_embeddings.size(1)
    max_sentence_length = max(
        len(sentence) for document in metadata
        for sentence in document['doc_tokens'])
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # ``spans[:, :, 0]`` is already 2-D. Squeezing the last dimension (as this code
    # used to do) collapses the span dimension when there is exactly one candidate span.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(
        text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # TODO features dropout
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(
        math.floor(self._spans_per_word * document_length))
    num_relex_spans_to_keep = int(
        math.floor(self._relex_spans_per_word * max_sentence_length))

    # Shapes:
    # (batch_size, num_spans_to_keep, span_dim),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep, 1)
    (top_span_embeddings, top_span_mask, top_span_indices,
     top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents. Note that this is independent
    # of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep,
                                         max_antecedents,
                                         util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = dict()

    output_dict["top_spans"] = top_spans
    output_dict["antecedent_indices"] = valid_antecedent_indices
    output_dict["predicted_antecedents"] = predicted_antecedents

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]

    # Running total of the weighted losses; stays the int 0 if no gold inputs are given.
    # Shape: (,)
    loss = 0

    # Shape: (batch_size, max_sentences, max_spans)
    doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
    # Shape: (batch_size, max_sentences, num_spans, span_dim)
    doc_span_embeddings = util.batched_index_select(
        span_embeddings,
        doc_span_offsets.squeeze(-1).long().clamp(min=0))

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    pruned = self._relex_mention_pruner(
        doc_span_embeddings,
        doc_span_mask,
        num_items_to_keep=num_relex_spans_to_keep,
        pass_through=['num_items_to_keep'])
    (top_relex_span_embeddings, top_relex_span_mask,
     top_relex_span_indices, top_relex_span_mention_scores) = pruned

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

    # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)
    # TODO do we need for a mask?
    doc_spans = util.batched_index_select(
        spans,
        doc_span_offsets.clamp(0).squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
    top_relex_spans = nd_batched_index_select(doc_spans,
                                              top_relex_span_indices)

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
    (relex_span_pair_embeddings,
     relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
         top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
    relex_scores = self._compute_relex_scores(
        relex_span_pair_embeddings, top_relex_span_mention_scores)
    output_dict['relex_scores'] = relex_scores
    output_dict['top_relex_spans'] = top_relex_spans

    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices)

        antecedent_labels_ = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs)
        # Padded top spans must not contribute to the loss.
        negative_marginal_log_likelihood *= top_span_mask.squeeze(
            -1).float()
        negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        coref_loss = negative_marginal_log_likelihood
        output_dict['coref_loss'] = coref_loss
        loss += self._loss_coref_weight * coref_loss

    if doc_relations is not None:
        # The adjacency matrix for relation extraction is very sparse.
        # As it is not just sparse, but row/column sparse (only few
        # rows and columns are non-zero and in that case these rows/columns
        # are not sparse), we implemented our own matrix for the case.
        # Here we have indices of truth spans and mapping, using which
        # we map prediction matrix on truth matrix.
        # TODO Add teacher forcing support.

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
        relative_indices = top_relex_span_indices
        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
        compressed_indices = nd_batched_padded_index_select(
            doc_spans_in_truth, relative_indices)

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
        gold_pruned_rows = nd_batched_padded_index_select(
            doc_relation_labels,
            compressed_indices.squeeze(-1),
            padding_value=0)
        gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3, 2).contiguous()

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
        gold_pruned_matrices = nd_batched_padded_index_select(
            gold_pruned_rows,
            compressed_indices.squeeze(-1),
            padding_value=0)  # pad with epsilon
        gold_pruned_matrices = gold_pruned_matrices.permute(
            0, 1, 3, 2).contiguous()

        # TODO log_mask relex score before passing
        relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                  gold_pruned_matrices,
                                                  relex_span_pair_mask)
        output_dict['relex_loss'] = relex_loss

        self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                   truth_spans)
        self._compute_relex_metrics(output_dict, doc_relations)

        loss += self._loss_relex_weight * relex_loss

    if doc_ner_labels is not None:
        # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
        ner_scores = self._ner_scorer(doc_span_embeddings)
        output_dict['ner_scores'] = ner_scores

        ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                doc_span_mask)
        output_dict['ner_loss'] = ner_loss
        loss += self._loss_ner_weight * ner_loss

    if not isinstance(loss, int):  # Only report a loss if some loss term was added.
        output_dict["loss"] = loss

    return output_dict
def forward(self, inputs: torch.Tensor, offsets: torch.Tensor = None) -> torch.Tensor:
    """
    Embed a batch of byte-pair encoded sequences with the transformer.

    Parameters
    ----------
    inputs : ``torch.Tensor``, required
        A ``(batch_size, num_timesteps)`` tensor of byte-pair ids for the current
        batch. Zero ids are excluded from the mask passed to the scalar mix.
    offsets : ``torch.Tensor``, optional (default = None)
        A ``(batch_size, max_sequence_length)`` tensor holding, for every word, the
        index of its last byte pair. When ``None``, one embedding per byte pair is
        returned instead.

    Returns
    -------
    ``torch.Tensor``
        ``(batch_size, max_sequence_length, embedding_dim)`` when ``offsets`` is
        given; otherwise ``(batch_size, L, embedding_dim)``, where ``L`` is the
        largest number of non-zero byte-pair ids in any row.
    """
    batch_size, num_timesteps = inputs.size()

    # The transformer's "vocab" is the byte-pair vocabulary followed by the
    # positional entries, so subtracting n_ctx leaves the byte-pair part and
    # the position ids start right after it: vocab_size, vocab_size + 1, ...
    num_byte_pair_ids = self._transformer.vocab_size - self._transformer.n_ctx
    position_ids = get_range_vector(num_timesteps, device=get_device_of(inputs)) + num_byte_pair_ids

    # Pair each byte-pair id with its position id.
    # Shape: (batch_size, num_timesteps, 2)
    transformer_input = torch.stack(
        (inputs, position_ids.expand(batch_size, num_timesteps)), dim=-1)

    token_mask = inputs != 0

    # One (batch_size, num_timesteps, embedding_dim) tensor per output layer.
    layer_outputs = self._transformer(transformer_input)
    if self._top_layer_only:
        mixed = layer_outputs[-1]
    else:
        mixed = self._scalar_mix(layer_outputs, token_mask)

    if offsets is None:
        # No word offsets: keep one embedding per byte pair, trimmed to the
        # longest run of non-padding ids in the batch.
        longest = (token_mask > 0).long().sum(dim=1).max()
        return mixed[:, :longest]

    # One embedding per word: pick the embedding of each word's last byte pair.
    row_index = get_range_vector(batch_size, device=get_device_of(mixed)).unsqueeze(1)
    return mixed[row_index, offsets]
def forward(
        self,  # type: ignore
        label_indices: torch.LongTensor,
        token_representations: torch.FloatTensor = None,
        raw_tokens: List[List[str]] = None,
        labels: torch.LongTensor = None,
        **kwargs) -> Dict[str, torch.Tensor]:
    """
    Predict a label for each selected token in every sentence of the batch.

    If ``token_representations`` is provided, ``raw_tokens`` is not required.
    If ``token_representations`` is ``None``, then ``raw_tokens`` is required
    and the model's own contextualizer produces the representations.

    Parameters
    ----------
    label_indices : torch.LongTensor
        A LongTensor of shape (batch_size, max_num_adpositions) with the tokens
        to predict a label for for each element (sentence) in the batch. Padding
        entries are -1 (see the masking step below).
    token_representations : torch.FloatTensor, optional (default = None)
        A tensor of shape (batch_size, sequence_length, representation_dim) with
        the representation of the first token. If None, we use a contextualizer
        within this model to produce the token representation.
    raw_tokens : List[List[str]], optional (default = None)
        A batch of lists with the raw token strings. Used to compute
        token_representations when those are None.
    labels : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels
        of shape ``(batch_size, num_label_indices)``.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_label_indices, num_classes)`` representing
        unnormalized log probabilities of the classes.
    class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, num_label_indices, num_classes)`` representing
        a distribution of the tag classes.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimized. Only present when ``labels`` is given.

    Raises
    ------
    ConfigurationError
        If ``token_representations`` is None and the model has no contextualizer.
    ValueError
        If ``token_representations`` is None and ``raw_tokens`` or ``label_indices``
        is None.
    """
    if token_representations is None:
        if self._contextualizer is None:
            raise ConfigurationError(
                "token_representation not provided as input to the model, and no "
                "contextualizer was specified. Either add a contextualizer to your "
                "dataset reader (preferred if your contextualizer is frozen) or to "
                "this model (if you wish to train your contextualizer).")
        if raw_tokens is None:
            raise ValueError(
                "Input raw_tokens is ``None`` --- make sure to set "
                "include_raw_tokens in the DatasetReader to True.")
        # This check must run before ``label_indices.long()`` below. Previously the
        # conversion came first, so a None here raised AttributeError and this
        # ValueError was unreachable.
        if label_indices is None:
            raise ValueError("Did not recieve any token indices, needed "
                             "if the contextualizer is within the model.")
        # Convert contextualizer output into a tensor
        # Shape: (batch_size, max_seq_len, representation_dim)
        token_representations, _ = pad_contextualizer_output(
            self._contextualizer(raw_tokens))

    # Convert to LongTensor
    # TODO: add PR to ArrayField to preserve array types.
    label_indices = label_indices.long()

    # Move token representation to the same device as the
    # module (CPU or CUDA). TODO(nfliu): This only works if the module
    # is on one device.
    device = next(self._decoder._linear_layers[0].parameters()).device
    token_representations = token_representations.to(device)
    text_mask = get_text_mask_from_representations(token_representations)
    text_mask = text_mask.to(device)
    label_mask = self._get_label_mask_from_label_indices(label_indices)
    label_mask = label_mask.to(device)

    # Mask out the -1 padding in the label_indices, since that doesn't
    # work with indexing. Note that we can't 0 pad because 0 is actually
    # a valid label index, so we pad with -1 just for the purposes of
    # proper mask calculation and then convert to 0-padding by applying
    # the mask.
    label_indices = label_indices * label_mask

    # Encode the token representation.
    encoded_token_representations = self._encoder(token_representations,
                                                  text_mask)

    batch_size = label_indices.size(0)
    # Index into the encoded_token_representations to get tensors corresponding
    # to the representations of the tokens to predict labels for.
    # Shape: (batch_size, num_label_indices, representation_dim)
    range_vector = get_range_vector(
        batch_size, get_device_of(label_indices)).unsqueeze(1)
    selected_token_representations = encoded_token_representations[
        range_vector, label_indices]
    selected_token_representations = selected_token_representations.contiguous()

    # Decode out a label from the token representation
    # Shape: (batch_size, num_label_indices, num_classes)
    logits = self._decoder(selected_token_representations)
    class_probabilities = F.softmax(logits, dim=-1)
    output_dict = {
        "logits": logits,
        "class_probabilities": class_probabilities
    }
    if labels is not None:
        loss = sequence_cross_entropy_with_logits(
            logits, labels, label_mask, average=self.loss_average)
        for name, metric in self.metrics.items():
            # When not running in error analysis mode, skip
            # metrics that start with "_"
            if not self.error_analysis and name.startswith("_"):
                continue
            metric(logits, labels, label_mask.float())
        output_dict["loss"] = loss
    return output_dict
def forward(
    self,  # type: ignore
    text: Dict[str, torch.LongTensor],
    spans: torch.IntTensor,
    span_labels: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score candidate spans as mentions, keep the top-scoring ones, and score
    antecedents for each kept span.

    Parameters
    ----------
    text : Dict[str, torch.LongTensor]
        Token indices, embedded with ``self._text_field_embedder``.
    spans : torch.IntTensor
        Candidate span endpoints, read as ``spans[:, :, 0]`` (start) and
        ``spans[:, :, 1]`` (end). A negative start marks a padding span.
    span_labels : torch.IntTensor, optional
        Gold labels per span. When given, a loss is computed and the mention
        recall / CoNLL metrics are updated.
    metadata : List[Dict[str, Any]], optional
        Per-instance metadata. ``metadata[0]["clusters"]`` is read when
        ``self._use_gold_mentions`` is set, and ``"original_text"`` is copied
        into the output.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``"top_spans"``, ``"antecedent_indices"`` and ``"predicted_antecedents"``.
        ``"loss"`` is added when ``span_labels`` is given, and ``"document"``
        when ``metadata`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    # The pruner's indices point into ``spans``, so batch shifting must use the
    # padded span dimension, not the number of valid or gold spans.
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    if self._use_gold_mentions:
        # Build gold tensors on the embeddings' exact device (including the GPU
        # index), not whatever the bare "cuda" device currently resolves to.
        device = text_embeddings.device
        # NOTE(review): only metadata[0] is read, so this path assumes batch_size == 1.
        gold_pairs = [
            torch.as_tensor(pair, dtype=torch.long, device=device)
            for cluster in metadata[0]["clusters"]
            for pair in cluster
        ]
        # Shape: (1, 1, num_gold_spans, 2)
        gold_mentions = torch.stack(gold_pairs, dim=0).unsqueeze(0).unsqueeze(1)
        # Shape: (batch_size, num_spans, num_gold_spans, 2)
        endpoint_diff = spans.unsqueeze(2) - gold_mentions
        # A span is a gold mention when BOTH its start and end match a gold pair.
        # ``&`` is a logical AND for uint8 and bool masks alike. Adding two bool
        # masks and comparing with 2 never matches, because bool ``+`` is an OR.
        endpoint_match = (endpoint_diff[:, :, :, 0] == 0) & (endpoint_diff[:, :, :, 1] == 0)
        span_mask, _ = endpoint_match.max(-1)
        span_mask = span_mask.float()
    else:
        # No squeeze here: squeezing would drop the span dimension when num_spans == 1.
        span_mask = (spans[:, :, 0] >= 0).float()

    # Clamp padding spans (negative indices) to 0 so they are safe to index with.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    (valid_antecedent_indices, valid_antecedent_offsets,
     valid_antecedent_log_mask) = self._generate_valid_antecedents(
         num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))

    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings,
        top_span_mention_scores,
        candidate_antecedent_mention_scores,
        valid_antecedent_log_mask,
    )

    # Column 0 of the scores is the "no antecedent" option, so shift by one so
    # that -1 means "no antecedent".
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": valid_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)
        coreference_log_probs = util.last_dim_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Run BERT over wordpiece ids and return either every wordpiece embedding
    or only the embeddings selected by ``offsets``.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        Wordpiece ids of shape (batch_size, ..., max_sequence_length). Id 0 is
        treated as padding when building the attention mask.
    offsets : ``torch.LongTensor``, optional
        BERT yields one vector per wordpiece. To get one vector per original
        token, pass the index of the wordpiece that should represent each
        token. Depending on the token indexer this is the first or the last
        wordpiece of the token. For the wordpieces
        ["Def", "##in", "##ite", "##ly", "not"] of "Definitely not",
        "last wordpiece" offsets are [3, 4]. When omitted, all wordpiece
        embeddings are returned.
    token_type_ids : ``torch.LongTensor``, optional
        Segment ids (0 for the first sentence, 1 for the second). Defaults to
        all zeros.

    Returns
    -------
    ``torch.Tensor``
        Shape (batch_size, ..., sequence_length, embedding_dim) without
        offsets, or (batch_size, ..., num_offsets, embedding_dim) with offsets.
    """
    # pylint: disable=arguments-differ
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)
    input_mask = (input_ids != 0).long()

    # BERT expects 2-d input, so fold any extra leading dimensions into the
    # batch dimension and unfold them again before returning.
    flat_input_ids = util.combine_initial_dims(input_ids)
    flat_token_type_ids = util.combine_initial_dims(token_type_ids)
    flat_input_mask = util.combine_initial_dims(input_mask)
    all_encoder_layers, _ = self.bert_model(input_ids=flat_input_ids,
                                            token_type_ids=flat_token_type_ids,
                                            attention_mask=flat_input_mask)

    # Shape: (batch_size * d1 * ... * dn, sequence_length, embedding_dim)
    if self._scalar_mix is None:
        mix = all_encoder_layers[-1]
    else:
        mix = self._scalar_mix(all_encoder_layers, input_mask)

    if offsets is None:
        return util.uncombine_initial_dims(mix, input_ids.size())

    # Shape: (batch_size * d1 * ... * dn, orig_sequence_length)
    flat_offsets = util.combine_initial_dims(offsets)
    # One row index per flattened sequence, broadcast against flat_offsets.
    row_index = util.get_range_vector(flat_offsets.size(0),
                                      device=util.get_device_of(mix)).unsqueeze(1)
    picked = mix[row_index, flat_offsets]
    return util.uncombine_initial_dims(picked, offsets.size())
def _distance_pruning(
    self,
    top_span_embeddings: torch.FloatTensor,
    top_span_mention_scores: torch.FloatTensor,
    max_antecedents: int,
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
    """
    Pick up to `max_antecedents` candidate antecedents per span, using only
    distance: the nearest preceding spans (by span count) are kept.

    # Parameters

    top_span_embeddings: torch.FloatTensor, required.
        Embeddings of the kept spans, (batch_size, num_spans_to_keep, embedding_size).
        Only its second dimension is used here.
    top_span_mention_scores: torch.FloatTensor, required.
        Mention scores of the kept spans, (batch_size, num_spans_to_keep).
    max_antecedents: int, required.
        How many antecedents to keep per span.

    # Returns

    top_partial_coreference_scores: torch.FloatTensor
        Span mention score plus antecedent mention score for every span/antecedent
        pair. "Partial" because the pairwise interaction term of the full
        coreference score is not included.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_mask: torch.BoolTensor
        Which antecedent slots are valid. Early spans have fewer real
        antecedents, and the first span has none.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_offsets: torch.LongTensor
        Distance from each span to each antecedent, counted in kept spans
        rather than words.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_indices: torch.LongTensor
        Index of each antecedent among the kept spans.
        (batch_size, num_spans_to_keep, max_antecedents)
    """
    # Spans are in document order, so antecedent structure depends only on
    # position and is shared by every batch element.
    num_spans_to_keep = top_span_embeddings.size(1)
    device = util.get_device_of(top_span_embeddings)

    # Shapes: (num_spans_to_keep, max_antecedents), (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    antecedent_indices, antecedent_offsets, antecedent_mask = self._generate_valid_antecedents(  # noqa
        num_spans_to_keep, max_antecedents, device
    )

    # Mention score of every candidate antecedent.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    antecedent_scores = util.flattened_index_select(
        top_span_mention_scores.unsqueeze(-1), antecedent_indices
    ).squeeze(-1)

    # Broadcast each span's own score across its antecedent slots.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    partial_scores = top_span_mention_scores.unsqueeze(-1) + antecedent_scores

    # Give the batch-independent tensors the same batched shape as the scores.
    return (
        partial_scores,
        antecedent_mask.expand_as(partial_scores),
        antecedent_offsets.unsqueeze(0).expand_as(partial_scores),
        antecedent_indices.unsqueeze(0).expand_as(partial_scores),
    )
def forward(
        self,
        passage_attention: torch.Tensor,
        passage_lengths: List[int],
        count_answer: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict a count from a passage attention vector.

    The predicted count mean is the sum of per-token sigmoid scores over
    unmasked positions. A Gaussian-shaped distribution over the count values
    from ``allenutil.get_range_vector(10)`` is built around that mean.

    Parameters
    ----------
    passage_attention : torch.Tensor
        2-d attention of shape (B, passage_length). Negative entries are
        treated as padding.
    passage_lengths : List[int]
        Not used by this method.
    count_answer : torch.LongTensor, optional
        Gold counts of shape (B,). When given, an MSE loss is computed and
        ``self.count_acc`` is updated.
    metadata : List[Dict[str, Any]], optional
        Not used by this method.

    Returns
    -------
    Dict[str, torch.Tensor]
        Loss, attention, token sigmoids, count mean/distribution, gold and
        predicted counts.
    """
    device_id = allenutil.get_device_of(passage_attention)

    # Unpacking also checks that passage_attention is 2-d.
    batch_size, _ = passage_attention.size()

    # Shape: (B, passage_length)
    passage_mask = (passage_attention >= 0).float()

    # Shape: (B, passage_length, num_scaling_factors)
    scaled_passage_attentions = torch.stack(
        [passage_attention * sf for sf in self.scaling_vals], dim=2)

    # Shape (batch_size, 1)
    # NOTE(review): this bias is computed but not added to passage_count_mean below.
    passage_len_bias = self.passagelength_to_bias(passage_mask.sum(1, keepdim=True))

    scaled_passage_attentions = scaled_passage_attentions * passage_mask.unsqueeze(2)

    # Shape: (B, passage_length, hidden_dim)
    count_hidden_repr = self.passage_attention_to_count(scaled_passage_attentions, passage_mask)

    # Shape: (B, passage_length, 1) -- score for each token
    passage_span_logits = self.passage_count_hidden2logits(count_hidden_repr)
    # Shape: (B, passage_length) -- masked sigmoid of the token scores
    token_sigmoids = torch.sigmoid(passage_span_logits.squeeze(2))
    token_sigmoids = token_sigmoids * passage_mask

    # Shape: (B, 1) -- sum of sigmoids, used as the predicted count mean
    passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True)

    # Shape: (1, count_vals)
    self.countvals = allenutil.get_range_vector(10, device=device_id).unsqueeze(0).float()

    variance = 0.2

    # Shape: (batch_size, count_vals)
    l2_by_vsquared = torch.pow(self.countvals - passage_count_mean, 2) / (2 * variance * variance)
    # The small epsilon keeps the normaliser below strictly positive.
    exp_val = torch.exp(-1 * l2_by_vsquared) + 1e-30
    # Shape: (batch_size, count_vals)
    count_distribution = exp_val / (torch.sum(exp_val, 1, keepdim=True))

    # Loss computation
    output_dict = {}
    loss = 0.0
    # Predicted count index per instance, Shape: (B,)
    pred_count_idx = torch.argmax(count_distribution, 1)
    if count_answer is not None:
        # L2 loss between the predicted mean and the gold count.
        passage_count_mean = passage_count_mean.squeeze(1)
        loss = F.mse_loss(input=passage_count_mean, target=count_answer.float())

        # np.round_ was removed in NumPy 2.0; np.round is the supported spelling.
        predictions = np.round(passage_count_mean.detach().cpu().numpy())
        gold_count = count_answer.detach().cpu().numpy()
        correct_vec = (predictions == gold_count)
        correct_perc = sum(correct_vec) / batch_size
        self.count_acc(correct_perc)

    # NOTE(review): F.mse_loss already averages over the batch (default
    # reduction), so this divides by batch_size a second time. Confirm intent.
    batch_loss = loss / batch_size

    output_dict["loss"] = batch_loss
    output_dict["passage_attention"] = passage_attention
    output_dict["passage_sigmoid"] = token_sigmoids
    output_dict["count_mean"] = passage_count_mean
    output_dict["count_distritbuion"] = count_distribution
    output_dict["count_answer"] = count_answer
    output_dict["pred_count"] = pred_count_idx

    return output_dict
def _construct_loss(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    head_indices: torch.Tensor,
                    head_tags: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the arc and tag negative log-likelihoods for the gold heads and tags.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        (batch_size, sequence_length, tag_representation_dim) head side of the
        tag classifier.
    child_tag_representation : ``torch.Tensor``, required.
        (batch_size, sequence_length, tag_representation_dim) child side of the
        tag classifier.
    attended_arcs : ``torch.Tensor``, required.
        (batch_size, sequence_length, sequence_length) unnormalised scores for
        attaching each word to every other word.
    head_indices : ``torch.Tensor``, required.
        (batch_size, sequence_length) gold head index for each word.
    head_tags : ``torch.Tensor``, required.
        (batch_size, sequence_length) gold dependency label for each word.
    mask : ``torch.Tensor``, required.
        (batch_size, sequence_length) mask of non-padding positions.

    Returns
    -------
    arc_nll : ``torch.Tensor``
        Arc negative log-likelihood, averaged over non-root, non-padding words.
    tag_nll : ``torch.Tensor``
        Tag negative log-likelihood, averaged the same way.
    """
    float_mask = mask.float()
    batch_size, sequence_length, _ = attended_arcs.size()
    device = get_device_of(attended_arcs)

    # (batch_size, 1) batch selector for advanced indexing.
    batch_index = get_range_vector(batch_size, device).unsqueeze(1)
    # (batch_size, sequence_length) position of each child word.
    child_index = (get_range_vector(sequence_length, device)
                   .view(1, sequence_length)
                   .expand(batch_size, sequence_length)
                   .long())

    # (batch_size, sequence_length, sequence_length). Entries in padded rows
    # and columns are zeroed.
    arc_log_probs = (masked_log_softmax(attended_arcs, mask)
                     * float_mask.unsqueeze(2)
                     * float_mask.unsqueeze(1))

    # (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag_representation,
                                          child_tag_representation,
                                          head_indices)
    tag_log_probs = (masked_log_softmax(head_tag_logits, mask.unsqueeze(-1))
                     * float_mask.unsqueeze(-1))

    # Gold log-probabilities per word, dropping position 0 (the symbolic ROOT),
    # whose head is not predicted.
    arc_loss = arc_log_probs[batch_index, child_index, head_indices][:, 1:]
    tag_loss = tag_log_probs[batch_index, child_index, head_tags][:, 1:]

    # One ROOT position per sentence is excluded from the normaliser.
    valid_positions = (mask.sum() - batch_size).float()
    arc_nll = -arc_loss.sum() / valid_positions
    tag_nll = -tag_loss.sum() / valid_positions
    return arc_nll, tag_nll
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes this passage contains the
        answer to the question, and predicts the start and end positions of
        the answer within the passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. The `inclusive` start token index of the
        answer. If given, a loss is computed and included in the output.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. The `inclusive` end token index of the
        answer. If given, a loss is computed and included in the output.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance. The keys read are ``question_tokens``,
        ``passage_tokens``, ``original_passage`` and ``token_offsets``. The
        optional key ``answer_texts`` (a list of strings) feeds the SQuAD
        EM/F1 metrics.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Passage-to-question attention weights.
    span_start_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``. Unnormalized log probabilities
        of the span start position; padded positions are set to -1e7.
    span_start_probs : torch.FloatTensor
        Masked ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``. Unnormalized log probabilities
        of the span end position (inclusive); padded positions are set to -1e7.
    span_end_probs : torch.FloatTensor
        Masked ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        Result of ``self.get_best_span`` over the start and end logits.
        Shape ``(batch_size, 2)``; each entry is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        Only when ``metadata`` is given: the substring of the original
        passage covered by ``best_span``.
    question_tokens, passage_tokens : List
        Only when ``metadata`` is given: copied from ``metadata``.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(
        passage_question_similarity, question_mask, dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # Shape: (batch_size, question_length, passage_length)
    question_passage_similarity = torch.transpose(passage_question_similarity, 1, 2)
    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask, dim=-1)
    # Shape: (batch_size, question_length, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)

    # Gated fusion: interpolate between the fused and the original encodings.
    passage_gate = torch.unsqueeze(
        self._passage_similarity_function(encoded_passage, passage_question_vectors), -1)
    passage_fusion = self._passage_fusion_function(encoded_passage, passage_question_vectors)
    gated_passage = passage_gate * passage_fusion + (1 - passage_gate) * encoded_passage

    question_gate = torch.unsqueeze(
        self._question_similarity_function(encoded_question, question_passage_vector), -1)
    question_fusion = self._question_fusion_function(encoded_question, question_passage_vector)
    gated_question = question_gate * question_fusion + (1 - question_gate) * encoded_question

    # Passage self-attention.
    passage_passage_similarity = self._self_matrix_attention(gated_passage, gated_passage)
    passage_passage_attention = util.masked_softmax(
        passage_passage_similarity, passage_mask, dim=-1)
    passage_passage_vector = util.weighted_sum(gated_passage, passage_passage_attention)
    final_passage = self._fusion_function(gated_passage, passage_passage_vector)

    modeled_passage = self._dropout(
        self._passage_modeling_layer(final_passage, passage_lstm_mask))
    span_logits = self._span_predictor(modeled_passage)

    modeled_question = self._question_modeling_layer(gated_question, question_lstm_mask)
    question_vector = self._question_encoding_layer(
        modeled_question, question_lstm_mask).unsqueeze(-1)

    span_start_logits = torch.bmm(self._span_start_weight(modeled_passage),
                                  question_vector).squeeze(-1)
    span_end_logits = torch.bmm(self._span_end_weight(modeled_passage),
                                question_vector).squeeze(-1)

    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        # Move the class weights next to the labels, whatever device they are on.
        weight = self._span_weight.to(span_start.device)
        arange_mask = util.get_range_vector(passage_length, util.get_device_of(span_start))
        # Per-token "inside the gold span" labels; assumes span_start/span_end
        # are (batch_size, 1) so this broadcasts to (batch_size, passage_length).
        span_mask = (arange_mask >= span_start) & (arange_mask <= span_end)
        span_loss = nll_loss(self._masked_log_softmax(span_logits, passage_mask).transpose(1, 2),
                             span_mask.long(),
                             weight=weight)

        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        # Concatenate along the last dim so the gold spans are (batch_size, 2),
        # matching best_span. Stacking would yield (batch_size, 1, 2).
        self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
        # NOTE(review): operator precedence halves only span_loss; confirm
        # whether (loss + span_loss) / 2 was intended.
        output_dict["loss"] = loss + span_loss / 2

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def create_cached_cnn_embeddings(self, tokens: List[str]) -> None:
    """
    Precompute context-free word vectors for `tokens`.

    The vectors come from the character convolutions and highway layers
    (``self._token_embedder``). Later forward passes can then look words up in
    an embedding table instead of running the CNN encoder again.

    Three attributes are set:

    _word_embedding : `Embedding`
        One row per entry of `tokens`, built with ``padding_index=0``.
    _bos_embedding : `torch.Tensor`
        The vector for the BOS token.
    _eos_embedding : `torch.Tensor`
        The vector for the EOS token.

    # Parameters

    tokens : `List[str]`, required.
        The tokens to precompute character-convolution vectors for.
    """
    # BOS and EOS are prepended so their vectors land at rows 0 and 1.
    all_tokens = [ELMoCharacterMapper.bos_token, ELMoCharacterMapper.eos_token] + tokens
    timesteps = 32
    batch_size = 32
    device = get_device_of(next(self.parameters()))

    # Group tokens into fake "sentences" of `timesteps` tokens, and those into
    # batches of `batch_size` sentences.
    sentence_chunks = lazy_groups_of(iter(all_tokens), timesteps)
    flat_embeddings = []
    for batch in lazy_groups_of(sentence_chunks, batch_size):
        # Shape (batch_size, timesteps, 50)
        char_ids = batch_to_ids(batch)
        # Only needed if the model was already moved to a GPU; when called from
        # the constructor this runs on CPU, which is cheap for a few convolutions.
        if device >= 0:
            char_ids = char_ids.cuda(device)

        output = self._token_embedder(char_ids)
        trimmed, _ = remove_sentence_boundaries(output["token_embedding"], output["mask"])
        flat_embeddings.append(trimmed.view(-1, trimmed.size(-1)))

    # The last batch may be padded, so clip back to exactly one row per token.
    full_embedding = torch.cat(flat_embeddings, 0)[:len(all_tokens), :]
    word_embedding = full_embedding[2:len(all_tokens), :]
    vocab_size, embedding_dim = word_embedding.size()

    from allennlp.modules.token_embedders import Embedding  # type: ignore

    self._bos_embedding = full_embedding[0, :]
    self._eos_embedding = full_embedding[1, :]
    self._word_embedding = Embedding(  # type: ignore
        vocab_size,
        embedding_dim,
        weight=word_embedding.data,
        trainable=self._requires_grad,
        padding_index=0,
    )
def forward(self,
            spans_tensor: torch.FloatTensor,
            spans_mask: torch.FloatTensor,
            question_tensor: torch.FloatTensor,
            question_mask: torch.FloatTensor,
            evd_chain_labels: torch.FloatTensor,
            self_att_layer: Seq2SeqEncoder,
            sent_encoder: Seq2SeqEncoder,
            transition_mask: torch.IntTensor = None,
            start_transition_mask: torch.FloatTensor = None,
            get_all_beam: bool = False):
    """
    Decode an evidence chain over spans and build per-span selection gates.

    The first position of each span (``spans_tensor[:, :, 0, :]``) serves as
    the span representation. The first question position
    (``question_tensor[:, 0, :]``) serves as the question vector. Both are
    passed to ``self.evd_decoder``.

    ``self_att_layer``, ``sent_encoder`` and ``question_mask`` are accepted for
    interface compatibility but are currently unused, so ``att_score`` is
    always ``None``.

    Index 0 of the decoder's prediction space is reserved for the end
    embedding; span ``j`` corresponds to prediction index ``j + 1``.

    Returns
    -------
    Tuple of
        all_predictions, all_logprobs, seq_logprobs :
            Decoder outputs. Without ``get_all_beam`` only beam 0 is kept and
            the beam dimension is removed.
        gate :
            1.0 for spans selected at any step, else 0.0. Shape
            (batch_size * num_spans, 1), or (batch_size * K * num_spans, 1)
            with ``get_all_beam``.
        gate_probs :
            Probability of the step that selected each span (0 if unselected),
            same shape as ``gate``.
        max_pooled_span_mask :
            (batch_size, num_spans) float mask of non-empty spans.
        att_score :
            Always ``None``.
        orders :
            (batch_size, K, num_spans) decoding step at which each span was
            selected, or -1 if it was not selected.
    """
    batch_size, num_spans, _ = spans_mask.size()

    # Shape: (batch_size, num_spans, embedding_dim)
    max_pooled_span_emb = spans_tensor[:, :, 0, :]
    # Shape: (batch_size, num_spans) -- 1.0 where the span has any unmasked position.
    max_pooled_span_mask = (torch.sum(spans_mask, dim=-1) >= 1).float()

    # Span self-attention / sentence encoding is currently disabled.
    att_score = None

    # Shape: (batch_size, embedding_dim)
    question_emb = question_tensor[:, 0, :]

    # Decode the most likely evidence path.
    # shape (all_predictions): (batch_size, K, num_decoding_steps)
    # shape (all_logprobs): (batch_size, K, num_decoding_steps)
    # shape (seq_logprobs): (batch_size, K)
    all_predictions, all_logprobs, seq_logprobs, _final_hidden = self.evd_decoder(
        max_pooled_span_emb,
        max_pooled_span_mask,
        question_emb,
        aux_input=None,
        transition_mask=transition_mask,
        start_transition_mask=start_transition_mask,
        labels=evd_chain_labels)

    # Selection order of each span; -1 if it was never chosen.
    # shape: (batch_size, K, num_spans)
    _, beam, num_steps = all_predictions.size()
    orders = spans_tensor.new_ones((batch_size, beam, 1 + num_spans)) * -1
    step_indices = (util.get_range_vector(num_steps, util.get_device_of(spans_tensor))
                    .float()
                    .unsqueeze(0)
                    .unsqueeze(0)
                    .expand(batch_size, beam, num_steps))
    orders.scatter_(2, all_predictions, step_indices)
    # Drop the column reserved for the end embedding.
    orders = orders[:, :, 1:]

    # Beam search: keep only the top beam. For other helpers this just removes
    # the singleton beam dimension.
    if not get_all_beam:
        all_predictions = all_predictions[:, 0, :]
        all_logprobs = all_logprobs[:, 0, :]
        seq_logprobs = seq_logprobs[:, 0]

    # Gates live in a 1 + num_spans space (column 0 = end embedding), which is
    # dropped before flattening.
    if not get_all_beam:
        selection_shape = (batch_size, 1 + num_spans)
        flat_shape = (batch_size * num_spans, 1)
    else:
        selection_shape = (batch_size, beam, 1 + num_spans)
        flat_shape = (batch_size * beam * num_spans, 1)

    # 1.0 for every span that was selected at some step.
    gate = spans_tensor.new_zeros(selection_shape)
    gate.scatter_(-1, all_predictions, 1.)
    gate = gate[..., 1:].reshape(flat_shape)

    # Probability of the step that selected each span; 0 if not selected.
    gate_probs = spans_tensor.new_zeros(selection_shape)
    gate_probs.scatter_(-1, all_predictions, all_logprobs.exp())
    gate_probs = gate_probs[..., 1:].reshape(flat_shape)

    return all_predictions, all_logprobs, seq_logprobs, gate, gate_probs, max_pooled_span_mask, att_score, orders
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Embed a batch of wordpiece ids with BERT, optionally keeping one wordpiece per token.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids. Id 0 is
        treated as padding when building the attention mask.
    offsets : ``torch.LongTensor``, optional
        The BERT embeddings are one per wordpiece. However it's possible/likely
        you might want one per original token. In that case, ``offsets``
        represents the indices of the desired wordpiece for each original token.
        Depending on how your token indexer is configured, this could be the
        position of the last wordpiece for each token, or it could be the position
        of the first wordpiece for each token.

        For example, if you had the sentence "Definitely not", and if the corresponding
        wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
        would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
        If offsets are provided, the returned tensor will contain only the wordpiece
        embeddings at those positions, and (in particular) will contain one embedding
        per token. If offsets are not provided, the entire tensor of wordpiece embeddings
        will be returned.
    token_type_ids : ``torch.LongTensor``, optional
        If an input consists of two sentences (as in the BERT paper),
        tokens from the first sentence should have type 0 and tokens from
        the second sentence should have type 1. If you don't provide this
        (the default BertIndexer doesn't) then it's assumed to be all 0s.
        When given, it must have the same shape as ``input_ids``. Inputs longer than
        ``self.max_pieces`` are windowed in exactly the same way as ``input_ids``.

    Returns
    -------
    ``torch.Tensor``
        If ``self._scalar_mix`` is set, the scalar mix of all encoder layers.
        Otherwise, if ``self.combine_layers == "last"``, only the last layer.
        Otherwise, all layers stacked along a new leading dimension.
        The sequence dimension has one entry per wordpiece, or one per ``offsets``
        entry when offsets are given.
    """
    # pylint: disable=arguments-differ
    batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
    initial_dims = list(input_ids.shape[:-1])

    def _pad_and_fold_windows(tensor: torch.Tensor) -> torch.Tensor:
        # Cut the last dimension into `max_pieces`-sized windows, right-pad the final
        # window with 0 so all windows have equal length, and stack the windows
        # along the batch dimension.
        windows = list(tensor.split(self.max_pieces, dim=-1))
        padding_amount = self.max_pieces - windows[-1].size(-1)
        windows[-1] = F.pad(windows[-1], pad=[0, padding_amount], value=0)
        return torch.cat(windows, dim=0)

    # The embedder may receive an input tensor that has a sequence length longer than can
    # be fit. In that case, we should expect the wordpiece indexer to create padded windows
    # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
    # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
    # We can then split the sequence into sub-sequences of that length, and concatenate them
    # along the batch dimension so we effectively have one huge batch of partial sentences.
    # This can then be fed into BERT without any sentence length issues. Keep in mind
    # that the memory consumption can dramatically increase for large batches with extremely
    # long sentences.
    needs_split = full_seq_len > self.max_pieces
    if needs_split:
        input_ids = _pad_and_fold_windows(input_ids)
        if token_type_ids is not None:
            # Window token types identically, otherwise their shape no longer matches input_ids.
            token_type_ids = _pad_and_fold_windows(token_type_ids)

    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)

    input_mask = (input_ids != 0).long()

    # input_ids may have extra dimensions, so we reshape down to 2-d
    # before calling the BERT model and then reshape back at the end.
    all_encoder_layers, _ = self.bert_model(
        input_ids=util.combine_initial_dims(input_ids),
        token_type_ids=util.combine_initial_dims(token_type_ids),
        attention_mask=util.combine_initial_dims(input_mask))
    all_encoder_layers = torch.stack(all_encoder_layers)

    if needs_split:
        # First, unpack the output embeddings into one long sequence again.
        # NOTE(review): this splits by `batch_size` (input_ids.size(0)). For inputs with
        # extra dims (d1..dn) each window spans batch_size * d1 * ... * dn rows, so the
        # split path looks correct only for 2-d input_ids. Confirm before relying on it.
        unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
        unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

        # Next, select indices of the sequence such that it will result in embeddings
        # representing the original sentence. To capture maximal context, the indices will
        # be the middle part of each embedded window sub-sequence (plus any leftover start
        # and final edge windows), e.g.,
        #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
        # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
        # with max_pieces = 8 (and one start/end token) should produce max context indices
        # [2, 3, 4, 10, 11, 12] with additional start and final windows with indices
        # [0, 1] and [13, 14, 15] respectively.

        # Find the stride as half the max pieces, ignoring the special start and end tokens.
        # Calculate an offset to extract the centermost embeddings of each window.
        stride = (self.max_pieces - self.start_tokens - self.end_tokens) // 2
        stride_offset = stride // 2 + self.start_tokens

        first_window = list(range(stride_offset))

        max_context_windows = [
            i for i in range(full_seq_len)
            if stride_offset - 1 < i % self.max_pieces < stride_offset + stride
        ]

        # Look back over the real tokens of the last window. When full_seq_len is an exact
        # multiple of max_pieces, the last window is a full window, not an empty one;
        # using `% max_pieces` alone would push the start past the end and drop its tail.
        lookback = full_seq_len % self.max_pieces or self.max_pieces
        final_window_start = full_seq_len - lookback + stride_offset + stride
        final_window = list(range(final_window_start, full_seq_len))

        select_indices = first_window + max_context_windows + final_window

        initial_dims.append(len(select_indices))

        recombined_embeddings = unpacked_embeddings[:, :, select_indices]
    else:
        recombined_embeddings = all_encoder_layers

    # Recombined outputs of all layers:
    # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
    # NOTE(review): this mask is element-wise over all layers and is not re-indexed by
    # `offsets`. It is only shape-compatible with a scalar mix that ignores its mask.
    # Confirm against the configured ScalarMix.
    input_mask = (recombined_embeddings != 0).long()

    if offsets is None:
        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
        dims = initial_dims if needs_split else input_ids.size()
        layers = util.uncombine_initial_dims(recombined_embeddings, dims)
    else:
        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(
            offsets2d.size(0),
            device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
        # Pick, for every row, the wordpiece at each requested offset (in every layer).
        selected_embeddings = recombined_embeddings[:, range_vector, offsets2d]
        layers = util.uncombine_initial_dims(selected_embeddings, offsets.size())

    if self._scalar_mix is not None:
        return self._scalar_mix(layers, input_mask)
    elif self.combine_layers == "last":
        return layers[-1]
    else:
        return layers
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Pool every span into a single vector using attention over the tokens it covers.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``
        Token representations. Dimension 1 is used as the sequence length
        (presumably ``(batch_size, sequence_length, embedding_dim)``; confirm).
    span_indices : ``torch.LongTensor``
        Span boundaries whose last dimension splits into ``start`` and ``end``.
        Ends are inclusive.
    sequence_mask : ``torch.LongTensor``, optional
        Accepted for interface compatibility. Not read by this implementation.
    span_indices_mask : ``torch.LongTensor``, optional
        When given, the pooled vector of each span is multiplied by this mask,
        zeroing padded spans.

    Returns
    -------
    ``torch.FloatTensor``
        ``(batch_size, num_spans, embedding_dim)`` attention-weighted span sums.
    """
    starts, ends = span_indices.split(1, dim=-1)

    # Inclusive ends make this one less than the real width.
    widths_minus_one = ends - starts

    # Length of the window of token positions gathered for every span. Positions
    # beyond a span's own width are masked out below.
    window = widths_minus_one.max().item() + 1

    # One attention logit per token.
    token_logits = self._global_attention(sequence_tensor)

    # (1, 1, window) offsets, broadcast against (batch_size, num_spans, 1).
    offsets = util.get_range_vector(window, util.get_device_of(sequence_tensor)).view(1, 1, -1)

    # Walk backwards from each (inclusive) end. Keep offsets within the span's width
    # (<=, because of inclusivity) and positions that do not fall before index 0.
    positions = ends - offsets
    in_span = (offsets <= widths_minus_one).float()
    in_sequence = (positions >= 0).float()
    window_mask = in_span * in_sequence
    # Negative positions are already masked; clamp them to 0 so they stay valid indices.
    safe_positions = torch.nn.functional.relu(positions.float()).long()

    flat_positions = util.flatten_and_batch_shift_indices(safe_positions, sequence_tensor.size(1))

    window_embeddings = util.batched_index_select(sequence_tensor, safe_positions, flat_positions)
    window_logits = util.batched_index_select(token_logits,
                                              safe_positions,
                                              flat_positions).squeeze(-1)
    window_weights = util.masked_softmax(window_logits, window_mask)

    pooled = util.weighted_sum(window_embeddings, window_weights)

    if span_indices_mask is None:
        return pooled
    # Zero out spans that were padding in the input.
    return pooled * span_indices_mask.unsqueeze(-1).float()
def _coarse_to_fine_pruning(
    self,
    top_span_embeddings: torch.FloatTensor,
    top_span_mention_scores: torch.FloatTensor,
    top_span_mask: torch.BoolTensor,
    max_antecedents: int,
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
    """
    Score every (span, candidate antecedent) pair cheaply and keep the best `max_antecedents`.

    The cheap pair score is the sum of both spans' mention scores plus a bilinear
    interaction term between the span embeddings.

    # Parameters

    top_span_embeddings: torch.FloatTensor, required.
        Embeddings of the top spans. (batch_size, num_spans_to_keep, embedding_size).
    top_span_mention_scores: torch.FloatTensor, required.
        Mention scores of the top spans. (batch_size, num_spans_to_keep).
    top_span_mask: torch.BoolTensor, required.
        Mask over the top spans. (batch_size, num_spans_to_keep).
    max_antecedents: int, required.
        How many antecedents to keep per span.

    # Returns

    top_partial_coreference_scores: torch.FloatTensor
        Partial score of each kept span-antecedent pair. It is the sum of the two
        mention scores and the bilinear interaction term. It is "partial" because the
        full coreference score additionally uses w * FFNN([g_i, g_j, g_i * g_j, features]).
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_mask: torch.BoolTensor
        Whether each kept antecedent is valid. Spans have different numbers of valid
        antecedents; e.g. the first span in a document has none.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_offsets: torch.LongTensor
        Distance from each span to each kept antecedent, counted in considered spans
        (not in words). (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_indices: torch.LongTensor
        Index of each kept antecedent among the top spans.
        (batch_size, num_spans_to_keep, max_antecedents)
    """
    num_spans_to_keep = top_span_embeddings.size(1)
    device = util.get_device_of(top_span_embeddings)

    # Which (span, antecedent) positions are allowed at all.
    # Shape: (1, num_spans_to_keep, num_spans_to_keep)
    _, _, valid_antecedent_mask = self._generate_valid_antecedents(
        num_spans_to_keep, num_spans_to_keep, device
    )

    # bilinear_score[b, i, j] = <e_i, scorer(e)_j>
    projected = self._coarse2fine_scorer(top_span_embeddings)
    bilinear_score = torch.matmul(top_span_embeddings, projected.transpose(1, 2))

    # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
    partial_antecedent_scores = (
        top_span_mention_scores.unsqueeze(1)
        + top_span_mention_scores.unsqueeze(2)
        + bilinear_score
    )

    # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
    span_pair_mask = top_span_mask.unsqueeze(-1) & valid_antecedent_mask

    # Each: (batch_size, num_spans_to_keep, max_antecedents)
    kept_scores, kept_mask, kept_indices = util.masked_topk(
        partial_antecedent_scores, span_pair_mask, max_antecedents
    )

    # The offset of antecedent j from span i is simply i - j, so it can be computed by
    # broadcasting span positions against the kept indices instead of gathering it.
    # Shape: (1, num_spans_to_keep, 1) - (batch_size, num_spans_to_keep, max_antecedents)
    span_positions = util.get_range_vector(num_spans_to_keep, device).view(1, -1, 1)
    kept_offsets = span_positions - kept_indices

    return kept_scores, kept_mask, kept_offsets, kept_indices
def inference_coref(self, batch, embedded_text_input_relation, mask):
    """
    Run fast coreference inference with the coref sub-model ``self.model._tagger_coref``.

    Parameters
    ----------
    batch : dict
        Must contain ``"spans"``. It is indexed as ``spans[:, :, 0]``, presumably
        ``(batch_size, num_spans, 2)`` start/end pairs; confirm. Spans with a negative
        start index are treated as padding.
    embedded_text_input_relation : ``torch.Tensor``
        Token embeddings. They are fed to the sub-model's context layer and to its
        attentive span extractor.
    mask : ``torch.Tensor``
        Token mask. ``mask.size(1)`` is used as the document length.

    Returns
    -------
    dict
        ``"top_spans"``: the spans kept by the mention pruner.
        ``"antecedent_indices"``: candidate antecedent indices produced by
        ``_generate_valid_antecedents``.
        ``"predicted_antecedents"``: the argmax over the last dimension of the coreference
        scores, minus 1. It is therefore -1 when column 0 wins (presumably the "no
        antecedent" score; confirm).
    """
    submodel = self.model._tagger_coref

    spans = batch["spans"]
    document_length = mask.size(1)
    num_spans = spans.size(1)

    # Spans with a negative start index are padding. Do not squeeze: with
    # num_spans == 1 that would collapse the mask from (batch_size, 1) to (batch_size,).
    span_mask = (spans[:, :, 0] >= 0).float()
    # Clamp padding indices to 0 so they stay valid positions for the span extractors.
    spans = F.relu(spans.float()).long()

    encoded_text_coref = submodel._context_layer(embedded_text_input_relation, mask)
    endpoint_span_embeddings = submodel._endpoint_span_extractor(encoded_text_coref, spans)
    attended_span_embeddings = submodel._attentive_span_extractor(
        embedded_text_input_relation, spans)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune to a number of spans proportional to the document length.
    num_spans_to_keep = int(math.floor(submodel._spans_per_word * document_length))
    (top_span_embeddings, _, top_span_indices,
     top_span_mention_scores) = submodel._mention_pruner(
        span_embeddings, span_mask, num_spans_to_keep)

    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)

    max_antecedents = min(submodel._max_antecedents, num_spans_to_keep)

    (valid_antecedent_indices, valid_antecedent_offsets,
     valid_antecedent_log_mask) = submodel._generate_valid_antecedents(
        num_spans_to_keep, max_antecedents, util.get_device_of(mask))

    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    span_pair_embeddings = submodel._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets)
    coreference_scores = submodel._compute_coreference_scores(
        span_pair_embeddings,
        top_span_mention_scores,
        candidate_antecedent_mention_scores,
        valid_antecedent_log_mask,
    )

    _, predicted_antecedents = coreference_scores.max(2)
    # Shift so that column 0 maps to -1 and column k maps to antecedent k - 1.
    predicted_antecedents -= 1

    return {
        "top_spans": top_spans,
        "antecedent_indices": valid_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }