def masking_blockdiagonal(passage_length, window, device_id):
    """
    Make a (passage_length, passage_length) tensor M in which, for each row x,
    M[x, y] = -1 if y < x - window or y > x + window, else it is 1.
    Basically, for the x-th row, the [x - win, x + win] columns should be 1, and the rest -1.
    """
    lower_limit = [max(0, i - window) for i in range(passage_length)]
    upper_limit = [
        min(passage_length, i + window) for i in range(passage_length)
    ]

    # Tensors of lower and upper limits for each row
    lower = allenutil.move_to_device(torch.LongTensor(lower_limit),
                                     cuda_device=device_id)
    upper = allenutil.move_to_device(torch.LongTensor(upper_limit),
                                     cuda_device=device_id)
    lower_un = lower.unsqueeze(1)
    upper_un = upper.unsqueeze(1)

    # Range vector for each row
    lower_range_vector = allenutil.get_range_vector(
        passage_length, device=device_id).unsqueeze(0)
    upper_range_vector = allenutil.get_range_vector(
        passage_length, device=device_id).unsqueeze(0)

    # Masks for lower and upper limits of the mask
    lower_mask = lower_range_vector >= lower_un
    upper_mask = upper_range_vector <= upper_un

    # Final masks that we require. Since lower <= upper, the two row masks can
    # never both be False at once, so equality here is equivalent to logical AND.
    # Shape: (passage_length, passage_length); (passage_length, passage_length)
    inwindow_mask = (lower_mask == upper_mask).float()
    outwindow_mask = (lower_mask != upper_mask).float()

    return inwindow_mask, outwindow_mask
def forward(self, inputs: torch.Tensor, offsets: torch.Tensor = None) -> torch.Tensor:
    """
    Parameters
    ----------
    inputs: ``torch.Tensor``, required
        A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
        for the current batch.
    offsets: ``torch.Tensor``, optional
        A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
        for the current batch. If not provided, one embedding per byte pair is returned.

    Returns
    -------
    ``[torch.Tensor]``
        An embedding representation of the input sequence
        having shape ``(batch_size, sequence_length, embedding_dim)``
    """
    # pylint: disable=arguments-differ
    batch_size, num_timesteps = inputs.size()

    # the transformer embedding consists of the byte pair embeddings,
    # the special embeddings and the position embeddings.
    # the position embeddings are always at least self._transformer.n_ctx,
    # but may be longer.

    # the transformer "vocab" consists of the actual vocab and the
    # positional encodings. Here we want the count of just the former.
    vocab_size = self._transformer.vocab_size - self._transformer.n_ctx

    # vocab_size, vocab_size + 1, ...
    positional_encodings = get_range_vector(num_timesteps,
                                            device=get_device_of(inputs)) + vocab_size

    # Combine the inputs with positional encodings
    batch_tensor = torch.stack([
        inputs,  # (batch_size, num_timesteps)
        positional_encodings.expand(batch_size, num_timesteps)
    ], dim=-1)

    byte_pairs_mask = inputs != 0

    # Embeddings is num_output_layers x (batch_size, num_timesteps, embedding_dim)
    layer_activations = self._transformer(batch_tensor)

    # Output of scalar_mix is (batch_size, num_timesteps, embedding_dim)
    if self._top_layer_only:
        mix = layer_activations[-1]
    else:
        mix = self._scalar_mix(layer_activations, byte_pairs_mask)

    # These embeddings are one per byte-pair, but we want one per original _word_.
    # So we choose the embedding corresponding to the last byte pair for each word,
    # which is captured by the ``offsets`` input.
    if offsets is not None:
        range_vector = get_range_vector(batch_size,
                                        device=get_device_of(mix)).unsqueeze(1)
        last_byte_pair_embeddings = mix[range_vector, offsets]
    else:
        # allow returning all byte pairs by passing no offsets
        seq_len = (byte_pairs_mask > 0).long().sum(dim=1).max()
        last_byte_pair_embeddings = mix[:, :seq_len]

    return last_byte_pair_embeddings
def forward(self, input_tensor: torch.Tensor):
    """
    Adds a positional encoding to `input_tensor`.
    """
    # TODO: Another option is to specify the expected size in init, so that we can construct
    # the positional encoding beforehand, and simply add it to the input tensor in forward.
    _, timesteps, hidden_dim = input_tensor.size()

    num_timescales = hidden_dim // 2
    device = get_device_of(input_tensor)

    timestep_range = get_range_vector(timesteps, device).data.float()
    timescale_range = get_range_vector(num_timescales, device).data.float()

    # Guard against division by zero when hidden_dim < 4 (i.e. a single timescale).
    log_timescale_increments = math.log(
        float(self.max_timescale) / float(self.min_timescale)) / max(
            float(num_timescales - 1), 1.0)
    inverse_timescales = self.min_timescale * torch.exp(
        timescale_range * -log_timescale_increments)

    # Broadcasted multiplication - shape (timesteps, num_timescales)
    scaled_time = timestep_range.unsqueeze(
        1) * inverse_timescales.unsqueeze(0)
    # shape (timesteps, 2 * num_timescales)
    sinusoids = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
    if hidden_dim % 2 != 0:
        # if the number of dimensions is odd, the cos and sin
        # timescales had size (hidden_dim - 1) / 2, so we need
        # to add a row of zeros to make up the difference.
        sinusoids = torch.cat(
            [sinusoids, sinusoids.new_zeros(timesteps, 1)], 1)
    return input_tensor + sinusoids.unsqueeze(0)
def masking_blockdiagonal(passage_length, window, device_id): """ Make a (passage_length, passage_length) tensor M of 1 and -1 in which for each row x, M[x, y] = -1 if y < x - window or y > x + window, else it is 1. Basically for the x-th row, the [x-win, x+win] columns should be 1, and rest -1 """ # The lower and upper limit of token-idx that won't be masked for a given token lower = allenutil.get_range_vector(passage_length, device=device_id) - window upper = allenutil.get_range_vector(passage_length, device=device_id) + window lower = torch.clamp(lower, min=0, max=passage_length - 1) upper = torch.clamp(upper, min=0, max=passage_length - 1) lower_un = lower.unsqueeze(1) upper_un = upper.unsqueeze(1) # Range vector for each row lower_range_vector = allenutil.get_range_vector( passage_length, device=device_id).unsqueeze(0) upper_range_vector = allenutil.get_range_vector( passage_length, device=device_id).unsqueeze(0) # Masks for lower and upper limits of the mask lower_mask = lower_range_vector >= lower_un upper_mask = upper_range_vector <= upper_un # Final-mask that we require inwindow_mask = (lower_mask == upper_mask).float() outwindow_mask = (lower_mask != upper_mask).float() return inwindow_mask, outwindow_mask
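# Sanity-check sketch for masking_blockdiagonal (not part of the original
# source): the same window mask built with plain torch, using torch.arange in
# place of allenutil.get_range_vector and assuming a CPU device.
import torch

def _blockdiagonal_mask_demo(passage_length=6, window=1):
    positions = torch.arange(passage_length)
    lower = (positions - window).clamp(min=0, max=passage_length - 1).unsqueeze(1)
    upper = (positions + window).clamp(min=0, max=passage_length - 1).unsqueeze(1)
    cols = positions.unsqueeze(0)  # (1, passage_length)
    inwindow_mask = ((cols >= lower) & (cols <= upper)).float()
    return inwindow_mask, 1.0 - inwindow_mask

_inwindow, _outwindow = _blockdiagonal_mask_demo()
# Row 2 has ones exactly on columns [2 - window, 2 + window] = [1, 3].
assert _inwindow[2].tolist() == [0.0, 1.0, 1.0, 1.0, 0.0, 0.0]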
def _generate_valid_antecedents(
    num_spans_to_keep: int, max_antecedents: int, device: int
) -> Tuple[torch.IntTensor, torch.IntTensor, torch.BoolTensor]:
    """
    This method generates possible antecedents per span which survived the pruning
    stage. This procedure is `generic across the batch`. The reason this is the case is
    that each span in a batch can be coreferent with any previous span, but here we
    are computing the possible `indices` of these spans. So, regardless of the batch,
    the 1st span _cannot_ have any antecedents, because there are none to select from.
    Similarly, each element can only predict previous spans, so this returns a matrix
    of shape (num_spans_to_keep, max_antecedents), where the (i,j)-th index is equal to
    (i - 1) - j if j < i, or zero otherwise.

    # Parameters

    num_spans_to_keep : `int`, required.
        The number of spans that were kept while pruning.
    max_antecedents : `int`, required.
        The maximum number of antecedent spans to consider for every span.
    device : `int`, required.
        The CUDA device to use.

    # Returns

    valid_antecedent_indices : `torch.LongTensor`
        The indices of every antecedent to consider with respect to the top k spans.
        Has shape `(num_spans_to_keep, max_antecedents)`.
    valid_antecedent_offsets : `torch.LongTensor`
        The distance between the span and each of its antecedents in terms of the number
        of considered spans (i.e not the word distance between the spans).
        Has shape `(1, max_antecedents)`.
    valid_antecedent_mask : `torch.BoolTensor`
        The mask representing whether each antecedent span is valid. Required since
        different spans have different numbers of valid antecedents. For example, the
        first span in the document should have no valid antecedents.
        Has shape `(1, num_spans_to_keep, max_antecedents)`.
    """
    # Shape: (num_spans_to_keep, 1)
    target_indices = util.get_range_vector(num_spans_to_keep, device).unsqueeze(1)

    # Shape: (1, max_antecedents)
    valid_antecedent_offsets = (
        util.get_range_vector(max_antecedents, device) + 1).unsqueeze(0)

    # This is a broadcasted subtraction.
    # Shape: (num_spans_to_keep, max_antecedents)
    raw_antecedent_indices = target_indices - valid_antecedent_offsets

    # In our matrix of indices, the upper triangular part will be negative
    # because the offsets will be > the target indices. We want to mask these,
    # because these are exactly the indices which we don't want to predict, per span.
    # Shape: (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_mask = (raw_antecedent_indices >= 0).unsqueeze(0)

    # Shape: (num_spans_to_keep, max_antecedents)
    valid_antecedent_indices = F.relu(
        raw_antecedent_indices.float()).long()
    return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_mask
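# A small worked example (hypothetical toy sizes) of the antecedent-index
# arithmetic above, with torch.arange standing in for util.get_range_vector:
import torch

_target = torch.arange(4).unsqueeze(1)         # (num_spans_to_keep=4, 1)
_offsets = (torch.arange(2) + 1).unsqueeze(0)  # (1, max_antecedents=2)
_raw = _target - _offsets                      # (4, 2)
# Span 0 has no antecedents; span i can point back to spans i-1, i-2, ...
assert _raw.tolist() == [[-1, -2], [0, -1], [1, 0], [2, 1]]
assert (_raw >= 0).tolist() == [[False, False], [True, False], [True, True], [True, True]]
assert _raw.clamp(min=0).tolist() == [[0, 0], [0, 0], [1, 0], [2, 1]]  # same as relu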
def _construct_loss(
        self, head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor, head_indices: torch.Tensor,
        head_tags: torch.Tensor, mask: torch.Tensor,
        head_tag_temperature: Optional[float] = None,
        head_temperature: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    float_mask = mask.float()
    tag_mask = self._get_unknown_tag_mask(mask, head_tags)

    batch_size, sequence_length, _ = attended_arcs.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(
        batch_size, get_device_of(attended_arcs)).unsqueeze(1)

    # shape (batch_size, sequence_length, sequence_length)
    if head_temperature:
        # out-of-place division, so the caller's tensor is not mutated
        attended_arcs = attended_arcs / head_temperature
    normalised_arc_logits = masked_log_softmax(
        attended_arcs,
        mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag_representation,
                                          child_tag_representation,
                                          head_indices)
    if head_tag_temperature:
        head_tag_logits /= head_tag_temperature
    normalised_head_tag_logits = masked_log_softmax(
        head_tag_logits,
        tag_mask.unsqueeze(-1)) * tag_mask.float().unsqueeze(-1)
    # index matrix with shape (batch, sequence_length)
    timestep_index = get_range_vector(sequence_length,
                                      get_device_of(attended_arcs))
    child_index = timestep_index.view(1, sequence_length).expand(
        batch_size, sequence_length).long()
    # shape (batch_size, sequence_length)
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                          head_tags]
    tag_loss *= (head_tags > 1).float()
    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]
    tag_loss = tag_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic HEAD token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -arc_loss.sum() / valid_positions.float()
    tag_nll = -tag_loss.sum() / valid_positions.float()
    loss = arc_nll + tag_nll

    return loss, normalised_arc_logits, normalised_head_tag_logits
def loss(self, edge_scores: torch.Tensor, head_indices: torch.Tensor,
         mask: torch.Tensor) -> torch.Tensor:
    """
    Computes the edge loss for a sequence given gold head indices.

    Parameters
    ----------
    edge_scores : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The indices of the heads
        for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded elements
        in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``, required.
        The negative log likelihood from the arc loss.
    """
    float_mask = mask.float()
    batch_size, sequence_length, _ = edge_scores.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(
        batch_size, get_device_of(edge_scores)).unsqueeze(1)
    # shape (batch_size, sequence_length, sequence_length)
    normalised_arc_logits = masked_log_softmax(
        edge_scores,
        mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

    # index matrix with shape (batch, sequence_length)
    timestep_index = get_range_vector(sequence_length,
                                      get_device_of(edge_scores))
    child_index = timestep_index.view(1, sequence_length).expand(
        batch_size, sequence_length).long()
    # shape (batch_size, sequence_length)
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic HEAD token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -arc_loss.sum()
    if self.normalize_wrt_seq_len:
        arc_nll /= valid_positions.float()
    return arc_nll
def forward(self, inputs: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: """ Parameters ---------- inputs: ``torch.Tensor``, required A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings for the current batch. offsets: ``torch.Tensor``, required A ``(batch_size, max_sequence_length)`` tensor representing the word offsets for the current batch. Returns ------- ``[torch.Tensor]`` An embedding representation of the input sequence having shape ``(batch_size, sequence_length, embedding_dim)`` """ # pylint: disable=arguments-differ batch_size, num_timesteps = inputs.size() # the transformer "vocab" consists of the actual vocab and the # positional encodings. Here we want the count of just the former. vocab_size = self._transformer.vocab_size - self._transformer.n_ctx # vocab_size, vocab_size + 1, ... positional_encodings = get_range_vector( num_timesteps, device=get_device_of(inputs)) + vocab_size # Combine the inputs with positional encodings batch_tensor = torch.stack( [ inputs, # (batch_size, num_timesteps) positional_encodings.expand(batch_size, num_timesteps) ], dim=-1) byte_pairs_mask = inputs != 0 # Embeddings is num_output_layers x (batch_size, num_timesteps, embedding_dim) layer_activations = self._transformer(batch_tensor) # Output of scalar_mix is (batch_size, num_timesteps, embedding_dim) mix = self._scalar_mix(layer_activations, byte_pairs_mask) # These embeddings are one per byte-pair, but we want one per original _word_. # So we choose the embedding corresponding to the last byte pair for each word, # which is captured by the ``offsets`` input. range_vector = get_range_vector(batch_size, device=get_device_of(mix)).unsqueeze(1) last_byte_pair_embeddings = mix[range_vector, offsets] return last_byte_pair_embeddings
def loss(self, edge_label_logits: torch.Tensor, mask: torch.Tensor,
         head_tags: torch.Tensor) -> torch.Tensor:
    """
    Computes the edge label loss for a sequence given gold labels.

    Parameters
    ----------
    edge_label_logits : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, num_head_tags),
        that contains raw predictions for incoming edge labels
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded elements
        in the sequence.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The dependency labels of the heads
        for every word.

    Returns
    -------
    tag_nll : ``torch.Tensor``, required.
        The negative log likelihood from the edge label loss.
    """
    float_mask = mask.float()
    batch_size, sequence_length, _ = edge_label_logits.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(
        batch_size, get_device_of(edge_label_logits)).unsqueeze(1)
    # shape (batch_size, sequence_length, num_head_tags)
    normalised_edge_label_logits = masked_log_softmax(
        edge_label_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)

    # index matrix with shape (batch, sequence_length)
    timestep_index = get_range_vector(sequence_length,
                                      get_device_of(edge_label_logits))
    child_index = timestep_index.view(1, sequence_length).expand(
        batch_size, sequence_length).long()
    # shape (batch_size, sequence_length)
    tag_loss = normalised_edge_label_logits[range_vector, child_index,
                                            head_tags]
    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    tag_loss = tag_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic HEAD token.
    valid_positions = mask.sum() - batch_size

    if self.normalize_wrt_seq_len:
        return -tag_loss.sum() / valid_positions.float()
    else:
        return -tag_loss.sum()
def _construct_loss( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: float_mask = mask.float() minus_inf = -1e8 minus_mask = (1 - float_mask) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector( batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = masked_log_softmax( attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = torch.nn.functional.log_softmax( head_tag_logits, dim=-1) * float_mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = timestep_index.view(1, sequence_length).expand( batch_size, sequence_length).long() # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll
def forward(self, inputs: torch.Tensor, mask: torch.Tensor, span: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ # input -> [B x seq_len x d], offset -> [B x 2] batch_size, seq_len, _ = inputs.size() pos_range = util.get_range_vector(seq_len, util.get_device_of(inputs)).repeat( (batch_size, 1)) start_offset = span[:, 0].unsqueeze(dim=1) end_offset = span[:, 1].unsqueeze(dim=1) left_mask = torch.lt(pos_range, start_offset).long() middle_mask = (torch.ge(pos_range, start_offset) * torch.le(pos_range, end_offset)).long() right_mask = torch.gt(pos_range, end_offset).long() offsets = start_offset * left_mask + end_offset * right_mask relative_positions = (1 + self._n_position + (pos_range - offsets) * (1 - middle_mask)) # mask padding so it won't receive a positional embedding relative_positions = relative_positions * mask.long() return self._embedding(relative_positions)
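# A toy check (not from the original source) of the relative-position logic
# above, before the 1 + self._n_position shift and the padding mask: tokens
# inside the span get 0, tokens outside get a signed distance to the nearer
# span boundary.
import torch

_pos = torch.arange(6).unsqueeze(0)                  # (1, seq_len=6)
_start, _end = torch.tensor([[2]]), torch.tensor([[3]])
_left = (_pos < _start).long()
_middle = ((_pos >= _start) & (_pos <= _end)).long()
_right = (_pos > _end).long()
_offsets = _start * _left + _end * _right
assert ((_pos - _offsets) * (1 - _middle)).tolist() == [[-2, -1, 0, 0, 1, 2]]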
def set_input(self, encoded_input: torch.Tensor, mask: torch.Tensor) -> None: self.encoded_input = encoded_input batch_size = encoded_input.shape[0] self.mask = mask self.batch_size_range = get_range_vector(batch_size, get_device_of(encoded_input))
def forward(self, word_inputs: torch.Tensor, char_inputs: torch.Tensor):
    embs = []
    if self.word_embedder is not None:
        word_inputs = torch.autograd.Variable(word_inputs,
                                              requires_grad=False)
        embed_words = self.word_embedder(word_inputs)
        embs.append(embed_words)

    if self.char_embedder is not None:
        char_inputs, char_lengths = char_inputs
        batch_size, seq_len = char_lengths.size()[:2]
        char_inputs = char_inputs.view(batch_size * seq_len, -1)
        char_lengths = char_lengths.view(batch_size * seq_len, -1)

        # (batch_size * seq_len, max_char, dim)
        embedded_chars = self.char_embedder(char_inputs)
        _, max_seq_len, dim = embedded_chars.size()

        layer = embedded_chars
        for length in range(1, max_seq_len):
            new_layer = layer.new_zeros(layer.size())
            range_vector = get_range_vector(max_seq_len,
                                            get_device_of(char_lengths))
            mask = ((range_vector.unsqueeze(0) - char_lengths + length) <=
                    0).unsqueeze(-1)
            for i in range(max_seq_len - length):
                new_layer[:, i, :] = self.cell(layer[:, i:i + 2, :])
            layer.masked_scatter_(mask, new_layer)

        embs.append(layer[:, 0, :].view(batch_size, seq_len, dim))

    token_embedding = torch.cat(embs, dim=2)

    return self.projection(token_embedding)
def forward( self, sequence_tensor: torch.FloatTensor, span_indices: torch.LongTensor, span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor: # both of shape (batch_size, num_spans, 1) span_starts, span_ends = span_indices.split(1, dim=-1) # shape (batch_size, num_spans, 1) # These span widths are off by 1, because the span ends are `inclusive`. span_widths = span_ends - span_starts # We need to know the maximum span width so we can # generate indices to extract the spans from the sequence tensor. # These indices will then get masked below, such that if the length # of a given span is smaller than the max, the rest of the values # are masked. max_batch_span_width = span_widths.max().item() + 1 # Shape: (1, 1, max_batch_span_width) max_span_range_indices = util.get_range_vector( max_batch_span_width, util.get_device_of(sequence_tensor)).view(1, 1, -1) # Shape: (batch_size, num_spans, max_batch_span_width) # This is a broadcasted comparison - for each span we are considering, # we are creating a range vector of size max_span_width, but masking values # which are greater than the actual length of the span. # # We're using <= here (and for the mask below) because the span ends are # inclusive, so we want to include indices which are equal to span_widths rather # than using it as a non-inclusive upper bound. span_mask = (max_span_range_indices <= span_widths).float() raw_span_indices = span_ends - max_span_range_indices # We also don't want to include span indices which are less than zero, # which happens because some spans near the beginning of the sequence # have an end index < max_batch_span_width, so we add this to the mask here. span_mask = span_mask * (raw_span_indices >= 0).float() span_indices = torch.nn.functional.relu( raw_span_indices.float()).long() # Shape: (batch_size * num_spans * max_batch_span_width) flat_span_indices = util.flatten_and_batch_shift_indices( span_indices, sequence_tensor.size(1)) # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim) span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices) # text_embeddings = span_embeddings * span_mask.unsqueeze(-1) batch_size, num_spans, max_batch_span_width, _ = span_embeddings.size() view_text_embeddings = span_embeddings.view(batch_size * num_spans, max_batch_span_width, -1) span_mask = span_mask.view(batch_size * num_spans, max_batch_span_width) cnn_text_embeddings = self.cnn(view_text_embeddings, span_mask) cnn_text_embeddings = cnn_text_embeddings.view(batch_size, num_spans, self._output_dim) return cnn_text_embeddings
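# A minimal sketch (toy values, CNN omitted) of the span index/mask
# construction above: spans are enumerated back-to-front from their inclusive
# end index, and positions past the span start are masked out.
import torch

_span_starts = torch.tensor([[1], [0]])  # (batch=2, num_spans=1)
_span_ends = torch.tensor([[3], [1]])    # inclusive ends
_widths = _span_ends - _span_starts
_max_width = _widths.max().item() + 1    # 3
_range = torch.arange(_max_width).view(1, 1, -1)
_mask = (_range <= _widths.unsqueeze(-1)).float()
_raw = _span_ends.unsqueeze(-1) - _range
_mask = _mask * (_raw >= 0).float()
assert _raw.clamp(min=0)[0, 0].tolist() == [3, 2, 1]  # tokens 3, 2, 1 of span [1, 3]
assert _mask[1, 0].tolist() == [1.0, 1.0, 0.0]        # span [0, 1] is shorter than max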
def _get_head_tags(self, head_tag_representation: torch.Tensor,
                   child_tag_representation: torch.Tensor,
                   head_indices: torch.Tensor) -> torch.Tensor:
    """
    Decodes the head tags given the head and child tag representations
    and a tensor of head indices to compute tags for. Note that these are
    either gold or predicted heads, depending on whether this function is
    being called to compute the loss, or if it's being called during inference.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The indices of the heads
        for every word.

    Returns
    -------
    head_tag_logits : ``torch.Tensor``
        A tensor of shape (batch_size, sequence_length, num_head_tags),
        representing logits for predicting a distribution over tags
        for each arc.
    """
    batch_size = head_tag_representation.size(0)
    # shape (batch_size,)
    range_vector = get_range_vector(
        batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

    # This next statement is quite a complex piece of indexing, which you really
    # need to read the docs to understand. See here:
    # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
    # In effect, we are selecting the indices corresponding to the heads of each word from the
    # sequence length dimension for each element in the batch.

    # shape (batch_size, sequence_length, tag_representation_dim)
    selected_head_tag_representations = head_tag_representation[
        range_vector, head_indices]
    selected_head_tag_representations = selected_head_tag_representations.contiguous()
    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                        child_tag_representation)
    return head_tag_logits
def get_last_sent_gate(all_predictions, num_spans, get_all_beam, eos_idx=0):
    """
    all_predictions: shape: (batch_size, K, num_decoding_steps)
    """
    batch_size = all_predictions.size(0)
    num_steps = all_predictions.size(-1)
    # shape: (batch_size, K)
    last_poses = torch.sum((all_predictions != eos_idx).float(), dim=-1) - 1
    # shape: (num_decoding_steps,)
    indices = get_range_vector(num_steps,
                               get_device_of(all_predictions)).float()
    # shape: (batch_size, K, num_decoding_steps)
    mask = (indices.view(*([1] * (all_predictions.dim() - 1)), num_steps) ==
            last_poses.unsqueeze(-1)).float()
    # shape: (batch_size, K, num_decoding_steps)
    last_predictions = all_predictions.float() * mask
    # build the last sent gate. The dim is set to 1 + num_spans to account for the end embedding
    # shape: (batch_size, 1+num_spans) or (batch_size, K, 1+num_spans)
    if not get_all_beam:
        gate = last_predictions.new_zeros((batch_size, 1 + num_spans))
    else:
        beam = all_predictions.size(1)
        gate = last_predictions.new_zeros((batch_size, beam, 1 + num_spans))
    gate.scatter_(-1, last_predictions.long(), 1.)
    # remove the column for the end embedding
    # shape: (batch_size, num_spans) or (batch_size, K, num_spans)
    gate = gate[..., 1:]
    # shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1)
    if not get_all_beam:
        gate = gate.reshape(batch_size * num_spans, 1)
    else:
        gate = gate.reshape(batch_size * beam * num_spans, 1)
    return gate
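# A toy run (hypothetical sizes, beam dimension omitted) of the gate
# construction above: scatter 1.0 at the last non-eos prediction of each
# sequence, then drop the end-embedding column.
import torch

_preds = torch.tensor([[2, 3, 0], [1, 0, 0]])    # eos_idx = 0
_last_poses = (_preds != 0).float().sum(-1) - 1  # [1., 0.]
_steps = torch.arange(3).float().view(1, 3)
_mask = (_steps == _last_poses.unsqueeze(-1)).float()
_last_preds = _preds.float() * _mask             # [[0, 3, 0], [1, 0, 0]]
_gate = torch.zeros(2, 4).scatter_(-1, _last_preds.long(), 1.)
assert _gate[:, 1:].tolist() == [[0., 0., 1.], [1., 0., 0.]]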
def get_timing_signal_1d(length, channels, device, min_timescale=1.0, max_timescale=1.0e4, start_index=0): """Gets a bunch of sinusoids of different frequencies. Each channel of the input Tensor is incremented by a sinusoid of a different frequency and phase. This allows attention to learn to use absolute and relative positions. Timing signals should be added to some precursors of both the query and the memory inputs to attention. The use of relative position is possible because sin(x+y) and cos(x+y) can be expressed in terms of y, sin(x) and cos(x). In particular, we use a geometric sequence of timescales starting with min_timescale and ending with max_timescale. The number of different timescales is equal to channels / 2. For each timescale, we generate the two sinusoidal signals sin(timestep/timescale) and cos(timestep/timescale). All of these sinusoids are concatenated in the channels dimension. Args: length: scalar, length of timing signal sequence. channels: scalar, size of timing embeddings to create. The number of different timescales is equal to channels / 2. min_timescale: a float max_timescale: a float start_index: index of first position Returns: a Tensor of timing signals [1, length, channels] """ position = util.get_range_vector(length, device) + start_index position = position.float() num_timescales = channels // 2 log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1.0, 1.0)) inv_timescales = min_timescale * torch.exp( util.get_range_vector(num_timescales, device).float() * -log_timescale_increment) scaled_time = torch.unsqueeze(position, 1) * torch.unsqueeze( inv_timescales, 0) signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) pad = nn.ConstantPad1d((0, channels % 2), 0) signal = pad(signal) signal = signal.view(1, length, channels) return signal
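# A plain-torch sanity check (not from the original source) of the timing
# signal for the even-channel case with default timescales: the first half of
# the channels is the sin part, the second half the cos part, so their squares
# sum to 1 channelwise.
import math
import torch

_length, _channels = 5, 8
_position = torch.arange(_length).float()
_num_timescales = _channels // 2
_log_inc = math.log(1.0e4) / max(_num_timescales - 1.0, 1.0)
_inv = torch.exp(torch.arange(_num_timescales).float() * -_log_inc)
_scaled = _position.unsqueeze(1) * _inv.unsqueeze(0)
_signal = torch.cat([torch.sin(_scaled), torch.cos(_scaled)], dim=1)
assert _signal.view(1, _length, _channels).shape == (1, _length, _channels)
assert torch.allclose(_signal[:, :_num_timescales] ** 2 +
                      _signal[:, _num_timescales:] ** 2,
                      torch.ones(_length, _num_timescales))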
def test_openai_transformer_matches_tensorflow(self): model_path = "https://allennlp.s3.amazonaws.com/models/openai-transformer-lm-2018.07.23.tar.gz" indexer = OpenaiTransformerBytePairIndexer(model_path=model_path) transformer = OpenaiTransformer(model_path=model_path) # get the test sentences with open(self.FIXTURES_ROOT / 'openai_transformer' / 'text.txt', 'r') as fin: sentences = fin.read().strip().split('\n') # tokenize and check that indices are correct nlp = spacy.load('en_core_web_sm') # make a batch of two sentences batch_indices = [] batch_lengths = [] for k, sentence in enumerate(sentences): tokens = [ token.text for token in nlp(text_standardize(sentence)) if not token.is_space ] indices = indexer.tokens_to_indices( [Token(token) for token in tokens], Vocabulary(), 'openai_indexer') batch_indices.append(indices['openai_indexer']) batch_lengths.append( len([i for i in indices['openai_indexer'] if i != 0])) batch_indices = torch.from_numpy(numpy.array(batch_indices)) batch_size, num_timesteps = batch_indices.size() vocab_size = transformer.vocab_size - transformer.n_ctx positional_encodings = get_range_vector(num_timesteps, device=-1) + vocab_size # Combine the inputs with positional encodings batch_tensor = torch.stack( [ batch_indices, # (batch_size, num_timesteps) positional_encodings.expand(batch_size, num_timesteps) ], dim=-1) # run the LM transformer.eval() activations = transformer(batch_tensor) # load the expected activations expected_activations = [] with h5py.File( self.FIXTURES_ROOT / 'openai_transformer' / 'expected_embeddings.hdf5', 'r') as fin: expected_activations.append(fin['0'][...]) expected_activations.append(fin['1'][...]) # just check the top layer for k in range(2): actual = activations[-1][k, :batch_lengths[k], :].numpy() expected = expected_activations[k] numpy.testing.assert_almost_equal(expected, actual, decimal=5)
def get_select_embedding(sub_words_embedding, offsets):
    # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
    offsets2d = util.combine_initial_dims(offsets)
    # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
    range_vector = util.get_range_vector(
        offsets2d.size(0),
        device=util.get_device_of(sub_words_embedding)).unsqueeze(1)
    # selected embeddings is (batch_size * d1 * ... * dn, orig_sequence_length, embedding_dim)
    selected_embeddings = sub_words_embedding[range_vector, offsets2d]
    return util.uncombine_initial_dims(selected_embeddings, offsets.size())
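# A toy illustration (made-up values) of the advanced indexing above: the
# range vector broadcasts against offsets2d to pick one embedding per token.
import torch

_emb = torch.arange(12.0).view(2, 3, 2)      # (batch=2, seq=3, dim=2)
_offsets2d = torch.tensor([[2, 0], [1, 1]])  # desired position per token
_range = torch.arange(2).unsqueeze(1)        # (2, 1)
_selected = _emb[_range, _offsets2d]         # (2, 2, 2)
assert torch.equal(_selected[0, 0], _emb[0, 2])
assert torch.equal(_selected[1, 1], _emb[1, 1])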
def attention_bias_proximal(length, device=-1): """Bias for self-attention to encourage attention to close positions. Args: length: an integer scalar. Returns: a Tensor with shape [1, 1, length, length] """ r = util.get_range_vector(length, device).float() diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) return torch.unsqueeze(torch.unsqueeze(-torch.log(1 + torch.abs(diff)), 0), 0)
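# Quick sanity check (plain-torch sketch, CPU device): the proximity bias is 0
# on the diagonal and grows more negative with distance.
import torch

_r = torch.arange(4).float()
_diff = _r.unsqueeze(0) - _r.unsqueeze(1)
_bias = -torch.log(1 + torch.abs(_diff))
assert _bias[0, 0].item() == 0.0
assert _bias[0, 3].item() < _bias[0, 1].item() < 0.0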
def _coarse_to_fine_pruning( self, top_span_embeddings: torch.FloatTensor, top_span_mention_scores: torch.FloatTensor, num_spans_to_keep: int, max_antecedents: int, device: int ) -> Tuple[torch.IntTensor, torch.IntTensor, torch.FloatTensor, torch.FloatTensor]: # Shape: (num_spans_to_keep) target_indices = util.get_range_vector(num_spans_to_keep, device) # Shape: (num_spans_to_keep, num_spans_to_keep) valid_antecedent_offsets = target_indices.unsqueeze( 1) - target_indices.unsqueeze(0) # Shape: (num_spans_to_keep, num_spans_to_keep) valid_antecedent_log_mask = (valid_antecedent_offsets >= 1).float().unsqueeze(0).log() # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep) fast_antecedent_scores = top_span_mention_scores + top_span_mention_scores.squeeze( -1).unsqueeze(1) fast_antecedent_scores += valid_antecedent_log_mask # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep) coarse_scores = self._compute_coarse_scores(top_span_embeddings) fast_antecedent_scores += coarse_scores # Shape: (batch_size, num_spans_to_keep, max_antecedents) _, top_antecedent_indices = fast_antecedent_scores.topk( max_antecedents, -1) # Now we order the selected indices in increasing order with # respect to their indices (and hence, with respect to the # order they originally appeared in the ``embeddings`` tensor). # Shape: (batch_size, num_spans_to_keep, max_antecedents) top_antecedent_indices, _ = torch.sort(top_antecedent_indices, dim=-1) # Shape: (batch_size, num_items_to_keep, max_antecedents) # (batch_size, num_spans_to_keep, max_antecedents) valid_antecedent_log_mask = valid_antecedent_log_mask.expand( top_antecedent_indices.size(0), -1, -1) top_antecedent_log_mask = torch.gather(valid_antecedent_log_mask, -1, top_antecedent_indices) # Shape: (batch_size, num_items_to_keep, max_antecedents) valid_antecedent_offsets = \ valid_antecedent_offsets.unsqueeze(0).expand(top_antecedent_indices.size(0), -1, -1) top_antecedent_offsets = torch.gather(valid_antecedent_offsets, -1, top_antecedent_indices) # Shape: (batch_size, num_items_to_keep, max_antecedents) top_fast_antecedent_scores = torch.gather(fast_antecedent_scores, -1, top_antecedent_indices) return top_antecedent_indices, top_antecedent_offsets, top_antecedent_log_mask, top_fast_antecedent_scores
def gather_indexes(sequence_tensor, positions): """Gathers the vectors at the specific positions over a minibatch.""" sequence_shape = sequence_tensor.size() batch_size = sequence_shape[0] seq_length = sequence_shape[1] width = sequence_shape[2] flat_offsets = util.get_range_vector( batch_size, util.get_device_of(sequence_tensor)) * seq_length flat_offsets = flat_offsets.unsqueeze(-1).long() flat_positions = (positions + flat_offsets).view(-1) flat_sequence_tensor = sequence_tensor.view(batch_size * seq_length, width) output_tensor = torch.index_select(flat_sequence_tensor, 0, flat_positions) return output_tensor
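# A toy run (made-up values) of the flat-offset gather above: position j of
# batch element i maps to row i * seq_length + j of the flattened tensor.
import torch

_seq = torch.arange(12.0).view(2, 3, 2)              # (batch=2, seq=3, width=2)
_positions = torch.tensor([[0, 2], [1, 2]])
_flat_offsets = (torch.arange(2) * 3).unsqueeze(-1)  # [[0], [3]]
_flat_positions = (_positions + _flat_offsets).view(-1)
assert _flat_positions.tolist() == [0, 2, 4, 5]
_gathered = torch.index_select(_seq.view(6, 2), 0, _flat_positions)
assert torch.equal(_gathered[2], _seq[1, 1])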
def forward(self, input_ids: torch.LongTensor, offsets: torch.LongTensor = None, token_type_ids: torch.LongTensor = None) -> torch.Tensor: """ Parameters ---------- input_ids : ``torch.LongTensor`` The (batch_size, max_sequence_length) tensor of wordpiece ids. offsets : ``torch.LongTensor``, optional The BERT embeddings are one per wordpiece. However it's possible/likely you might want one per original token. In that case, ``offsets`` represents the indices of the desired wordpiece for each original token. Depending on how your token indexer is configured, this could be the position of the last wordpiece for each token, or it could be the position of the first wordpiece for each token. For example, if you had the sentence "Definitely not", and if the corresponding wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4]. If offsets are provided, the returned tensor will contain only the wordpiece embeddings at those positions, and (in particular) will contain one embedding per token. If offsets are not provided, the entire tensor of wordpiece embeddings will be returned. token_type_ids : ``torch.LongTensor``, optional If an input consists of two sentences (as in the BERT paper), tokens from the first sentence should have type 0 and tokens from the second sentence should have type 1. If you don't provide this (the default BertIndexer doesn't) then it's assumed to be all 0s. """ # pylint: disable=arguments-differ if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) input_mask = (input_ids != 0).long() all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids) if self._scalar_mix is not None: mix = self._scalar_mix(all_encoder_layers, input_mask) else: mix = all_encoder_layers[-1] if offsets is None: return mix else: batch_size = input_ids.size(0) range_vector = util.get_range_vector( batch_size, device=util.get_device_of(mix)).unsqueeze(1) return mix[range_vector, offsets]
def forward(self, input_ids: torch.LongTensor, offsets: torch.LongTensor = None, token_type_ids: torch.LongTensor = None) -> torch.Tensor: """ Parameters ---------- input_ids : ``torch.LongTensor`` The (batch_size, max_sequence_length) tensor of wordpiece ids. offsets : ``torch.LongTensor``, optional The BERT embeddings are one per wordpiece. However it's possible/likely you might want one per original token. In that case, ``offsets`` represents the indices of the desired wordpiece for each original token. Depending on how your token indexer is configured, this could be the position of the last wordpiece for each token, or it could be the position of the first wordpiece for each token. For example, if you had the sentence "Definitely not", and if the corresponding wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4]. If offsets are provided, the returned tensor will contain only the wordpiece embeddings at those positions, and (in particular) will contain one embedding per token. If offsets are not provided, the entire tensor of wordpiece embeddings will be returned. token_type_ids : ``torch.LongTensor``, optional If an input consists of two sentences (as in the BERT paper), tokens from the first sentence should have type 0 and tokens from the second sentence should have type 1. If you don't provide this (the default BertIndexer doesn't) then it's assumed to be all 0s. """ # pylint: disable=arguments-differ if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) input_mask = (input_ids != 0).long() all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids) if self._scalar_mix is not None: mix = self._scalar_mix(all_encoder_layers, input_mask) else: mix = all_encoder_layers[-1] if offsets is None: return mix else: batch_size = input_ids.size(0) range_vector = util.get_range_vector(batch_size, device=util.get_device_of(mix)).unsqueeze(1) return mix[range_vector, offsets]
def get_evd_prediction_mask(all_predictions, eos_idx): # get the mask w.r.t to ``all_predictions`` that includes the index of the first eos and those before it # shape(all_predictions): (batch_size, ..., num_steps) # Shape: (batch_size,) batch_size = all_predictions.size(0) num_steps = all_predictions.size(-1) # shape: (batch_size, ...) valid_decoding_lens = torch.sum( (all_predictions != eos_idx).float(), dim=-1) + 1 indices = get_range_vector(num_steps, get_device_of(all_predictions)).float() mask = (indices.view(*([1] * (all_predictions.dim() - 1)), num_steps) < valid_decoding_lens.unsqueeze(-1)).int() eos_mask = (all_predictions == eos_idx).int() * mask return mask, eos_mask
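# A toy run of the mask logic above (assuming eos appears only as tail
# padding): positions up to and including the first eos are kept, and the
# eos mask singles out that final position.
import torch

_preds = torch.tensor([[3, 5, 0, 0], [2, 0, 0, 0]])  # eos_idx = 0
_valid_lens = (_preds != 0).float().sum(dim=-1) + 1  # [3., 2.]
_indices = torch.arange(_preds.size(-1)).float()
_mask = (_indices.view(1, -1) < _valid_lens.unsqueeze(-1)).int()
assert _mask.tolist() == [[1, 1, 1, 0], [1, 1, 0, 0]]
assert ((_preds == 0).int() * _mask).tolist() == [[0, 0, 1, 0], [0, 1, 0, 0]]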
def _get_head_tags(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. Returns ------- head_tag_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch. # shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous() # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_bilinear(selected_head_tag_representations, child_tag_representation) return head_tag_logits
def label_scores(self, encoded_text: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Computes edge label scores for a fixed tree structure (given by head_indices) for a batch of sentences. Parameters ---------- encoded_text: (batch_size, sequence_length, encoder_output_dim) head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word (predicted or gold). Returns ------- head_tag_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout( self.head_tag_feedforward(encoded_text) ) # will be used to generate predictions for the edge labels for the given arcs. child_tag_representation = self._dropout( self.child_tag_feedforward(encoded_text) ) # will be used to generate predictions for the edge labels for the given arcs. batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector( batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch. # shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[ range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous( ) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_bilinear(selected_head_tag_representations, child_tag_representation) return head_tag_logits
def label_scores(self, encoded_text: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Computes edge label scores for a fixed tree structure (given by head_indices) for a batch of sentences. Parameters ---------- encoded_text : torch.Tensor, required The input sentence, with artifical root node (head sentinel) added in the beginning of shape (batch_size, sequence length, encoding dim) head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word (predicted or gold). Returns ------- edge_label_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each given arc. """ # shape (batch_size, sequence_length, tag_representation_dim) head_label_representation = self.head_label_feedforward(encoded_text) child_label_representation = self.child_label_feedforward(encoded_text) batch_size = head_label_representation.size(0) # shape (batch_size,) range_vector = get_range_vector( batch_size, get_device_of(head_label_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch. # shape (batch_size, sequence_length, tag_representation_dim) selected_head_label_representations = head_label_representation[ range_vector, head_indices] selected_head_label_representations = selected_head_label_representations.contiguous( ) combined = self.activation(selected_head_label_representations + child_label_representation) #(batch_size, sequence_length, num_head_tags) edge_label_logits = self.label_out_layer(combined) return edge_label_logits
def test_openai_transformer_matches_tensorflow(self): model_path = "https://s3-us-west-2.amazonaws.com/allennlp/models/openai-transformer-lm-2018.07.23.tar.gz" indexer = OpenaiTransformerBytePairIndexer(model_path=model_path) transformer = OpenaiTransformer(model_path=model_path) # get the test sentences with open(self.FIXTURES_ROOT / 'openai_transformer' / 'text.txt', 'r') as fin: sentences = fin.read().strip().split('\n') # tokenize and check that indices are correct nlp = spacy.load('en_core_web_sm') # make a batch of two sentences batch_indices = [] batch_lengths = [] for k, sentence in enumerate(sentences): tokens = [token.text for token in nlp(text_standardize(sentence)) if not token.is_space] indices = indexer.tokens_to_indices( [Token(token) for token in tokens], Vocabulary(), 'openai_indexer' ) batch_indices.append(indices['openai_indexer']) batch_lengths.append(len([i for i in indices['openai_indexer'] if i != 0])) batch_indices = torch.from_numpy(numpy.array(batch_indices)) batch_size, num_timesteps = batch_indices.size() vocab_size = transformer.vocab_size - transformer.n_ctx positional_encodings = get_range_vector(num_timesteps, device=-1) + vocab_size # Combine the inputs with positional encodings batch_tensor = torch.stack([ batch_indices, # (batch_size, num_timesteps) positional_encodings.expand(batch_size, num_timesteps) ], dim=-1) # run the LM transformer.eval() activations = transformer(batch_tensor) # load the expected activations expected_activations = [] with h5py.File(self.FIXTURES_ROOT / 'openai_transformer' / 'expected_embeddings.hdf5', 'r') as fin: expected_activations.append(fin['0'][...]) expected_activations.append(fin['1'][...]) # just check the top layer for k in range(2): actual = activations[-1][k, :batch_lengths[k], :].numpy() expected = expected_activations[k] numpy.testing.assert_almost_equal(expected, actual, decimal=5)
def get_input_type_ids(self, type_ids, offsets, embedder):
    "Converts wordpiece-level (bsz, seq_len_wp) type ids to token-level (bsz, num_tokens) by indexing with offsets."
    batch_size = type_ids.size(0)
    full_seq_len = type_ids.size(1)
    if full_seq_len > embedder.max_pieces:
        # Recombine if we had used the sliding-window approach
        assert batch_size == 1 and type_ids.max() > 0
        num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero(
        ).size(0)
        select_indices = embedder.indices_to_select(full_seq_len,
                                                    num_question_tokens)
        type_ids = type_ids[:, select_indices]

    range_vector = util.get_range_vector(
        batch_size, device=util.get_device_of(type_ids)).unsqueeze(1)
    type_ids = type_ids[range_vector, offsets]
    return type_ids
def number2count_auxloss(passage_number_values: List[List[float]], device_id=-1):
    """
    Using passage numbers, make a (batch_size, max_passage_numbers) (padded) tensor, each containing a
    noisy distribution with mass distributed over x numbers. The corresponding count-answer will be x.
    Use the attention2count rnn to predict a count value and compute the loss.
    """
    batch_size = len(passage_number_values)
    # List of length -- batch-size
    num_of_passage_numbers = [len(nums) for nums in passage_number_values]
    max_passage_numbers = max(num_of_passage_numbers)

    # Shape: (batch_size, )
    num_passage_numbers = util.move_to_device(
        torch.LongTensor(num_of_passage_numbers), cuda_device=device_id)
    # Shape: (max_passage_numbers, )
    range_vector = util.get_range_vector(size=max_passage_numbers,
                                         device=device_id)

    mask = (range_vector.unsqueeze(0) <
            num_passage_numbers.unsqueeze(1)).float()

    number_distributions = mask.new_zeros(batch_size,
                                          max_passage_numbers).normal_(
                                              0, 0.01).abs_()
    count_answers = number_distributions.new_zeros(batch_size,
                                                   max_passage_numbers).long()
    for i, num_numbers in enumerate(num_of_passage_numbers):
        """ Sample a count value in [0, min(7, num_numbers)]. Sample indices in this range,
            and set them as 1. Add gaussian noise to the whole tensor and normalize.
        """
        # Pick a count answer
        count_value = random.randint(0, min(7, num_numbers))
        count_answers[i, count_value] = 1
        # Pick the indices that will have mass
        if count_value > 0:
            indices = random.sample(range(num_numbers), count_value)
            # Add 1.0 to all sampled indices
            number_distributions[i, indices] += 1.0

    number_distributions = number_distributions * mask
    number_distributions = number_distributions / torch.sum(
        number_distributions, dim=1).unsqueeze(1)
def common_step(self, batch, phase="train"):
    (token_ids, type_ids, offsets, wordpiece_mask, pos_tags, word_mask,
     mrc_mask, meta_data, parent_idxs, parent_tags) = (
         batch["token_ids"], batch["type_ids"], batch["offsets"],
         batch["wordpiece_mask"], batch["pos_tags"], batch["word_mask"],
         batch["mrc_mask"], batch["meta_data"], batch["parent_idxs"],
         batch["parent_tags"])
    parent_probs, parent_tag_probs, parent_arc_nll, parent_tag_nll = self(
        token_ids, type_ids, offsets, wordpiece_mask, pos_tags, word_mask,
        mrc_mask, parent_idxs, parent_tags)
    loss = parent_arc_nll + parent_tag_nll
    eval_mask = self._get_mask_for_eval(mask=word_mask, pos_tags=pos_tags)
    bsz = parent_probs.size(0)
    # [bsz]
    batch_range_vector = get_range_vector(bsz, get_device_of(parent_tags))
    eval_mask = eval_mask[batch_range_vector, parent_idxs]  # [bsz]
    if phase == "train" or not self.args.use_mst:
        # [bsz]
        pred_positions = parent_probs.argmax(1)
        metric_name = f"{phase}_stat"
        metric = getattr(self, metric_name)
        metric.update(
            pred_positions.unsqueeze(-1),  # [bsz, 1]
            parent_tag_probs[batch_range_vector, pred_positions].argmax(
                1).unsqueeze(-1),  # [bsz, 1]
            parent_idxs.unsqueeze(-1),  # [bsz, 1]
            parent_tags.unsqueeze(-1),  # [bsz, 1]
            eval_mask.unsqueeze(-1)  # [bsz, 1]
        )
    else:  # todo implement mst decoding
        metric = getattr(self, f"{phase}_stat")
        metric.update(meta_data["ann_idx"], meta_data["word_idx"],
                      [len(x) for x in meta_data["words"]], parent_probs,
                      parent_tag_probs, eval_mask)

    self.log(f'{phase}_loss', loss)
    return loss
def flatten_and_batch_shift_indices(indices: torch.Tensor,
                                    sequence_length: int) -> torch.Tensor:
    """
    This is a subroutine for :func:`~batched_index_select`. The given ``indices`` of size
    ``(batch_size, d_1, ..., d_n)`` indexes into dimension 2 of a target tensor, which has size
    ``(batch_size, sequence_length, embedding_size)``. This function returns a vector that
    correctly indexes into the flattened target. The sequence length of the target must be
    provided to compute the appropriate offsets.

    .. code-block:: python

        indices = torch.ones([2,3], dtype=torch.long)
        # Sequence length of the target tensor.
        sequence_length = 10
        shifted_indices = flatten_and_batch_shift_indices(indices, sequence_length)
        # Indices into the second element in the batch are correctly shifted
        # to take into account that the target tensor will be flattened before
        # the indices are applied.
        assert shifted_indices == [1, 1, 1, 11, 11, 11]

    Parameters
    ----------
    indices : ``torch.LongTensor``, required.
    sequence_length : ``int``, required.
        The length of the sequence the indices index into.
        This must be the second dimension of the tensor.

    Returns
    -------
    offset_indices : ``torch.LongTensor``
    """
    # Shape: (batch_size)
    offsets = get_range_vector(indices.size(0),
                               get_device_of(indices)) * sequence_length
    for _ in range(len(indices.size()) - 1):
        offsets = offsets.unsqueeze(1)

    # Shape: (batch_size, d_1, ..., d_n)
    offset_indices = indices + offsets

    # Shape: (batch_size * d_1 * ... * d_n)
    offset_indices = offset_indices.view(-1)
    return offset_indices
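# Reproducing the docstring example with plain torch (get_range_vector is
# effectively torch.arange on the requested device):
import torch

_indices = torch.ones([2, 3], dtype=torch.long)
_offsets = torch.arange(2) * 10  # sequence_length = 10
for _ in range(_indices.dim() - 1):
    _offsets = _offsets.unsqueeze(1)
assert (_indices + _offsets).view(-1).tolist() == [1, 1, 1, 11, 11, 11]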
def forward(self, inputs: torch.Tensor, mask: torch.Tensor, span: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ,unused-argument # input -> [B x seq_len x d], offset -> [B x 2] batch_size, seq_len, _ = inputs.size() offset = span[:, 0].unsqueeze(-1) position_range = util.get_range_vector( seq_len, util.get_device_of(inputs)).repeat((batch_size, 1)) offset_mask = position_range == offset position_markers = inputs.new_ones((batch_size, seq_len), requires_grad=True) position_markers = position_markers * offset_mask.float() position_markers = position_markers.unsqueeze(-1) return position_markers
def _construct_loss(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    head_indices: torch.Tensor,
                    head_tags: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag loss for a sequence given gold head indices and tags.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The indices of the heads
        for every word.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length). The dependency labels of the heads
        for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded elements
        in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``, required.
        The negative log likelihood from the arc loss.
    tag_nll : ``torch.Tensor``, required.
        The negative log likelihood from the arc tag loss.
    """
    float_mask = mask.float()
    batch_size, sequence_length, _ = attended_arcs.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
    # shape (batch_size, sequence_length, sequence_length)
    normalised_arc_logits = masked_log_softmax(attended_arcs,
                                               mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
    normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                    mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
    # index matrix with shape (batch, sequence_length)
    timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
    child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
    # shape (batch_size, sequence_length)
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]
    tag_loss = tag_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic HEAD token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -arc_loss.sum() / valid_positions.float()
    tag_nll = -tag_loss.sum() / valid_positions.float()
    return arc_nll, tag_nll
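# A tiny illustration (random toy values) of the three-way advanced indexing
# used in the loss above: it picks log_probs[b, child, head_indices[b, child]]
# for every (batch, child) pair in one shot.
import torch

_log_probs = torch.log_softmax(torch.randn(2, 3, 3), dim=-1)  # (batch, child, head)
_head_indices = torch.tensor([[0, 0, 1], [0, 2, 2]])
_range_vector = torch.arange(2).unsqueeze(1)                  # (2, 1)
_child_index = torch.arange(3).view(1, 3).expand(2, 3)
_arc_loss = _log_probs[_range_vector, _child_index, _head_indices]  # (2, 3)
assert _arc_loss[1, 2].item() == _log_probs[1, 2, 2].item()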
def _generate_valid_antecedents(num_spans_to_keep: int,
                                max_antecedents: int,
                                device: int) -> Tuple[torch.IntTensor,
                                                      torch.IntTensor,
                                                      torch.FloatTensor]:
    """
    This method generates possible antecedents per span which survived the pruning
    stage. This procedure is `generic across the batch`. The reason this is the case is
    that each span in a batch can be coreferent with any previous span, but here we
    are computing the possible `indices` of these spans. So, regardless of the batch,
    the 1st span _cannot_ have any antecedents, because there are none to select from.
    Similarly, each element can only predict previous spans, so this returns a matrix
    of shape (num_spans_to_keep, max_antecedents), where the (i,j)-th index is equal to
    (i - 1) - j if j < i, or zero otherwise.

    Parameters
    ----------
    num_spans_to_keep : ``int``, required.
        The number of spans that were kept while pruning.
    max_antecedents : ``int``, required.
        The maximum number of antecedent spans to consider for every span.
    device: ``int``, required.
        The CUDA device to use.

    Returns
    -------
    valid_antecedent_indices : ``torch.IntTensor``
        The indices of every antecedent to consider with respect to the top k spans.
        Has shape ``(num_spans_to_keep, max_antecedents)``.
    valid_antecedent_offsets : ``torch.IntTensor``
        The distance between the span and each of its antecedents in terms of the number
        of considered spans (i.e. not the word distance between the spans).
        Has shape ``(1, max_antecedents)``.
    valid_antecedent_log_mask : ``torch.FloatTensor``
        The logged mask representing whether each antecedent span is valid. Required since
        different spans have different numbers of valid antecedents. For example, the first
        span in the document should have no valid antecedents.
        Has shape ``(1, num_spans_to_keep, max_antecedents)``.
    """
    # Shape: (num_spans_to_keep, 1)
    target_indices = util.get_range_vector(num_spans_to_keep, device).unsqueeze(1)

    # Shape: (1, max_antecedents)
    valid_antecedent_offsets = (util.get_range_vector(max_antecedents, device) + 1).unsqueeze(0)

    # This is a broadcasted subtraction.
    # Shape: (num_spans_to_keep, max_antecedents)
    raw_antecedent_indices = target_indices - valid_antecedent_offsets

    # In our matrix of indices, the upper triangular part will be negative
    # because the offsets will be > the target indices. We want to mask these,
    # because these are exactly the indices which we don't want to predict, per span.
    # We're generating a logspace mask here because we will eventually create a
    # distribution over these indices, so we need the 0 elements of the mask to be -inf
    # in order to not mess up the normalisation of the distribution.
    # Shape: (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_log_mask = (raw_antecedent_indices >= 0).float().unsqueeze(0).log()

    # Shape: (num_spans_to_keep, max_antecedents)
    valid_antecedent_indices = F.relu(raw_antecedent_indices.float()).long()
    return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask
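# --- Worked example (not part of the original function) -----------------------
# The index/offset/mask construction above for num_spans_to_keep=4 and
# max_antecedents=2, using torch.arange in place of util.get_range_vector.
import torch

target_indices = torch.arange(4).unsqueeze(1)                  # (4, 1)
valid_antecedent_offsets = (torch.arange(2) + 1).unsqueeze(0)  # (1, 2): [[1, 2]]
raw = target_indices - valid_antecedent_offsets
# tensor([[-1, -2],    span 0: no antecedents
#         [ 0, -1],    span 1: span 0 only
#         [ 1,  0],    span 2: spans 1 and 0
#         [ 2,  1]])   span 3: spans 2 and 1
log_mask = (raw >= 0).float().unsqueeze(0).log()               # -inf where invalid
indices = torch.relu(raw.float()).long()                       # negatives clamped to 0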
def forward(self, input_ids: torch.LongTensor, offsets: torch.LongTensor = None, token_type_ids: torch.LongTensor = None) -> torch.Tensor: """ Parameters ---------- input_ids : ``torch.LongTensor`` The (batch_size, ..., max_sequence_length) tensor of wordpiece ids. offsets : ``torch.LongTensor``, optional The BERT embeddings are one per wordpiece. However it's possible/likely you might want one per original token. In that case, ``offsets`` represents the indices of the desired wordpiece for each original token. Depending on how your token indexer is configured, this could be the position of the last wordpiece for each token, or it could be the position of the first wordpiece for each token. For example, if you had the sentence "Definitely not", and if the corresponding wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4]. If offsets are provided, the returned tensor will contain only the wordpiece embeddings at those positions, and (in particular) will contain one embedding per token. If offsets are not provided, the entire tensor of wordpiece embeddings will be returned. token_type_ids : ``torch.LongTensor``, optional If an input consists of two sentences (as in the BERT paper), tokens from the first sentence should have type 0 and tokens from the second sentence should have type 1. If you don't provide this (the default BertIndexer doesn't) then it's assumed to be all 0s. """ # pylint: disable=arguments-differ if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) input_mask = (input_ids != 0).long() # input_ids may have extra dimensions, so we reshape down to 2-d # before calling the BERT model and then reshape back at the end. all_encoder_layers, _ = self.bert_model(input_ids=util.combine_initial_dims(input_ids), token_type_ids=util.combine_initial_dims(token_type_ids), attention_mask=util.combine_initial_dims(input_mask)) if self._scalar_mix is not None: mix = self._scalar_mix(all_encoder_layers, input_mask) else: mix = all_encoder_layers[-1] # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim) if offsets is None: # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim) return util.uncombine_initial_dims(mix, input_ids.size()) else: # offsets is (batch_size, d1, ..., dn, orig_sequence_length) offsets2d = util.combine_initial_dims(offsets) # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length) range_vector = util.get_range_vector(offsets2d.size(0), device=util.get_device_of(mix)).unsqueeze(1) # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length) selected_embeddings = mix[range_vector, offsets2d] return util.uncombine_initial_dims(selected_embeddings, offsets.size())
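# --- Offset-selection sketch (not part of the original embedder) --------------
# How the (range_vector, offsets) indexing collapses per-wordpiece embeddings to
# per-token embeddings for the "Definitely not" example in the docstring above.
# Shapes are hypothetical and special tokens ([CLS]/[SEP]) are omitted.
import torch

batch_size, num_wordpieces, embedding_dim = 1, 5, 4
mix = torch.randn(batch_size, num_wordpieces, embedding_dim)
offsets = torch.tensor([[3, 4]])                       # last wordpiece per token

range_vector = torch.arange(batch_size).unsqueeze(1)   # (batch_size, 1)
selected = mix[range_vector, offsets]                  # (batch_size, 2, embedding_dim)
assert torch.equal(selected[0, 0], mix[0, 3])          # "##ly" -> "Definitely"
assert torch.equal(selected[0, 1], mix[0, 4])          # "not"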
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer within the passage.  This is an `inclusive` token
        index.  If this is given, we will compute a loss that gets included in the output
        dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer within the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    p1_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 0.
        This is a tensor of shape [batch_size, max_qa_count, max_passage_length].
        Most passage tokens will be assigned 'O', except the passage tokens that belong to
        the previous answer in the dialog, which will be assigned labels such as <1_start>,
        <1_in>, <1_end>.
        For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
    p2_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 1.
        It is similar to p1_answer_marker, but marks the second-most-recent answer in
        the passage.
    p3_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 2.
        It is similar to p1_answer_marker, but marks the third-most-recent answer in
        the passage.
    yesno_list : ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three-way classification (yes / no / not a yes-no question).
    followup_list : ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three-way classification (followup / maybe followup / don't followup).
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question ID, original passage text, and token
        offsets into the passage for each instance in the batch.  We use this for computing
        official metrics using the official SQuAD evaluation script.  The length of this list
        should be the batch size, and each dictionary should have the keys ``id``,
        ``original_passage``, and ``token_offsets``.  If you only want the best span string
        and don't care about official metrics, you can omit the ``id`` key.

    Returns
    -------
    An output dictionary consisting of the following.  Each value is a nested list: the outer
    list iterates over the dialogs, the inner list over the questions in each dialog.
    qid : List[List[str]]
        A list of lists containing question ids.
    followup : List[List[int]]
        A list of lists containing the predicted continuation-marker index
        (y: yes, m: maybe follow up, n: don't follow up).
    yesno : List[List[int]]
        A list of lists containing the predicted affirmation-marker index
        (y: yes, x: not a yes/no question, n: no).
    best_span_str : List[List[str]]
        If sufficient metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
    total_qa_count = batch_size * max_qa_count
    qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
    embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
    embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                  self._text_field_embedder.get_output_dim())
    embedded_question = self._variational_dropout(embedded_question)
    embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
    passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
    question_mask = question_mask.reshape(total_qa_count, max_q_len)
    passage_mask = util.get_text_field_mask(passage).float()

    repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
    repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

    if self._num_context_answers > 0:
        # Encode question turn number inside the dialog into question embedding.
        question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
        question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
        question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
        question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
        question_num_marker_emb = self._question_num_marker(question_num_ind)
        embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

        # Encode the previous answers in passage embedding.
        repeated_embedded_passage = embedded_passage.unsqueeze(1) \
            .repeat(1, max_qa_count, 1, 1) \
            .view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
        # batch_size * max_qa_count, passage_length, word_embed_dim
        p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
        p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
        repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
        if self._num_context_answers > 1:
            p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
            p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb],
                                                  dim=-1)
            if self._num_context_answers > 2:
                p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                      dim=-1)
        repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                repeated_passage_mask))
    else:
        encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                 passage_length,
                                                                 self._encoding_dim)

    encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask) # Shape: (batch_size * max_qa_count, encoding_dim) question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count, passage_length, self._encoding_dim) # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([repeated_encoded_passage, passage_question_vectors, repeated_encoded_passage * passage_question_vectors, repeated_encoded_passage * tiled_question_passage_vector], dim=-1) final_merged_passage = F.relu(self._merge_atten(final_merged_passage)) residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage, repeated_passage_mask)) self_attention_matrix = self._self_attention(residual_layer, residual_layer) mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \ * repeated_passage_mask.reshape(total_qa_count, 1, passage_length) self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device) self_mask = self_mask.reshape(1, passage_length, passage_length) mask = mask * (1 - self_mask) self_attention_probs = util.masked_softmax(self_attention_matrix, mask) # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim) self_attention_vecs = torch.matmul(self_attention_probs, residual_layer) self_attention_vecs = torch.cat([self_attention_vecs, residual_layer, residual_layer * self_attention_vecs], dim=-1) residual_layer = F.relu(self._merge_self_attention(self_attention_vecs)) final_merged_passage = final_merged_passage + residual_layer # batch_size * maxqa_pair_len * max_passage_len * 200 final_merged_passage = self._variational_dropout(final_merged_passage) start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask) span_start_logits = self._span_start_predictor(start_rep).squeeze(-1) end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1), repeated_passage_mask) span_end_logits = self._span_end_predictor(end_rep).squeeze(-1) span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1) span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1) span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7) # batch_size * maxqa_len_pair, max_document_len span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7) best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits, span_yesno_logits, span_followup_logits, self._max_span_length) output_dict: Dict[str, Any] = {} # Compute the loss. 
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask),
                        span_start.view(-1), ignore_index=-1)
        self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
        loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask),
                         span_end.view(-1), ignore_index=-1)
        self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
        self._span_accuracy(best_span[:, 0:2],
                            torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                            mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
        # Select the yes/no and followup logits at the gold span end.  The logits are
        # flattened to 1-D, so the 3 logits for QA pair i at end position `end` live at
        # i * passage_length * 3 + end * 3 + {0, 1, 2}.
        gold_span_end_loc = []
        span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
        for i in range(0, total_qa_count):
            gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
            gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
            gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
        gold_span_end_loc = span_start.new(gold_span_end_loc)

        pred_span_end_loc = []
        for i in range(0, total_qa_count):
            pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
            pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
            pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
        predicted_end = span_start.new(pred_span_end_loc)

        _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
        loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

        _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
        self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
        output_dict["loss"] = loss

    # Compute F1 and prepare the output dictionary.
    output_dict['best_span_str'] = []
    output_dict['qid'] = []
    output_dict['followup'] = []
    output_dict['yesno'] = []
    best_span_cpu = best_span.detach().cpu().numpy()
    for i in range(batch_size):
        passage_str = metadata[i]['original_passage']
        offsets = metadata[i]['token_offsets']
        f1_score = 0.0
        per_dialog_best_span_list = []
        per_dialog_yesno_list = []
        per_dialog_followup_list = []
        per_dialog_query_id_list = []
        for per_dialog_query_index, (iid, answer_texts) in enumerate(
                zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
            predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]

            yesno_pred = predicted_span[2]
            followup_pred = predicted_span[3]
            per_dialog_yesno_list.append(yesno_pred)
            per_dialog_followup_list.append(followup_pred)
            per_dialog_query_id_list.append(iid)

            best_span_string = passage_str[start_offset:end_offset]
            per_dialog_best_span_list.append(best_span_string)
            if answer_texts:
                if len(answer_texts) > 1:
                    t_f1 = []
                    # Compute F1 over N-1 human references and average the scores.
for answer_index in range(len(answer_texts)): idxes = list(range(len(answer_texts))) idxes.pop(answer_index) refs = [answer_texts[z] for z in idxes] t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, refs)) f1_score = 1.0 * sum(t_f1) / len(t_f1) else: f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, answer_texts) self._official_f1(100 * f1_score) output_dict['qid'].append(per_dialog_query_id_list) output_dict['best_span_str'].append(per_dialog_best_span_list) output_dict['yesno'].append(per_dialog_yesno_list) output_dict['followup'].append(per_dialog_followup_list) return output_dict
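# --- Flattened-index sketch (not part of the original model) ------------------
# The ``* 3`` arithmetic used for gold_span_end_loc / pred_span_end_loc above:
# span_yesno_logits has shape (total_qa_count, passage_length, 3), so after
# .view(-1) the three logits for QA pair i at end position e sit at
# i * passage_length * 3 + e * 3 + {0, 1, 2}.  Toy tensors, for illustration only.
import torch

total_qa_count, passage_length = 2, 4
span_yesno_logits = torch.randn(total_qa_count, passage_length, 3)
span_end = [1, 3]                                      # hypothetical gold ends

locs = []
for i in range(total_qa_count):
    base = i * passage_length * 3 + span_end[i] * 3
    locs.extend([base, base + 1, base + 2])
picked = span_yesno_logits.view(-1).index_select(0, torch.tensor(locs)).view(-1, 3)
assert torch.equal(picked[0], span_yesno_logits[0, 1])
assert torch.equal(picked[1], span_yesno_logits[1, 3])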
def forward(self, sequence_tensor: torch.FloatTensor, span_indices: torch.LongTensor, sequence_mask: torch.LongTensor = None, span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor: # both of shape (batch_size, num_spans, 1) span_starts, span_ends = span_indices.split(1, dim=-1) # shape (batch_size, num_spans, 1) # These span widths are off by 1, because the span ends are `inclusive`. span_widths = span_ends - span_starts # We need to know the maximum span width so we can # generate indices to extract the spans from the sequence tensor. # These indices will then get masked below, such that if the length # of a given span is smaller than the max, the rest of the values # are masked. max_batch_span_width = span_widths.max().item() + 1 # shape (batch_size, sequence_length, 1) global_attention_logits = self._global_attention(sequence_tensor) # Shape: (1, 1, max_batch_span_width) max_span_range_indices = util.get_range_vector(max_batch_span_width, util.get_device_of(sequence_tensor)).view(1, 1, -1) # Shape: (batch_size, num_spans, max_batch_span_width) # This is a broadcasted comparison - for each span we are considering, # we are creating a range vector of size max_span_width, but masking values # which are greater than the actual length of the span. # # We're using <= here (and for the mask below) because the span ends are # inclusive, so we want to include indices which are equal to span_widths rather # than using it as a non-inclusive upper bound. span_mask = (max_span_range_indices <= span_widths).float() raw_span_indices = span_ends - max_span_range_indices # We also don't want to include span indices which are less than zero, # which happens because some spans near the beginning of the sequence # have an end index < max_batch_span_width, so we add this to the mask here. span_mask = span_mask * (raw_span_indices >= 0).float() span_indices = torch.nn.functional.relu(raw_span_indices.float()).long() # Shape: (batch_size * num_spans * max_batch_span_width) flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1)) # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim) span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices) # Shape: (batch_size, num_spans, max_batch_span_width) span_attention_logits = util.batched_index_select(global_attention_logits, span_indices, flat_span_indices).squeeze(-1) # Shape: (batch_size, num_spans, max_batch_span_width) span_attention_weights = util.masked_softmax(span_attention_logits, span_mask) # Do a weighted sum of the embedded spans with # respect to the normalised attention distributions. # Shape: (batch_size, num_spans, embedding_dim) attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights) if span_indices_mask is not None: # Above we were masking the widths of spans with respect to the max # span width in the batch. Here we are masking the spans which were # originally passed in as padding. return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float() return attended_text_embeddings
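# --- Span-index sketch (not part of the original extractor) -------------------
# The span_mask / raw_span_indices construction above on toy values: spans are
# enumerated right-to-left from each (inclusive) end index, then positions past
# the span start, or below zero, are masked out.
import torch

span_starts = torch.tensor([[[0], [2]]])               # (batch=1, num_spans=2, 1)
span_ends = torch.tensor([[[1], [4]]])
span_widths = span_ends - span_starts                  # off by 1; ends are inclusive
max_batch_span_width = span_widths.max().item() + 1    # -> 3

max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
# tensor([[[ 1,  0, -1],      span (0, 1) read backwards; -1 is out of range
#          [ 4,  3,  2]]])    span (2, 4) read backwards
span_mask = span_mask * (raw_span_indices >= 0).float()
span_indices = torch.relu(raw_span_indices.float()).long()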