def test_tie_break_categorical_accuracy(self):
    accuracy = CategoricalAccuracy(tie_break=True)
    predictions = torch.Tensor([[0.35, 0.25, 0.35, 0.35, 0.35],
                                [0.1, 0.6, 0.1, 0.2, 0.2],
                                [0.1, 0.0, 0.1, 0.2, 0.2]])
    # Test without mask:
    targets = torch.Tensor([2, 1, 4])
    accuracy(predictions, targets)
    assert accuracy.get_metric(reset=True) == (0.25 + 1 + 0.5) / 3.0

    # Test with mask
    mask = torch.Tensor([1, 0, 1])
    targets = torch.Tensor([2, 1, 4])
    accuracy(predictions, targets, mask)
    assert accuracy.get_metric(reset=True) == (0.25 + 0.5) / 2.0

    # Test tie-break with sequence
    predictions = torch.Tensor([[[0.35, 0.25, 0.35, 0.35, 0.35],
                                 [0.1, 0.6, 0.1, 0.2, 0.2],
                                 [0.1, 0.0, 0.1, 0.2, 0.2]],
                                [[0.35, 0.25, 0.35, 0.35, 0.35],
                                 [0.1, 0.6, 0.1, 0.2, 0.2],
                                 [0.1, 0.0, 0.1, 0.2, 0.2]]])
    targets = torch.Tensor([[0, 1, 3],   # 0.25 + 1 + 0.5
                            [0, 3, 4]])  # 0.25 + 0 + 0.5 = 2.5
    accuracy(predictions, targets)
    actual_accuracy = accuracy.get_metric(reset=True)
    numpy.testing.assert_almost_equal(actual_accuracy, 2.5 / 6.0)
class CoLATask(Task):
    '''Class for Warstadt acceptability task'''
    def __init__(self, path, max_seq_len, name="acceptability"):
        ''' '''
        super(CoLATask, self).__init__(name, 2)
        self.pair_input = 0
        self.load_data(path, max_seq_len)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = Average()
        self.scorer2 = CategoricalAccuracy()

    def load_data(self, path, max_seq_len):
        '''Load the data'''
        tr_data = load_tsv(os.path.join(path, "train.tsv"), max_seq_len,
                           s1_idx=3, s2_idx=None, targ_idx=1)
        val_data = load_tsv(os.path.join(path, "dev.tsv"), max_seq_len,
                            s1_idx=3, s2_idx=None, targ_idx=1)
        te_data = load_tsv(os.path.join(path, 'test.tsv'), max_seq_len,
                           s1_idx=1, s2_idx=None, targ_idx=None, idx_idx=0, skip_rows=1)
        self.train_data_text = tr_data
        self.val_data_text = val_data
        self.test_data_text = te_data
        log.info("\tFinished loading CoLA.")

    def get_metrics(self, reset=False):
        # NB: I think I call it accuracy b/c something weird in training
        return {'accuracy': self.scorer1.get_metric(reset),
                'acc': self.scorer2.get_metric(reset)}
class LstmTagger(Model):
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> torch.Tensor:
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}
def test_top_k_categorical_accuracy(self):
    accuracy = CategoricalAccuracy(top_k=2)
    predictions = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                                [0.1, 0.6, 0.1, 0.2, 0.0]])
    targets = torch.Tensor([0, 3])
    accuracy(predictions, targets)
    actual_accuracy = accuracy.get_metric()
    assert actual_accuracy == 1.0
def test_top_k_categorical_accuracy_respects_mask(self):
    accuracy = CategoricalAccuracy(top_k=2)
    predictions = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                                [0.1, 0.6, 0.1, 0.2, 0.0],
                                [0.1, 0.2, 0.5, 0.2, 0.0]])
    targets = torch.Tensor([0, 3, 0])
    mask = torch.Tensor([0, 1, 1])
    accuracy(predictions, targets, mask)
    actual_accuracy = accuracy.get_metric()
    assert actual_accuracy == 0.50
def test_top_k_categorical_accuracy_works_for_sequences(self):
    accuracy = CategoricalAccuracy(top_k=2)
    predictions = torch.Tensor([[[0.35, 0.25, 0.1, 0.1, 0.2],
                                 [0.1, 0.6, 0.1, 0.2, 0.0],
                                 [0.1, 0.6, 0.1, 0.2, 0.0]],
                                [[0.35, 0.25, 0.1, 0.1, 0.2],
                                 [0.1, 0.6, 0.1, 0.2, 0.0],
                                 [0.1, 0.6, 0.1, 0.2, 0.0]]])
    targets = torch.Tensor([[0, 3, 4],
                            [0, 1, 4]])
    accuracy(predictions, targets)
    actual_accuracy = accuracy.get_metric(reset=True)
    numpy.testing.assert_almost_equal(actual_accuracy, 0.6666666)

    # Test the same thing but with a mask:
    mask = torch.Tensor([[0, 1, 1],
                         [1, 0, 1]])
    accuracy(predictions, targets, mask)
    actual_accuracy = accuracy.get_metric(reset=True)
    numpy.testing.assert_almost_equal(actual_accuracy, 0.50)
def test_top_k_categorical_accuracy_accumulates_and_resets_correctly(self):
    accuracy = CategoricalAccuracy(top_k=2)
    predictions = torch.Tensor([[0.35, 0.25, 0.1, 0.1, 0.2],
                                [0.1, 0.6, 0.1, 0.2, 0.0]])
    targets = torch.Tensor([0, 3])
    accuracy(predictions, targets)
    accuracy(predictions, targets)
    accuracy(predictions, torch.Tensor([4, 4]))
    accuracy(predictions, torch.Tensor([4, 4]))
    actual_accuracy = accuracy.get_metric(reset=True)
    assert actual_accuracy == 0.50
    assert accuracy.correct_count == 0.0
    assert accuracy.total_count == 0.0
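# The tests above pin down the metric's bookkeeping contract: every call adds to running
# correct/total counts, masked positions are skipped, and get_metric(reset=True) returns the
# ratio and zeroes both counters. The class below is a minimal sketch of that contract written
# from the test behaviour alone -- it is NOT the AllenNLP CategoricalAccuracy implementation,
# and the class name is made up for illustration.
import torch


class SimpleTopKAccuracy:
    def __init__(self, top_k: int = 1):
        self.top_k = top_k
        self.correct_count = 0.0
        self.total_count = 0.0

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor = None):
        # Flatten any sequence dimensions so each row is a single prediction.
        num_classes = predictions.size(-1)
        predictions = predictions.reshape(-1, num_classes)
        targets = targets.reshape(-1).long()
        # A prediction counts as correct if the target index appears among the top-k scores.
        top_k_indices = predictions.topk(min(self.top_k, num_classes), dim=-1)[1]
        correct = top_k_indices.eq(targets.unsqueeze(-1)).any(dim=-1).float()
        if mask is not None:
            mask = mask.reshape(-1).float()
            correct = correct * mask
            self.total_count += mask.sum().item()
        else:
            self.total_count += targets.numel()
        self.correct_count += correct.sum().item()

    def get_metric(self, reset: bool = False) -> float:
        accuracy = self.correct_count / self.total_count if self.total_count > 0 else 0.0
        if reset:
            self.correct_count = 0.0
            self.total_count = 0.0
        return accuracy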
class LstmTagger(Model):
    #### One thing that might seem unusual is that we're going to pass in the embedder and the sequence encoder as constructor parameters. This allows us to experiment with different embedders and encoders without having to change the model code.
    def __init__(self,
                 #### The embedding layer is specified as an AllenNLP <code>TextFieldEmbedder</code> which represents a general way of turning tokens into tensors. (Here we know that we want to represent each unique word with a learned tensor, but using the general class allows us to easily experiment with different types of embeddings, for example <a href = "https://allennlp.org/elmo">ELMo</a>.)
                 word_embeddings: TextFieldEmbedder,
                 #### Similarly, the encoder is specified as a general <code>Seq2SeqEncoder</code> even though we know we want to use an LSTM. Again, this makes it easy to experiment with other sequence encoders, for example a Transformer.
                 encoder: Seq2SeqEncoder,
                 #### Every AllenNLP model also expects a <code>Vocabulary</code>, which contains the namespaced mappings of tokens to indices and labels to indices.
                 vocab: Vocabulary) -> None:
        #### Notice that we have to pass the vocab to the base class constructor.
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        #### The feed forward layer is not passed in as a parameter, but is constructed by us. Notice that it looks at the encoder to find the correct input dimension and looks at the vocabulary (and, in particular, at the label -> index mapping) to find the correct output dimension.
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        #### The last thing to notice is that we also instantiate a <code>CategoricalAccuracy</code> metric, which we'll use to track accuracy during each training and validation epoch.
        self.accuracy = CategoricalAccuracy()

    #### Next we need to implement <code>forward</code>, which is where the actual computation happens. Each <code>Instance</code> in your dataset will get (batched with other instances and) fed into <code>forward</code>. The <code>forward</code> method expects dicts of tensors as input, and it expects their names to be the names of the fields in your <code>Instance</code>. In this case we have a sentence field and (possibly) a labels field, so we'll construct our <code>forward</code> accordingly:
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> torch.Tensor:
        #### AllenNLP is designed to operate on batched inputs, but different input sequences have different lengths. Behind the scenes AllenNLP is padding the shorter inputs so that the batch has uniform shape, which means our computations need to use a mask to exclude the padding. Here we just use the utility function <code>get_text_field_mask</code>, which returns a tensor of 0s and 1s corresponding to the padded and unpadded locations.
        mask = get_text_field_mask(sentence)
        #### We start by passing the <code>sentence</code> tensor (each sentence a sequence of token ids) to the <code>word_embeddings</code> module, which converts each sentence into a sequence of embedded tensors.
        embeddings = self.word_embeddings(sentence)
        #### We next pass the embedded tensors (and the mask) to the LSTM, which produces a sequence of encoded outputs.
        encoder_out = self.encoder(embeddings, mask)
        #### Finally, we pass each encoded output tensor to the feedforward layer to produce logits corresponding to the various tags.
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        #### As before, the labels were optional, as we might want to run this model to make predictions on unlabeled data. If we do have labels, then we use them to update our accuracy metric and compute the "loss" that goes in our output.
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    #### We included an accuracy metric that gets updated each forward pass. That means we need to override a <code>get_metrics</code> method that pulls the data out of it. Behind the scenes, the <code>CategoricalAccuracy</code> metric is storing the number of predictions and the number of correct predictions, updating those counts during each call to forward. Each call to get_metric returns the calculated accuracy and (optionally) resets the counts, which is what allows us to track accuracy anew for each epoch.
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}
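#### For context, the tutorial this walkthrough comes from wires the model up roughly as follows. This is a sketch: it assumes a <code>vocab</code> already built from the dataset elsewhere, and the dimensions are arbitrary toy values.
import torch
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

EMBEDDING_DIM = 6
HIDDEN_DIM = 6

# Learned word embeddings for the 'tokens' namespace of the (assumed) vocabulary.
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
# Wrap a plain PyTorch LSTM so it satisfies the Seq2SeqEncoder interface expected above.
lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

model = LstmTagger(word_embeddings, lstm, vocab)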
class Task():
    '''Abstract class for a task

    Methods and attributes:
        - load_data: load dataset from a path and create splits
        - yield dataset for training
        - dataset size
        - validate and test

    Outside the task:
        - process: pad and indexify data given a mapping
        - optimizer
    '''
    __metaclass__ = ABCMeta

    def __init__(self, name, n_classes):
        self.name = name
        self.n_classes = n_classes
        self.train_data_text, self.val_data_text, self.test_data_text = \
            None, None, None
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.pred_layer = None
        self.pair_input = 1
        self.categorical = 1  # most tasks are
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.scorer1 = CategoricalAccuracy()
        self.scorer2 = None

    @abstractmethod
    def load_data(self, path, max_seq_len):
        '''Load data from path and create splits.'''
        raise NotImplementedError

    def get_metrics(self, reset=False):
        '''Get metrics specific to the task'''
        acc = self.scorer1.get_metric(reset)
        return {'accuracy': acc}
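# A hypothetical concrete subclass, following the same pattern as the CoLATask snippet above:
# load_data fills the *_data_text splits and get_metrics reads from scorer1. The task name,
# file names, and column indices are all illustrative placeholders.
class ToySentenceTask(Task):
    '''Hypothetical single-sentence classification task (illustrative only).'''
    def __init__(self, path, max_seq_len, name="toy-sentence"):
        super(ToySentenceTask, self).__init__(name, 2)
        self.pair_input = 0
        self.load_data(path, max_seq_len)

    def load_data(self, path, max_seq_len):
        '''Load train/dev/test splits; the column indices are placeholders.'''
        self.train_data_text = load_tsv(os.path.join(path, "train.tsv"), max_seq_len,
                                        s1_idx=0, s2_idx=None, targ_idx=1)
        self.val_data_text = load_tsv(os.path.join(path, "dev.tsv"), max_seq_len,
                                      s1_idx=0, s2_idx=None, targ_idx=1)
        self.test_data_text = load_tsv(os.path.join(path, "test.tsv"), max_seq_len,
                                       s1_idx=0, s2_idx=None, targ_idx=1)

    def get_metrics(self, reset=False):
        # The base class already created a CategoricalAccuracy scorer as scorer1.
        return {'accuracy': self.scorer1.get_metric(reset)}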
class DialogQA(Model): """ This class implements modified version of BiDAF (with self attention and residual layer, from Clark and Gardner ACL 17 paper) model as used in Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf]. In this set-up, a single instance is a dialog, list of question answer pairs. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. span_start_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span end predictions into the passage state. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). num_context_answers : ``int``, optional (default=0) If greater than 0, the model will consider previous question answering context. max_span_length: ``int``, optional (default=0) Maximum token length of the output span. max_turn_length: ``int``, optional (default=12) Maximum length of an interaction. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, phrase_layer: Seq2SeqEncoder, residual_encoder: Seq2SeqEncoder, span_start_encoder: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, initializer: InitializerApplicator, dropout: float = 0.2, num_context_answers: int = 0, marker_embedding_dim: int = 10, max_span_length: int = 30, max_turn_length: int = 12) -> None: super().__init__(vocab) self._num_context_answers = num_context_answers self._max_span_length = max_span_length self._text_field_embedder = text_field_embedder self._phrase_layer = phrase_layer self._marker_embedding_dim = marker_embedding_dim self._encoding_dim = phrase_layer.get_output_dim() self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim)) self._residual_encoder = residual_encoder if num_context_answers > 0: self._question_num_marker = torch.nn.Embedding(max_turn_length, marker_embedding_dim * num_context_answers) self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim) self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') self._followup_lin = torch.nn.Linear(self._encoding_dim, 3) self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim)) self._span_start_encoder = span_start_encoder self._span_end_encoder = span_end_encoder self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1)) self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1)) self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3)) self._span_followup_predictor = TimeDistributed(self._followup_lin) check_dimensions_match(phrase_layer.get_input_dim(), text_field_embedder.get_output_dim() + marker_embedding_dim * num_context_answers, "phrase layer input dim", "embedding dim + marker dim * num context answers") initializer(self) self._span_start_accuracy = 
CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_yesno_accuracy = CategoricalAccuracy() self._span_followup_accuracy = CategoricalAccuracy() self._span_gt_yesno_accuracy = CategoricalAccuracy() self._span_gt_followup_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._official_f1 = Average() self._variational_dropout = InputVariationalDropout(dropout) def forward(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, p1_answer_marker: torch.IntTensor = None, p2_answer_marker: torch.IntTensor = None, p3_answer_marker: torch.IntTensor = None, yesno_list: torch.IntTensor = None, followup_list: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. p1_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 0. This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length]. Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>. For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac p2_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 1. It is similar to p1_answer_marker, but marking previous previous answer in passage. p3_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 2. It is similar to p1_answer_marker, but marking previous previous previous answer in passage. yesno_list : ``torch.IntTensor``, optional This is one of the outputs that we are trying to predict. Three way classification (the yes/no/not a yes no question). followup_list : ``torch.IntTensor``, optional This is one of the outputs that we are trying to predict. Three way classification (followup / maybe followup / don't followup). metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key. 
Returns ------- An output dictionary consisting of the followings. Each of the followings is a nested list because first iterates over dialog, then questions in dialog. qid : List[List[str]] A list of list, consisting of question ids. followup : List[List[int]] A list of list, consisting of continuation marker prediction index. (y :yes, m: maybe follow up, n: don't follow up) yesno : List[List[int]] A list of list, consisting of affirmation marker prediction index. (y :yes, x: not a yes/no question, n: np) best_span_str : List[List[str]] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size() total_qa_count = batch_size * max_qa_count qa_mask = torch.ge(followup_list, 0).view(total_qa_count) embedded_question = self._text_field_embedder(question, num_wrapping_dims=1) embedded_question = embedded_question.reshape(total_qa_count, max_q_len, self._text_field_embedder.get_output_dim()) embedded_question = self._variational_dropout(embedded_question) embedded_passage = self._variational_dropout(self._text_field_embedder(passage)) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float() question_mask = question_mask.reshape(total_qa_count, max_q_len) passage_mask = util.get_text_field_mask(passage).float() repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1) repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length) if self._num_context_answers > 0: # Encode question turn number inside the dialog into question embedding. question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question)) question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len) question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1) question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len) question_num_marker_emb = self._question_num_marker(question_num_ind) embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1) # Encode the previous answers in passage embedding. repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). 
\ view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim()) # batch_size * max_qa_count, passage_length, word_embed_dim p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length) p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker) repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1) if self._num_context_answers > 1: p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length) p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker) repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1) if self._num_context_answers > 2: p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length) p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker) repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb], dim=-1) repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage, repeated_passage_mask)) else: encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask)) repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1) repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count, passage_length, self._encoding_dim) encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask)) # Shape: (batch_size * max_qa_count, passage_length, question_length) passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question) # Shape: (batch_size * max_qa_count, passage_length, question_length) passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask) # Shape: (batch_size * max_qa_count, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask) # Shape: (batch_size * max_qa_count, encoding_dim) question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count, passage_length, self._encoding_dim) # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([repeated_encoded_passage, passage_question_vectors, repeated_encoded_passage * passage_question_vectors, repeated_encoded_passage * tiled_question_passage_vector], dim=-1) final_merged_passage = F.relu(self._merge_atten(final_merged_passage)) residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage, repeated_passage_mask)) self_attention_matrix = self._self_attention(residual_layer, residual_layer) mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \ * repeated_passage_mask.reshape(total_qa_count, 1, passage_length) self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device) self_mask = self_mask.reshape(1, passage_length, passage_length) mask = mask * (1 - self_mask) self_attention_probs = util.masked_softmax(self_attention_matrix, mask) # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim) self_attention_vecs = torch.matmul(self_attention_probs, residual_layer) self_attention_vecs = torch.cat([self_attention_vecs, residual_layer, residual_layer * self_attention_vecs], dim=-1) residual_layer = F.relu(self._merge_self_attention(self_attention_vecs)) final_merged_passage = final_merged_passage + residual_layer # batch_size * maxqa_pair_len * max_passage_len * 200 final_merged_passage = self._variational_dropout(final_merged_passage) start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask) span_start_logits = self._span_start_predictor(start_rep).squeeze(-1) end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1), repeated_passage_mask) span_end_logits = self._span_end_predictor(end_rep).squeeze(-1) span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1) span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1) span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7) # batch_size * maxqa_len_pair, max_document_len span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7) best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits, span_yesno_logits, span_followup_logits, self._max_span_length) output_dict: Dict[str, Any] = {} # Compute the loss. 
if span_start is not None: loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1), ignore_index=-1) self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask) loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask), span_end.view(-1), ignore_index=-1) self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask) self._span_accuracy(best_span[:, 0:2], torch.stack([span_start, span_end], -1).view(total_qa_count, 2), mask=qa_mask.unsqueeze(1).expand(-1, 2).long()) # add a select for the right span to compute loss gold_span_end_loc = [] span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy() for i in range(0, total_qa_count): gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0)) gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0)) gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0)) gold_span_end_loc = span_start.new(gold_span_end_loc) pred_span_end_loc = [] for i in range(0, total_qa_count): pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0)) pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0)) pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0)) predicted_end = span_start.new(pred_span_end_loc) _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3) _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3) loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1) loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1) _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3) _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3) self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask) self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask) output_dict["loss"] = loss # Compute F1 and preparing the output dictionary. output_dict['best_span_str'] = [] output_dict['qid'] = [] output_dict['followup'] = [] output_dict['yesno'] = [] best_span_cpu = best_span.detach().cpu().numpy() for i in range(batch_size): passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] f1_score = 0.0 per_dialog_best_span_list = [] per_dialog_yesno_list = [] per_dialog_followup_list = [] per_dialog_query_id_list = [] for per_dialog_query_index, (iid, answer_texts) in enumerate( zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])): predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index]) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] yesno_pred = predicted_span[2] followup_pred = predicted_span[3] per_dialog_yesno_list.append(yesno_pred) per_dialog_followup_list.append(followup_pred) per_dialog_query_id_list.append(iid) best_span_string = passage_str[start_offset:end_offset] per_dialog_best_span_list.append(best_span_string) if answer_texts: if len(answer_texts) > 1: t_f1 = [] # Compute F1 over N-1 human references and averages the scores. 
for answer_index in range(len(answer_texts)): idxes = list(range(len(answer_texts))) idxes.pop(answer_index) refs = [answer_texts[z] for z in idxes] t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, refs)) f1_score = 1.0 * sum(t_f1) / len(t_f1) else: f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, answer_texts) self._official_f1(100 * f1_score) output_dict['qid'].append(per_dialog_query_id_list) output_dict['best_span_str'].append(per_dialog_best_span_list) output_dict['yesno'].append(per_dialog_yesno_list) output_dict['followup'].append(per_dialog_followup_list) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] \ for yn_list in output_dict.pop("yesno")] followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list] \ for followup_list in output_dict.pop("followup")] output_dict['yesno'] = yesno_tags output_dict['followup'] = followup_tags return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return {'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'yesno': self._span_yesno_accuracy.get_metric(reset), 'followup': self._span_followup_accuracy.get_metric(reset), 'f1': self._official_f1.get_metric(reset), } @staticmethod def _get_best_span_yesno_followup(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor, span_yesno_logits: torch.Tensor, span_followup_logits: torch.Tensor, max_span_length: int) -> torch.Tensor: # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as # yesno prediction bit and followup prediction bit from the predicted span end token. if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError("Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() max_span_log_prob = [-1e20] * batch_size span_start_argmax = [0] * batch_size best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long) span_start_logits = span_start_logits.data.cpu().numpy() span_end_logits = span_end_logits.data.cpu().numpy() span_yesno_logits = span_yesno_logits.data.cpu().numpy() span_followup_logits = span_followup_logits.data.cpu().numpy() for b_i in range(batch_size): # pylint: disable=invalid-name for j in range(passage_length): val1 = span_start_logits[b_i, span_start_argmax[b_i]] if val1 < span_start_logits[b_i, j]: span_start_argmax[b_i] = j val1 = span_start_logits[b_i, j] val2 = span_end_logits[b_i, j] if val1 + val2 > max_span_log_prob[b_i]: if j - span_start_argmax[b_i] > max_span_length: continue best_word_span[b_i, 0] = span_start_argmax[b_i] best_word_span[b_i, 1] = j max_span_log_prob[b_i] = val1 + val2 for b_i in range(batch_size): j = best_word_span[b_i, 1] yesno_pred = np.argmax(span_yesno_logits[b_i, j]) followup_pred = np.argmax(span_followup_logits[b_i, j]) best_word_span[b_i, 2] = int(yesno_pred) best_word_span[b_i, 3] = int(followup_pred) return best_word_span
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training.
if span_start is not None:
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                               span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                             span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss

    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()

    print("Loss: ", loss)
    print("span_start_accuracy: ", span_start_accuracy)
    print("span_end_accuracy: ", span_end_accuracy)

# Compute the EM and F1 on SQuAD and add the tokenized input to the output.
if metadata is not None:
    best_span_str = []
    question_tokens = []
    passage_tokens = []
    for i in range(batch_size):
class ESIM(Model):
    """
    This ``Model`` implements the ESIM sequence model described in `"Enhanced LSTM for Natural Language Inference"
    <https://www.semanticscholar.org/paper/Enhanced-LSTM-for-Natural-Language-Inference-Chen-Zhu/83e7654d545fbbaaf2328df365a781fb67b841b4>`_
    by Chen et al., 2017.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        Used to encode the premise and hypothesis.
    similarity_function : ``SimilarityFunction``
        This is the similarity function used when computing the similarity matrix between encoded
        words in the premise and words in the hypothesis.
    projection_feedforward : ``FeedForward``
        The feedforward network used to project down the encoded and enhanced premise and hypothesis.
    inference_encoder : ``Seq2SeqEncoder``
        Used to encode the projected premise and hypothesis for prediction.
    output_feedforward : ``FeedForward``
        Used to prepare the concatenated premise and hypothesis for prediction.
    output_logit : ``FeedForward``
        This feedforward network computes the output logits.
    dropout : ``float``, optional (default=0.5)
        Dropout percentage to use.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        similarity_function: SimilarityFunction,
        projection_feedforward: FeedForward,
        inference_encoder: Seq2SeqEncoder,
        output_feedforward: FeedForward,
        output_logit: FeedForward,
        dropout: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward
        self._inference_encoder = inference_encoder

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = InputVariationalDropout(dropout)
        else:
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )
        check_dimensions_match(
            encoder.get_output_dim() * 4,
            projection_feedforward.get_input_dim(),
            "encoder output dim",
            "projection feedforward input",
        )
        check_dimensions_match(
            projection_feedforward.get_output_dim(),
            inference_encoder.get_input_dim(),
            "proj feedforward output dim",
            "inference lstm input dim",
        )

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(  # type: ignore
        self,
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
            [
                encoded_premise,
                attended_hypothesis,
                encoded_premise - attended_hypothesis,
                encoded_premise * attended_hypothesis,
            ],
            dim=-1,
        )
        hypothesis_enhanced = torch.cat(
            [
                encoded_hypothesis,
                attended_premise,
                encoded_hypothesis - attended_premise,
                encoded_hypothesis * attended_premise,
            ],
            dim=-1,
        )

        # The projection layer down to the model dimension. Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(v_ai, premise_mask.unsqueeze(-1), -1e7).max(dim=1)
        v_b_max, _ = replace_masked_values(v_bi, hypothesis_mask.unsqueeze(-1), -1e7).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
            premise_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
            hypothesis_mask, 1, keepdim=True)

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self._accuracy.get_metric(reset)}
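# The second check_dimensions_match call in __init__ above exists because the "enhancement"
# layer concatenates four encoder-sized tensors, so the projection feedforward must expect an
# input of 4 * encoder_output_dim. A standalone check of that arithmetic with made-up sizes:
import torch

batch_size, premise_len, encoder_dim = 2, 7, 300  # illustrative sizes only
encoded_premise = torch.randn(batch_size, premise_len, encoder_dim)
attended_hypothesis = torch.randn(batch_size, premise_len, encoder_dim)

premise_enhanced = torch.cat(
    [
        encoded_premise,
        attended_hypothesis,
        encoded_premise - attended_hypothesis,
        encoded_premise * attended_hypothesis,
    ],
    dim=-1,
)
assert premise_enhanced.size(-1) == 4 * encoder_dim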
def test_does_not_divide_by_zero_with_no_count(self):
    accuracy = CategoricalAccuracy()
    self.assertAlmostEqual(accuracy.get_metric(), 0.0)
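# That guard matters whenever get_metric can be called before the metric has seen any batches,
# for example an empty validation loader. A hedged sketch of such a loop; model, val_loader,
# num_epochs, and the batch keys are placeholders, not part of any snippet above.
accuracy = CategoricalAccuracy()
for epoch in range(num_epochs):
    for batch in val_loader:  # may yield nothing
        logits = model(batch["tokens"])  # (batch_size, num_classes)
        accuracy(logits, batch["labels"], batch.get("mask"))
    # Safe even if no batches were seen: get_metric() returns 0.0 rather than dividing by zero.
    print(f"epoch {epoch}: validation accuracy = {accuracy.get_metric(reset=True):.4f}")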
class DecomposableAttention(Model): """ This ``Model`` implements the Decomposable Attention model described in `"A Decomposable Attention Model for Natural Language Inference" <https://www.semanticscholar.org/paper/A-Decomposable-Attention-Model-for-Natural-Languag-Parikh-T%C3%A4ckstr%C3%B6m/07a9478e87a8304fc3267fa16e83e9f3bbd98b27>`_ by Parikh et al., 2016, with some optional enhancements before the decomposable attention actually happens. Parikh's original model allowed for computing an "intra-sentence" attention before doing the decomposable entailment step. We generalize this to any :class:`Seq2SeqEncoder` that can be applied to the premise and/or the hypothesis before computing entailment. The basic outline of this model is to get an embedded representation of each word in the premise and hypothesis, align words between the two, compare the aligned phrases, and make a final entailment decision based on this aggregated comparison. Each step in this process uses a feedforward network to modify the representation. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the model. attend_feedforward : ``FeedForward`` This feedforward network is applied to the encoded sentence representations before the similarity matrix is computed between words in the premise and words in the hypothesis. similarity_function : ``SimilarityFunction`` This is the similarity function used when computing the similarity matrix between words in the premise and words in the hypothesis. compare_feedforward : ``FeedForward`` This feedforward network is applied to the aligned premise and hypothesis representations, individually. aggregate_feedforward : ``FeedForward`` This final feedforward network is applied to the concatenated, summed result of the ``compare_feedforward`` network, and its output is used as the entailment class logits. premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``) After embedding the premise, we can optionally apply an encoder. If this is ``None``, we will do nothing. hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``) After embedding the hypothesis, we can optionally apply an encoder. If this is ``None``, we will use the ``premise_encoder`` for the encoding (doing nothing if ``premise_encoder`` is also ``None``). initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. 
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, attend_feedforward: FeedForward, similarity_function: SimilarityFunction, compare_feedforward: FeedForward, aggregate_feedforward: FeedForward, premise_encoder: Optional[Seq2SeqEncoder] = None, hypothesis_encoder: Optional[Seq2SeqEncoder] = None, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(DecomposableAttention, self).__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder self._attend_feedforward = TimeDistributed(attend_feedforward) self._matrix_attention = LegacyMatrixAttention(similarity_function) self._compare_feedforward = TimeDistributed(compare_feedforward) self._aggregate_feedforward = aggregate_feedforward self._premise_encoder = premise_encoder self._hypothesis_encoder = hypothesis_encoder or premise_encoder self._num_labels = vocab.get_vocab_size(namespace="labels") check_dimensions_match(text_field_embedder.get_output_dim(), attend_feedforward.get_input_dim(), "text field embedding dim", "attend feedforward input dim") check_dimensions_match(aggregate_feedforward.get_output_dim(), self._num_labels, "final output dimension", "number of labels") self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward( self, # type: ignore premise: Dict[str, torch.LongTensor], hypothesis: Dict[str, torch.LongTensor], label: torch.IntTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- premise : Dict[str, torch.LongTensor] From a ``TextField`` hypothesis : Dict[str, torch.LongTensor] From a ``TextField`` label : torch.IntTensor, optional (default = None) From a ``LabelField`` Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the entailment label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the entailment label. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" embedded_premise = self._text_field_embedder(premise) embedded_hypothesis = self._text_field_embedder(hypothesis) premise_mask = get_text_field_mask(premise).float() hypothesis_mask = get_text_field_mask(hypothesis).float() if self._premise_encoder: embedded_premise = self._premise_encoder(embedded_premise, premise_mask) if self._hypothesis_encoder: embedded_hypothesis = self._hypothesis_encoder( embedded_hypothesis, hypothesis_mask) projected_premise = self._attend_feedforward(embedded_premise) projected_hypothesis = self._attend_feedforward(embedded_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask) # Shape: (batch_size, premise_length, embedding_dim) attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention) # Shape: (batch_size, hypothesis_length, premise_length) h2p_attention = last_dim_softmax( similarity_matrix.transpose(1, 2).contiguous(), premise_mask) # Shape: (batch_size, hypothesis_length, embedding_dim) attended_premise = weighted_sum(embedded_premise, h2p_attention) premise_compare_input = torch.cat( [embedded_premise, attended_hypothesis], dim=-1) hypothesis_compare_input = torch.cat( [embedded_hypothesis, attended_premise], dim=-1) compared_premise = self._compare_feedforward(premise_compare_input) compared_premise = compared_premise * premise_mask.unsqueeze(-1) # Shape: (batch_size, compare_dim) compared_premise = compared_premise.sum(dim=1) compared_hypothesis = self._compare_feedforward( hypothesis_compare_input) compared_hypothesis = compared_hypothesis * \ hypothesis_mask.unsqueeze(-1) # Shape: (batch_size, compare_dim) compared_hypothesis = compared_hypothesis.sum(dim=1) aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1) output_dict = { "h2p_attention": h2p_attention, "p2h_attention": p2h_attention, "final_hidden": aggregate_input, } if self._aggregate_feedforward: label_logits = self._aggregate_feedforward(aggregate_input) label_probs = torch.nn.functional.softmax(label_logits, dim=-1) output_dict["label_logits"] = label_logits output_dict["label_probs"] = label_probs if label is not None: loss = self._loss(label_logits, label.long().view(-1)) self._accuracy(label_logits, label) output_dict["loss"] = loss return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'accuracy': self._accuracy.get_metric(reset), } @classmethod def from_params(cls, vocab: Vocabulary, params: Params) -> 'DecomposableAttention': embedder_params = params.pop("text_field_embedder") text_field_embedder = TextFieldEmbedder.from_params( vocab, embedder_params) premise_encoder_params = params.pop("premise_encoder", None) premise_encoder = Seq2SeqEncoder.from_params( premise_encoder_params) if premise_encoder_params else None hypothesis_encoder_params = params.pop("hypothesis_encoder", None) hypothesis_encoder = Seq2SeqEncoder.from_params( hypothesis_encoder_params) if hypothesis_encoder_params else None attend_feedforward = FeedForward.from_params( params.pop('attend_feedforward')) similarity_function = SimilarityFunction.from_params( params.pop("similarity_function")) compare_feedforward = FeedForward.from_params( params.pop('compare_feedforward')) aggregated_params = params.pop('aggregate_feedforward', None) aggregate_feedforward = FeedForward.from_params( aggregated_params) if aggregated_params 
else None initializer = InitializerApplicator.from_params( params.pop('initializer', [])) regularizer = RegularizerApplicator.from_params( params.pop('regularizer', [])) params.assert_empty(cls.__name__) return cls(vocab=vocab, text_field_embedder=text_field_embedder, attend_feedforward=attend_feedforward, similarity_function=similarity_function, compare_feedforward=compare_feedforward, aggregate_feedforward=aggregate_feedforward, premise_encoder=premise_encoder, hypothesis_encoder=hypothesis_encoder, initializer=initializer, regularizer=regularizer) def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # add label to output argmax_indices = output_dict['label_probs'].max(dim=-1)[1].data.numpy() output_dict['label'] = [ self.vocab.get_token_from_index(x, namespace="labels") for x in argmax_indices ] # do not show last hidden layer del output_dict["final_hidden"] return output_dict
class TweetJointly(Model): def __init__( self, vocab: Vocabulary, transformer_model_name: str = "bert-base-uncased", feedforward: Optional[FeedForward] = None, smoothing: bool = False, smooth_alpha: float = 0.7, sentiment_task: bool = False, sentiment_task_weight: float = 1.0, sentiment_classification_with_label: bool = True, sentiment_seq2vec: Optional[Seq2VecEncoder] = None, candidate_span_task: bool = False, candidate_span_task_weight: float = 1.0, candidate_delay: int = 30000, candidate_span_num: int = 5, candidate_classification_layer_units: int = 128, candidate_span_extractor: Optional[SpanExtractor] = None, candidate_span_with_logits: bool = False, dropout: Optional[float] = None, **kwargs, ) -> None: super().__init__(vocab, **kwargs) if "BERTweet" not in transformer_model_name: self._text_field_embedder = BasicTextFieldEmbedder({ "tokens": PretrainedTransformerEmbedder(transformer_model_name) }) else: self._text_field_embedder = BasicTextFieldEmbedder( {"tokens": TweetBertEmbedder(transformer_model_name)}) # span start & end task if feedforward is None: self._linear_layer = nn.Sequential( nn.Linear(self._text_field_embedder.get_output_dim(), 128), nn.ReLU(), nn.Linear(128, 2), ) else: self._linear_layer = feedforward self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._jaccard = Jaccard() self._candidate_delay = candidate_delay self._delay = 0 self._smoothing = smoothing self._smooth_alpha = smooth_alpha if smoothing: self._loss = nn.KLDivLoss(reduction="batchmean") else: self._loss = nn.CrossEntropyLoss() # sentiment task self._sentiment_task = sentiment_task if self._sentiment_task: self._sentiment_classification_accuracy = CategoricalAccuracy() self._sentiment_loss_log = LossLog() self.register_buffer("sentiment_task_weight", torch.tensor(sentiment_task_weight)) self._sentiment_classification_with_label = ( sentiment_classification_with_label) if sentiment_seq2vec is None: raise ConfigurationError( "sentiment task is True, we need a sentiment seq2vec encoder" ) else: self._sentiment_encoder = sentiment_seq2vec self._sentiment_linear = nn.Linear( self._sentiment_encoder.get_output_dim(), vocab.get_vocab_size("labels"), ) # candidate span task self._candidate_span_task = candidate_span_task if candidate_span_task: assert candidate_span_num > 0 assert candidate_span_task_weight > 0 assert candidate_classification_layer_units > 0 self._candidate_span_num = candidate_span_num self.register_buffer("candidate_span_task_weight", torch.tensor(candidate_span_task_weight)) self._candidate_classification_layer_units = ( candidate_classification_layer_units) self._span_classification_accuracy = CategoricalAccuracy() self._candidate_loss_log = LossLog() self._candidate_span_linear = nn.Linear( self._text_field_embedder.get_output_dim(), self._candidate_classification_layer_units, ) if candidate_span_extractor is None: self._candidate_span_extractor = EndpointSpanExtractor( input_dim=self._candidate_classification_layer_units) else: self._candidate_span_extractor = candidate_span_extractor if candidate_span_with_logits: self._candidate_with_logits = True self._candidate_span_vec_linear = nn.Linear( self._candidate_span_extractor.get_output_dim() + 1, 1) else: self._candidate_with_logits = False self._candidate_span_vec_linear = nn.Linear( self._candidate_span_extractor.get_output_dim(), 1) self._candidate_jaccard = Jaccard() if sentiment_task or candidate_span_task: self._base_loss_log = LossLog() else: 
self._base_loss_log = None if dropout is not None: self._dropout = nn.Dropout(dropout) else: self._dropout = None def forward( # type: ignore self, text: Dict[str, Dict[str, torch.LongTensor]], sentiment: torch.IntTensor, text_with_sentiment: Dict[str, Dict[str, torch.LongTensor]], text_span: torch.IntTensor, selected_text_span: Optional[torch.IntTensor] = None, metadata: List[Dict[str, Any]] = None, ) -> Dict[str, torch.Tensor]: # batch_size * text_length * hidden_dims embedded_question = self._text_field_embedder(text_with_sentiment) if self._dropout is not None: embedded_question = self._dropout(embedded_question) self._delay += int(embedded_question.size(0)) # span start & span end task logits = self._linear_layer(embedded_question) span_start_logits, span_end_logits = logits.split(1, dim=-1) span_start_logits = span_start_logits.squeeze(-1) span_end_logits = span_end_logits.squeeze(-1) possible_answer_mask = torch.zeros_like( util.get_token_ids_from_text_field_tensors( text_with_sentiment)).bool() for i, (start, end) in enumerate(text_span): possible_answer_mask[i, start:end + 1] = True span_start_logits = util.replace_masked_values(span_start_logits, possible_answer_mask, -1e32) span_end_logits = util.replace_masked_values(span_end_logits, possible_answer_mask, -1e32) span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1) span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1) best_spans = get_best_span(span_start_logits, span_end_logits) best_span_scores = torch.gather( span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) + torch.gather( span_end_logits, 1, best_spans[:, 1].unsqueeze(1)) best_span_scores = best_span_scores.squeeze(1) output_dict = { "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_spans, "best_span_scores": best_span_scores, } loss = torch.tensor(0.0).to(embedded_question.device) # sentiment task if self._sentiment_task: if self._sentiment_classification_with_label: global_context_vec = self._sentiment_encoder(embedded_question) else: embedded_only_text = self._text_field_embedder(text) if self._dropout is not None: embedded_only_text = self._dropout(embedded_only_text) global_context_vec = self._sentiment_encoder( embedded_only_text) sentiment_logits = self._sentiment_linear(global_context_vec) sentiment_probs = torch.softmax(sentiment_logits, dim=-1) self._sentiment_classification_accuracy(sentiment_probs, sentiment) sentiment_loss = cross_entropy(sentiment_logits, sentiment) self._sentiment_loss_log(sentiment_loss) loss.add_(self.sentiment_task_weight * sentiment_loss) predict_sentiment_idx = sentiment_probs.argmax(dim=-1) sentiment_predicts = [] for i in predict_sentiment_idx.tolist(): sentiment_predicts.append( self.vocab.get_token_from_index(i, "labels")) output_dict["sentiment_logits"] = sentiment_logits output_dict["sentiment_probs"] = sentiment_probs output_dict["sentiment_predicts"] = sentiment_predicts # span classification if self._candidate_span_task and (self._delay >= self._candidate_delay): # shape: (batch_size, passage_length, embedding_dim) text_features_for_candidate = self._candidate_span_linear( embedded_question) text_features_for_candidate = torch.relu( text_features_for_candidate) with torch.no_grad(): # batch_size * candidate_num * 2 candidate_span = get_candidate_span(span_start_probs, span_end_probs, self._candidate_span_num) candidate_span_list = candidate_span.tolist() 
output_dict["candidate_spans"] = candidate_span_list if selected_text_span is not None: candidate_span, candidate_span_label = self.candidate_span_with_labels( candidate_span, selected_text_span) else: candidate_span_label = None # shape: (batch_size, candidate_num, span_extractor_output_dim) span_feature_vec = self._candidate_span_extractor( text_features_for_candidate, candidate_span) if self._candidate_with_logits: candidate_span_start_logits = torch.gather( span_start_logits, 1, candidate_span[:, :, 0]) candidate_span_end_logits = torch.gather( span_end_logits, 1, candidate_span[:, :, 1]) candidate_span_sum_logits = (candidate_span_start_logits + candidate_span_end_logits) span_feature_vec = torch.cat( (span_feature_vec, candidate_span_sum_logits.unsqueeze(2)), -1) # batch_size * candidate_num span_classification_logits = self._candidate_span_vec_linear( span_feature_vec).squeeze() span_classification_probs = torch.softmax( span_classification_logits, -1) output_dict[ "span_classification_probs"] = span_classification_probs candidate_best_span_idx = span_classification_probs.argmax(dim=-1) view_idx = ( candidate_best_span_idx + torch.arange(0, end=candidate_best_span_idx.shape[0]).to( candidate_best_span_idx.device) * self._candidate_span_num) candidate_span_view = candidate_span.view(-1, 2) candidate_best_spans = candidate_span_view.index_select( 0, view_idx) output_dict["candidate_best_spans"] = candidate_best_spans.tolist() if selected_text_span is not None: self._span_classification_accuracy(span_classification_probs, candidate_span_label) candidate_span_loss = cross_entropy(span_classification_logits, candidate_span_label) self._candidate_loss_log(candidate_span_loss) weighted_loss = self.candidate_span_task_weight * candidate_span_loss if candidate_span_loss > 1e2: print(f"candidate loss: {candidate_span_loss}") print( f"span_classification_logits: {span_classification_logits}" ) print(f"candidate_span_label: {candidate_span_label}") loss.add_(weighted_loss) candidate_best_spans = candidate_best_spans.detach().cpu().numpy() output_dict["best_candidate_span_str"] = [] for metadata_entry, best_span in zip(metadata, candidate_best_spans): text_with_sentiment_tokens = metadata_entry[ "text_with_sentiment_tokens"] predicted_start, predicted_end = tuple(best_span) if predicted_end >= len(text_with_sentiment_tokens): predicted_end = len(text_with_sentiment_tokens) - 1 best_span_string = self.span_tokens_to_text( metadata_entry["text"], text_with_sentiment_tokens, predicted_start, predicted_end, ) output_dict["best_candidate_span_str"].append(best_span_string) answers = metadata_entry.get("selected_text", "") if len(answers) > 0: self._candidate_jaccard(best_span_string, answers) # Compute the loss for training. 
if selected_text_span is not None: span_start = selected_text_span[:, 0] span_end = selected_text_span[:, 1] span_mask = span_start != -1 self._span_accuracy( best_spans, selected_text_span, span_mask.unsqueeze(-1).expand_as(best_spans), ) if not self._smoothing: start_loss = cross_entropy(span_start_logits, span_start, ignore_index=-1) if torch.any(start_loss > 1e9): logger.critical("Start loss too high (%r)", start_loss) logger.critical("span_start_logits: %r", span_start_logits) logger.critical("span_start: %r", span_start) logger.critical("text_with_sentiment: %r", text_with_sentiment) assert False end_loss = cross_entropy(span_end_logits, span_end, ignore_index=-1) if torch.any(end_loss > 1e9): logger.critical("End loss too high (%r)", end_loss) logger.critical("span_end_logits: %r", span_end_logits) logger.critical("span_end: %r", span_end) assert False else: sequence_length = span_start_logits.size(1) device = span_start.device start_distance = get_sequence_distance_from_span_endpoint( sequence_length, span_start) start_smooth_probs = torch.exp( start_distance * torch.log(torch.tensor(self._smooth_alpha).to(device))) start_smooth_probs = start_smooth_probs * possible_answer_mask start_smooth_probs = start_smooth_probs / start_smooth_probs.sum( -1, keepdim=True) span_start_log_probs = span_start_logits - torch.log( torch.exp(span_start_logits).sum(-1)).unsqueeze(-1) end_distance = get_sequence_distance_from_span_endpoint( sequence_length, span_end) end_smooth_probs = torch.exp( end_distance * torch.log(torch.tensor(self._smooth_alpha).to(device))) end_smooth_probs = end_smooth_probs * possible_answer_mask end_smooth_probs = end_smooth_probs / end_smooth_probs.sum( -1, keepdim=True) span_end_log_probs = span_end_logits - torch.log( torch.exp(span_end_logits).sum(-1)).unsqueeze(-1) # print(end_smooth_probs) # print(start_smooth_probs) # print(span_end_log_probs) # print(span_start_log_probs) start_loss = self._loss(span_start_log_probs, start_smooth_probs) end_loss = self._loss(span_end_log_probs, end_smooth_probs) span_start_end_loss = (start_loss + end_loss) / 2 if self._base_loss_log is not None: self._base_loss_log(span_start_end_loss) loss.add_(span_start_end_loss) self._span_start_accuracy(span_start_logits, span_start, span_mask) self._span_end_accuracy(span_end_logits, span_end, span_mask) output_dict["loss"] = loss # compute best span jaccard best_spans = best_spans.detach().cpu().numpy() output_dict["best_span_str"] = [] for metadata_entry, best_span in zip(metadata, best_spans): text_with_sentiment_tokens = metadata_entry[ "text_with_sentiment_tokens"] predicted_start, predicted_end = tuple(best_span) best_span_string = self.span_tokens_to_text( metadata_entry["text"], text_with_sentiment_tokens, predicted_start, predicted_end, ) output_dict["best_span_str"].append(best_span_string) answers = metadata_entry.get("selected_text", "") if len(answers) > 0: self._jaccard(best_span_string, answers) return output_dict # @staticmethod # def candidate_span_with_labels( # candidate_span: torch.Tensor, selected_text_span: torch.Tensor # ) -> Tuple[torch.Tensor, torch.Tensor]: # correct_span_idx = (candidate_span == selected_text_span.unsqueeze(1)).prod(-1) # candidate_span_adjust = torch.where( # ~(correct_span_idx.unsqueeze(-1) == 1), # candidate_span, # selected_text_span.unsqueeze(1), # ) # candidate_span_label = correct_span_idx.argmax(-1) # return candidate_span_adjust, candidate_span_label @staticmethod def candidate_span_with_labels( candidate_span: torch.Tensor, 
selected_text_span: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: candidate_span_label = batch_span_jaccard( candidate_span, selected_text_span).max(-1).indices return candidate_span, candidate_span_label @staticmethod def get_candidate_span_mask(candidate_span: torch.Tensor, passage_length: int) -> torch.Tensor: device = candidate_span.device batch_size, candidate_num = candidate_span.size()[:-1] candidate_span_mask = torch.zeros(batch_size, candidate_num, passage_length).to(device) for i in range(batch_size): for j in range(candidate_num): span_start, span_end = candidate_span[i][j] candidate_span_mask[i][j][span_start:span_end + 1] = 1 return candidate_span_mask @staticmethod def span_tokens_to_text(source_text, tokens, span_start, span_end): text_with_sentiment_tokens = tokens predicted_start = span_start predicted_end = span_end while (predicted_start >= 0 and text_with_sentiment_tokens[predicted_start].idx is None): predicted_start -= 1 if predicted_start < 0: logger.warning( f"Could not map the token '{text_with_sentiment_tokens[span_start].text}' at index " f"'{span_start}' to an offset in the original text.") character_start = 0 else: character_start = text_with_sentiment_tokens[predicted_start].idx while (predicted_end < len(text_with_sentiment_tokens) and text_with_sentiment_tokens[predicted_end].idx is None): predicted_end -= 1 if predicted_end >= len(text_with_sentiment_tokens): print(text_with_sentiment_tokens) print(len(text_with_sentiment_tokens)) print(span_end) print(predicted_end) logger.warning( f"Could not map the token '{text_with_sentiment_tokens[span_end].text}' at index " f"'{span_end}' to an offset in the original text.") character_end = len(source_text) else: end_token = text_with_sentiment_tokens[predicted_end] if end_token.idx == 0: character_end = (end_token.idx + len(sanitize_wordpiece(end_token.text)) + 1) else: character_end = end_token.idx + len( sanitize_wordpiece(end_token.text)) best_span_string = source_text[character_start:character_end].strip() return best_span_string def get_metrics(self, reset: bool = False) -> Dict[str, float]: jaccard = self._jaccard.get_metric(reset) metrics = { "start_acc": self._span_start_accuracy.get_metric(reset), "end_acc": self._span_end_accuracy.get_metric(reset), "span_acc": self._span_accuracy.get_metric(reset), "jaccard": jaccard, } if self._candidate_span_task: metrics[ "candidate_span_acc"] = self._span_classification_accuracy.get_metric( reset) metrics["candidate_jaccard"] = self._candidate_jaccard.get_metric( reset) metrics["candidate_loss"] = self._candidate_loss_log.get_metric( reset) if self._sentiment_task: metrics[ "sentiment_acc"] = self._sentiment_classification_accuracy.get_metric( reset) metrics["sentiment_loss"] = self._sentiment_loss_log.get_metric( reset) if self._base_loss_log is not None: metrics["base_loss"] = self._base_loss_log.get_metric(reset) return metrics
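# A minimal, standalone sketch (not the model's exact helper) of the distance-based
# label smoothing used for the span-endpoint loss above: the target distribution
# decays geometrically as alpha ** distance from the gold index, is restricted to
# the allowed answer region, renormalised, and compared against log-softmax logits
# with a batch-mean KL divergence. `smooth_span_target` and the absolute-distance
# decay are illustrative assumptions, not the original implementation.
import torch
import torch.nn.functional as F

def smooth_span_target(gold_idx: torch.Tensor, answer_mask: torch.Tensor, alpha: float) -> torch.Tensor:
    # gold_idx: (batch,); answer_mask: (batch, seq_len), 1 where an answer may start/end.
    seq_len = answer_mask.size(1)
    positions = torch.arange(seq_len).unsqueeze(0)                  # (1, seq_len)
    distance = (positions - gold_idx.unsqueeze(1)).abs().float()    # distance to the gold index
    target = (alpha ** distance) * answer_mask.float()              # geometric decay, masked
    return target / target.sum(-1, keepdim=True)                    # renormalise to a distribution

batch_size, seq_len, alpha = 2, 6, 0.7
span_start_logits = torch.randn(batch_size, seq_len)
gold_start = torch.tensor([1, 4])
answer_mask = torch.ones(batch_size, seq_len)
target = smooth_span_target(gold_start, answer_mask, alpha)
start_loss = F.kl_div(F.log_softmax(span_start_logits, dim=-1), target, reduction="batchmean")
print(start_loss.item())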
class LSTMBatchNormFreezeDetGlobalNoFinalImageFull(Model): def __init__(self, vocab: Vocabulary, option_encoder: Seq2SeqEncoder, input_dropout: float = 0.3, initializer: InitializerApplicator = InitializerApplicator(), ): super(LSTMBatchNormFreezeDetGlobalNoFinalImageFull, self).__init__(vocab) self.rnn_input_dropout = TimeDistributed(InputVariationalDropout(input_dropout)) if input_dropout > 0 else None self.detector = SimpleDetector(pretrained=True, average_pool=True, semantic=False, final_dim=512) # freeze everything related to conv net for submodule in self.detector.backbone.modules(): # if isinstance(submodule, BatchNorm2d): # submodule.track_running_stats = False for p in submodule.parameters(): p.requires_grad = False for submodule in self.detector.after_roi_align.modules(): # if isinstance(submodule, BatchNorm2d): # submodule.track_running_stats = False for p in submodule.parameters(): p.requires_grad = False self.image_BN = BatchNorm1d(512) self.option_encoder = TimeDistributed(option_encoder) self.option_BN = torch.nn.Sequential( BatchNorm1d(512) ) self.query_BN = torch.nn.Sequential( BatchNorm1d(512) ) self.final_mlp = torch.nn.Sequential( torch.nn.Linear(1024, 512), torch.nn.ReLU(inplace=True), ) self.final_BN = torch.nn.Sequential( BatchNorm1d(512) ) self.final_mlp_linear = torch.nn.Sequential( torch.nn.Linear(512,1) ) self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) # recevie redundent parameters for convinence def _collect_obj_reps(self, span_tags, object_reps): """ Collect span-level object representations :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] :return: """ span_tags_fixed = torch.clamp(span_tags, min=0) # In case there were masked values here row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] # Add extra diminsions to the row broadcaster so it matches row_id leading_dims = len(span_tags.shape) - 2 for i in range(leading_dims): row_id_broadcaster = row_id_broadcaster[..., None] row_id += row_id_broadcaster return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1) def embed_span(self, span, span_tags, span_mask, object_reps): """ :param span: Thing that will get embed and turned into [batch_size, ..leading_dims.., L, word_dim] :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] :param span_mask: [batch_size, ..leading_dims.., span_mask :return: """ retrieved_feats = self._collect_obj_reps(span_tags, object_reps) span_rep = torch.cat((span['bert'], retrieved_feats), -1) # add recurrent dropout here if self.rnn_input_dropout: span_rep = self.rnn_input_dropout(span_rep) return span_rep, retrieved_feats def forward(self, images: torch.Tensor, objects: torch.LongTensor, segms: torch.Tensor, boxes: torch.Tensor, box_mask: torch.LongTensor, question: Dict[str, torch.Tensor], question_tags: torch.LongTensor, question_mask: torch.LongTensor, answers: Dict[str, torch.Tensor], answer_tags: torch.LongTensor, answer_mask: torch.LongTensor, metadata: List[Dict[str, Any]] = None, label: torch.LongTensor = None) -> Dict[str, torch.Tensor]: """ :param images: [batch_size, 3, im_height, im_width] :param objects: [batch_size, max_num_objects] Padded objects :param boxes: [batch_size, max_num_objects, 4] Padded boxes :param box_mask: [batch_size, max_num_objects] Mask 
for whether or not each box is OK :param question: AllenNLP representation of the question. [batch_size, num_answers, seq_length] :param question_tags: A detection label for each item in the Q [batch_size, num_answers, seq_length] :param question_mask: Mask for the Q [batch_size, num_answers, seq_length] :param answers: AllenNLP representation of the answer. [batch_size, num_answers, seq_length] :param answer_tags: A detection label for each item in the A [batch_size, num_answers, seq_length] :param answer_mask: Mask for the As [batch_size, num_answers, seq_length] :param metadata: Ignore, this is about which dataset item we're on :param label: Optional, which item is valid :return: shit """ # Trim off boxes that are too long. this is an issue b/c dataparallel, it'll pad more zeros that are # not needed max_len = int(box_mask.sum(1).max().item()) objects = objects[:, :max_len] box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len] segms = segms[:, :max_len] obj_reps = self.detector(images=images, boxes=boxes, box_mask=box_mask, classes=objects, segms=segms) # option part batch_size, num_options, padded_seq_len, _ = answers['bert'].shape options, option_obj_reps = self.embed_span(answers, answer_tags, answer_mask, obj_reps['obj_reps']) assert (options.shape == (batch_size, num_options, padded_seq_len, 1280)) option_rep = self.option_encoder(options, answer_mask) # (batch_size, 4, seq_len, emb_len(512)) option_rep = replace_masked_values(option_rep, answer_mask[...,None], 0) seq_real_length = torch.sum(answer_mask, dim=-1, dtype=torch.float) # (batch_size, 4) seq_real_length = seq_real_length.view(-1,1) # (batch_size * 4,1) option_rep = option_rep.sum(dim=2) # (batch_size, 4, emb_len(512)) option_rep = option_rep.view(batch_size * num_options,512) # (batch_size * 4, emb_len(512)) option_rep = option_rep.div(seq_real_length) # (batch_size * 4, emb_len(512)) option_rep = self.option_BN(option_rep) option_rep = option_rep.view(batch_size, num_options, 512) # (batch_size, 4, emb_len(512)) # query part batch_size, num_options, padded_seq_len, _ = question['bert'].shape query, query_obj_reps = self.embed_span(question, question_tags, question_mask, obj_reps['obj_reps']) assert (query.shape == (batch_size, num_options, padded_seq_len, 1280)) query_rep = self.option_encoder(query, question_mask) # (batch_size, 4, seq_len, emb_len(512)) query_rep = replace_masked_values(query_rep, question_mask[...,None], 0) seq_real_length = torch.sum(question_mask, dim=-1, dtype=torch.float) # (batch_size, 4) seq_real_length = seq_real_length.view(-1,1) # (batch_size * 4,1) query_rep = query_rep.sum(dim=2) # (batch_size, 4, emb_len(512)) query_rep = query_rep.view(batch_size * num_options,512) # (batch_size * 4, emb_len(512)) query_rep = query_rep.div(seq_real_length) # (batch_size * 4, emb_len(512)) query_rep = self.query_BN(query_rep) query_rep = query_rep.view(batch_size, num_options, 512) # (batch_size, 4, emb_len(512)) # image part # assert (obj_reps['obj_reps'][:,0,:].shape == (batch_size, 512)) # images = obj_reps['obj_reps'][:,0,:] # the background i.e. 
whole image # images = self.image_BN(images) # images = images[:,None,:] # images = images.repeat(1,4,1) # (batch_size, 4, 512) # assert (images.shape == (batch_size, num_options,512)) query_option_image_cat = torch.cat((option_rep,query_rep),-1) assert (query_option_image_cat.shape == (batch_size,num_options, 512*2)) query_option_image_cat = self.final_mlp(query_option_image_cat) query_option_image_cat = query_option_image_cat.view(batch_size*num_options,512) query_option_image_cat = self.final_BN(query_option_image_cat) query_option_image_cat = query_option_image_cat.view(batch_size,num_options,512) logits = self.final_mlp_linear(query_option_image_cat) logits = logits.squeeze(2) class_probabilities = F.softmax(logits, dim=-1) output_dict = {"label_logits": logits, "label_probs": class_probabilities} if label is not None: loss = self._loss(logits, label.long().view(-1)) self._accuracy(logits, label) output_dict["loss"] = loss[None] # print ('one pass') return output_dict def get_metrics(self,reset=False): return {'accuracy': self._accuracy.get_metric(reset)}
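# A small self-contained sketch of the pooling pattern used for the option and
# query encodings above: padded timesteps are zeroed, the sequence is mean-pooled
# with the true lengths, and BatchNorm1d runs over the flattened
# (batch_size * num_options) axis before reshaping back. All shapes here are
# illustrative assumptions.
import torch
from torch.nn import BatchNorm1d

batch_size, num_options, seq_len, hidden = 2, 4, 5, 512
encoded = torch.randn(batch_size, num_options, seq_len, hidden)
mask = torch.ones(batch_size, num_options, seq_len)
mask[:, :, -2:] = 0                                      # pretend the last two steps are padding

masked = encoded * mask.unsqueeze(-1)                    # zero out padded positions
lengths = mask.sum(dim=-1).view(-1, 1).clamp(min=1)      # (batch_size * num_options, 1)
pooled = masked.sum(dim=2).view(-1, hidden) / lengths    # masked mean over the time axis
pooled = BatchNorm1d(hidden)(pooled)                     # normalise across batch * num_options
pooled = pooled.view(batch_size, num_options, hidden)
print(pooled.shape)                                      # torch.Size([2, 4, 512])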
class DependencyParser(Model): """ This dependency parser follows the model of ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . Word representations are generated using a bidirectional LSTM, followed by separate biaffine classifiers for pairs of words, predicting whether a directed arc exists between the two words and the dependency label the arc should have. Decoding can either be done greedily, or the optimal Minimum Spanning Tree can be decoded using Edmond's algorithm by viewing the dependency tree as a MST on a fully connected graph, where nodes are words and edges are scored dependency arcs. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the ``tokens`` ``TextField`` we get as input to the model. encoder : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use to generate representations of tokens. tag_representation_dim : ``int``, required. The dimension of the MLPs used for dependency tag prediction. arc_representation_dim : ``int``, required. The dimension of the MLPs used for head arc prediction. tag_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce tag representations. By default, a 1 layer feedforward network with an elu activation is used. arc_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce arc representations. By default, a 1 layer feedforward network with an elu activation is used. pos_tag_embedding : ``Embedding``, optional. Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model. use_mst_decoding_for_validation : ``bool``, optional (default = True). Whether to use Edmond's algorithm to find the optimal minimum spanning tree during validation. If false, decoding is greedy. dropout : ``float``, optional, (default = 0.0) The variational dropout applied to the output of the encoder and MLP layers. input_dropout : ``float``, optional, (default = 0.0) The dropout applied to the embedded text input. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. 
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, lemmatize_helper: LemmatizeHelper, task_config: TaskConfig, morpho_vector_dim: int = 0, gram_val_representation_dim: int = -1, lemma_representation_dim: int = -1, tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, pos_tag_embedding: Embedding = None, use_mst_decoding_for_validation: bool = True, dropout: float = 0.0, input_dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(DependencyParser, self).__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.encoder = encoder self.lemmatize_helper = lemmatize_helper self.task_config = task_config encoder_dim = encoder.get_output_dim() self.head_arc_feedforward = arc_feedforward or \ FeedForward(encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention(arc_representation_dim, arc_representation_dim, use_input_biases=True) num_labels = self.vocab.get_vocab_size("head_tags") self.head_tag_feedforward = tag_feedforward or \ FeedForward(encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, tag_representation_dim, num_labels) self._pos_tag_embedding = pos_tag_embedding or None assert self.task_config.params.get("use_pos_tag", False) == (self._pos_tag_embedding is not None) self._dropout = InputVariationalDropout(dropout) self._input_dropout = Dropout(input_dropout) self._head_sentinel = torch.nn.Parameter( torch.randn([1, 1, encoder.get_output_dim()])) if gram_val_representation_dim <= 0: self._gram_val_output = torch.nn.Linear( encoder_dim, self.vocab.get_vocab_size("grammar_value_tags")) else: self._gram_val_output = torch.nn.Sequential( Dropout(dropout), torch.nn.Linear(encoder_dim, gram_val_representation_dim), Dropout(dropout), torch.nn.Linear( gram_val_representation_dim, self.vocab.get_vocab_size("grammar_value_tags"))) if lemma_representation_dim <= 0: self._lemma_output = torch.nn.Linear(encoder_dim, len(lemmatize_helper)) else: self._lemma_output = torch.nn.Sequential( Dropout(dropout), torch.nn.Linear(encoder_dim, lemma_representation_dim), Dropout(dropout), torch.nn.Linear(lemma_representation_dim, len(lemmatize_helper))) representation_dim = text_field_embedder.get_output_dim( ) + morpho_vector_dim if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match(representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim") check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim") check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim") self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = { tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE } self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info( f"Found POS tags corresponding to the following punctuation : 
{punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() self._gram_val_prediction_accuracy = CategoricalAccuracy() self._lemma_prediction_accuracy = CategoricalAccuracy() initializer(self) @overrides def forward( self, # type: ignore words: Dict[str, torch.LongTensor], metadata: List[Dict[str, Any]], morpho_embedding: torch.FloatTensor = None, pos_tags: torch.LongTensor = None, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, grammar_values: torch.LongTensor = None, lemma_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- words : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. pos_tags : ``torch.LongTensor``, required The output of a ``SequenceLabelField`` containing POS tags. POS tags are required regardless of whether they are used in the model, because they are used to filter the evaluation metric to only consider heads of words which are not punctuation. metadata : List[Dict[str, Any]], optional (default=None) A dictionary of metadata for each batch element which has keys: words : ``List[str]``, required. The tokens in the original sentence. pos : ``List[str]``, required. The dependencies POS tags for each word. head_tags : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer gold class labels for the arcs in the dependency parse. Has shape ``(batch_size, sequence_length)``. head_indices : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer indices denoting the parent of every word in the dependency parse. Has shape ``(batch_size, sequence_length)``. Returns ------- An output dictionary consisting of: loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. arc_loss : ``torch.FloatTensor`` The loss contribution from the unlabeled arcs. loss : ``torch.FloatTensor``, optional The loss contribution from predicting the dependency tags for the gold arcs. heads : ``torch.FloatTensor`` The predicted head indices for each word. A tensor of shape (batch_size, sequence_length). head_types : ``torch.FloatTensor`` The predicted head types for each arc. A tensor of shape (batch_size, sequence_length). mask : ``torch.LongTensor`` A mask denoting the padded elements in the batch. 
""" embedded_text_input = self.text_field_embedder(words) if morpho_embedding is not None: embedded_text_input = torch.cat( [embedded_text_input, morpho_embedding], -1) if grammar_values is not None and self._pos_tag_embedding is not None: embedded_pos_tags = self._pos_tag_embedding(grammar_values) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) elif self._pos_tag_embedding is not None: raise ConfigurationError( "Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(words) output_dict = self._parse(embedded_text_input, mask, head_tags, head_indices, grammar_values, lemma_indices) if self.task_config.task_type == "multitask": losses = ["arc_nll", "tag_nll", "grammar_nll", "lemma_nll"] elif self.task_config.task_type == "single": if self.task_config.params["model"] == "morphology": losses = ["grammar_nll"] elif self.task_config.params["model"] == "lemmatization": losses = ["lemma_nll"] elif self.task_config.params["model"] == "syntax": losses = ["arc_nll", "tag_nll"] else: assert False, "Unknown model type {}".format( self.task_config.params["model"]) else: assert False, "Unknown task type {}".format( self.task_config.task_type) output_dict["loss"] = sum(output_dict[loss_name] for loss_name in losses) if head_indices is not None and head_tags is not None: evaluation_mask = self._get_mask_for_eval(mask, pos_tags) # We calculate attatchment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. self._attachment_scores(output_dict["heads"][:, 1:], output_dict["head_tags"][:, 1:], head_indices, head_tags, evaluation_mask) output_dict["words"] = [meta["words"] for meta in metadata] if metadata and "pos" in metadata[0]: output_dict["pos"] = [meta["pos"] for meta in metadata] return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: head_tags = output_dict.pop("head_tags").cpu().detach().numpy() heads = output_dict.pop("heads").cpu().detach().numpy() predicted_gram_vals = output_dict.pop( "gram_vals").cpu().detach().numpy() predicted_lemmas = output_dict.pop("lemmas").cpu().detach().numpy() mask = output_dict.pop("mask") lengths = get_lengths_from_binary_sequence_mask(mask) assert len(head_tags) == len(heads) == len(lengths) == len( predicted_gram_vals) == len(predicted_lemmas) head_tag_labels, head_indices, decoded_gram_vals, decoded_lemmas = [], [], [], [] for instance_index in range(len(head_tags)): instance_heads, instance_tags = heads[instance_index], head_tags[ instance_index] words, length = output_dict["words"][instance_index], lengths[ instance_index] gram_vals, lemmas = predicted_gram_vals[ instance_index], predicted_lemmas[instance_index] words = words[:length.item() - 1] gram_vals = gram_vals[:length.item() - 1] lemmas = lemmas[:length.item() - 1] instance_heads = list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [ self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags ] head_tag_labels.append(labels) head_indices.append(instance_heads) decoded_gram_vals.append([ self.vocab.get_token_from_index(gram_val, "grammar_value_tags") for gram_val in gram_vals ]) decoded_lemmas.append([ self.lemmatize_helper.lemmatize(word, lemmatize_rule_index) for word, lemmatize_rule_index in zip(words, lemmas) ]) if self.task_config.task_type == "multitask": output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = 
head_indices output_dict["predicted_gram_vals"] = decoded_gram_vals output_dict["predicted_lemmas"] = decoded_lemmas elif self.task_config.task_type == "single": if self.task_config.params["model"] == "morphology": output_dict["predicted_gram_vals"] = decoded_gram_vals elif self.task_config.params["model"] == "lemmatization": output_dict["predicted_lemmas"] = decoded_lemmas elif self.task_config.params["model"] == "syntax": output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices else: assert False, "Unknown model type {}".format( self.task_config.params["model"]) else: assert False, "Unknown task type {}".format( self.task_config.task_type) return output_dict def _parse(self, embedded_text_input: torch.Tensor, mask: torch.LongTensor, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, grammar_values: torch.LongTensor = None, lemma_indices: torch.LongTensor = None): embedded_text_input = self._input_dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) grammar_value_logits = self._gram_val_output(encoded_text) predicted_gram_vals = grammar_value_logits.argmax(-1) lemma_logits = self._lemma_output(encoded_text) predicted_lemmas = lemma_logits.argmax(-1) batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation. encoded_text = torch.cat([head_sentinel, encoded_text], 1) token_mask = mask.float() mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat( [head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat( [head_tags.new_zeros(batch_size, 1), head_tags], 1) float_mask = mask.float() encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self._dropout( self.head_arc_feedforward(encoded_text)) child_arc_representation = self._dropout( self.child_arc_feedforward(encoded_text)) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout( self.head_tag_feedforward(encoded_text)) child_tag_representation = self._dropout( self.child_tag_feedforward(encoded_text)) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation) minus_inf = -1e8 minus_mask = (1 - float_mask) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) if self.training or not self.use_mst_decoding_for_validation: predicted_heads, predicted_head_tags = self._greedy_decode( head_tag_representation, child_tag_representation, attended_arcs, mask) else: predicted_heads, predicted_head_tags = self._mst_decode( head_tag_representation, child_tag_representation, attended_arcs, mask) if head_indices is not None and head_tags is not None: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=head_indices, head_tags=head_tags, mask=mask) else: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=predicted_heads.long(), head_tags=predicted_head_tags.long(), mask=mask) grammar_nll = torch.tensor(0.) 
if grammar_values is not None: grammar_nll = self._update_multiclass_prediction_metrics( logits=grammar_value_logits, targets=grammar_values, mask=token_mask, accuracy_metric=self._gram_val_prediction_accuracy) lemma_nll = torch.tensor(0.) if lemma_indices is not None: lemma_nll = self._update_multiclass_prediction_metrics( logits=lemma_logits, targets=lemma_indices, mask=token_mask, accuracy_metric=self._lemma_prediction_accuracy, masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX) output_dict = { "heads": predicted_heads, "head_tags": predicted_head_tags, "gram_vals": predicted_gram_vals, "lemmas": predicted_lemmas, "mask": mask, "arc_nll": arc_nll, "tag_nll": tag_nll, "grammar_nll": grammar_nll, "lemma_nll": lemma_nll, } return output_dict @staticmethod def _update_multiclass_prediction_metrics(logits, targets, mask, accuracy_metric, masked_index=None): accuracy_metric(logits, targets, mask) logits = logits.view(-1, logits.shape[-1]) loss = F.cross_entropy(logits, targets.view(-1), reduction='none') if masked_index is not None: mask = mask * (targets != masked_index) loss_mask = mask.view(-1) return (loss * loss_mask).sum() / loss_mask.sum() def _construct_loss( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. head_tags : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The dependency labels of the heads for every word. mask : ``torch.Tensor``, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- arc_nll : ``torch.Tensor``, required. The negative log likelihood from the arc loss. tag_nll : ``torch.Tensor``, required. The negative log likelihood from the arc tag loss. 
""" float_mask = mask.float() batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector( batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = masked_log_softmax( attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = masked_log_softmax( head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = timestep_index.view(1, sequence_length).expand( batch_size, sequence_length).long() # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then again, predicting the head tags of these greedily chosen arcs independently. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag( attended_arcs.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. 
# shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, heads) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _mst_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions using the Edmonds' Algorithm for finding minimum spanning trees on directed graphs. Nodes in the graph are the words in the sentence, and between each pair of nodes, there is an edge in each direction, where the weight of the edge corresponds to the most likely dependency label probability for that arc. The MST is then generated from this directed graph. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the optimally decoded heads of each word. """ batch_size, sequence_length, tag_representation_dim = head_tag_representation.size( ) lengths = mask.data.sum(dim=1).long().cpu().numpy() expanded_shape = [ batch_size, sequence_length, sequence_length, tag_representation_dim ] head_tag_representation = head_tag_representation.unsqueeze(2) head_tag_representation = head_tag_representation.expand( *expanded_shape).contiguous() child_tag_representation = child_tag_representation.unsqueeze(1) child_tag_representation = child_tag_representation.expand( *expanded_shape).contiguous() # Shape (batch_size, sequence_length, sequence_length, num_head_tags) pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation) # Note that this log_softmax is over the tag dimension, and we don't consider pairs # of tags which are invalid (e.g are a pair which includes a padded element) anyway below. # Shape (batch, num_labels,sequence_length, sequence_length) normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute( 0, 3, 1, 2) # Mask padded tokens, because we only want to consider actual words as heads. minus_inf = -1e8 minus_mask = (1 - mask.float()) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) # Shape (batch_size, sequence_length, sequence_length) normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2) # Shape (batch_size, num_head_tags, sequence_length, sequence_length) # This energy tensor expresses the following relation: # energy[i,j] = "Score that i is the head of j". In this # case, we have heads pointing to their children. 
batch_energy = torch.exp( normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits) return self._run_mst_decoding(batch_energy, lengths) @staticmethod def _run_mst_decoding( batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: heads = [] head_tags = [] for energy, length in zip(batch_energy.detach().cpu(), lengths): scores, tag_ids = energy.max(dim=0) # Although we need to include the root node so that the MST includes it, # we do not want any word to be the parent of the root node. # Here, we enforce this by setting the scores for all word -> ROOT edges # edges to be 0. scores[0, :] = 0 # Decode the heads. Because we modify the scores to prevent # adding in word -> ROOT edges, we need to find the labels ourselves. instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False) # Find the labels which correspond to the edges in the max spanning tree. instance_head_tags = [] for child, parent in enumerate(instance_heads): instance_head_tags.append(tag_ids[parent, child].item()) # We don't care what the head or tag is for the root token, but by default it's # not necesarily the same in the batched vs unbatched case, which is annoying. # Here we'll just set them to zero. instance_heads[0] = 0 instance_head_tags[0] = 0 heads.append(instance_heads) head_tags.append(instance_head_tags) return torch.from_numpy(numpy.stack(heads)), torch.from_numpy( numpy.stack(head_tags)) def _get_head_tags(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. Returns ------- head_tag_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector( batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch. 
# shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[ range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous( ) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_bilinear(selected_head_tag_representations, child_tag_representation) return head_tag_logits def _get_mask_for_eval(self, mask: torch.LongTensor, pos_tags: torch.LongTensor) -> torch.LongTensor: """ Dependency evaluation excludes words that are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. Parameters ---------- mask : ``torch.LongTensor``, required. The original mask. pos_tags : ``torch.LongTensor``, required. The pos tags for the sequence. Returns ------- A new mask, where any indices equal to labels we should be ignoring are masked. """ new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = self._attachment_scores.get_metric(reset) metrics['GramValAcc'] = self._gram_val_prediction_accuracy.get_metric( reset) metrics['LemmaAcc'] = self._lemma_prediction_accuracy.get_metric(reset) metrics['MeanAcc'] = (metrics['GramValAcc'] + metrics['LemmaAcc'] + metrics['LAS']) / 3. return metrics
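# A minimal pure-PyTorch sketch of the arc loss assembled in _construct_loss above:
# log-softmax the head scores for each word, gather the log-probability of the gold
# head via advanced indexing, drop the symbolic ROOT position, and average over the
# remaining tokens. Padding masks are omitted here to keep the sketch short.
import torch
import torch.nn.functional as F

batch_size, seq_len = 2, 5                                   # position 0 stands in for ROOT
attended_arcs = torch.randn(batch_size, seq_len, seq_len)    # score[b, child, head]
head_indices = torch.randint(0, seq_len, (batch_size, seq_len))

log_probs = F.log_softmax(attended_arcs, dim=-1)             # head distribution per child
batch_index = torch.arange(batch_size).unsqueeze(1)          # (batch_size, 1)
child_index = torch.arange(seq_len).unsqueeze(0)             # (1, seq_len)
arc_log_likelihood = log_probs[batch_index, child_index, head_indices]

arc_log_likelihood = arc_log_likelihood[:, 1:]               # ROOT's own head is not scored
valid_positions = batch_size * (seq_len - 1)
arc_nll = -arc_log_likelihood.sum() / valid_positions
print(arc_nll.item())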
class ImageTextMatchingModel(Model): def __init__(self, vocab: Vocabulary, scibert_path: str, pretrained: bool = True, fusion_layer: int = 0, num_layers: int = None, image_root: str = None, full_matching: bool = False, retrieval_file: str = None, dropout: float = None, pretrained_bert: bool = False, tokens_namespace: str = "tokens", labels_namespace: str = "labels"): super().__init__(vocab) image_model = torchvision.models.resnet50(pretrained=pretrained) self.image_classifier_in_features = image_model.fc.in_features image_model.fc = torch.nn.Identity() self.image_feature_extractor = image_model config = BertConfig.from_json_file( os.path.join(scibert_path, 'config.json')) self.tokenizer = BertTokenizer(config=config, vocab_file=os.path.join( scibert_path, 'vocab.txt')) if dropout is not None: config.hidden_dropout_prob = dropout if num_layers is not None: config.num_hidden_layers = num_layers num_visual_positions = 1 self.bert = VisBertModel(config, self.image_classifier_in_features, fusion_layer, num_visual_positions) num_classifier_in_features = self.bert.config.hidden_size self.matching_classifier = torch.nn.Linear(num_classifier_in_features, 1) if pretrained_bert: state = torch.load(os.path.join(scibert_path, 'pytorch_model.bin')) filtered_state = {} for key in state: if key[:5] == 'bert.': filtered_state[key[5:]] = state[key] self.bert.load_state_dict(filtered_state, strict=False) self.mode = "binary_matching" self.max_sequence_length = 512 self.head_input_feature_dim = self.bert.config.hidden_size self._tokens_namespace = tokens_namespace if full_matching: expected_img_size = 224 self.image_transform = transforms.Compose([ transforms.Resize(expected_img_size), transforms.CenterCrop(expected_img_size), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) self.mode = "full_matching" self.images = [] self.image_id_index_map = {} f = open(retrieval_file) lines = f.readlines() for line in lines: fname = json.loads(line)['image_id'] img = Image.open(os.path.join(image_root, fname)).convert('RGB') self.images.append(self.image_transform(img)) self.image_id_index_map[fname] = len(self.images) - 1 self.images = torch.stack(self.images) self.top5_accuracy = CategoricalAccuracy(top_k=5) self.top10_accuracy = CategoricalAccuracy(top_k=10) self.top20_accuracy = CategoricalAccuracy(top_k=20) self.loss = torch.nn.CrossEntropyLoss() self.accuracy = CategoricalAccuracy() def forward( self, token_ids: Dict[str, torch.Tensor], segment_ids: torch.Tensor, images: torch.Tensor, image_ids: List[str], mask: torch.Tensor = None, matching_labels: torch.Tensor = None, labels: Dict[str, torch.Tensor] = None, labels_text: Dict[str, torch.Tensor] = None, words_tokens_map: torch.Tensor = None, words_to_mask: torch.Tensor = None, tokens_to_mask: torch.Tensor = None, question: List[str] = None, answer: List[str] = None, category: torch.Tensor = None, ): original_token_ids = token_ids if mask is not None: input_mask = mask else: input_mask = util.get_text_field_mask(token_ids) token_ids = token_ids[self._tokens_namespace] batch_size = min(token_ids.shape[0], images.shape[0]) if images.shape[0] > token_ids.shape[0]: repeat_num = images.shape[0] // token_ids.shape[0] token_ids = token_ids.repeat(1, repeat_num).view( -1, token_ids.shape[-1]) input_mask = input_mask.repeat(1, repeat_num).view( -1, input_mask.shape[-1]) segment_ids = segment_ids.repeat(1, repeat_num).view( -1, segment_ids.shape[-1]) elif token_ids.shape[0] > images.shape[0]: repeat_num = token_ids.shape[0] // 
images.shape[0] images = images.unsqueeze(1).repeat(1, repeat_num, 1, 1, 1).view( -1, images.shape[1], images.shape[2], images.shape[3]) input_token_ids = token_ids visual_feats = self.image_feature_extractor(images) visual_feats = visual_feats.view(token_ids.shape[0], self.image_classifier_in_features) positions = torch.zeros( (visual_feats.shape[0], 4)).to(visual_feats.device).float() visual_inputs = (positions, visual_feats) sequence_encodings, joint_representation = self.bert( input_ids=input_token_ids, attention_mask=input_mask, token_type_ids=segment_ids, visual_inputs=visual_inputs) outputs = {'loss': torch.tensor(0.)} if self.mode == "binary_matching": joint_representation = joint_representation.view( batch_size, -1, joint_representation.shape[-1]) match_predictions = self.matching_classifier( joint_representation).view(batch_size, -1) if labels is not None: loss = self.loss(match_predictions, labels) outputs['loss'] = loss self.accuracy(match_predictions, labels) if self.mode == "full_matching": with torch.no_grad(): predictions = torch.zeros( (batch_size, self.images.shape[0])).to(images.device) batch = [] batch_image_ids = [] finished_images = set() sub_batch_size = 100 for i in range(0, self.images.shape[0], sub_batch_size): batch_images = self.images[i:i + sub_batch_size].to( images.device) visual_feats = self.image_feature_extractor(batch_images) visual_feats = visual_feats.view( -1, self.image_classifier_in_features).repeat( batch_size, 1) positions = torch.zeros( (visual_feats.shape[0], 4)).to(visual_feats.device).float() visual_inputs = (positions, visual_feats) batch_input_token_ids = input_token_ids.repeat( 1, batch_images.shape[0]).view(-1, input_token_ids.shape[-1]) batch_input_mask = input_mask.repeat( 1, batch_images.shape[0]).view(-1, input_mask.shape[-1]) batch_segment_ids = segment_ids.repeat( 1, batch_images.shape[0]).view(-1, segment_ids.shape[-1]) sequence_encodings, joint_representation = self.bert( input_ids=batch_input_token_ids, attention_mask=batch_input_mask, token_type_ids=batch_segment_ids, visual_inputs=visual_inputs) joint_representation = joint_representation.view( batch_size, batch_images.shape[0], joint_representation.shape[-1]) match_predictions = self.matching_classifier( joint_representation).view(batch_size, batch_images.shape[0]) predictions[:, i:i + batch_images.shape[0]] = match_predictions labels = torch.Tensor([ self.image_id_index_map[image_id] for image_id in image_ids ]).long().to(images.device) outputs['loss'] = self.loss(predictions, labels) outputs['predictions'] = predictions outputs['image_ids'] = image_ids self.accuracy(predictions, labels) self.top5_accuracy(predictions, labels) self.top10_accuracy(predictions, labels) self.top20_accuracy(predictions, labels) return outputs def get_metrics(self, reset: bool = False): metrics = {'accuracy': self.accuracy.get_metric(reset)} if "full_matching" in self.mode: metrics["top5_accuracy"] = self.top5_accuracy.get_metric(reset) metrics["top10_accuracy"] = self.top10_accuracy.get_metric(reset) metrics["top20_accuracy"] = self.top20_accuracy.get_metric(reset) return metrics
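# A short sketch of the retrieval metrics reported above: each image in the candidate
# pool is treated as a class, and CategoricalAccuracy(top_k=k) checks whether the gold
# image falls among the k highest-scoring candidates. The random scores below stand in
# for the matching logits the model produces.
import torch
from allennlp.training.metrics import CategoricalAccuracy

num_queries, num_images = 4, 50
scores = torch.randn(num_queries, num_images)                # matching score per (caption, image)
gold_image_index = torch.randint(0, num_images, (num_queries,))

top1_accuracy = CategoricalAccuracy()
top5_accuracy = CategoricalAccuracy(top_k=5)
top1_accuracy(scores, gold_image_index)
top5_accuracy(scores, gold_image_index)
print(top1_accuracy.get_metric(reset=True), top5_accuracy.get_metric(reset=True))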
class SequenceClassifier(Model): """ This ``SequenceClassifier`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then predicts a label for the sequence. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the ``tokens`` ``TextField`` we get as input to the model. stacked_encoder : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and predicting output tags. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, stacked_encoder: Seq2SeqEncoder, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(SequenceClassifier, self).__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.num_classes = self.vocab.get_vocab_size("labels") self.stacked_encoder = stacked_encoder self.projection_layer = Linear(self.stacked_encoder.get_output_dim(), self.num_classes) if text_field_embedder.get_output_dim( ) != stacked_encoder.get_input_dim(): raise ConfigurationError( "The output dimension of the text_field_embedder must match the " "input dimension of the phrase_encoder. Found {} and {}, " "respectively.".format(text_field_embedder.get_output_dim(), stacked_encoder.get_input_dim())) self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) @overrides def forward( self, # type: ignore tokens: Dict[str, torch.LongTensor], label: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- tokens : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. tags : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer gold class labels of shape ``(batch_size, num_tokens)``. Returns ------- An output dictionary consisting of: logits : torch.FloatTensor A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing unnormalised log probabilities of the tag classes. class_probabilities : torch.FloatTensor A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing a distribution of the tag classes per word. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" embedded_text_input = self.text_field_embedder(tokens) batch_size, sequence_length, _ = embedded_text_input.size() mask = get_text_field_mask(tokens) encoded_text = self.stacked_encoder(embedded_text_input, mask) logits = self.projection_layer(torch.mean(encoded_text, 1).squeeze()) class_probabilities = F.softmax(logits) output_dict = { "logits": logits, "class_probabilities": class_probabilities } if label is not None: loss = self._loss(logits, label.long().view(-1)) output_dict["loss"] = loss self._accuracy(logits, label.squeeze(-1)) return output_dict @overrides def forward_on_instance(self, instance: Instance, cuda_device: int) -> Dict[str, numpy.ndarray]: """ Takes an :class:`~allennlp.data.instance.Instance`, which typically has raw text in it, converts that text into arrays using this model's :class:`Vocabulary`, passes those arrays through :func:`self.forward()` and :func:`self.decode()` (which by default does nothing) and returns the result. Before returning the result, we convert any ``torch.autograd.Variables`` or ``torch.Tensors`` into numpy arrays and remove the batch dimension. """ instance.index_fields(self.vocab) model_input = arrays_to_variables(instance.as_array_dict(), add_batch_dimension=True, cuda_device=cuda_device, for_training=False) outputs = self.decode(self.forward(**model_input)) for name, output in list(outputs.items()): if isinstance(output, torch.autograd.Variable): output = output.data.cpu().numpy() outputs[name] = output return outputs @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Does a simple position-wise argmax over each token, converts indices to string labels, and adds a ``"tags"`` key to the dictionary with the result. """ all_predictions = output_dict['class_probabilities'] if not isinstance(all_predictions, numpy.ndarray): all_predictions = all_predictions.data.numpy() argmax_i = numpy.argmax(all_predictions) logger.info(argmax_i) label = self.vocab.get_token_from_index(argmax_i, namespace="labels") output_dict['label'] = label return output_dict @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return {'accuracy': self._accuracy.get_metric(reset)} @classmethod def from_params(cls, vocab: Vocabulary, params: Params) -> 'SimpleTagger': embedder_params = params.pop("text_field_embedder") text_field_embedder = TextFieldEmbedder.from_params( vocab, embedder_params) stacked_encoder = Seq2SeqEncoder.from_params( params.pop("stacked_encoder")) initializer = InitializerApplicator.from_params( params.pop('initializer', [])) regularizer = RegularizerApplicator.from_params( params.pop('regularizer', [])) return cls(vocab=vocab, text_field_embedder=text_field_embedder, stacked_encoder=stacked_encoder, initializer=initializer, regularizer=regularizer)
class BasicClassifier(Model): """ This ``Model`` implements a basic text classifier. After embedding the text into a text field, we will optionally encode the embeddings with a ``Seq2SeqEncoder``. The resulting sequence is pooled using a ``Seq2VecEncoder`` and then passed to a linear classification layer, which projects into the label space. If a ``Seq2SeqEncoder`` is not provided, we will pass the embedded text directly to the ``Seq2VecEncoder``. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the input text into a ``TextField`` seq2seq_encoder : ``Seq2SeqEncoder``, optional (default=``None``) Optional Seq2Seq encoder layer for the input text. seq2vec_encoder : ``Seq2VecEncoder`` Required Seq2Vec encoder layer. If `seq2seq_encoder` is provided, this encoder will pool its output. Otherwise, this encoder will operate directly on the output of the `text_field_embedder`. dropout : ``float``, optional (default = ``None``) Dropout percentage to use. num_labels: ``int``, optional (default = ``None``) Number of labels to project to in classification layer. By default, the classification layer will project to the size of the vocabulary namespace corresponding to labels. label_namespace: ``str``, optional (default = "labels") Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) If provided, will be used to initialize the model parameters. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, seq2vec_encoder: Seq2VecEncoder, seq2seq_encoder: Seq2SeqEncoder = None, dropout: float = None, num_labels: int = None, label_namespace: str = "labels", initializer: InitializerApplicator = InitializerApplicator()) -> None: super().__init__(vocab) self._text_field_embedder = text_field_embedder if seq2seq_encoder: self._seq2seq_encoder = seq2seq_encoder else: self._seq2seq_encoder = None self._seq2vec_encoder = seq2vec_encoder self._classifier_input_dim = self._seq2vec_encoder.get_output_dim() if dropout: self._dropout = torch.nn.Dropout(dropout) else: self._dropout = None self._label_namespace = label_namespace if num_labels: self._num_labels = num_labels else: self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace) self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels) self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward(self, # type: ignore tokens: Dict[str, torch.LongTensor], label: torch.IntTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- tokens : Dict[str, torch.LongTensor] From a ``TextField`` label : torch.IntTensor, optional (default = None) From a ``LabelField`` Returns ------- An output dictionary consisting of: logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalized log probabilities of the label. probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the label. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" embedded_text = self._text_field_embedder(tokens) mask = get_text_field_mask(tokens).float() if self._seq2seq_encoder: embedded_text = self._seq2seq_encoder(embedded_text, mask=mask) embedded_text = self._seq2vec_encoder(embedded_text, mask=mask) if self._dropout: embedded_text = self._dropout(embedded_text) logits = self._classification_layer(embedded_text) probs = torch.nn.functional.softmax(logits, dim=-1) output_dict = {"logits": logits, "probs": probs} if label is not None: loss = self._loss(logits, label.long().view(-1)) output_dict["loss"] = loss self._accuracy(logits, label) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Does a simple argmax over the probabilities, converts index to string label, and add ``"label"`` key to the dictionary with the result. """ predictions = output_dict["probs"] if predictions.dim() == 2: predictions_list = [predictions[i] for i in range(predictions.shape[0])] else: predictions_list = [predictions] classes = [] for prediction in predictions_list: label_idx = prediction.argmax(dim=-1).item() label_str = (self.vocab.get_index_to_token_vocabulary(self._label_namespace) .get(label_idx, str(label_idx))) classes.append(label_str) output_dict["label"] = classes return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = {'accuracy': self._accuracy.get_metric(reset)} return metrics
class SyntacticEntailment(Model): """ This ``Model`` implements the Decomposable Attention model with Late Fusion. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the model. attend_feedforward : ``FeedForward`` This feedforward network is applied to the encoded sentence representations before the similarity matrix is computed between words in the premise and words in the hypothesis. similarity_function : ``SimilarityFunction`` This is the similarity function used when computing the similarity matrix between words in the premise and words in the hypothesis. compare_feedforward : ``FeedForward`` This feedforward network is applied to the aligned premise and hypothesis representations, individually. aggregate_feedforward : ``FeedForward`` This final feedforward network is applied to the concatenated, summed result of the ``compare_feedforward`` network, and its output is used as the entailment class logits. parser_model_path : str This specifies the filepath of the pretrained parser. parser_cuda_device : int The cuda device that the pretrained parser should run on. freeze_parser : bool Whether to allow the parser to be fine-tuned. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, attend_feedforward: FeedForward, similarity_function: SimilarityFunction, compare_feedforward: FeedForward, aggregate_feedforward: FeedForward, parser_model_path: str, parser_cuda_device: int, freeze_parser: bool, premise_encoder: Optional[Seq2SeqEncoder] = None, hypothesis_encoder: Optional[Seq2SeqEncoder] = None, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(SyntacticEntailment, self).__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder self._attend_feedforward = TimeDistributed(attend_feedforward) self._attention = LegacyMatrixAttention(similarity_function) self._compare_feedforward = TimeDistributed(compare_feedforward) self._aggregate_feedforward = aggregate_feedforward self._premise_encoder = premise_encoder self._hypothesis_encoder = hypothesis_encoder or premise_encoder self._num_labels = vocab.get_vocab_size(namespace="labels") check_dimensions_match(text_field_embedder.get_output_dim(), attend_feedforward.get_input_dim(), "text field embedding dim", "attend feedforward input dim") check_dimensions_match(aggregate_feedforward.get_output_dim(), self._num_labels, "final output dimension", "number of labels") self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() self._parser = load_archive(parser_model_path, cuda_device=parser_cuda_device).model self._parser._head_sentinel.requires_grad = False for child in self._parser.children(): for param in child.parameters(): param.requires_grad = False if not freeze_parser: for param in self._parser.encoder.parameters(): param.requires_grad = True initializer(self) def forward( self, # type: ignore premise: Dict[str, torch.LongTensor], premise_tags: torch.LongTensor, hypothesis: Dict[str, torch.LongTensor], hypothesis_tags: torch.LongTensor, label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) 
-> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- premise : Dict[str, torch.LongTensor] From a ``TextField`` premise_tags : torch.LongTensor The POS tags of the premise. hypothesis : Dict[str, torch.LongTensor] From a ``TextField``. hypothesis_tags: torch.LongTensor The POS tags of the hypothesis. label : torch.IntTensor, optional, (default = None) From a ``LabelField``. metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively. Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the entailment label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the entailment label. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ embedded_premise = self._text_field_embedder(premise) embedded_hypothesis = self._text_field_embedder(hypothesis) premise_mask = get_text_field_mask(premise).float() hypothesis_mask = get_text_field_mask(hypothesis).float() if self._premise_encoder: embedded_premise = self._premise_encoder(embedded_premise, premise_mask) if self._hypothesis_encoder: embedded_hypothesis = self._hypothesis_encoder( embedded_hypothesis, hypothesis_mask) projected_premise = self._attend_feedforward(embedded_premise) projected_hypothesis = self._attend_feedforward(embedded_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) similarity_matrix = self._attention(projected_premise, projected_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask) # Shape: (batch_size, premise_length, embedding_dim) attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention) # Shape: (batch_size, hypothesis_length, premise_length) h2p_attention = masked_softmax( similarity_matrix.transpose(1, 2).contiguous(), premise_mask) # Shape: (batch_size, hypothesis_length, embedding_dim) attended_premise = weighted_sum(embedded_premise, h2p_attention) premise_compare_input = torch.cat( [embedded_premise, attended_hypothesis], dim=-1) hypothesis_compare_input = torch.cat( [embedded_hypothesis, attended_premise], dim=-1) compared_premise = self._compare_feedforward(premise_compare_input) compared_premise = compared_premise * premise_mask.unsqueeze(-1) # Shape: (batch_size, compare_dim) compared_premise = compared_premise.sum(dim=1) compared_hypothesis = self._compare_feedforward( hypothesis_compare_input) compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze( -1) # Shape: (batch_size, compare_dim) compared_hypothesis = compared_hypothesis.sum(dim=1) # running the parser encoded_p_parse, p_parse_mask = self._parser(premise, premise_tags) p_parse_encoder_final_state = get_final_encoder_states( encoded_p_parse, p_parse_mask) encoded_h_parse, h_parse_mask = self._parser(hypothesis, hypothesis_tags) h_parse_encoder_final_state = get_final_encoder_states( encoded_h_parse, h_parse_mask) compared_premise = torch.cat( [compared_premise, p_parse_encoder_final_state], dim=-1) compared_hypothesis = torch.cat( [compared_hypothesis, h_parse_encoder_final_state], dim=-1) aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1) label_logits = self._aggregate_feedforward(aggregate_input) label_probs = 
torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {'logits': label_logits, 'label_probs': label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict['loss'] = loss

        if metadata is not None:
            output_dict['premise_tokens'] = [x['premise_tokens'] for x in metadata]
            output_dict['hypothesis_tokens'] = [x['hypothesis_tokens'] for x in metadata]

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'accuracy': self._accuracy.get_metric(reset)}
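# ---------------------------------------------------------------------------
# [Added illustration] The core of the decomposable-attention step used above:
# a premise/hypothesis similarity matrix, a masked softmax over each side, and
# a weighted sum producing the "attended" representations that are concatenated
# before compare_feedforward. This is a plain-PyTorch sketch under assumed
# shapes, not the model's own helper functions.
import torch
import torch.nn.functional as F


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dim, ignoring positions where mask == 0."""
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1)


def cross_attend(premise: torch.Tensor, hypothesis: torch.Tensor,
                 p_mask: torch.Tensor, h_mask: torch.Tensor):
    # (batch, p_len, h_len): dot-product similarity between every word pair
    similarity = torch.bmm(premise, hypothesis.transpose(1, 2))
    # For each premise word, a distribution over hypothesis words (and vice versa)
    p2h = masked_softmax(similarity, h_mask.unsqueeze(1))
    h2p = masked_softmax(similarity.transpose(1, 2), p_mask.unsqueeze(1))
    attended_hypothesis = torch.bmm(p2h, hypothesis)  # aligned to the premise
    attended_premise = torch.bmm(h2p, premise)        # aligned to the hypothesis
    return attended_premise, attended_hypothesis


premise = torch.randn(2, 7, 16)
hypothesis = torch.randn(2, 5, 16)
p_mask = torch.ones(2, 7)
h_mask = torch.ones(2, 5)
a_p, a_h = cross_attend(premise, hypothesis, p_mask, h_mask)
print(a_p.shape, a_h.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 7, 16])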
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--vocab_file", default='bert-base-uncased-vocab.txt', type=str, required=True) parser.add_argument("--model_file", default='bert-base-uncased.tar.gz', type=str, required=True) parser.add_argument( "--output_dir", default=None, type=str, required=True, help= "The output directory where the model checkpoints and predictions will be written." ) parser.add_argument( "--predict_dir", default=None, type=str, required=True, help="The output directory where the predictions will be written.") # Other parameters parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json") parser.add_argument( "--predict_file", default=None, type=str, help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json" ) parser.add_argument("--test_file", default=None, type=str) parser.add_argument( "--max_seq_length", default=384, type=int, help= "The maximum total input sequence length after WordPiece tokenization. Sequences " "longer than this will be truncated, and sequences shorter than this will be padded." ) parser.add_argument( "--doc_stride", default=128, type=int, help= "When splitting up a long document into chunks, how much stride to take between chunks." ) parser.add_argument( "--max_query_length", default=64, type=int, help= "The maximum number of tokens for the question. Questions longer than this will " "be truncated to this length.") parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.") parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--num_train_epochs", default=2.0, type=float, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " "of training.") parser.add_argument( "--n_best_size", default=20, type=int, help= "The total number of n-best predictions to generate in the nbest_predictions.json " "output file.") parser.add_argument( "--max_answer_length", default=30, type=int, help= "The maximum length of an answer that can be generated. This is needed because the start " "and end predictions are not conditioned on one another.") parser.add_argument( "--verbose_logging", default=False, action='store_true', help= "If true, all of the warnings related to data processing will be printed. 
" "A number of warnings are expected for a normal SQuAD evaluation.") parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--view_id', type=int, default=1, help="view id of multi-view co-training(two-view)") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumulate before performing a backward/update pass." ) parser.add_argument( "--do_lower_case", default=True, action='store_true', help= "Whether to lower case the input text. True for uncased models, False for cased models." ) parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument( '--fp16', default=False, action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--save_all', default=False, action='store_true') # Base setting parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--max_ctx', type=int, default=2) parser.add_argument('--task_name', type=str, default='race') parser.add_argument('--bert_name', type=str, default='pool-race') parser.add_argument('--reader_name', type=str, default='race') parser.add_argument('--per_eval_step', type=int, default=10000000) # model parameters parser.add_argument('--evidence_lambda', type=float, default=0.8) # Parameters for running labeling model parser.add_argument('--do_label', default=False, action='store_true') parser.add_argument('--sentence_id_file', nargs='*') parser.add_argument('--weight_threshold', type=float, default=0.0) parser.add_argument('--only_correct', default=False, action='store_true') parser.add_argument('--label_threshold', type=float, default=0.0) parser.add_argument('--multi_evidence', default=False, action='store_true') parser.add_argument('--metric', default='accuracy', type=str) parser.add_argument('--num_evidence', default=1, type=int) parser.add_argument('--power_length', default=1., type=float) parser.add_argument('--num_choices', default=4, type=int) args = parser.parse_args() logger = setting_logger(args.output_dir) logger.info('================== Program start. ========================') model_params = prepare_model_params(args) read_params = prepare_read_params(args) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_predict and not args.do_label: raise ValueError( "At least one of `do_train` or `do_predict` or `do_label` must be True." ) if args.do_train: if not args.train_file: raise ValueError( "If `do_train` is True, then `train_file` must be specified.") if args.do_predict: if not args.predict_file: raise ValueError( "If `do_predict` is True, then `predict_file` must be specified." ) if args.do_train: if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError( "Output directory () already exists and is not empty.") os.makedirs(args.output_dir, exist_ok=True) if args.do_predict: os.makedirs(args.predict_dir, exist_ok=True) tokenizer = BertTokenizer.from_pretrained(args.vocab_file) data_reader = initialize_reader(args.reader_name) num_train_steps = None if args.do_train or args.do_label: train_examples = data_reader.read(input_file=args.train_file, **read_params) cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}'.format( args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length), str(args.max_ctx), str(args.task_name)) try: with open(cached_train_features_file, "rb") as reader: train_features = pickle.load(reader) except FileNotFoundError: train_features = data_reader.convert_examples_to_features( examples=train_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length) if args.local_rank == -1 or torch.distributed.get_rank() == 0: logger.info(" Saving train features into cached file %s", cached_train_features_file) with open(cached_train_features_file, "wb") as writer: pickle.dump(train_features, writer) num_train_steps = int( len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) # Prepare model if args.pretrain is not None: logger.info('Load pretrained model from {}'.format(args.pretrain)) model_state_dict = torch.load(args.pretrain, map_location='cuda:0') model = initialize_model(args.bert_name, args.model_file, state_dict=model_state_dict, **model_params) else: model = initialize_model(args.bert_name, args.model_file, **model_params) if args.fp16: model.half() model.to(device) if args.local_rank != -1: try: from apex.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." 
) model = DDP(model) elif n_gpu > 1: model = torch.nn.DataParallel(model) # Prepare optimizer param_optimizer = list(model.named_parameters()) # hack to remove pooler, which is not used # thus it produce None grad that break apex param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] t_total = num_train_steps if num_train_steps is not None else -1 if args.local_rank != -1: t_total = t_total // torch.distributed.get_world_size() if args.fp16: try: from apex.optimizers import FP16_Optimizer from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0) if args.loss_scale == 0: optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=t_total) logger.info( f"warm up linear: warmup = {warmup_linear.warmup}, t_total = {warmup_linear.t_total}." ) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=t_total) # Prepare data eval_examples = data_reader.read(input_file=args.predict_file, **read_params) eval_features = data_reader.convert_examples_to_features( examples=eval_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length) eval_tensors = data_reader.data_to_tensors(eval_features) eval_data = TensorDataset(*eval_tensors) eval_sampler = SequentialSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) if args.do_train: if args.do_label: logger.info('Training in State Wise.') sentence_label_file = args.sentence_id_file if sentence_label_file is not None: for file in sentence_label_file: train_features = data_reader.generate_features_sentence_ids( train_features, file) else: logger.info('No sentence id supervision is found.') else: logger.info('Training in traditional way.') logger.info("***** Running training *****") logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num split examples = %d", len(train_features)) logger.info(" Num train total optimization steps = %d", t_total) logger.info(" Batch size = %d", args.predict_batch_size) train_loss = AverageMeter() best_acc = 0.0 best_loss = 1000000 summary_writer = SummaryWriter(log_dir=args.output_dir) global_step = 0 eval_loss = AverageMeter() eval_accuracy = CategoricalAccuracy() eval_epoch = 0 train_tensors = data_reader.data_to_tensors(train_features) train_data = TensorDataset(*train_tensors) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) for epoch in range(int(args.num_train_epochs)): logger.info(f'Running at Epoch {epoch}') # Train for step, batch in enumerate( tqdm(train_dataloader, desc="Iteration", dynamic_ncols=True)): model.train() if n_gpu == 1: batch = batch_to_device( batch, device) # multi-gpu does scattering 
it-self inputs = data_reader.generate_inputs( batch, train_features, model_state=ModelState.Train) model_output = model(**inputs) loss = model_output['loss'] if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: # modify learning rate with special warm up BERT uses # if args.fp16 is False, BertAdam is used and handles this automatically if args.fp16: lr_this_step = args.learning_rate * warmup_linear.get_lr( global_step) for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step summary_writer.add_scalar('lr', lr_this_step, global_step) else: summary_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) optimizer.step() optimizer.zero_grad() global_step += 1 train_loss.update(loss.item(), 1) summary_writer.add_scalar('train_loss', train_loss.avg, global_step) # logger.info(f'Train loss: {train_loss.avg}') if (step + 1) % args.per_eval_step == 0 or step == len( train_dataloader) - 1: # Evaluation model.eval() logger.info("Start evaluating") for _, eval_batch in enumerate( tqdm(eval_dataloader, desc="Evaluating", dynamic_ncols=True)): if n_gpu == 1: eval_batch = batch_to_device( eval_batch, device) # multi-gpu does scattering it-self inputs = data_reader.generate_inputs( eval_batch, eval_features, model_state=ModelState.Evaluate) with torch.no_grad(): output_dict = model(**inputs) loss, choice_logits = output_dict[ 'loss'], output_dict['choice_logits'] eval_loss.update(loss.item(), 1) eval_accuracy(choice_logits, inputs["labels"]) eval_epoch_loss = eval_loss.avg summary_writer.add_scalar('eval_loss', eval_epoch_loss, eval_epoch) eval_loss.reset() current_acc = eval_accuracy.get_metric(reset=True) summary_writer.add_scalar('eval_acc', current_acc, eval_epoch) torch.cuda.empty_cache() if args.save_all: model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_model_file = os.path.join( args.output_dir, f"pytorch_model_{eval_epoch}.bin") torch.save(model_to_save.state_dict(), output_model_file) if current_acc > best_acc: best_acc = current_acc model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_model_file = os.path.join( args.output_dir, "pytorch_model.bin") torch.save(model_to_save.state_dict(), output_model_file) if eval_epoch_loss < best_loss: best_loss = eval_epoch_loss model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_model_file = os.path.join( args.output_dir, "pytorch_loss_model.bin") torch.save(model_to_save.state_dict(), output_model_file) logger.info( 'Eval Epoch: %d, Accuracy: %.4f (Best Accuracy: %.4f)' % (eval_epoch, current_acc, best_acc)) eval_epoch += 1 logger.info( f'Epoch {epoch}: Accuracy: {best_acc}, Train Loss: {train_loss.avg}' ) summary_writer.close() for output_model_name in ["pytorch_model.bin", "pytorch_loss_model.bin"]: # Loading trained model output_model_file = os.path.join(args.output_dir, output_model_name) model_state_dict = torch.load(output_model_file, map_location='cuda:0') model = initialize_model(args.bert_name, args.model_file, state_dict=model_state_dict, **model_params) model.to(device) # Write Yes/No predictions if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): test_examples = data_reader.read(args.test_file) test_features = 
data_reader.convert_examples_to_features( test_examples, tokenizer, args.max_seq_length) test_tensors = data_reader.data_to_tensors(test_features) test_data = TensorDataset(*test_tensors) test_sampler = SequentialSampler(test_data) test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.predict_batch_size) logger.info("***** Running predictions *****") logger.info(" Num orig examples = %d", len(test_examples)) logger.info(" Num split examples = %d", len(test_features)) logger.info(" Batch size = %d", args.predict_batch_size) model.eval() all_results = [] test_acc = CategoricalAccuracy() logger.info("Start predicting yes/no on Dev set.") for batch in tqdm(test_dataloader, desc="Testing"): if n_gpu == 1: batch = batch_to_device( batch, device) # multi-gpu does scattering it-self inputs = data_reader.generate_inputs( batch, test_features, model_state=ModelState.Evaluate) with torch.no_grad(): batch_choice_logits = model(**inputs)['choice_logits'] test_acc(batch_choice_logits, inputs['labels']) example_indices = batch[-1] for i, example_index in enumerate(example_indices): choice_logits = batch_choice_logits[i].detach().cpu( ).tolist() test_feature = test_features[example_index.item()] unique_id = int(test_feature.unique_id) all_results.append( RawResultChoice(unique_id=unique_id, choice_logits=choice_logits)) if "loss" in output_model_name: logger.info( 'Predicting question choice on test set using model with lowest loss on validation set.' ) output_prediction_file = os.path.join(args.predict_dir, 'loss_predictions.json') else: logger.info( 'Predicting question choice on test set using model with best accuracy on validation set,' ) output_prediction_file = os.path.join(args.predict_dir, 'predictions.json') data_reader.write_predictions(test_examples, test_features, all_results, output_prediction_file) logger.info( f"Accuracy on Test set: {test_acc.get_metric(reset=True)}") # Loading trained model. if args.metric == 'accuracy': logger.info("Load model with best accuracy on validation set.") output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") elif args.metric == 'loss': logger.info("Load model with lowest loss on validation set.") output_model_file = os.path.join(args.output_dir, "pytorch_loss_model.bin") else: raise RuntimeError( f"Wrong metric type for {args.metric}, which must be in ['accuracy', 'loss']." ) model_state_dict = torch.load(output_model_file, map_location='cuda:0') model = initialize_model(args.bert_name, args.model_file, state_dict=model_state_dict, **model_params) model.to(device) # Labeling sentence id. 
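# ---------------------------------------------------------------------------
# [Added illustration] The training loop above divides the loss by
# gradient_accumulation_steps and only calls optimizer.step() every N-th
# micro-batch, simulating a larger effective batch size. A minimal,
# self-contained version of that pattern (toy model and data are assumptions):
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
accumulation_steps = 4

data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(16)]

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    loss = loss_fn(model(x), y) / accumulation_steps  # scale so gradients average out
    loss.backward()                                   # gradients accumulate in .grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                              # one update per N micro-batches
        optimizer.zero_grad()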
if args.do_label and (args.local_rank == -1 or torch.distributed.get_rank() == 0): f = open('debug_log.txt', 'w') def softmax(x): """Compute softmax values for each sets of scores in x.""" e_x = np.exp(x - np.max(x)) return e_x / e_x.sum() def beam_search(sentence_sim, beam_num=10): """ sentence_sim(numpy) """ max_length = args.num_evidence sentence_sim = np.pad(sentence_sim, (1, 0), 'constant', constant_values=(0, )) sentences = [{'sim': sentence_sim, 'sentences': [], 'value': 0.}] while sentences[0][ 'sentences'] == [] or sentences[0]['sentences'][-1] != 0: new_sentences = [] for sentence in sentences: if sentence['sentences'] != [] and sentence['sentences'][ -1] == 0: new_sentences.append(sentence) continue scores = softmax(sentence['sim']) for i in range(len(sentence['sim'])): if i == 0 and sentence['sentences'] == []: continue if len(sentence['sentences']) > max_length: continue if len(sentence['sentences']) == max_length and i != 0: continue if i in sentence['sentences']: continue if max_length == 1 and i == 0: value = sentence['value'] else: value = sentence['value'] + np.log(scores[i]) # `i - 1` refers to original sentence id new_sentence = { 'sim': np.copy(sentence['sim']), 'sentences': sentence['sentences'] + [i], 'value': value } new_sentence['sim'][i] = -1e15 new_sentences.append(new_sentence) sentences = sorted(new_sentences, key=lambda x: x['value'] / np.power( len(x['sentences']), args.power_length), reverse=True)[:beam_num] sentence = sentences[0] sentence['value'] = sentence['value'] / np.power( len(sentence['sentences']), args.power_length) print(sentence['value'], file=f, flush=True) return sentence def batch_choice_beam_search(sentence_sim, sentence_mask, beam_num=10) -> List[List[Dict]]: """ :param sentence_sim: [batch, num_choices, max_sen] -> torch.FloatTensor, device=cpu :param sentence_mask: [batch, num_choices, max_sen] -> torch.FloatTensor, device=cpu :param beam_num: int :return: batch * num_choices * num_evidences -> List[List[int]] """ batch_size = sentence_sim.size(0) num_choices = sentence_sim.size(1) sentence_sim = sentence_sim.numpy() + 1e-15 sentence_mask = sentence_mask.numpy() sentence_ids = [] for b in range(batch_size): choice_sentence_ids = [] for c in range(num_choices): choice_sentence_ids.append( beam_search( sentence_sim[b, c, :int(sum(sentence_mask[b, c]))], beam_num)) print( '=================single choice=====================', file=f, flush=True) sentence_ids.append(choice_sentence_ids) return sentence_ids test_examples = train_examples test_features = train_features test_tensors = data_reader.data_to_tensors(test_features) test_data = TensorDataset(*test_tensors) test_sampler = SequentialSampler(test_data) test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.predict_batch_size) logger.info("***** Running labeling *****") logger.info(" Num orig examples = %d", len(test_examples)) logger.info(" Num split examples = %d", len(test_features)) logger.info(" Batch size = %d", args.predict_batch_size) model.eval() all_results = [] logger.info("Start labeling.") for batch in tqdm(test_dataloader, desc="Testing"): if n_gpu == 1: batch = batch_to_device(batch, device) inputs = data_reader.generate_inputs(batch, test_features, model_state=ModelState.Test) with torch.no_grad(): output_dict = model(**inputs) batch_choice_logits, batch_sentence_logits = output_dict[ "choice_logits"], output_dict["sentence_logits"] batch_sentence_mask = output_dict["sentence_mask"] example_indices = batch[-1] batch_beam_results = 
batch_choice_beam_search( batch_sentence_logits, batch_sentence_mask) for i, example_index in enumerate(example_indices): choice_logits = batch_choice_logits[i].detach().cpu() evidence_list = batch_beam_results[i] test_feature = test_features[example_index.item()] unique_id = int(test_feature.unique_id) all_results.append( RawOutput(unique_id=unique_id, model_output={ "choice_logits": choice_logits, "evidence_list": evidence_list })) output_prediction_file = os.path.join(args.predict_dir, 'sentence_id_file.json') data_reader.predict_sentence_ids( test_examples, test_features, all_results, output_prediction_file, weight_threshold=args.weight_threshold, only_correct=args.only_correct, label_threshold=args.label_threshold)
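# ---------------------------------------------------------------------------
# [Added illustration] A simplified sketch of the evidence-sentence beam
# search used in the labeling step above: keep `beam_size` partial selections
# alive, score them by summed log-softmax values, and normalise by
# len(selection) ** power. Unlike the original (which pads a "stop" sentence
# at index 0 and can return shorter selections), this sketch always returns
# exactly `max_evidence` indices; all names here are assumptions.
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D score vector."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()


def evidence_beam_search(sentence_scores: np.ndarray,
                         max_evidence: int = 2,
                         beam_size: int = 10,
                         power: float = 1.0):
    """Pick `max_evidence` sentence indices maximising length-normalised log prob."""
    log_probs = np.log(softmax(sentence_scores))
    num_evidence = min(max_evidence, len(sentence_scores))
    beams = [([], 0.0)]  # (chosen sentence indices, summed log prob)
    for _ in range(num_evidence):
        candidates = []
        for chosen, value in beams:
            for i, lp in enumerate(log_probs):
                if i in chosen:
                    continue  # each sentence can be selected at most once
                candidates.append((chosen + [i], value + lp))
        # keep the best `beam_size` partial selections, length-normalised
        beams = sorted(candidates,
                       key=lambda b: b[1] / len(b[0]) ** power,
                       reverse=True)[:beam_size]
    best_chosen, best_value = beams[0]
    return best_chosen, best_value / len(best_chosen) ** power


scores = np.array([2.0, 0.1, 1.5, -0.3])
print(evidence_beam_search(scores, max_evidence=2))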
class AMTask(Model): """ A class that implements a task-specific model. It conceptually belongs to a formalism or corpus. """ def __init__(self, vocab: Vocabulary, name: str, edge_model: EdgeModel, loss_function: EdgeLoss, supertagger: Supertagger, lexlabeltagger: Supertagger, supertagger_loss: SupertaggingLoss, lexlabel_loss: SupertaggingLoss, output_null_lex_label: bool = True, loss_mixing: Dict[str, float] = None, dropout: float = 0.0, validation_evaluator: Optional[Evaluator] = None, regularizer: Optional[RegularizerApplicator] = None): super().__init__(vocab, regularizer) self.name = name self.edge_model = edge_model self.supertagger = supertagger self.lexlabeltagger = lexlabeltagger self.supertagger_loss = supertagger_loss self.lexlabel_loss = lexlabel_loss self.loss_function = loss_function self.loss_mixing = loss_mixing or dict() self.validation_evaluator = validation_evaluator self.output_null_lex_label = output_null_lex_label self._dropout = InputVariationalDropout(dropout) loss_names = [ "edge_existence", "edge_label", "supertagging", "lexlabel" ] for loss_name in loss_names: if loss_name not in self.loss_mixing: self.loss_mixing[loss_name] = 1.0 logger.info( f"Loss name {loss_name} not found in loss_mixing, using a weight of 1.0" ) else: if self.loss_mixing[loss_name] is None: if loss_name not in ["supertagging", "lexlabel"]: raise ConfigurationError( "Only the loss mixing coefficients for supertagging and lexlabel may be None, but not " + loss_name) not_contained = set(self.loss_mixing.keys()) - set(loss_names) if len(not_contained): logger.critical( f"The following loss name(s) are unknown: {not_contained}") raise ValueError( f"The following loss name(s) are unknown: {not_contained}") self._supertagging_acc = CategoricalAccuracy() self._lexlabel_acc = CategoricalAccuracy() self._attachment_scores = AttachmentScores() self.current_epoch = 0 tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = { tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE } self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info( f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") def check_all_dimensions_match(self, encoder_output_dim): check_dimensions_match(encoder_output_dim, self.edge_model.encoder_dim(), "encoder output dim", self.name + " input dim edge model") check_dimensions_match(encoder_output_dim, self.supertagger.encoder_dim(), "encoder output dim", self.name + " supertagger input dim") check_dimensions_match(encoder_output_dim, self.lexlabeltagger.encoder_dim(), "encoder output dim", self.name + " lexical label tagger input dim") @overrides def forward( self, # type: ignore encoded_text_parsing: torch.Tensor, encoded_text_tagging: torch.Tensor, mask: torch.Tensor, pos_tags: torch.LongTensor, metadata: List[Dict[str, Any]], supertags: torch.LongTensor = None, lexlabels: torch.LongTensor = None, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: """ Takes a batch of encoded sentences and returns a dictionary with loss and predictions. 
:param encoded_text_parsing: sentence representation of shape (batch_size, seq_len, encoder_output_dim) :param encoded_text_tagging: sentence representation of shape (batch_size, seq_len, encoder_output_dim) or None if formalism of batch doesn't need supertagging :param mask: matching the sentence representation of shape (batch_size, seq_len) :param pos_tags: the accompanying pos tags (batch_size, seq_len) :param metadata: :param supertags: the accompanying supertags (batch_size, seq_len) :param lexlabels: the accompanying lexical labels (batch_size, seq_len) :param head_tags: the gold heads of every word (batch_size, seq_len) :param head_indices: the gold edge labels for each word (incoming edge, see amconll files) (batch_size, seq_len) :return: """ encoded_text_parsing = self._dropout(encoded_text_parsing) if encoded_text_tagging is not None: encoded_text_tagging = self._dropout(encoded_text_tagging) batch_size, seq_len, _ = encoded_text_parsing.shape edge_existence_scores = self.edge_model.edge_existence( encoded_text_parsing, mask) # shape (batch_size, seq_len, seq_len) # shape (batch_size, seq_len, num_supertags) if encoded_text_tagging is not None: supertagger_logits = self.supertagger.compute_logits( encoded_text_tagging) lexlabel_logits = self.lexlabeltagger.compute_logits( encoded_text_tagging ) # shape (batch_size, seq_len, num label tags) else: supertagger_logits = None lexlabel_logits = None # Make predictions on data: if self.training: predicted_heads = self._greedy_decode_arcs(edge_existence_scores, mask) edge_label_logits = self.edge_model.label_scores( encoded_text_parsing, predicted_heads ) # shape (batch_size, seq_len, num edge labels) predicted_edge_labels = self._greedy_decode_edge_labels( edge_label_logits) else: # Find best tree with CLE predicted_heads = cle_decode(edge_existence_scores, mask.data.sum(dim=1).long()) # With info about tree structure, get edge label scores edge_label_logits = self.edge_model.label_scores( encoded_text_parsing, predicted_heads) # Predict edge labels predicted_edge_labels = self._greedy_decode_edge_labels( edge_label_logits) output_dict = { "heads": predicted_heads, "edge_existence_scores": edge_existence_scores, "label_logits": edge_label_logits, # shape (batch_size, seq_len, num edge labels) "full_label_logits": self.edge_model.full_label_scores( encoded_text_parsing ), #these are mostly required for the projective decoder "mask": mask, "words": [meta["words"] for meta in metadata], "attributes": [meta["attributes"] for meta in metadata], "token_ranges": [meta["token_ranges"] for meta in metadata], "encoded_text_parsing": encoded_text_parsing, "encoded_text_tagging": encoded_text_tagging, "position_in_corpus": [meta["position_in_corpus"] for meta in metadata], "formalism": self.name } if encoded_text_tagging is not None and self.loss_mixing[ "supertagging"] is not None: output_dict[ "supertag_scores"] = supertagger_logits # shape (batch_size, seq_len, num supertags) output_dict["best_supertags"] = Supertagger.top_k_supertags( supertagger_logits, 1).squeeze(2) # shape (batch_size, seq_len) if encoded_text_tagging is not None and self.loss_mixing[ "lexlabel"] is not None: if not self.output_null_lex_label: bottom_lex_label_index = self.vocab.get_token_index( "_", namespace=self.name + "_lex_labels") masked_lexlabel_logits = lexlabel_logits.clone().detach( ) # shape (batch_size, seq_len, num label tags) masked_lexlabel_logits[:, :, bottom_lex_label_index] = -1e20 else: masked_lexlabel_logits = lexlabel_logits output_dict["lexlabels"] = 
Supertagger.top_k_supertags( masked_lexlabel_logits, 1).squeeze(2) # shape (batch_size, seq_len) is_annotated = metadata[0]["is_annotated"] if any(metadata[i]["is_annotated"] != is_annotated for i in range(batch_size)): raise ValueError( "Batch contained inconsistent information if data is annotated." ) # Compute loss: if is_annotated and head_indices is not None and head_tags is not None: gold_edge_label_logits = self.edge_model.label_scores( encoded_text_parsing, head_indices) edge_label_loss = self.loss_function.label_loss( gold_edge_label_logits, mask, head_tags) edge_existence_loss = self.loss_function.edge_existence_loss( edge_existence_scores, head_indices, mask) # compute loss, remove loss for artificial root if encoded_text_tagging is not None and self.loss_mixing[ "supertagging"] is not None: supertagger_logits = supertagger_logits[:, 1:, :].contiguous() supertagging_nll = self.supertagger_loss.loss( supertagger_logits, supertags, mask[:, 1:]) else: supertagging_nll = None if encoded_text_tagging is not None and self.loss_mixing[ "lexlabel"] is not None: lexlabel_logits = lexlabel_logits[:, 1:, :].contiguous() lexlabel_nll = self.lexlabel_loss.loss(lexlabel_logits, lexlabels, mask[:, 1:]) else: lexlabel_nll = None loss = self.loss_mixing[ "edge_existence"] * edge_existence_loss + self.loss_mixing[ "edge_label"] * edge_label_loss if supertagging_nll is not None: loss += self.loss_mixing["supertagging"] * supertagging_nll if lexlabel_nll is not None: loss += self.loss_mixing["lexlabel"] * lexlabel_nll # Compute LAS/UAS/Supertagging acc/Lex label acc: evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags) # We calculate attachment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. if edge_existence_loss is not None and edge_label_loss is not None: self._attachment_scores(predicted_heads[:, 1:], predicted_edge_labels[:, 1:], head_indices[:, 1:], head_tags[:, 1:], evaluation_mask) if supertagging_nll is not None: self._supertagging_acc(supertagger_logits, supertags, mask[:, 1:]) # compare against gold data if lexlabel_nll is not None: self._lexlabel_acc(lexlabel_logits, lexlabels, mask[:, 1:]) # compare against gold data output_dict["arc_loss"] = edge_existence_loss output_dict["edge_label_loss"] = edge_label_loss output_dict["supertagging_loss"] = supertagging_nll output_dict["lexlabel_loss"] = lexlabel_nll output_dict["loss"] = loss return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]): """ In contrast to its name, this function does not perform the decoding but only prepares it. 
Therefore, we take the result of forward and perform the following steps (for each sentence in batch): - remove padding - identifiy the root of the sentence, group other root-candidates under the proper root - collect a selection of supertags to speed up computation (top k selection is done later) :param output_dict: result of forward :return: output_dict with the following keys added: - lexlabels: nested list: contains for each sentence, for each word the most likely lexical label (w/o artificial root) - supertags: nested list: contains for each sentence, for each word the most likely lexical label (w/o artificial root) """ best_supertags = output_dict.pop( "best_supertags").cpu().detach().numpy() supertag_scores = output_dict.pop( "supertag_scores") # shape (batch_size, seq_len, num supertags) full_label_logits = output_dict.pop("full_label_logits").cpu().detach( ).numpy() #shape (batch size, seq len, seq len, num edge labels) edge_existence_scores = output_dict.pop( "edge_existence_scores").cpu().detach().numpy( ) #shape (batch size, seq len, seq len, num edge labels) k = 10 if self.validation_evaluator: #retrieve k supertags from validation evaluator. if isinstance(self.validation_evaluator.predictor, AMconllPredictor): k = self.validation_evaluator.predictor.k top_k_supertags = Supertagger.top_k_supertags( supertag_scores, k).cpu().detach().numpy() # shape (batch_size, seq_len, k) supertag_scores = supertag_scores.cpu().detach().numpy() lexlabels = output_dict.pop( "lexlabels").cpu().detach().numpy() #shape (batch_size, seq_len) heads = output_dict.pop("heads") heads_cpu = heads.cpu().detach().numpy() mask = output_dict.pop("mask") edge_label_logits = output_dict.pop("label_logits").cpu().detach( ).numpy() # shape (batch_size, seq_len, num edge labels) encoded_text_parsing = output_dict.pop("encoded_text_parsing") output_dict.pop("encoded_text_tagging") #don't need that lengths = get_lengths_from_binary_sequence_mask(mask) #here we collect things, in the end we will have one entry for each sentence: all_edge_label_logits = [] all_supertags = [] head_indices = [] roots = [] all_predicted_lex_labels = [] all_full_label_logits = [] all_edge_existence_scores = [] all_supertag_scores = [] #we need the following to identify the root root_edge_label_id = self.vocab.get_token_index("ROOT", namespace=self.name + "_head_tags") bot_id = self.vocab.get_token_index(AMSentence.get_bottom_supertag(), namespace=self.name + "_supertag_labels") for i, length in enumerate(lengths): instance_heads_cpu = list(heads_cpu[i, 1:length]) #Postprocess heads and find root of sentence: instance_heads_cpu, root = find_root( instance_heads_cpu, best_supertags[i, 1:length], edge_label_logits[i, 1:length, :], root_edge_label_id, bot_id, modify=True) roots.append(root) #apply changes to instance_heads tensor: instance_heads = heads[i, :] for j, x in enumerate(instance_heads_cpu): instance_heads[j + 1] = torch.tensor( x ) #+1 because we removed the first position from instance_heads_cpu # re-calculate edge label logits since heads might have changed: label_logits = self.edge_model.label_scores( encoded_text_parsing[i].unsqueeze(0), instance_heads.unsqueeze(0)).squeeze(0).detach().cpu().numpy() #(un)squeeze: fake batch dimension all_edge_label_logits.append(label_logits[1:length, :]) all_full_label_logits.append( full_label_logits[i, :length, :length, :]) all_edge_existence_scores.append( edge_existence_scores[i, :length, :length]) #calculate supertags for this sentence: all_supertag_scores.append(supertag_scores[ i, 
1:length, :]) #new shape (sent length, num supertags) supertags_for_this_sentence = [] for word in range(1, length): supertags_for_this_word = [] for top_k in top_k_supertags[i, word]: fragment, typ = AMSentence.split_supertag( self.vocab.get_token_from_index(top_k, namespace=self.name + "_supertag_labels")) score = supertag_scores[i, word, top_k] supertags_for_this_word.append((score, fragment, typ)) if bot_id not in top_k_supertags[ i, word]: #\bot is not in the top k, but we have to add it anyway in order for the decoder to work properly. fragment, typ = AMSentence.split_supertag( AMSentence.get_bottom_supertag()) supertags_for_this_word.append( (supertag_scores[i, word, bot_id], fragment, typ)) supertags_for_this_sentence.append(supertags_for_this_word) all_supertags.append(supertags_for_this_sentence) all_predicted_lex_labels.append([ self.vocab.get_token_from_index(label, namespace=self.name + "_lex_labels") for label in lexlabels[i, 1:length] ]) head_indices.append(instance_heads_cpu) output_dict["lexlabels"] = all_predicted_lex_labels output_dict["supertags"] = all_supertags output_dict["root"] = roots output_dict["label_logits"] = all_edge_label_logits output_dict["predicted_heads"] = head_indices output_dict["full_label_logits"] = all_full_label_logits output_dict["edge_existence_scores"] = all_edge_existence_scores output_dict["supertag_scores"] = all_supertag_scores return output_dict def _greedy_decode_edge_labels( self, edge_label_logits: torch.Tensor) -> torch.Tensor: """ Assigns edge labels according to (existing) edges. Parameters ---------- edge_label_logits: ``torch.Tensor`` of shape (batch_size, sequence_length, num_head_tags) Returns ------- head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded head tags (labels of incoming edges) of each word. """ _, head_tags = edge_label_logits.max(dim=2) return head_tags def _greedy_decode_arcs(self, existence_scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """ Decodes the head predictions by decoding the unlabeled arcs independently for each word. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). Parameters ---------- existence_scores : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. mask: torch.Tensor, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. existence_scores = existence_scores + torch.diag( existence_scores.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).byte().unsqueeze(2) existence_scores.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = existence_scores.max(dim=2) return heads def _get_mask_for_eval(self, mask: torch.LongTensor, pos_tags: torch.LongTensor) -> torch.LongTensor: """ Dependency evaluation excludes words are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. 
Parameters ---------- mask : ``torch.LongTensor``, required. The original mask. pos_tags : ``torch.LongTensor``, required. The pos tags for the sequence. Returns ------- A new mask, where any indices equal to labels we should be ignoring are masked. """ new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask def metrics(self, parser_model, reset: bool = False, model_path=None) -> Dict[str, float]: """ Is called by a GraphDependencyParser :param parser_model: a GraphDependencyParser :param reset: :return: """ r = self.get_metrics(reset) if reset: #epoch done if self.training: #done on the training data self.current_epoch += 1 else: #done on dev/test data if self.validation_evaluator: metrics = self.validation_evaluator.eval( parser_model, self.current_epoch, model_path) for name, val in metrics.items(): r[name] = val return r def get_metrics(self, reset: bool = False) -> Dict[str, float]: r = self._attachment_scores.get_metric(reset) if self.loss_mixing["supertagging"] is not None: r["Constant_Acc"] = self._supertagging_acc.get_metric(reset) if self.loss_mixing["lexlabel"] is not None: r["Label_Acc"] = self._lexlabel_acc.get_metric(reset) return r
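# ---------------------------------------------------------------------------
# [Added illustration] The greedy arc decoder above picks, for every word, the
# highest-scoring head after masking the diagonal (a word cannot head itself)
# and the padded positions. A self-contained plain-PyTorch sketch under
# assumed shapes:
import torch


def greedy_decode_heads(existence_scores: torch.Tensor,
                        mask: torch.Tensor) -> torch.Tensor:
    """existence_scores: (batch, seq_len, seq_len); mask: (batch, seq_len), 1 = real token."""
    seq_len = existence_scores.size(1)
    # A word cannot be its own head: put -inf on the diagonal.
    scores = existence_scores + torch.diag(
        existence_scores.new_full((seq_len,), float("-inf")))
    # Padded tokens can never serve as heads: mask out their columns.
    scores = scores.masked_fill(~mask.bool().unsqueeze(1), float("-inf"))
    return scores.argmax(dim=2)  # (batch, seq_len): one head index per word


scores = torch.randn(2, 5, 5)
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])
print(greedy_decode_heads(scores, mask))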
class SeqClassificationModel(Model): """ Question answering model where answers are sentences """ def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, use_sep: bool = True, with_crf: bool = False, self_attn: Seq2SeqEncoder = None, bert_dropout: float = 0.1, sci_sum: bool = False, additional_feature_size: int = 0, ) -> None: super(SeqClassificationModel, self).__init__(vocab) self.text_field_embedder = text_field_embedder self.vocab = vocab self.use_sep = use_sep self.with_crf = with_crf self.sci_sum = sci_sum self.self_attn = self_attn self.additional_feature_size = additional_feature_size self.dropout = torch.nn.Dropout(p=bert_dropout) # define loss if self.sci_sum: self.loss = torch.nn.MSELoss( reduction='none') # labels are rouge scores self.labels_are_scores = True self.num_labels = 1 else: self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none') self.labels_are_scores = False self.num_labels = self.vocab.get_vocab_size(namespace='labels') # define accuracy metrics self.label_accuracy = CategoricalAccuracy() self.label_f1_metrics = {} # define F1 metrics per label for label_index in range(self.num_labels): label_name = self.vocab.get_token_from_index( namespace='labels', index=label_index) self.label_f1_metrics[label_name] = F1Measure(label_index) encoded_senetence_dim = text_field_embedder._token_embedders[ 'bert'].output_dim ff_in_dim = encoded_senetence_dim if self.use_sep else self_attn.get_output_dim( ) ff_in_dim += self.additional_feature_size self.time_distributed_aggregate_feedforward = TimeDistributed( Linear(ff_in_dim, self.num_labels)) if self.with_crf: self.crf = ConditionalRandomField( self.num_labels, constraints=None, include_start_end_transitions=True) def forward( self, # type: ignore sentences: torch.LongTensor, labels: torch.IntTensor = None, confidences: torch.Tensor = None, additional_features: torch.Tensor = None, ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- TODO: add description Returns ------- An output dictionary consisting of: loss : torch.FloatTensor, optional A scalar loss to be optimised. """ # =========================================================================================================== # Layer 1: For each sentence, participant pair: create a Glove embedding for each token # Input: sentences # Output: embedded_sentences # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size embedded_sentences = self.text_field_embedder(sentences) mask = get_text_field_mask(sentences, num_wrapping_dims=1).float() batch_size, num_sentences, _, _ = embedded_sentences.size() if self.use_sep: # The following code collects vectors of the SEP tokens from all the examples in the batch, # and arrange them in one list. It does the same for the labels and confidences. 
# TODO: replace 103 with '[SEP]' sentences_mask = sentences[ 'bert'] == 102 # mask for all the SEP tokens in the batch embedded_sentences = embedded_sentences[ sentences_mask] # given batch_size x num_sentences_per_example x sent_len x vector_len # returns num_sentences_per_batch x vector_len assert embedded_sentences.dim() == 2 num_sentences = embedded_sentences.shape[0] # for the rest of the code in this model to work, think of the data we have as one example # with so many sentences and a batch of size 1 batch_size = 1 embedded_sentences = embedded_sentences.unsqueeze(dim=0) embedded_sentences = self.dropout(embedded_sentences) if labels is not None: if self.labels_are_scores: labels_mask = labels != 0.0 # mask for all the labels in the batch (no padding) else: labels_mask = labels != -1 # mask for all the labels in the batch (no padding) labels = labels[ labels_mask] # given batch_size x num_sentences_per_example return num_sentences_per_batch assert labels.dim() == 1 if confidences is not None: confidences = confidences[labels_mask] assert confidences.dim() == 1 if additional_features is not None: additional_features = additional_features[labels_mask] assert additional_features.dim() == 2 num_labels = labels.shape[0] if num_labels != num_sentences: # bert truncates long sentences, so some of the SEP tokens might be gone assert num_labels > num_sentences # but `num_labels` should be at least greater than `num_sentences` logger.warning( f'Found {num_labels} labels but {num_sentences} sentences' ) labels = labels[: num_sentences] # Ignore some labels. This is ok for training but bad for testing. # We are ignoring this problem for now. # TODO: fix, at least for testing # do the same for `confidences` if confidences is not None: num_confidences = confidences.shape[0] if num_confidences != num_sentences: assert num_confidences > num_sentences confidences = confidences[:num_sentences] # and for `additional_features` if additional_features is not None: num_additional_features = additional_features.shape[0] if num_additional_features != num_sentences: assert num_additional_features > num_sentences additional_features = additional_features[: num_sentences] # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1 labels = labels.unsqueeze(dim=0) if confidences is not None: confidences = confidences.unsqueeze(dim=0) if additional_features is not None: additional_features = additional_features.unsqueeze(dim=0) else: # ['CLS'] token embedded_sentences = embedded_sentences[:, :, 0, :] embedded_sentences = self.dropout(embedded_sentences) batch_size, num_sentences, _ = embedded_sentences.size() sent_mask = (mask.sum(dim=2) != 0) embedded_sentences = self.self_attn(embedded_sentences, sent_mask) if additional_features is not None: embedded_sentences = torch.cat( (embedded_sentences, additional_features), dim=-1) label_logits = self.time_distributed_aggregate_feedforward( embedded_sentences) # label_logits: batch_size, num_sentences, num_labels if self.labels_are_scores: label_probs = label_logits else: label_probs = torch.nn.functional.softmax(label_logits, dim=-1) # Create output dictionary for the trainer # Compute loss and epoch metrics output_dict = {"action_probs": label_probs} # ===================================================================== if self.with_crf: # Layer 4 = CRF layer across labels of sentences in an abstract mask_sentences = (labels != -1) best_paths = self.crf.viterbi_tags(label_logits, mask_sentences) # # # Just get the tags and ignore 
the score. predicted_labels = [x for x, y in best_paths] # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}") label_loss = 0.0 if labels is not None: # Compute cross entropy loss flattened_logits = label_logits.view((batch_size * num_sentences), self.num_labels) flattened_gold = labels.contiguous().view(-1) if not self.with_crf: label_loss = self.loss(flattened_logits.squeeze(), flattened_gold) if confidences is not None: label_loss = label_loss * confidences.type_as( label_loss).view(-1) label_loss = label_loss.mean() flattened_probs = torch.softmax(flattened_logits, dim=-1) else: clamped_labels = torch.clamp(labels, min=0) log_likelihood = self.crf(label_logits, clamped_labels, mask_sentences) label_loss = -log_likelihood # compute categorical accuracy crf_label_probs = label_logits * 0. for i, instance_labels in enumerate(predicted_labels): for j, label_id in enumerate(instance_labels): crf_label_probs[i, j, label_id] = 1 flattened_probs = crf_label_probs.view( (batch_size * num_sentences), self.num_labels) if not self.labels_are_scores: evaluation_mask = (flattened_gold != -1) self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask) # compute F1 per label for label_index in range(self.num_labels): label_name = self.vocab.get_token_from_index( namespace='labels', index=label_index) metric = self.label_f1_metrics[label_name] metric(flattened_probs, flattened_gold, mask=evaluation_mask) if labels is not None: output_dict["loss"] = label_loss output_dict['action_logits'] = label_logits return output_dict def get_metrics(self, reset: bool = False): metric_dict = {} if not self.labels_are_scores: type_accuracy = self.label_accuracy.get_metric(reset) metric_dict['acc'] = type_accuracy average_F1 = 0.0 for name, metric in self.label_f1_metrics.items(): metric_val = metric.get_metric(reset) metric_dict[name + 'F'] = metric_val[2] average_F1 += metric_val[2] average_F1 /= len(self.label_f1_metrics.items()) metric_dict['avgF'] = average_F1 return metric_dict
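SeqClassificationModel scores only real sentences by building an `evaluation_mask` from the gold labels (padding is marked with -1) and passing it to `CategoricalAccuracy`. A small standalone sketch of that call pattern, assuming the AllenNLP metric API used in this model; the toy probabilities and labels are made up:

import torch
from allennlp.training.metrics import CategoricalAccuracy

label_accuracy = CategoricalAccuracy()

# Flattened (batch_size * num_sentences, num_labels) probabilities and gold labels;
# -1 marks padded sentences that must not contribute to the metric.
flattened_probs = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
flattened_gold = torch.tensor([0, 0, -1])

evaluation_mask = flattened_gold != -1
label_accuracy(flattened_probs, flattened_gold, mask=evaluation_mask)
print(label_accuracy.get_metric(reset=True))  # 0.5: one of the two unmasked sentences is correct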
class BidirectionalAttentionFlow_1(Model): """ This class implements a Bayesian version of Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). """ def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None: super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer) """ Initialize some data structures """ self.cf_a = cf_a # Bayesian data models self.VBmodels = [] self.LinearModels = [] """ ############## TEXT FIELD EMBEDDER with ELMO #################### text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. """ if (cf_a.use_ELMO): if (type(preloaded_elmo) != type(None)): text_field_embedder = preloaded_elmo else: text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput ) print ("ELMO loaded from disk or downloaded") else: text_field_embedder = None # embedder_out_dim = text_field_embedder.get_output_dim() self._text_field_embedder = text_field_embedder if(cf_a.Add_Linear_projection_ELMO): if (self.cf_a.VB_Linear_projection_ELMO): prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior)) print ("----------------- Bayesian Linear Projection ELMO --------------") linear_projection_ELMO = LinearVB(text_field_embedder.get_output_dim(), 200, prior = prior) self.VBmodels.append(linear_projection_ELMO) else: linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(), 200) self._linear_projection_ELMO = linear_projection_ELMO """ ############## Highway layers #################### num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. """ Input_dimension_highway = None if (cf_a.Add_Linear_projection_ELMO): Input_dimension_highway = 200 else: Input_dimension_highway = text_field_embedder.get_output_dim() num_highway_layers = cf_a.num_highway_layers # Linear later to compute the start if (self.cf_a.VB_highway_layers): print ("----------------- Bayesian Highway network --------------") prior = Vil.Prior(**(cf_a.VB_highway_layers_prior)) highway_layer = HighwayVB(Input_dimension_highway, num_highway_layers, prior = prior) self.VBmodels.append(highway_layer) else: highway_layer = Highway(Input_dimension_highway, num_highway_layers) highway_layer = TimeDistributed(highway_layer) self._highway_layer = highway_layer """ ############## Phrase layer #################### phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. """ if cf_a.phrase_layer_dropout > 0: ## Create dropout layer dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout) else: dropout_phrase_layer = lambda x: x phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway, hidden_size = cf_a.phrase_layer_hidden_size, batch_first=True, bidirectional = True, num_layers = cf_a.phrase_layer_num_layers, dropout = cf_a.phrase_layer_dropout)) phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2 self._phrase_layer = phrase_layer self._dropout_phrase_layer = dropout_phrase_layer """ ############## Matrix attention layer #################### similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. 
""" # Linear later to compute the start if (self.cf_a.VB_similarity_function): prior = Vil.Prior(**(cf_a.VB_similarity_function_prior)) print ("----------------- Bayesian Similarity matrix --------------") similarity_function = LinearSimilarityVB( combination = "x,y,x*y", tensor_1_dim = phrase_encoding_out_dim, tensor_2_dim = phrase_encoding_out_dim, prior = prior) self.VBmodels.append(similarity_function) else: similarity_function = LinearSimilarity( combination = "x,y,x*y", tensor_1_dim = phrase_encoding_out_dim, tensor_2_dim = phrase_encoding_out_dim) matrix_attention = LegacyMatrixAttention(similarity_function) self._matrix_attention = matrix_attention """ ############## Modelling Layer #################### modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. """ ## Create dropout layer if cf_a.modeling_passage_dropout > 0: ## Create dropout layer dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout) else: dropout_modeling_passage = lambda x: x modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4, hidden_size = cf_a.modeling_passage_hidden_size, batch_first=True, bidirectional = True, num_layers = cf_a.modeling_passage_num_layers, dropout = cf_a.modeling_passage_dropout)) self._modeling_layer = modeling_layer self._dropout_modeling_passage = dropout_modeling_passage """ ############## Span Start Representation ##################### span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. """ encoding_dim = phrase_layer.get_output_dim() modeling_dim = modeling_layer.get_output_dim() span_start_input_dim = encoding_dim * 4 + modeling_dim # Linear later to compute the start if (self.cf_a.VB_span_start_predictor_linear): prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior)) print ("----------------- Bayesian Span Start Predictor--------------") span_start_predictor_linear = LinearVB(span_start_input_dim, 1, prior = prior) self.VBmodels.append(span_start_predictor_linear) else: span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1) self._span_start_predictor_linear = span_start_predictor_linear self._span_start_predictor = TimeDistributed(span_start_predictor_linear) """ ############## Span End Representation ##################### """ ## Create dropout layer if cf_a.span_end_encoder_dropout > 0: dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout) else: dropout_span_end_encode = lambda x: x span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3, hidden_size = cf_a.modeling_span_end_hidden_size, batch_first=True, bidirectional = True, num_layers = cf_a.modeling_span_end_num_layers, dropout = cf_a.span_end_encoder_dropout)) span_end_encoding_dim = span_end_encoder.get_output_dim() span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim self._span_end_encoder = span_end_encoder self._dropout_span_end_encode = dropout_span_end_encode if (self.cf_a.VB_span_end_predictor_linear): print ("----------------- Bayesian Span End Predictor--------------") prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior)) span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior) self.VBmodels.append(span_end_predictor_linear) else: span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1) self._span_end_predictor_linear = 
span_end_predictor_linear self._span_end_predictor = TimeDistributed(span_end_predictor_linear) """ Dropput last layers """ if cf_a.spans_output_dropout > 0: dropout_spans_output = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout) else: dropout_spans_output = lambda x: x self._dropout_spans_output = dropout_spans_output """ Checkings and accuracy """ # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily # obvious from the configuration files, so we check here. check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim, "modeling layer input dim", "4 * encoding dim") check_dimensions_match(Input_dimension_highway , phrase_layer.get_input_dim(), "text field embedder output dim", "phrase layer input dim") check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim, "span end encoder input dim", "4 * encoding dim + 3 * modeling dim") self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() """ mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. """ self._mask_lstms = cf_a.mask_lstms """ ################### Initialize parameters ############################## """ #### THEY ARE ALL INITIALIZED WHEN INSTANTING THE COMPONENTS ### """ ####################### OPTIMIZER ################ """ optimizer = pytut.get_optimizers(self, cf_a) self._optimizer = optimizer #### TODO: Learning rate scheduler #### #scheduler = optim.ReduceLROnPlateau(optimizer, 'max') def forward_ensemble(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None, get_sample_level_information = False) -> Dict[str, torch.Tensor]: """ Sample 10 times and add them together """ self.set_posterior_mean(True) most_likely_output = self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information) self.set_posterior_mean(False) subresults = [most_likely_output] for i in range(10): subresults.append(self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information)) batch_size = len(subresults[0]["best_span"]) best_span = bidut.merge_span_probs(subresults) output = { "best_span": best_span, "best_span_str": [], "models_output": subresults } if (get_sample_level_information): output["em_samples"] = [] output["f1_samples"] = [] for index in range(batch_size): if metadata is not None: passage_str = metadata[index]['original_passage'] offsets = metadata[index]['token_offsets'] predicted_span = tuple(best_span[index].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output["best_span_str"].append(best_span_string) answer_texts = metadata[index].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) if (get_sample_level_information): em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts) 
output["em_samples"].append(em_sample) output["f1_samples"].append(f1_sample) if (get_sample_level_information): # Add information about the individual samples for future analysis output["span_start_sample_loss"] = [] output["span_end_sample_loss"] = [] for i in range (batch_size): span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults) span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults) span_start_loss = nll_loss(span_start_probs[[i],:], span_start.squeeze(-1)[[i]]) span_end_loss = nll_loss(span_end_probs[[i],:], span_end.squeeze(-1)[[i]]) output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy())) output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy())) return output def forward(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None, get_sample_level_information = False, get_attentions = False) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key. Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span end position (inclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)`` and each offset is a token index. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. """ """ #################### Sample Bayesian weights ################## """ self.sample_posterior() """ ################## MASK COMPUTING ######################## """ question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None """ ###################### EMBEDDING + HIGHWAY LAYER ######################## """ # self.cf_a.use_ELMO if(self.cf_a.Add_Linear_projection_ELMO): embedded_question = self._highway_layer(self._linear_projection_ELMO (self._text_field_embedder(question['character_ids'])["elmo_representations"][-1])) embedded_passage = self._highway_layer(self._linear_projection_ELMO(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1])) else: embedded_question = self._highway_layer(self._text_field_embedder(question['character_ids'])["elmo_representations"][-1]) embedded_passage = self._highway_layer(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1]) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) """ ###################### phrase_layer LAYER ######################## """ encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) """ ###################### Attention LAYER ######################## """ # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * tiled_question_passage_vector], dim=-1) modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask)) modeling_dim = modeled_passage.size(-1) """ ###################### Spans LAYER ######################## """ # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)) span_start_input = self._dropout_spans_output(torch.cat([final_merged_passage, modeled_passage], dim=-1)) # Shape: (batch_size, passage_length) span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1) # Shape: (batch_size, passage_length) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) # Shape: (batch_size, modeling_dim) span_start_representation = util.weighted_sum(modeled_passage, span_start_probs) # Shape: (batch_size, passage_length, modeling_dim) tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size, passage_length, modeling_dim) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3) span_end_representation = torch.cat([final_merged_passage, modeled_passage, tiled_start_representation, modeled_passage * tiled_start_representation], dim=-1) # Shape: (batch_size, passage_length, encoding_dim) encoded_span_end = self._dropout_span_end_encode(self._span_end_encoder(span_end_representation, passage_lstm_mask)) # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim) span_end_input = self._dropout_spans_output(torch.cat([final_merged_passage, encoded_span_end], dim=-1)) span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = bidut.get_best_span(span_start_logits, span_end_logits) output_dict = { "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span, } # Compute the loss for training. 
if span_start is not None: span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) loss = span_start_loss + span_end_loss self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) output_dict["loss"] = loss output_dict["span_start_loss"] = span_start_loss output_dict["span_end_loss"] = span_end_loss # Compute the EM and F1 on SQuAD and add the tokenized input to the output. if metadata is not None: if (get_sample_level_information): output_dict["em_samples"] = [] output_dict["f1_samples"] = [] output_dict['best_span_str'] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens']) passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output_dict['best_span_str'].append(best_span_string) answer_texts = metadata[i].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) if (get_sample_level_information): em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts) output_dict["em_samples"].append(em_sample) output_dict["f1_samples"].append(f1_sample) output_dict['question_tokens'] = question_tokens output_dict['passage_tokens'] = passage_tokens if (get_sample_level_information): # Add information about the individual samples for future analysis output_dict["span_start_sample_loss"] = [] output_dict["span_end_sample_loss"] = [] for i in range (batch_size): span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i],:], passage_mask[[i],:]), span_start.squeeze(-1)[[i]]) span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i],:], passage_mask[[i],:]), span_end.squeeze(-1)[[i]]) output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy())) output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy())) if(get_attentions): output_dict["C2Q_attention"] = passage_question_attention output_dict["Q2C_attention"] = question_passage_attention output_dict["simmilarity"] = passage_question_similarity return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: exact_match, f1_score = self._squad_metrics.get_metric(reset) return { 'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'em': exact_match, 'f1': f1_score, } def train_batch(self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: """ It is enough to just compute the total loss because the normal weights do not depend on the KL Divergence """ # Now we can just compute both losses which will build the dynamic graph output = self.forward(question,passage,span_start,span_end,metadata ) data_loss = output["loss"] KL_div = self.get_KL_divergence() total_loss = 
self.combine_losses(data_loss, KL_div) self.zero_grad() # zeroes the gradient buffers of all parameters total_loss.backward() if (type(self._optimizer) == type(None)): parameters = filter(lambda p: p.requires_grad, self.parameters()) with torch.no_grad(): for f in parameters: f.data.sub_(f.grad.data * self.lr ) else: # print ("Training") self._optimizer.step() self._optimizer.zero_grad() return output def fill_batch_training_information(self, training_logger, output_batch): """ Function to fill the the training_logger for each batch. training_logger: Dictionary that will hold all the training info output_batch: Output from training the batch """ training_logger["train"]["span_start_loss_batch"].append(output_batch["span_start_loss"].detach().cpu().numpy()) training_logger["train"]["span_end_loss_batch"].append(output_batch["span_end_loss"].detach().cpu().numpy()) training_logger["train"]["loss_batch"].append(output_batch["loss"].detach().cpu().numpy()) # Training metrics: metrics = self.get_metrics() training_logger["train"]["start_acc_batch"].append(metrics["start_acc"]) training_logger["train"]["end_acc_batch"].append(metrics["end_acc"]) training_logger["train"]["span_acc_batch"].append(metrics["span_acc"]) training_logger["train"]["em_batch"].append(metrics["em"]) training_logger["train"]["f1_batch"].append(metrics["f1"]) def fill_epoch_training_information(self, training_logger,device, validation_iterable, num_batches_validation): """ Fill the information per each epoch """ Ntrials_CUDA = 100 # Training Epoch final metrics metrics = self.get_metrics(reset = True) training_logger["train"]["start_acc"].append(metrics["start_acc"]) training_logger["train"]["end_acc"].append(metrics["end_acc"]) training_logger["train"]["span_acc"].append(metrics["span_acc"]) training_logger["train"]["em"].append(metrics["em"]) training_logger["train"]["f1"].append(metrics["f1"]) self.set_posterior_mean(True) self.eval() data_loss_validation = 0 loss_validation = 0 with torch.no_grad(): # Compute the validation accuracy by using all the Validation dataset but in batches. for j in range(num_batches_validation): tensor_dict = next(validation_iterable) trial_index = 0 while (1): try: tensor_dict = pytut.move_to_device(tensor_dict, device) ## Move the tensor to cuda output_batch = self.forward(**tensor_dict) break; except RuntimeError as er: print (er.args) torch.cuda.empty_cache() time.sleep(5) torch.cuda.empty_cache() trial_index += 1 if (trial_index == Ntrials_CUDA): print ("Too many failed trials to allocate in memory") send_error_email(str(er.args)) sys.exit(0) data_loss_validation += output_batch["loss"].detach().cpu().numpy() ## Memmory management !! 
if (self.cf_a.force_free_batch_memory): del tensor_dict["question"]; del tensor_dict["passage"] del tensor_dict del output_batch torch.cuda.empty_cache() if (self.cf_a.force_call_garbage_collector): gc.collect() data_loss_validation = data_loss_validation/num_batches_validation # loss_validation = loss_validation/num_batches_validation # Training Epoch final metrics metrics = self.get_metrics(reset = True) training_logger["validation"]["start_acc"].append(metrics["start_acc"]) training_logger["validation"]["end_acc"].append(metrics["end_acc"]) training_logger["validation"]["span_acc"].append(metrics["span_acc"]) training_logger["validation"]["em"].append(metrics["em"]) training_logger["validation"]["f1"].append(metrics["f1"]) training_logger["validation"]["data_loss"].append(data_loss_validation) self.train() self.set_posterior_mean(False) def trim_model(self, mu_sigma_ratio = 2): total_size_w = [] total_removed_w = [] total_size_b = [] total_removed_b = [] if (self.cf_a.VB_Linear_projection_ELMO): VBmodel = self._linear_projection_ELMO size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_highway_layers): VBmodel = self._highway_layer._module.VBmodels[0] Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_similarity_function): VBmodel = self._matrix_attention._similarity_function Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_span_start_predictor_linear): VBmodel = self._span_start_predictor_linear Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) if (self.cf_a.VB_span_end_predictor_linear): VBmodel = self._span_end_predictor_linear Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel, mu_sigma_ratio) total_size_w.append(size_w) total_removed_w.append(removed_w) total_size_b.append(size_b) total_removed_b.append(removed_b) return total_size_w, total_removed_w, total_size_b, total_removed_b # print (weights_to_remove_W.shape) """ BAYESIAN NECESSARY FUNCTIONS """ sample_posterior = GeneralVBModel.sample_posterior get_KL_divergence = GeneralVBModel.get_KL_divergence set_posterior_mean = GeneralVBModel.set_posterior_mean combine_losses = GeneralVBModel.combine_losses def save_VB_weights(self): """ Function that saves only the VB weights of the model. """ pretrained_dict = ... model_dict = self.state_dict() # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict self.load_state_dict(pretrained_dict)
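`forward` relies on `bidut.get_best_span` to turn start and end logits into a single answer span, but that helper is not shown here. A plain-PyTorch sketch of the standard constrained decoding it presumably performs (maximize start_logit + end_logit subject to start <= end):

import torch

def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
    # span_start_logits, span_end_logits: (batch_size, passage_length)
    batch_size, length = span_start_logits.size()
    # scores[b, i, j] = start_logit[b, i] + end_logit[b, j]
    scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Only spans with end >= start are valid.
    valid = torch.triu(torch.ones(length, length, dtype=torch.bool, device=scores.device))
    scores = scores.masked_fill(~valid, float("-inf"))
    best = scores.view(batch_size, -1).argmax(dim=-1)
    starts = torch.div(best, length, rounding_mode="floor")
    ends = best % length
    return torch.stack([starts, ends], dim=-1)  # (batch_size, 2), inclusive token indices

start = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
end = torch.tensor([[0.0, 0.1, 3.0, 0.2]])
print(get_best_span(start, end))  # tensor([[1, 2]])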
class LstmClassifier(Model): def __init__(self, embedder: TextFieldEmbedder, encoder: Seq2VecEncoder, vocab: Vocabulary, positive_label: str = '4') -> None: super().__init__(vocab) # We need the embeddings to convert word IDs to their vector representations self.embedder = embedder self.encoder = encoder # After converting a sequence of vectors to a single vector, we feed it into # a fully-connected linear layer to reduce the dimension to the total number of labels. self.linear = torch.nn.Linear( in_features=encoder.get_output_dim(), out_features=vocab.get_vocab_size('labels')) # Monitor the metrics - we use accuracy, as well as prec, rec, f1 for 4 (very positive) positive_index = vocab.get_token_index(positive_label, namespace='labels') self.accuracy = CategoricalAccuracy() self.f1_measure = F1Measure(positive_index) # We use the cross entropy loss because this is a classification task. # Note that PyTorch's CrossEntropyLoss combines softmax and log likelihood loss, # which makes it unnecessary to add a separate softmax layer. self.loss_function = torch.nn.CrossEntropyLoss() # Instances are fed to forward after batching. # Fields are passed through arguments with the same name. def forward(self, tokens: TextFieldTensors, label: torch.Tensor = None) -> torch.Tensor: # In deep NLP, when sequences of tensors in different lengths are batched together, # shorter sequences get padded with zeros to make them equal length. # Masking is the process to ignore extra zeros added by padding mask = get_text_field_mask(tokens) # Forward pass embeddings = self.embedder(tokens) encoder_out = self.encoder(embeddings, mask) logits = self.linear(encoder_out) probs = torch.softmax(logits, dim=-1) # In AllenNLP, the output of forward() is a dictionary. # Your output dictionary must contain a "loss" key for your model to be trained. output = {"logits": logits, "cls_emb": encoder_out, "probs": probs} if label is not None: self.accuracy(logits, label) self.f1_measure(logits, label) output["loss"] = self.loss_function(logits, label) return output def get_metrics(self, reset: bool = False) -> Dict[str, float]: precision, recall, f1_measure = self.f1_measure.get_metric(reset) return { 'accuracy': self.accuracy.get_metric(reset), 'precision': precision, 'recall': recall, 'f1_measure': f1_measure }
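The comment in `LstmClassifier.__init__` notes that `CrossEntropyLoss` already folds the softmax into the loss, so no explicit softmax layer is needed before it. A two-line check of that claim with toy logits (nothing model-specific):

import torch

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

ce = torch.nn.CrossEntropyLoss()(logits, labels)
nll = torch.nn.NLLLoss()(torch.log_softmax(logits, dim=-1), labels)
print(torch.allclose(ce, nll))  # True: CrossEntropyLoss == log_softmax + NLLLoss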
class SimpleFusion(Model): def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, output_feedforward: FeedForward, regularizer: Optional[RegularizerApplicator] = None, detector_final_dim: int = 512, dropout: float = 0.5, initializer: InitializerApplicator = InitializerApplicator()) -> None: """ :param vocab: :param text_field_embedder: :param encoder: :param output_feedforward: :param regularizer: :param detector_final_dim: :param dropout: :param initializer: """ super().__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder self._encoder = encoder self.detector = SimpleDetector(detector_final_dim) if dropout: self.dropout = nn.Dropout(dropout) self.rnn_input_dropout = InputVariationalDropout(dropout) else: self.dropout = None self.rnn_input_dropout = None self._output_feedforward = output_feedforward self._accuracy = CategoricalAccuracy() self._loss = nn.CrossEntropyLoss() initializer(self) def forward(self, premise_img: torch.Tensor, hypothesis: Dict[str, torch.Tensor], label: torch.LongTensor = None) -> Dict[str, torch.Tensor]: """ :param premise_img: :param hypothesis: :param label: :return: """ embedded_hypothesis = self._text_field_embedder(hypothesis) hypothesis_mask = get_text_field_mask(hypothesis).float() if self.rnn_input_dropout: embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis) encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask) hypothesis_hidden_state = get_final_encoder_states( encoded_hypothesis, hypothesis_mask, self._encoder.is_bidirectional() ) img_feats = self.detector(premise_img) fused_features = torch.cat((img_feats, hypothesis_hidden_state), dim=-1) label_logits = self._output_feedforward(fused_features) label_probs = nn.functional.softmax(label_logits, dim=-1) output_dict = { "label_logits": label_logits, "label_probs": label_probs } if label is not None: loss = self._loss(label_logits, label.long().view(-1)) self._accuracy(label_logits, label) output_dict["loss"] = loss return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return {'accuracy': self._accuracy.get_metric(reset)}
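`SimpleFusion` combines the two modalities by concatenating the detector's image features with the final hidden state of the hypothesis encoder and projecting to label logits. A shape-level sketch of that fusion step; the dimensions (512-d image features, 256-d text state, 3 labels) are placeholders, and a single `Linear` stands in for `output_feedforward`:

import torch

batch_size, img_dim, text_dim, num_labels = 2, 512, 256, 3
img_feats = torch.randn(batch_size, img_dim)                  # stand-in for SimpleDetector output
hypothesis_hidden_state = torch.randn(batch_size, text_dim)   # final encoder state of the hypothesis

fused = torch.cat((img_feats, hypothesis_hidden_state), dim=-1)  # (batch, img_dim + text_dim)
classifier = torch.nn.Linear(img_dim + text_dim, num_labels)      # stands in for output_feedforward
label_logits = classifier(fused)
label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
print(label_logits.shape, label_probs.sum(dim=-1))  # torch.Size([2, 3]), ~1.0 per row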
class BertForClassification(Model): """ An AllenNLP Model that runs pretrained BERT, takes the pooled output, and adds a Linear layer on top. If you want an easy way to use BERT for classification, this is it. Note that this is a somewhat non-AllenNLP-ish model architecture, in that it essentially requires you to use the "bert-pretrained" token indexer, rather than configuring whatever indexing scheme you like. See `allennlp/tests/fixtures/bert/bert_for_classification.jsonnet` for an example of what your config might look like. # Parameters vocab : ``Vocabulary`` bert_model : ``Union[str, BertModel]`` The BERT model to be wrapped. If a string is provided, we will call ``BertModel.from_pretrained(bert_model)`` and use the result. num_labels : ``int``, optional (default: None) How many output classes to predict. If not provided, we'll use the vocab_size for the ``label_namespace``. index : ``str``, optional (default: "bert") The index of the token indexer that generates the BERT indices. label_namespace : ``str``, optional (default : "labels") Used to determine the number of classes if ``num_labels`` is not supplied. trainable : ``bool``, optional (default : True) If True, the weights of the pretrained BERT model will be updated during training. Otherwise, they will be frozen and only the final linear layer will be trained. initializer : ``InitializerApplicator``, optional If provided, will be used to initialize the final linear layer *only*. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__( self, vocab: Vocabulary, bert_model: Union[str, BertModel], dropout: float = 0.0, num_labels: int = None, index: str = "bert", label_namespace: str = "labels", trainable: bool = True, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None, ) -> None: super().__init__(vocab, regularizer) if isinstance(bert_model, str): self.bert_model = PretrainedBertModel.load(bert_model) else: self.bert_model = bert_model for param in self.bert_model.parameters(): param.requires_grad = trainable in_features = self.bert_model.config.hidden_size self._label_namespace = label_namespace if num_labels: out_features = num_labels else: out_features = vocab.get_vocab_size( namespace=self._label_namespace) self._dropout = torch.nn.Dropout(p=dropout) self._classification_layer = torch.nn.Linear(in_features, out_features) self._accuracy = CategoricalAccuracy() # ****** add by jlk ****** self._f1score = F1Measure(positive_label=1) # ****** add by jlk ****** self._loss = torch.nn.CrossEntropyLoss() self._index = index initializer(self._classification_layer) def forward( # type: ignore self, data_id: Dict[str, torch.LongTensor], tokens: Dict[str, torch.LongTensor], label: torch.IntTensor = None) -> Dict[str, torch.Tensor]: """ # Parameters tokens : Dict[str, torch.LongTensor] From a ``TextField`` (that has a bert-pretrained token indexer) label : torch.IntTensor, optional (default = None) From a ``LabelField`` # Returns An output dictionary consisting of: logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalized log probabilities of the label. probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the label. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" input_ids = tokens[self._index] token_type_ids = tokens[f"{self._index}-type-ids"] input_mask = (input_ids != 0).long() _, pooled = self.bert_model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=input_mask) pooled = self._dropout(pooled) # apply classification layer logits = self._classification_layer(pooled) probs = torch.nn.functional.softmax(logits, dim=-1) output_dict = {"logits": logits, "probs": probs} if label is not None: loss = self._loss(logits, label.long().view(-1)) output_dict["loss"] = loss self._accuracy(logits, label) self._f1score(logits, label) return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Does a simple argmax over the probabilities, converts index to string label, and add ``"label"`` key to the dictionary with the result. """ predictions = output_dict["probs"] if predictions.dim() == 2: predictions_list = [ predictions[i] for i in range(predictions.shape[0]) ] else: predictions_list = [predictions] classes = [] for prediction in predictions_list: label_idx = prediction.argmax(dim=-1).item() label_str = self.vocab.get_index_to_token_vocabulary( self._label_namespace).get(label_idx, str(label_idx)) classes.append(label_str) output_dict["label"] = classes return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: # ****** add by jlk ****** metrics = { "accuracy": self._accuracy.get_metric(reset), "f1score": self._f1score.get_metric(reset)[2] } return metrics
class ShallowProductOfExpertsClassifier(Model): """ This `Model` implements a basic text classifier. After embedding the text into a text field, we will optionally encode the embeddings with a `Seq2SeqEncoder`. The resulting sequence is pooled using a `Seq2VecEncoder` and then passed to a linear classification layer, which projects into the label space. If a `Seq2SeqEncoder` is not provided, we will pass the embedded text directly to the `Seq2VecEncoder`. Registered as a `Model` with name "basic_classifier". # Parameters vocab : `Vocabulary` text_field_embedder : `TextFieldEmbedder` Used to embed the input text into a `TextField` seq2seq_encoder : `Seq2SeqEncoder`, optional (default=`None`) Optional Seq2Seq encoder layer for the input text. seq2vec_encoder : `Seq2VecEncoder` Required Seq2Vec encoder layer. If `seq2seq_encoder` is provided, this encoder will pool its output. Otherwise, this encoder will operate directly on the output of the `text_field_embedder`. feedforward : `FeedForward`, optional, (default = None). An optional feedforward layer to apply after the seq2vec_encoder. dropout : `float`, optional (default = `None`) Dropout percentage to use. num_labels : `int`, optional (default = `None`) Number of labels to project to in classification layer. By default, the classification layer will project to the size of the vocabulary namespace corresponding to labels. label_namespace : `str`, optional (default = "labels") Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace. initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`) If provided, will be used to initialize the model parameters. """ def __init__( self, vocab: Vocabulary, beta: float, text_field_embedder: TextFieldEmbedder, seq2vec_encoder: Seq2VecEncoder, seq2seq_encoder: Seq2SeqEncoder = None, feedforward: Optional[FeedForward] = None, feedforward_hyp_only: Optional[FeedForward] = None, dropout: float = None, num_labels: int = None, label_namespace: str = "labels", evaluation_mode: bool = False, initializer: InitializerApplicator = InitializerApplicator(), **kwargs, ) -> None: super().__init__(vocab, **kwargs) self.evaluation_mode = evaluation_mode self._text_field_embedder = text_field_embedder if seq2seq_encoder: self._seq2seq_encoder = seq2seq_encoder else: self._seq2seq_encoder = None self._seq2vec_encoder = seq2vec_encoder self._feedforward = feedforward self._feedforward_hyp_only = feedforward_hyp_only if feedforward is not None: self._classifier_input_dim = self._feedforward.get_output_dim() else: self._classifier_input_dim = self._seq2vec_encoder.get_output_dim() if feedforward_hyp_only is not None: self._classifier_hyp_only_input_dim = self._feedforward_hyp_only.get_output_dim( ) else: self._classifier_hyp_only_input_dim = self._seq2vec_encoder.get_output_dim( ) if dropout: self._dropout = torch.nn.Dropout(dropout) else: self._dropout = None self._label_namespace = label_namespace if num_labels: self._num_labels = num_labels else: self._num_labels = vocab.get_vocab_size( namespace=self._label_namespace) self._classification_layer = torch.nn.Linear( self._classifier_input_dim, self._num_labels) self._classification_layer_hyp_only = torch.nn.Linear( self._classifier_hyp_only_input_dim, self._num_labels) self._beta = beta self._accuracy = CategoricalAccuracy() self._hyp_only_accuracy = CategoricalAccuracy() self._cross_ent_loss = torch.nn.CrossEntropyLoss() self._nll_loss = torch.nn.NLLLoss() initializer(self) def finetune(self): self.evaluation_mode = True 
def forward( # type: ignore self, tokens: TextFieldTensors, bias_tokens: TextFieldTensors = None, label: torch.IntTensor = None) -> Dict[str, torch.Tensor]: embedded_text = self._text_field_embedder(tokens) mask = get_text_field_mask(tokens) if self._seq2seq_encoder: embedded_text = self._seq2seq_encoder(embedded_text, mask=mask) embedded_text = self._seq2vec_encoder(embedded_text, mask=mask) if self._dropout: embedded_text = self._dropout(embedded_text) if self._feedforward is not None: embedded_text = self._feedforward(embedded_text) sentence_pair_logits = self._classification_layer(embedded_text) # If we're training, also compute loss and accuracy for the bias-only model if not self.evaluation_mode and bias_tokens is not None: # Make predictions with hypothesis only embedded_text = self._text_field_embedder(bias_tokens) mask = get_text_field_mask(bias_tokens) if self._seq2seq_encoder: embedded_text = self._seq2seq_encoder(embedded_text, mask=mask) embedded_text = self._seq2vec_encoder(embedded_text, mask=mask) if self._dropout: embedded_text = self._dropout(embedded_text) if self._feedforward_hyp_only is not None: embedded_text = self._feedforward_hyp_only(embedded_text) hyp_only_logits = self._classification_layer_hyp_only( embedded_text) log_probs_pair = torch.log_softmax(sentence_pair_logits, dim=1) log_probs_hyp = torch.log_softmax(hyp_only_logits, dim=1) # Combine with product of experts (normalized log space sum) # Do not require gradients from hyp-only classifier combined = log_probs_pair + log_probs_hyp.detach() # NLL loss over combined labels loss = self._nll_loss(combined, label.long().view(-1)) hyp_loss = self._nll_loss(log_probs_hyp, label.long().view(-1)) self._accuracy(combined, label) self._hyp_only_accuracy(hyp_only_logits, label) output_dict = {"loss": loss + self._beta * hyp_loss} return output_dict else: loss = self._cross_ent_loss(sentence_pair_logits, label) self._accuracy(sentence_pair_logits, label) return { "loss": loss, "logits": sentence_pair_logits, "probs": torch.softmax(sentence_pair_logits, dim=1) } @overrides def make_output_human_readable( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Does a simple argmax over the probabilities, converts index to string label, and add `"label"` key to the dictionary with the result. """ predictions = output_dict["probs"] if predictions.dim() == 2: predictions_list = [ predictions[i] for i in range(predictions.shape[0]) ] else: predictions_list = [predictions] classes = [] for prediction in predictions_list: label_idx = prediction.argmax(dim=-1).item() label_str = self.vocab.get_index_to_token_vocabulary( self._label_namespace).get(label_idx, str(label_idx)) classes.append(label_str) output_dict["label"] = classes return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = { "hyp_only_accuracy": self._hyp_only_accuracy.get_metric(reset), "accuracy": self._accuracy.get_metric(reset) } return metrics
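The core of `ShallowProductOfExpertsClassifier` is the product-of-experts combination: the pair classifier's log-probabilities are added to the *detached* log-probabilities of the hypothesis-only classifier, so the bias model shapes the combined loss without receiving gradients through it, and is trained separately via the `beta`-weighted term. A self-contained sketch of that loss with random logits standing in for both classifiers:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
pair_logits = torch.randn(4, 3, requires_grad=True)   # main (premise + hypothesis) classifier
hyp_logits = torch.randn(4, 3, requires_grad=True)    # hypothesis-only bias classifier
labels = torch.tensor([0, 2, 1, 1])
beta = 0.3  # weight on the bias-only loss term

log_probs_pair = F.log_softmax(pair_logits, dim=1)
log_probs_hyp = F.log_softmax(hyp_logits, dim=1)

# Product of experts in log space; detach so the main model gets no gradient
# through the bias expert, while the bias expert trains on its own loss term.
combined = log_probs_pair + log_probs_hyp.detach()
loss = F.nll_loss(combined, labels) + beta * F.nll_loss(log_probs_hyp, labels)
loss.backward()
print(pair_logits.grad is not None, hyp_logits.grad is not None)  # True True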
class SmartdataEventxModel(Model): def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, span_extractor: SpanExtractor, entity_embedder: TokenEmbedder, hidden_dim: int, loss_weight: float = 1.0, trigger_gamma: float = None, role_gamma: float = None, positive_class_weight: float = 1.0, triggers_namespace: str = 'trigger_labels', roles_namespace: str = 'arg_role_labels', initializer: InitializerApplicator = InitializerApplicator(), regularizer: RegularizerApplicator = None) -> None: super().__init__(vocab=vocab, regularizer=regularizer) self._triggers_namespace = triggers_namespace self._roles_namespace = roles_namespace self.num_trigger_classes = self.vocab.get_vocab_size( triggers_namespace) self.num_role_classes = self.vocab.get_vocab_size(roles_namespace) self.hidden_dim = hidden_dim self.loss_weight = loss_weight self.trigger_gamma = trigger_gamma self.role_gamma = role_gamma self.text_field_embedder = text_field_embedder self.encoder = encoder self.entity_embedder = entity_embedder self.span_extractor = span_extractor self.trigger_projection = Linear(self.encoder.get_output_dim(), self.num_trigger_classes) self.trigger_to_hidden = Linear(self.encoder.get_output_dim(), self.hidden_dim) self.entities_to_hidden = Linear(self.encoder.get_output_dim(), self.hidden_dim) self.hidden_bias = Parameter(torch.Tensor(self.hidden_dim)) torch.nn.init.normal_(self.hidden_bias) self.hidden_to_roles = Linear(self.hidden_dim, self.num_role_classes) self.trigger_accuracy = CategoricalAccuracy() trigger_labels_to_idx = self.vocab.get_token_to_index_vocabulary( namespace=triggers_namespace) evaluated_trigger_idxs = list(trigger_labels_to_idx.values()) evaluated_trigger_idxs.remove( trigger_labels_to_idx[NEGATIVE_TRIGGER_LABEL]) self.trigger_f1 = MicroFBetaMeasure( average='micro', # Macro averaging in get_metrics labels=evaluated_trigger_idxs) role_labels_to_idx = self.vocab.get_token_to_index_vocabulary( namespace=roles_namespace) evaluated_role_idxs = list(role_labels_to_idx.values()) evaluated_role_idxs.remove(role_labels_to_idx[NEGATIVE_ARGUMENT_LABEL]) self.role_accuracy = CategoricalAccuracy() self.role_f1 = MicroFBetaMeasure( average='micro', # Macro averaging in get_metrics labels=evaluated_role_idxs) # Trigger class weighting as done in JMEE repo trigger_labels_to_idx = self.vocab\ .get_token_to_index_vocabulary(namespace=triggers_namespace) self.trigger_class_weights = torch.ones( len(trigger_labels_to_idx)) * positive_class_weight self.trigger_class_weights[ trigger_labels_to_idx[NEGATIVE_TRIGGER_LABEL]] = 1.0 initializer(self) @overrides def forward( self, tokens: Dict[str, torch.LongTensor], entity_tags: torch.LongTensor, entity_spans: torch.LongTensor, trigger_spans: torch.LongTensor, trigger_labels: torch.LongTensor = None, arg_roles: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: embedded_tokens = self.text_field_embedder(tokens) text_mask = get_text_field_mask(tokens) embedded_entity_tags = self.entity_embedder(entity_tags) embedded_input = torch.cat([embedded_tokens, embedded_entity_tags], dim=-1) encoded_input = self.encoder(embedded_input, text_mask) ########################### # Trigger type prediction # ########################### # Extract the spans of the triggers trigger_spans_mask = (trigger_spans[:, :, 0] >= 0).long() encoded_triggers = self.span_extractor( sequence_tensor=encoded_input, span_indices=trigger_spans, sequence_mask=text_mask, span_indices_mask=trigger_spans_mask) # Pass the 
extracted triggers through a projection for classification trigger_logits = self.trigger_projection(encoded_triggers) # Add the trigger predictions to the output trigger_probabilities = F.softmax(trigger_logits, dim=-1) output_dict = { "trigger_logits": trigger_logits, "trigger_probabilities": trigger_probabilities } if trigger_labels is not None: # Compute loss and metrics using the given trigger labels trigger_mask = (trigger_labels != -1) trigger_labels = trigger_labels * trigger_mask self.trigger_accuracy(trigger_logits, trigger_labels, trigger_mask.float()) self.trigger_f1(trigger_logits, trigger_labels, trigger_mask.float()) trigger_logits_t = trigger_logits.permute(0, 2, 1) trigger_loss = self._cross_entropy_focal_loss( logits=trigger_logits_t, target=trigger_labels, target_mask=trigger_mask, gamma=self.trigger_gamma, alpha=self.trigger_class_weights) output_dict["triggers_loss"] = trigger_loss output_dict["loss"] = trigger_loss ######################################## # Argument detection and role labeling # ######################################## # Extract the spans of the encoded entities entity_spans_mask = (entity_spans[:, :, 0] >= 0).long() encoded_entities = self.span_extractor( sequence_tensor=encoded_input, span_indices=entity_spans, sequence_mask=text_mask, span_indices_mask=entity_spans_mask) # Project both triggers and entities/args into a 'hidden' comparison space triggers_hidden = self.trigger_to_hidden(encoded_triggers) args_hidden = self.entities_to_hidden(encoded_entities) # Create the cross-product of triggers and args via broadcasting trigger = triggers_hidden.unsqueeze(2) # B x T x 1 x H args = args_hidden.unsqueeze(1) # B x T x E x H trigger_arg = trigger + args + self.hidden_bias # B x T x E x H # Pass through activation and projection for classification role_activations = F.relu(trigger_arg) role_logits = self.hidden_to_roles(role_activations) # B x T x E x R # Add the role predictions to the output role_probabilities = torch.softmax(role_logits, dim=-1) output_dict['role_logits'] = role_logits output_dict['role_probabilities'] = role_probabilities # Compute loss and metrics using the given role labels if arg_roles is not None: arg_roles = self._assert_target_shape(logits=role_logits, target=arg_roles) target_mask = (arg_roles != -1) target = arg_roles * target_mask # remove negative indices self.role_accuracy(role_logits, target, target_mask.float()) self.role_f1(role_logits, target, target_mask.float()) # Masked batch-wise cross entropy loss, optionally with focal-loss role_logits_t = role_logits.permute(0, 3, 1, 2) role_loss = cross_entropy_focal_loss(logits=role_logits_t, target=target, target_mask=target_mask, gamma=self.role_gamma) output_dict['role_loss'] = role_loss output_dict['loss'] += self.loss_weight * role_loss # Append the original tokens for visualization if metadata is not None: output_dict["words"] = [x["words"] for x in metadata] # Append the trigger and entity spans to reconstruct the event after prediction output_dict['entity_spans'] = entity_spans output_dict['trigger_spans'] = trigger_spans return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: trigger_predictions = output_dict['trigger_probabilities'].cpu( ).data.numpy() trigger_labels = [[ self.vocab.get_token_from_index(trigger_idx, namespace=self._triggers_namespace) for trigger_idx in example ] for example in np.argmax(trigger_predictions, axis=-1)] output_dict['trigger_labels'] = trigger_labels arg_role_predictions = 
output_dict['role_logits'].cpu().data.numpy() arg_role_labels = [[ [ self.vocab.get_token_from_index( role_idx, namespace=self._roles_namespace) for role_idx in event ] for event in example ] for example in np.argmax(arg_role_predictions, axis=-1)] output_dict['role_labels'] = arg_role_labels events = [] for batch_idx in range(len(trigger_labels)): words = output_dict['words'][batch_idx] batch_events = [] for trigger_idx, trigger_label in enumerate( trigger_labels[batch_idx]): if trigger_label == NEGATIVE_TRIGGER_LABEL: continue trigger_span = output_dict['trigger_spans'][batch_idx][ trigger_idx] trigger_start = trigger_span[0].item() trigger_end = trigger_span[1].item() + 1 if trigger_start < 0: continue event = { 'event_type': trigger_label, 'trigger': { 'text': " ".join(words[trigger_start:trigger_end]), 'start': trigger_start, 'end': trigger_end }, 'arguments': [] } for entity_idx, role_label in enumerate( arg_role_labels[batch_idx][trigger_idx]): if role_label == NEGATIVE_ARGUMENT_LABEL: continue arg_span = output_dict['entity_spans'][batch_idx][ entity_idx] arg_start = arg_span[0].item() arg_end = arg_span[1].item() + 1 if arg_start < 0: continue argument = { 'text': " ".join(words[arg_start:arg_end]), 'start': arg_start, 'end': arg_end, 'role': role_label } event['arguments'].append(argument) if len(event['arguments']) > 0: batch_events.append(event) events.append(batch_events) output_dict['events'] = events return output_dict @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'trigger_acc': self.trigger_accuracy.get_metric(reset=reset), 'trigger_f1': self.trigger_f1.get_metric(reset=reset)['fscore'], 'role_acc': self.role_accuracy.get_metric(reset=reset), 'role_f1': self.role_f1.get_metric(reset=reset)['fscore'] } @staticmethod def _assert_target_shape(logits, target): """ Asserts that target tensors are always of the same size of logits. This is not always the case since some batches are not completely filled. """ expected_shape = logits.shape[:-1] if target.shape == expected_shape: return target else: new_target = torch.full(size=expected_shape, fill_value=-1, dtype=target.dtype, device=target.device) batch_size, triggers_len, arguments_len = target.shape new_target[:, :triggers_len, :arguments_len] = target return new_target
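# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): how the trigger/argument
# cross-product used for role classification is formed via broadcasting.
# The toy sizes, the stand-in Linear layer, and the random tensors below are
# assumptions chosen purely for demonstration.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter

batch_size, num_triggers, num_entities, hidden_dim, num_roles = 2, 3, 4, 8, 5

# Stand-ins for the projected trigger and entity span representations.
triggers_hidden = torch.randn(batch_size, num_triggers, hidden_dim)   # B x T x H
args_hidden = torch.randn(batch_size, num_entities, hidden_dim)       # B x E x H
hidden_bias = Parameter(torch.zeros(hidden_dim))
hidden_to_roles = Linear(hidden_dim, num_roles)

# Broadcasting creates every (trigger, argument) pair without an explicit loop.
trigger = triggers_hidden.unsqueeze(2)           # B x T x 1 x H
args = args_hidden.unsqueeze(1)                  # B x 1 x E x H
trigger_arg = trigger + args + hidden_bias       # B x T x E x H

role_logits = hidden_to_roles(F.relu(trigger_arg))   # B x T x E x num_roles
assert role_logits.shape == (batch_size, num_triggers, num_entities, num_roles)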
class DecomposableAttention(Model): """ This `Model` implements the Decomposable Attention model described in [A Decomposable Attention Model for Natural Language Inference]( https://www.semanticscholar.org/paper/A-Decomposable-Attention-Model-for-Natural-Languag-Parikh-T%C3%A4ckstr%C3%B6m/07a9478e87a8304fc3267fa16e83e9f3bbd98b27) by Parikh et al., 2016, with some optional enhancements before the decomposable attention actually happens. Parikh's original model allowed for computing an "intra-sentence" attention before doing the decomposable entailment step. We generalize this to any [`Seq2SeqEncoder`](../modules/seq2seq_encoders/seq2seq_encoder.md) that can be applied to the premise and/or the hypothesis before computing entailment. The basic outline of this model is to get an embedded representation of each word in the premise and hypothesis, align words between the two, compare the aligned phrases, and make a final entailment decision based on this aggregated comparison. Each step in this process uses a feedforward network to modify the representation. Registered as a `Model` with name "decomposable_attention". # Parameters vocab : `Vocabulary` text_field_embedder : `TextFieldEmbedder` Used to embed the `premise` and `hypothesis` `TextFields` we get as input to the model. attend_feedforward : `FeedForward` This feedforward network is applied to the encoded sentence representations before the similarity matrix is computed between words in the premise and words in the hypothesis. matrix_attention : `MatrixAttention` This is the attention function used when computing the similarity matrix between words in the premise and words in the hypothesis. compare_feedforward : `FeedForward` This feedforward network is applied to the aligned premise and hypothesis representations, individually. aggregate_feedforward : `FeedForward` This final feedforward network is applied to the concatenated, summed result of the `compare_feedforward` network, and its output is used as the entailment class logits. premise_encoder : `Seq2SeqEncoder`, optional (default=`None`) After embedding the premise, we can optionally apply an encoder. If this is `None`, we will do nothing. hypothesis_encoder : `Seq2SeqEncoder`, optional (default=`None`) After embedding the hypothesis, we can optionally apply an encoder. If this is `None`, we will use the `premise_encoder` for the encoding (doing nothing if `premise_encoder` is also `None`). initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`) Used to initialize the model parameters. 
""" def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, attend_feedforward: FeedForward, matrix_attention: MatrixAttention, compare_feedforward: FeedForward, aggregate_feedforward: FeedForward, premise_encoder: Optional[Seq2SeqEncoder] = None, hypothesis_encoder: Optional[Seq2SeqEncoder] = None, initializer: InitializerApplicator = InitializerApplicator(), **kwargs, ) -> None: super().__init__(vocab, **kwargs) self._text_field_embedder = text_field_embedder self._attend_feedforward = TimeDistributed(attend_feedforward) self._matrix_attention = matrix_attention self._compare_feedforward = TimeDistributed(compare_feedforward) self._aggregate_feedforward = aggregate_feedforward self._premise_encoder = premise_encoder self._hypothesis_encoder = hypothesis_encoder or premise_encoder self._num_labels = vocab.get_vocab_size(namespace="labels") check_dimensions_match( text_field_embedder.get_output_dim(), attend_feedforward.get_input_dim(), "text field embedding dim", "attend feedforward input dim", ) check_dimensions_match( aggregate_feedforward.get_output_dim(), self._num_labels, "final output dimension", "number of labels", ) self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward( # type: ignore self, premise: TextFieldTensors, hypothesis: TextFieldTensors, label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None, ) -> Dict[str, torch.Tensor]: """ # Parameters premise : `TextFieldTensors` From a `TextField` hypothesis : `TextFieldTensors` From a `TextField` label : `torch.IntTensor`, optional (default = `None`) From a `LabelField` metadata : `List[Dict[str, Any]]`, optional (default = `None`) Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively. # Returns An output dictionary consisting of: label_logits : `torch.FloatTensor` A tensor of shape `(batch_size, num_labels)` representing unnormalised log probabilities of the entailment label. label_probs : `torch.FloatTensor` A tensor of shape `(batch_size, num_labels)` representing probabilities of the entailment label. loss : `torch.FloatTensor`, optional A scalar loss to be optimised. 
""" embedded_premise = self._text_field_embedder(premise) embedded_hypothesis = self._text_field_embedder(hypothesis) premise_mask = get_text_field_mask(premise) hypothesis_mask = get_text_field_mask(hypothesis) if self._premise_encoder: embedded_premise = self._premise_encoder(embedded_premise, premise_mask) if self._hypothesis_encoder: embedded_hypothesis = self._hypothesis_encoder( embedded_hypothesis, hypothesis_mask) projected_premise = self._attend_feedforward(embedded_premise) projected_hypothesis = self._attend_feedforward(embedded_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask) # Shape: (batch_size, premise_length, embedding_dim) attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention) # Shape: (batch_size, hypothesis_length, premise_length) h2p_attention = masked_softmax( similarity_matrix.transpose(1, 2).contiguous(), premise_mask) # Shape: (batch_size, hypothesis_length, embedding_dim) attended_premise = weighted_sum(embedded_premise, h2p_attention) premise_compare_input = torch.cat( [embedded_premise, attended_hypothesis], dim=-1) hypothesis_compare_input = torch.cat( [embedded_hypothesis, attended_premise], dim=-1) compared_premise = self._compare_feedforward(premise_compare_input) compared_premise = compared_premise * premise_mask.unsqueeze(-1) # Shape: (batch_size, compare_dim) compared_premise = compared_premise.sum(dim=1) compared_hypothesis = self._compare_feedforward( hypothesis_compare_input) compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze( -1) # Shape: (batch_size, compare_dim) compared_hypothesis = compared_hypothesis.sum(dim=1) aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1) label_logits = self._aggregate_feedforward(aggregate_input) label_probs = torch.nn.functional.softmax(label_logits, dim=-1) output_dict = { "label_logits": label_logits, "label_probs": label_probs, "h2p_attention": h2p_attention, "p2h_attention": p2h_attention, } if label is not None: loss = self._loss(label_logits, label.long().view(-1)) self._accuracy(label_logits, label) output_dict["loss"] = loss if metadata is not None: output_dict["premise_tokens"] = [ x["premise_tokens"] for x in metadata ] output_dict["hypothesis_tokens"] = [ x["hypothesis_tokens"] for x in metadata ] return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return {"accuracy": self._accuracy.get_metric(reset)} default_predictor = "textual_entailment"
class EntityNLMDiscriminator(Model): """ Implementation of the discriminative model from: https://arxiv.org/abs/1708.00781 used to draw importance samples. Parameters ---------- vocab : ``Vocabulary`` The model vocabulary. text_field_embedder : ``TextFieldEmbedder`` Used to embed tokens. encoder : ``Seq2SeqEncoder`` Used to encode the sequence of token embeddings. embedding_dim : ``int`` The dimension of entity / length embeddings. Should match the encoder output size. max_mention_length : ``int`` Maximum entity mention length. max_embeddings : ``int`` Maximum number of embeddings. variational_dropout_rate : ``float``, optional Dropout rate of variational dropout applied to input embeddings. Default: 0.0 dropout_rate : ``float``, optional Dropout rate applied to hidden states. Default: 0.0 initializer : ``InitializerApplicator``, optional Used to initialize model parameters. """ # pylint: disable=line-too-long def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, embedding_dim: int, max_mention_length: int, max_embeddings: int, variational_dropout_rate: float = 0.0, dropout_rate: float = 0.0, initializer: InitializerApplicator = InitializerApplicator()) -> None: super(EntityNLMDiscriminator, self).__init__(vocab) self._text_field_embedder = text_field_embedder self._encoder = encoder self._embedding_dim = embedding_dim self._max_mention_length = max_mention_length self._max_embeddings = max_embeddings self._state: Optional[StateDict] = None # Input variational dropout self._variational_dropout = InputVariationalDropout( variational_dropout_rate) self._dropout = torch.nn.Dropout(dropout_rate) # For entity type prediction self._entity_type_projection = torch.nn.Linear(in_features=embedding_dim, out_features=2, bias=False) self._dynamic_embeddings = DynamicEmbedding(embedding_dim=embedding_dim, max_embeddings=max_embeddings) # For mention length prediction self._mention_length_projection = torch.nn.Linear(in_features=2*embedding_dim, out_features=max_mention_length) self._entity_type_accuracy = CategoricalAccuracy() self._entity_id_accuracy = CategoricalAccuracy() self._mention_length_accuracy = CategoricalAccuracy() initializer(self) @overrides def forward(self, # pylint: disable=arguments-differ tokens: Dict[str, torch.Tensor], entity_types: Optional[torch.Tensor] = None, entity_ids: Optional[torch.Tensor] = None, mention_lengths: Optional[torch.Tensor] = None, reset: bool = False)-> Dict[str, torch.Tensor]: """ Computes the loss during training / validation. Parameters ---------- tokens : ``Dict[str, torch.Tensor]`` A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of tokens. entity_types : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the corresponding token belongs to a mention. entity_ids : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the entities the corresponding token is mentioning. mention_lengths : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining tokens (including the current one) there are in the mention. reset : ``bool`` Whether or not to reset the model's state. This should be done at the start of each new sequence. Returns ------- An output dictionary consisting of: loss : ``torch.Tensor`` The combined loss. 
""" batch_size = tokens['tokens'].shape[0] if reset: self.reset_states(batch_size) else: self.detach_states() if entity_types is not None: output_dict = self._forward_loop(tokens=tokens, entity_types=entity_types, entity_ids=entity_ids, mention_lengths=mention_lengths) else: output_dict = {} return output_dict def sample(self, tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Generates a sample from the discriminative model. WARNING: Unlike during training, this function expects the full (unsplit) sequence of tokens. Parameters ---------- tokens : ``Dict[str, torch.Tensor]`` A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of tokens. Returns ------- An output dictionary consisting of: logp : ``torch.Tensor`` A tensor containing the log-probability of the sample (averaged over time) entity_types : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the corresponding token belongs to a mention. entity_ids : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the entities the corresponding token is mentioning. mention_lengths : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining tokens (including the current one) there are in the mention. """ batch_size, sequence_length = tokens['tokens'].shape # We will use a standard iterator during evaluation instead of a split iterator. Otherwise # it will be a pain to handle generating multiple samples for a sequence since there's no # way to get back to the first split. self.reset_states(batch_size) # Embed tokens and get RNN hidden state. mask = get_text_field_mask(tokens) embeddings = self._text_field_embedder(tokens) hidden = self._encoder(embeddings, mask) prev_mention_lengths = tokens['tokens'].new_ones(batch_size) # Initialize outputs # Track total logp for **each** generated sample logp = hidden.new_zeros(batch_size) entity_types = torch.zeros_like(tokens['tokens'], dtype=torch.uint8) entity_ids = torch.zeros_like(tokens['tokens']) mention_lengths = torch.ones_like(tokens['tokens']) # Generate outputs for timestep in range(sequence_length): current_hidden = hidden[:, timestep] # We only predict types / ids / lengths if the previous mention is terminated. predict_mask = prev_mention_lengths == 1 predict_mask = predict_mask * mask[:, timestep].byte() if predict_mask.sum() > 0: # Predict entity types entity_type_logits = self._entity_type_projection( current_hidden[predict_mask]) entity_type_logp = F.log_softmax(entity_type_logits, dim=-1) entity_type_prediction_logp, entity_type_predictions = sample_from_logp( entity_type_logp) entity_type_predictions = entity_type_predictions.byte() entity_types[predict_mask, timestep] = entity_type_predictions logp[predict_mask] += entity_type_prediction_logp # Only predict entity and mention lengths if we predicted that there was a mention predict_em = entity_types[:, timestep] * predict_mask if predict_em.sum() > 0: # Predict entity ids entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden, mask=predict_em) entity_id_logits = entity_id_prediction_outputs['logits'] entity_id_logp = F.log_softmax(entity_id_logits, dim=-1) entity_id_prediction_logp, entity_id_predictions = sample_from_logp( entity_id_logp) # Predict mention lengths - we do this before writing the # entity id predictions since we'll need to reindex the new # entities, but need the null embeddings here. 
predicted_entity_embeddings = self._dynamic_embeddings.embeddings[ predict_em, entity_id_predictions] concatenated = torch.cat( (current_hidden[predict_em], predicted_entity_embeddings), dim=-1) mention_length_logits = self._mention_length_projection( concatenated) mention_length_logp = F.log_softmax( mention_length_logits, dim=-1) mention_length_prediction_logp, mention_length_predictions = sample_from_logp( mention_length_logp) # Write predictions new_entity_mask = entity_id_predictions == 0 new_entity_labels = self._dynamic_embeddings.num_embeddings[predict_em] entity_id_predictions[new_entity_mask] = new_entity_labels[new_entity_mask] entity_ids[predict_em, timestep] = entity_id_predictions logp[predict_em] += entity_id_prediction_logp mention_lengths[predict_em, timestep] = mention_length_predictions logp[predict_em] += mention_length_prediction_logp # Add / update entity embeddings new_entities = entity_ids[:, timestep] == self._dynamic_embeddings.num_embeddings self._dynamic_embeddings.add_embeddings(timestep, new_entities) self._dynamic_embeddings.update_embeddings(hidden=current_hidden, update_indices=entity_ids[:, timestep], timestep=timestep, mask=predict_em) # If the previous mentions are ongoing, we assign the output deterministically. Mention # lengths decrease by 1, all other outputs are copied from the previous timestep. Do # not need to add anything to logp since these 'predictions' have probability 1 under # the model. deterministic_mask = prev_mention_lengths > 1 deterministic_mask = deterministic_mask * mask[:, timestep].byte() if deterministic_mask.sum() > 1: entity_types[deterministic_mask, timestep] = entity_types[deterministic_mask, timestep - 1] entity_ids[deterministic_mask, timestep] = entity_ids[deterministic_mask, timestep - 1] mention_lengths[deterministic_mask, timestep] = mention_lengths[deterministic_mask, timestep - 1] - 1 # Update mention lengths for next timestep prev_mention_lengths = mention_lengths[:, timestep] return { 'logp': logp, 'sample': { 'entity_types': entity_types, 'entity_ids': entity_ids, 'mention_lengths': mention_lengths } } def _forward_loop(self, tokens: Dict[str, torch.Tensor], entity_types: torch.Tensor, entity_ids: torch.Tensor, mention_lengths: torch.Tensor) -> Dict[str, torch.Tensor]: """ Performs the forward pass to calculate the loss on a chunk of training data. Parameters ---------- tokens : ``Dict[str, torch.Tensor]`` A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of tokens. entity_types : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` indicating whether or not the corresponding token belongs to a mention. entity_ids : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` containing the ids of the entities the corresponding token is mentioning. mention_lengths : ``torch.Tensor`` A tensor of shape ``(batch_size, sequence_length)`` tracking how many remaining tokens (including the current one) there are in the mention. Returns ------- An output dictionary consisting of: entity_type_loss : ``torch.Tensor`` The loss of entity type predictions. entity_id_loss : ``torch.Tensor`` The loss of entity id predictions. mention_length_loss : ``torch.Tensor`` The loss of mention length predictions. loss : ``torch.Tensor`` The combined loss. """ batch_size, sequence_length = tokens['tokens'].shape # Need to track previous mention lengths in order to know when to measure loss. 
if self._state is None: prev_mention_lengths = mention_lengths.new_ones(batch_size) else: prev_mention_lengths = self._state['prev_mention_lengths'] # Embed tokens and get RNN hidden state. mask = get_text_field_mask(tokens) embeddings = self._text_field_embedder(tokens) embeddings = self._variational_dropout(embeddings) hidden = self._encoder(embeddings, mask) # Initialize losses entity_type_loss = torch.tensor( 0.0, requires_grad=True, device=hidden.device) entity_id_loss = torch.tensor( 0.0, requires_grad=True, device=hidden.device) mention_length_loss = torch.tensor( 0.0, requires_grad=True, device=hidden.device) for timestep in range(sequence_length): current_entity_types = entity_types[:, timestep] current_entity_ids = entity_ids[:, timestep] current_mention_lengths = mention_lengths[:, timestep] current_hidden = hidden[:, timestep] current_hidden = self._dropout(hidden[:, timestep]) # We only predict types / ids / lengths if we are not currently in the process of # generating a mention (e.g. if the previous remaining mention length is 1). Indexing / # masking with ``predict_all`` makes it possible to do this in batch. predict_all = prev_mention_lengths == 1 predict_all = predict_all * mask[:, timestep].byte() if predict_all.sum() > 0: # Equation 3 in the paper. entity_type_logits = self._entity_type_projection( current_hidden[predict_all]) entity_type_loss = entity_type_loss + F.cross_entropy( entity_type_logits, current_entity_types[predict_all].long(), reduction='sum') self._entity_type_accuracy(predictions=entity_type_logits, gold_labels=current_entity_types[predict_all].long()) # Only proceed to predict entity and mention length if there is in fact an entity. predict_em = current_entity_types * predict_all if predict_em.sum() > 0: # Equation 4 in the paper. We want new entities to correspond to a prediction of # zero, their embedding should be added after they've been predicted for the first # time. modified_entity_ids = current_entity_ids.clone() modified_entity_ids[modified_entity_ids == self._dynamic_embeddings.num_embeddings] = 0 entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden, target=modified_entity_ids, mask=predict_em) entity_id_loss = entity_id_loss + \ entity_id_prediction_outputs['loss'].sum() self._entity_id_accuracy(predictions=entity_id_prediction_outputs['logits'], gold_labels=modified_entity_ids[predict_em]) # Equation 5 in the paper. predicted_entity_embeddings = self._dynamic_embeddings.embeddings[ predict_em, modified_entity_ids[predict_em]] predicted_entity_embeddings = self._dropout( predicted_entity_embeddings) concatenated = torch.cat( (current_hidden[predict_em], predicted_entity_embeddings), dim=-1) mention_length_logits = self._mention_length_projection( concatenated) mention_length_loss = mention_length_loss + F.cross_entropy( mention_length_logits, current_mention_lengths[predict_em]) self._mention_length_accuracy(predictions=mention_length_logits, gold_labels=current_mention_lengths[predict_em]) # We add new entities to any sequence where the current entity id matches the number of # embeddings that currently exist for that sequence (this means we need a new one since # there is an additional dummy embedding). new_entities = current_entity_ids == self._dynamic_embeddings.num_embeddings self._dynamic_embeddings.add_embeddings(timestep, new_entities) # We also perform updates of the currently observed entities. 
self._dynamic_embeddings.update_embeddings(hidden=current_hidden, update_indices=current_entity_ids, timestep=timestep, mask=current_entity_types) prev_mention_lengths = current_mention_lengths # Normalize the losses entity_type_loss = entity_type_loss / mask.sum() entity_id_loss = entity_id_loss / mask.sum() mention_length_loss = mention_length_loss / mask.sum() total_loss = entity_type_loss + entity_id_loss + mention_length_loss output_dict = { 'entity_type_loss': entity_type_loss, 'entity_id_loss': entity_id_loss, 'mention_length_loss': mention_length_loss, 'loss': total_loss } # Update state self._state = { 'prev_mention_lengths': prev_mention_lengths.detach() } return output_dict def reset_states(self, batch_size: int) -> None: """Resets the model's internals. Should be called at the start of a new batch.""" self._encoder.reset_states() self._dynamic_embeddings.reset_states(batch_size) self._state = None def detach_states(self): """Detaches the model's state to enforce truncated backpropagation.""" self._dynamic_embeddings.detach_states() @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'et_acc': self._entity_type_accuracy.get_metric(reset), 'eid_acc': self._entity_id_accuracy.get_metric(reset), 'ml_acc': self._mention_length_accuracy.get_metric(reset) }
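# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): one way the
# ``sample_from_logp`` helper used in ``sample()`` could behave, i.e. drawing
# one categorical sample per row and returning the log-probability of what was
# drawn. This is an assumption about the helper, not its actual implementation.
# ---------------------------------------------------------------------------
import torch

def sample_from_logp_sketch(logp: torch.Tensor):
    """Sample one index per row of ``logp`` (shape: N x num_classes)."""
    samples = torch.multinomial(logp.exp(), num_samples=1)            # N x 1
    sample_logp = logp.gather(dim=-1, index=samples).squeeze(-1)      # N
    return sample_logp, samples.squeeze(-1)

logits = torch.randn(4, 3)
logp = torch.log_softmax(logits, dim=-1)
sample_logp, samples = sample_from_logp_sketch(logp)
assert samples.shape == (4,) and sample_logp.shape == (4,)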
class TransformerClassificationTT(Model): """ This class implements a classification patterned after the proposed model in [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al)] (https://api.semanticscholar.org/CorpusID:198953378). Parameters ---------- vocab : ``Vocabulary`` transformer_model : ``str``, optional (default=``"roberta-large"``) This model chooses the embedder according to this setting. You probably want to make sure this matches the setting in the reader. """ def __init__( self, vocab: Vocabulary, transformer_model: str = "roberta-large", num_labels: Optional[int] = None, label_namespace: str = "labels", override_weights_file: Optional[str] = None, **kwargs, ) -> None: super().__init__(vocab, **kwargs) transformer_kwargs = { "model_name": transformer_model, "weights_path": override_weights_file, } self.embeddings = TransformerEmbeddings.from_pretrained_module( **transformer_kwargs) self.transformer_stack = TransformerStack.from_pretrained_module( **transformer_kwargs) self.pooler = TransformerPooler.from_pretrained_module( **transformer_kwargs) self.pooler_dropout = Dropout(p=0.1) self.label_tokens = vocab.get_index_to_token_vocabulary( label_namespace) if num_labels is None: num_labels = len(self.label_tokens) self.linear_layer = torch.nn.Linear(self.pooler.get_output_dim(), num_labels) self.linear_layer.weight.data.normal_(mean=0.0, std=0.02) self.linear_layer.bias.data.zero_() from allennlp.training.metrics import CategoricalAccuracy, FBetaMeasure self.loss = torch.nn.CrossEntropyLoss() self.acc = CategoricalAccuracy() self.f1 = FBetaMeasure() def forward( # type: ignore self, text: Dict[str, torch.Tensor], label: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """ Parameters ---------- text : ``Dict[str, torch.LongTensor]`` From a ``TensorTextField``. Contains the text to be classified. label : ``Optional[torch.LongTensor]`` From a ``LabelField``, specifies the true class of the instance Returns ------- An output dictionary consisting of: loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. This is only returned when `correct_alternative` is not `None`. logits : ``torch.FloatTensor`` The logits for every possible answer choice """ embedded_alternatives = self.embeddings(**text) embedded_alternatives = self.transformer_stack(embedded_alternatives, text["attention_mask"]) embedded_alternatives = self.pooler( embedded_alternatives.final_hidden_states) embedded_alternatives = self.pooler_dropout(embedded_alternatives) logits = self.linear_layer(embedded_alternatives) result = {"logits": logits, "answers": logits.argmax(1)} if label is not None: result["loss"] = self.loss(logits, label) self.acc(logits, label) self.f1(logits, label) return result def get_metrics(self, reset: bool = False) -> Dict[str, float]: result = {"acc": self.acc.get_metric(reset)} for metric_name, metrics_per_class in self.f1.get_metric( reset).items(): for class_index, value in enumerate(metrics_per_class): result[ f"{self.label_tokens[class_index]}-{metric_name}"] = value return result
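# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): how ``get_metrics``
# flattens per-class F-beta values into named metrics. The numbers and label
# vocabulary are made up, and the nested-dict structure below is only assumed
# to mirror what ``FBetaMeasure.get_metric`` returns when no averaging is set.
# ---------------------------------------------------------------------------
label_tokens = {0: "entailment", 1: "contradiction", 2: "neutral"}   # hypothetical labels
fbeta_output = {
    "precision": [0.80, 0.75, 0.60],
    "recall":    [0.70, 0.65, 0.55],
    "fscore":    [0.747, 0.696, 0.574],
}

result = {"acc": 0.72}                         # accuracy is already a scalar
for metric_name, metrics_per_class in fbeta_output.items():
    for class_index, value in enumerate(metrics_per_class):
        result[f"{label_tokens[class_index]}-{metric_name}"] = value

assert "neutral-fscore" in result and "contradiction-precision" in result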
class SpanConstituencyParser(Model): """ This ``SpanConstituencyParser`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, extracts span representations using a ``SpanExtractor``, and then predicts a label for each span in the sequence. These labels are non-terminal nodes in a constituency parse tree, which we then greedily reconstruct. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the ``tokens`` ``TextField`` we get as input to the model. span_extractor : ``SpanExtractor``, required. The method used to extract the spans from the encoded sequence. encoder : ``Seq2SeqEncoder``, required. The encoder that we will use in between embedding tokens and generating span representations. feedforward_layer : ``FeedForward``, required. The FeedForward layer that we will use in between the encoder and the linear projection to a distribution over span labels. pos_tag_embedding : ``Embedding``, optional. Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, span_extractor: SpanExtractor, encoder: Seq2SeqEncoder, feedforward_layer: FeedForward = None, pos_tag_embedding: Embedding = None, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None, evalb_directory_path: str = None) -> None: super(SpanConstituencyParser, self).__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.span_extractor = span_extractor self.num_classes = self.vocab.get_vocab_size("labels") self.encoder = encoder self.feedforward_layer = TimeDistributed(feedforward_layer) if feedforward_layer else None self.pos_tag_embedding = pos_tag_embedding or None if feedforward_layer is not None: output_dim = feedforward_layer.get_output_dim() else: output_dim = span_extractor.get_output_dim() self.tag_projection_layer = TimeDistributed(Linear(output_dim, self.num_classes)) representation_dim = text_field_embedder.get_output_dim() if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match(representation_dim, encoder.get_input_dim(), "representation dim (tokens + optional POS tags)", "encoder input dim") check_dimensions_match(encoder.get_output_dim(), span_extractor.get_input_dim(), "encoder input dim", "span extractor input dim") if feedforward_layer is not None: check_dimensions_match(span_extractor.get_output_dim(), feedforward_layer.get_input_dim(), "span extractor output dim", "feedforward input dim") self.tag_accuracy = CategoricalAccuracy() if evalb_directory_path is not None: self._evalb_score = EvalbBracketingScorer(evalb_directory_path) else: self._evalb_score = None initializer(self) @overrides def forward(self, # type: ignore tokens: Dict[str, torch.LongTensor], spans: torch.LongTensor, metadata: List[Dict[str, Any]], pos_tags: Dict[str, torch.LongTensor] = None, span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- tokens : Dict[str, torch.LongTensor], required The 
output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. spans : ``torch.LongTensor``, required. A tensor of shape ``(batch_size, num_spans, 2)`` representing the inclusive start and end indices of all possible spans in the sentence. metadata : List[Dict[str, Any]], required. A dictionary of metadata for each batch element which has keys: tokens : ``List[str]``, required. The original string tokens in the sentence. gold_tree : ``nltk.Tree``, optional (default = None) Gold NLTK trees for use in evaluation. pos_tags : ``List[str]``, optional. The POS tags for the sentence. These can be used in the model as embedded features, but they are passed here in addition for use in constructing the tree. pos_tags : ``torch.LongTensor``, optional (default = None) The output of a ``SequenceLabelField`` containing POS tags. span_labels : ``torch.LongTensor``, optional (default = None) A torch tensor representing the integer gold class labels for all possible spans, of shape ``(batch_size, num_spans)``. Returns ------- An output dictionary consisting of: class_probabilities : ``torch.FloatTensor`` A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)`` representing a distribution over the label classes per span. spans : ``torch.LongTensor`` The original spans tensor. tokens : ``List[List[str]]``, required. A list of tokens in the sentence for each element in the batch. pos_tags : ``List[List[str]]``, required. A list of POS tags in the sentence for each element in the batch. num_spans : ``torch.LongTensor``, required. A tensor of shape (batch_size), representing the lengths of non-padded spans in ``enumerated_spans``. loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. """ embedded_text_input = self.text_field_embedder(tokens) if pos_tags is not None and self.pos_tag_embedding is not None: embedded_pos_tags = self.pos_tag_embedding(pos_tags) embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1) elif self.pos_tag_embedding is not None: raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(tokens) # Looking at the span start index is enough to know if # this is padding or not. Shape: (batch_size, num_spans) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long() if span_mask.dim() == 1: # This happens if you use batch_size 1 and encounter # a length 1 sentence in PTB, which do exist. 
-.- span_mask = span_mask.unsqueeze(-1) if span_labels is not None and span_labels.dim() == 1: span_labels = span_labels.unsqueeze(-1) num_spans = get_lengths_from_binary_sequence_mask(span_mask) encoded_text = self.encoder(embedded_text_input, mask) span_representations = self.span_extractor(encoded_text, spans, mask, span_mask) if self.feedforward_layer is not None: span_representations = self.feedforward_layer(span_representations) logits = self.tag_projection_layer(span_representations) class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1)) output_dict = { "class_probabilities": class_probabilities, "spans": spans, "tokens": [meta["tokens"] for meta in metadata], "pos_tags": [meta.get("pos_tags") for meta in metadata], "num_spans": num_spans } if span_labels is not None: loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask) self.tag_accuracy(class_probabilities, span_labels, span_mask) output_dict["loss"] = loss # The evalb score is expensive to compute, so we only compute # it for the validation and test sets. batch_gold_trees = [meta.get("gold_tree") for meta in metadata] if all(batch_gold_trees) and self._evalb_score is not None and not self.training: gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1] for tree in batch_gold_trees] predicted_trees = self.construct_trees(class_probabilities.cpu().data, spans.cpu().data, num_spans.data, output_dict["tokens"], gold_pos_tags) self._evalb_score(predicted_trees, batch_gold_trees) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Constructs an NLTK ``Tree`` given the scored spans. We also switch to exclusive span ends when constructing the tree representation, because it makes indexing into lists cleaner for ranges of text, rather than individual indices. Finally, for batch prediction, we will have padded spans and class probabilities. In order to make this less confusing, we remove all the padded spans and distributions from ``spans`` and ``class_probabilities`` respectively. """ all_predictions = output_dict['class_probabilities'].cpu().data all_spans = output_dict["spans"].cpu().data all_sentences = output_dict["tokens"] all_pos_tags = output_dict["pos_tags"] if all(output_dict["pos_tags"]) else None num_spans = output_dict["num_spans"].data trees = self.construct_trees(all_predictions, all_spans, num_spans, all_sentences, all_pos_tags) batch_size = all_predictions.size(0) output_dict["spans"] = [all_spans[i, :num_spans[i]] for i in range(batch_size)] output_dict["class_probabilities"] = [all_predictions[i, :num_spans[i], :] for i in range(batch_size)] output_dict["trees"] = trees return output_dict def construct_trees(self, predictions: torch.FloatTensor, all_spans: torch.LongTensor, num_spans: torch.LongTensor, sentences: List[List[str]], pos_tags: List[List[str]] = None) -> List[Tree]: """ Construct ``nltk.Tree``'s for each batch element by greedily nesting spans. The trees use exclusive end indices, which contrasts with how spans are represented in the rest of the model. Parameters ---------- predictions : ``torch.FloatTensor``, required. A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)`` representing a distribution over the label classes per span. all_spans : ``torch.LongTensor``, required. A tensor of shape (batch_size, num_spans, 2), representing the span indices we scored. num_spans : ``torch.LongTensor``, required. 
A tensor of shape (batch_size), representing the lengths of non-padded spans in ``enumerated_spans``. sentences : ``List[List[str]]``, required. A list of tokens in the sentence for each element in the batch. pos_tags : ``List[List[str]]``, optional (default = None). A list of POS tags for each word in the sentence for each element in the batch. Returns ------- A ``List[Tree]`` containing the decoded trees for each element in the batch. """ # Switch to using exclusive end spans. exclusive_end_spans = all_spans.clone() exclusive_end_spans[:, :, -1] += 1 no_label_id = self.vocab.get_token_index("NO-LABEL", "labels") trees: List[Tree] = [] for batch_index, (scored_spans, spans, sentence) in enumerate(zip(predictions, exclusive_end_spans, sentences)): selected_spans = [] for prediction, span in zip(scored_spans[:num_spans[batch_index]], spans[:num_spans[batch_index]]): start, end = span no_label_prob = prediction[no_label_id] label_prob, label_index = torch.max(prediction, -1) # Does the span have a label != NO-LABEL or is it the root node? # If so, include it in the spans that we consider. if int(label_index) != no_label_id or (start == 0 and end == len(sentence)): # TODO(Mark): Remove this once pylint sorts out named tuples. # https://github.com/PyCQA/pylint/issues/1418 selected_spans.append(SpanInformation(start=int(start), # pylint: disable=no-value-for-parameter end=int(end), label_prob=float(label_prob), no_label_prob=float(no_label_prob), label_index=int(label_index))) # The spans we've selected might overlap, which causes problems when we try # to construct the tree as they won't nest properly. consistent_spans = self.resolve_overlap_conflicts_greedily(selected_spans) spans_to_labels = {(span.start, span.end): self.vocab.get_token_from_index(span.label_index, "labels") for span in consistent_spans} sentence_pos = pos_tags[batch_index] if pos_tags is not None else None trees.append(self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos)) return trees @staticmethod def resolve_overlap_conflicts_greedily(spans: List[SpanInformation]) -> List[SpanInformation]: """ Given a set of spans, removes spans which overlap by evaluating the difference in probability between one being labeled and the other explicitly having no label and vice-versa. The worst case time complexity of this method is ``O(k * n^4)`` where ``n`` is the length of the sentence that the spans were enumerated from (and therefore ``k * m^2`` complexity with respect to the number of spans ``m``) and ``k`` is the number of conflicts. However, in practice, there are very few conflicts. Hopefully. This function modifies ``spans`` to remove overlapping spans. Parameters ---------- spans: ``List[SpanInformation]``, required. A list of spans, where each span is a ``namedtuple`` containing the following attributes: start : ``int`` The start index of the span. end : ``int`` The exclusive end index of the span. no_label_prob : ``float`` The probability of this span being assigned the ``NO-LABEL`` label. label_prob : ``float`` The probability of the most likely label. Returns ------- A modified list of ``spans``, with the conflicts resolved by considering local differences between pairs of spans and removing one of the two spans. 
""" conflicts_exist = True while conflicts_exist: conflicts_exist = False for span1_index, span1 in enumerate(spans): for span2_index, span2 in list(enumerate(spans))[span1_index + 1:]: if (span1.start < span2.start < span1.end < span2.end or span2.start < span1.start < span2.end < span1.end): # The spans overlap. conflicts_exist = True # What's the more likely situation: that span2 was labeled # and span1 was unlabled, or that span1 was labeled and span2 # was unlabled? In the first case, we delete span2 from the # set of spans to form the tree - in the second case, we delete # span1. if (span1.no_label_prob + span2.label_prob < span2.no_label_prob + span1.label_prob): spans.pop(span2_index) else: spans.pop(span1_index) break return spans @staticmethod def construct_tree_from_spans(spans_to_labels: Dict[Tuple[int, int], str], sentence: List[str], pos_tags: List[str] = None) -> Tree: """ Parameters ---------- spans_to_labels : ``Dict[Tuple[int, int], str]``, required. A mapping from spans to constituency labels. sentence : ``List[str]``, required. A list of tokens forming the sentence to be parsed. pos_tags : ``List[str]``, optional (default = None) A list of the pos tags for the words in the sentence, if they were either predicted or taken as input to the model. Returns ------- An ``nltk.Tree`` constructed from the labelled spans. """ def assemble_subtree(start: int, end: int): if (start, end) in spans_to_labels: # Some labels contain nested spans, e.g S-VP. # We actually want to create (S (VP ...)) nodes # for these labels, so we split them up here. labels: List[str] = spans_to_labels[(start, end)].split("-") else: labels = None # This node is a leaf. if end - start == 1: word = sentence[start] pos_tag = pos_tags[start] if pos_tags is not None else "XX" tree = Tree(pos_tag, [word]) if labels is not None and pos_tags is not None: # If POS tags were passed explicitly, # they are added as pre-terminal nodes. while labels: tree = Tree(labels.pop(), [tree]) elif labels is not None: # Otherwise, we didn't want POS tags # at all. tree = Tree(labels.pop(), [word]) while labels: tree = Tree(labels.pop(), [tree]) return [tree] argmax_split = start + 1 # Find the next largest subspan such that # the left hand side is a constituent. 
for split in range(end - 1, start, -1): if (start, split) in spans_to_labels: argmax_split = split break left_trees = assemble_subtree(start, argmax_split) right_trees = assemble_subtree(argmax_split, end) children = left_trees + right_trees if labels is not None: while labels: children = [Tree(labels.pop(), children)] return children tree = assemble_subtree(0, len(sentence)) return tree[0] @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics = {} all_metrics["tag_accuracy"] = self.tag_accuracy.get_metric(reset=reset) if self._evalb_score is not None: evalb_metrics = self._evalb_score.get_metric(reset=reset) all_metrics.update(evalb_metrics) return all_metrics @classmethod def from_params(cls, vocab: Vocabulary, params: Params) -> 'SpanConstituencyParser': embedder_params = params.pop("text_field_embedder") text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params) span_extractor = SpanExtractor.from_params(params.pop("span_extractor")) encoder = Seq2SeqEncoder.from_params(params.pop("encoder")) feed_forward_params = params.pop("feedforward", None) if feed_forward_params is not None: feedforward_layer = FeedForward.from_params(feed_forward_params) else: feedforward_layer = None pos_tag_embedding_params = params.pop("pos_tag_embedding", None) if pos_tag_embedding_params is not None: pos_tag_embedding = Embedding.from_params(vocab, pos_tag_embedding_params) else: pos_tag_embedding = None initializer = InitializerApplicator.from_params(params.pop('initializer', [])) regularizer = RegularizerApplicator.from_params(params.pop('regularizer', [])) evalb_directory_path = params.pop("evalb_directory_path", None) params.assert_empty(cls.__name__) return cls(vocab=vocab, text_field_embedder=text_field_embedder, span_extractor=span_extractor, encoder=encoder, feedforward_layer=feedforward_layer, pos_tag_embedding=pos_tag_embedding, initializer=initializer, regularizer=regularizer, evalb_directory_path=evalb_directory_path)
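# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): the greedy overlap
# resolution on a toy pair of crossing spans, assuming SpanConstituencyParser
# (defined above) is in scope. The local ``Span`` namedtuple is a stand-in for
# the module's ``SpanInformation``; only the attributes the static method
# actually reads are provided.
# ---------------------------------------------------------------------------
from collections import namedtuple

Span = namedtuple("Span", ["start", "end", "label_prob", "no_label_prob", "label_index"])

# Spans (2, 6) and (4, 8) cross, so they cannot both appear in a well-formed tree.
spans = [
    Span(start=2, end=6, label_prob=0.9, no_label_prob=0.05, label_index=3),
    Span(start=4, end=8, label_prob=0.4, no_label_prob=0.50, label_index=7),
]

# Keeping span 0 labeled and span 1 unlabeled scores 0.9 + 0.5 = 1.4, which
# beats the alternative 0.05 + 0.4 = 0.45, so the crossing span 1 is dropped.
resolved = SpanConstituencyParser.resolve_overlap_conflicts_greedily(list(spans))
assert resolved == [spans[0]]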
class BasicClassifier(Model): """ This `Model` implements a basic text classifier. After embedding the text into a text field, we will optionally encode the embeddings with a `Seq2SeqEncoder`. The resulting sequence is pooled using a `Seq2VecEncoder` and then passed to a linear classification layer, which projects into the label space. If a `Seq2SeqEncoder` is not provided, we will pass the embedded text directly to the `Seq2VecEncoder`. Registered as a `Model` with name "basic_classifier". # Parameters vocab : `Vocabulary` text_field_embedder : `TextFieldEmbedder` Used to embed the input text into a `TextField` seq2seq_encoder : `Seq2SeqEncoder`, optional (default=`None`) Optional Seq2Seq encoder layer for the input text. seq2vec_encoder : `Seq2VecEncoder` Required Seq2Vec encoder layer. If `seq2seq_encoder` is provided, this encoder will pool its output. Otherwise, this encoder will operate directly on the output of the `text_field_embedder`. feedforward : `FeedForward`, optional, (default = None). An optional feedforward layer to apply after the seq2vec_encoder. dropout : `float`, optional (default = `None`) Dropout percentage to use. num_labels : `int`, optional (default = `None`) Number of labels to project to in classification layer. By default, the classification layer will project to the size of the vocabulary namespace corresponding to labels. label_namespace : `str`, optional (default = "labels") Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace. initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`) If provided, will be used to initialize the model parameters. """ def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, seq2vec_encoder: Seq2VecEncoder, seq2seq_encoder: Seq2SeqEncoder = None, feedforward: Optional[FeedForward] = None, dropout: float = None, num_labels: int = None, label_namespace: str = "labels", namespace: str = "tokens", initializer: InitializerApplicator = InitializerApplicator(), **kwargs, ) -> None: super().__init__(vocab, **kwargs) self._text_field_embedder = text_field_embedder if seq2seq_encoder: self._seq2seq_encoder = seq2seq_encoder else: self._seq2seq_encoder = None self._seq2vec_encoder = seq2vec_encoder self._feedforward = feedforward if feedforward is not None: self._classifier_input_dim = self._feedforward.get_output_dim() else: self._classifier_input_dim = self._seq2vec_encoder.get_output_dim() if dropout: self._dropout = torch.nn.Dropout(dropout) else: self._dropout = None self._label_namespace = label_namespace self._namespace = namespace if num_labels: self._num_labels = num_labels else: self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace) self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels) self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward( # type: ignore self, tokens: TextFieldTensors, label: torch.IntTensor = None ) -> Dict[str, torch.Tensor]: """ # Parameters tokens : TextFieldTensors From a `TextField` label : torch.IntTensor, optional (default = None) From a `LabelField` # Returns An output dictionary consisting of: - `logits` (`torch.FloatTensor`) : A tensor of shape `(batch_size, num_labels)` representing unnormalized log probabilities of the label. - `probs` (`torch.FloatTensor`) : A tensor of shape `(batch_size, num_labels)` representing probabilities of the label. 
- `loss` : (`torch.FloatTensor`, optional) : A scalar loss to be optimised. """ embedded_text = self._text_field_embedder(tokens) mask = get_text_field_mask(tokens) if self._seq2seq_encoder: embedded_text = self._seq2seq_encoder(embedded_text, mask=mask) embedded_text = self._seq2vec_encoder(embedded_text, mask=mask) if self._dropout: embedded_text = self._dropout(embedded_text) if self._feedforward is not None: embedded_text = self._feedforward(embedded_text) logits = self._classification_layer(embedded_text) probs = torch.nn.functional.softmax(logits, dim=-1) output_dict = {"logits": logits, "probs": probs} output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens) if label is not None: loss = self._loss(logits, label.long().view(-1)) output_dict["loss"] = loss self._accuracy(logits, label) return output_dict @overrides def make_output_human_readable( self, output_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: """ Does a simple argmax over the probabilities, converts index to string label, and add `"label"` key to the dictionary with the result. """ predictions = output_dict["probs"] if predictions.dim() == 2: predictions_list = [predictions[i] for i in range(predictions.shape[0])] else: predictions_list = [predictions] classes = [] for prediction in predictions_list: label_idx = prediction.argmax(dim=-1).item() label_str = self.vocab.get_index_to_token_vocabulary(self._label_namespace).get( label_idx, str(label_idx) ) classes.append(label_str) output_dict["label"] = classes tokens = [] for instance_tokens in output_dict["token_ids"]: tokens.append( [ self.vocab.get_token_from_index(token_id.item(), namespace=self._namespace) for token_id in instance_tokens ] ) output_dict["tokens"] = tokens return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = {"accuracy": self._accuracy.get_metric(reset)} return metrics
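# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): the BasicClassifier data
# flow in plain PyTorch, with masked mean pooling standing in for the
# Seq2VecEncoder. The vocabulary, sizes, and token ids are assumptions.
# ---------------------------------------------------------------------------
import torch

vocab_size, embedding_dim, num_labels = 100, 32, 3
embedder = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
classification_layer = torch.nn.Linear(embedding_dim, num_labels)

token_ids = torch.tensor([[5, 17, 42, 0, 0],        # 0 = padding
                          [8,  9,  0, 0, 0]])
mask = (token_ids != 0).float()                     # B x num_tokens

embedded_text = embedder(token_ids)                                      # B x T x D
# Masked mean pooling plays the role of the Seq2VecEncoder.
pooled = (embedded_text * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)

logits = classification_layer(pooled)                                    # B x num_labels
probs = torch.nn.functional.softmax(logits, dim=-1)
assert probs.shape == (2, num_labels)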
class DecomposableAttention(Model): """ This ``Model`` implements the Decomposable Attention model described in `"A Decomposable Attention Model for Natural Language Inference" <https://www.semanticscholar.org/paper/A-Decomposable-Attention-Model-for-Natural-Languag-Parikh-T%C3%A4ckstr%C3%B6m/07a9478e87a8304fc3267fa16e83e9f3bbd98b27>`_ by Parikh et al., 2016, with some optional enhancements before the decomposable attention actually happens. Parikh's original model allowed for computing an "intra-sentence" attention before doing the decomposable entailment step. We generalize this to any :class:`Seq2SeqEncoder` that can be applied to the premise and/or the hypothesis before computing entailment. The basic outline of this model is to get an embedded representation of each word in the premise and hypothesis, align words between the two, compare the aligned phrases, and make a final entailment decision based on this aggregated comparison. Each step in this process uses a feedforward network to modify the representation. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the model. attend_feedforward : ``FeedForward`` This feedforward network is applied to the encoded sentence representations before the similarity matrix is computed between words in the premise and words in the hypothesis. similarity_function : ``SimilarityFunction`` This is the similarity function used when computing the similarity matrix between words in the premise and words in the hypothesis. compare_feedforward : ``FeedForward`` This feedforward network is applied to the aligned premise and hypothesis representations, individually. aggregate_feedforward : ``FeedForward`` This final feedforward network is applied to the concatenated, summed result of the ``compare_feedforward`` network, and its output is used as the entailment class logits. premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``) After embedding the premise, we can optionally apply an encoder. If this is ``None``, we will do nothing. hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``) After embedding the hypothesis, we can optionally apply an encoder. If this is ``None``, we will use the ``premise_encoder`` for the encoding (doing nothing if ``premise_encoder`` is also ``None``). initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. 
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, attend_feedforward: FeedForward, similarity_function: SimilarityFunction, compare_feedforward: FeedForward, aggregate_feedforward: FeedForward, premise_encoder: Optional[Seq2SeqEncoder] = None, hypothesis_encoder: Optional[Seq2SeqEncoder] = None, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(DecomposableAttention, self).__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder self._attend_feedforward = TimeDistributed(attend_feedforward) self._matrix_attention = LegacyMatrixAttention(similarity_function) self._compare_feedforward = TimeDistributed(compare_feedforward) self._aggregate_feedforward = aggregate_feedforward self._premise_encoder = premise_encoder self._hypothesis_encoder = hypothesis_encoder or premise_encoder self._num_labels = vocab.get_vocab_size(namespace="labels") check_dimensions_match(text_field_embedder.get_output_dim(), attend_feedforward.get_input_dim(), "text field embedding dim", "attend feedforward input dim") check_dimensions_match(aggregate_feedforward.get_output_dim(), self._num_labels, "final output dimension", "number of labels") self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward(self, # type: ignore premise: Dict[str, torch.LongTensor], hypothesis: Dict[str, torch.LongTensor], label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- premise : Dict[str, torch.LongTensor] From a ``TextField`` hypothesis : Dict[str, torch.LongTensor] From a ``TextField`` label : torch.IntTensor, optional, (default = None) From a ``LabelField`` metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively. Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the entailment label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the entailment label. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" embedded_premise = self._text_field_embedder(premise) embedded_hypothesis = self._text_field_embedder(hypothesis) premise_mask = get_text_field_mask(premise).float() hypothesis_mask = get_text_field_mask(hypothesis).float() if self._premise_encoder: embedded_premise = self._premise_encoder(embedded_premise, premise_mask) if self._hypothesis_encoder: embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask) projected_premise = self._attend_feedforward(embedded_premise) projected_hypothesis = self._attend_feedforward(embedded_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask) # Shape: (batch_size, premise_length, embedding_dim) attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention) # Shape: (batch_size, hypothesis_length, premise_length) h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask) # Shape: (batch_size, hypothesis_length, embedding_dim) attended_premise = weighted_sum(embedded_premise, h2p_attention) premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1) hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1) compared_premise = self._compare_feedforward(premise_compare_input) compared_premise = compared_premise * premise_mask.unsqueeze(-1) # Shape: (batch_size, compare_dim) compared_premise = compared_premise.sum(dim=1) compared_hypothesis = self._compare_feedforward(hypothesis_compare_input) compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1) # Shape: (batch_size, compare_dim) compared_hypothesis = compared_hypothesis.sum(dim=1) aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1) label_logits = self._aggregate_feedforward(aggregate_input) label_probs = torch.nn.functional.softmax(label_logits, dim=-1) output_dict = {"label_logits": label_logits, "label_probs": label_probs, "h2p_attention": h2p_attention, "p2h_attention": p2h_attention} if label is not None: loss = self._loss(label_logits, label.long().view(-1)) self._accuracy(label_logits, label) output_dict["loss"] = loss if metadata is not None: output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata] output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata] return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'accuracy': self._accuracy.get_metric(reset), }
class QuarelSemanticParser(Model):
    """
    A ``QuarelSemanticParser`` is a variant of ``WikiTablesSemanticParser`` with various tweaks and
    changes.

    Parameters
    ----------
    vocab : ``Vocabulary``
    question_embedder : ``TextFieldEmbedder``
        Embedder for questions.
    action_embedding_dim : ``int``
        Dimension to use for action embeddings.
    encoder : ``Seq2SeqEncoder``
        The encoder to use for the input question.
    decoder_beam_search : ``BeamSearch``
        When we're not training, this is how we will do decoding.
    max_decoding_steps : ``int``
        When we're decoding with a beam search, what's the maximum number of steps we should take?
        This only applies at evaluation time, not during training.
    attention : ``Attention``
        We compute an attention over the input question at each step of the decoder, using the
        decoder hidden state as the query.  Passed to the transition function.
    dropout : ``float``, optional (default=0)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_linking_features : ``int``, optional (default=0)
        We need to construct a parameter vector for the linking features, so we need to know how
        many there are.  When nonzero, this should match the number of features defined in the
        ``KnowledgeGraphField`` (eight by default).  If this is 0, another term will be added to
        the linking score.  This term contains the maximum similarity value from the entity's
        neighbors and the question.
    use_entities : ``bool``, optional (default=False)
        Whether dynamic entities are part of the action space.
    num_entity_bits : ``int``, optional (default=0)
        Number of bits added to the encoder input/output to represent tagged entities.
    entity_bits_output : ``bool``, optional (default=True)
        Whether entity bits are added to the encoder output or input.
    denotation_only : ``bool``, optional (default=False)
        Whether to only predict the target denotation, skipping the whole logical form decoder.
    entity_similarity_mode : ``str``, optional (default="dot_product")
        How to compute vector similarity between question and entity tokens; can take values
        "dot_product" or "weighted_dot_product" (learned weights on each dimension).
    rule_namespace : ``str``, optional (default=rule_labels)
        The vocabulary namespace to use for production rules.  The default corresponds to the
        default used in the dataset reader, so you likely don't need to modify this.
""" def __init__(self, vocab: Vocabulary, question_embedder: TextFieldEmbedder, action_embedding_dim: int, encoder: Seq2SeqEncoder, decoder_beam_search: BeamSearch, max_decoding_steps: int, attention: Attention, mixture_feedforward: FeedForward = None, add_action_bias: bool = True, dropout: float = 0.0, num_linking_features: int = 0, num_entity_bits: int = 0, entity_bits_output: bool = True, use_entities: bool = False, denotation_only: bool = False, # Deprecated parameter to load older models entity_encoder: Seq2VecEncoder = None, # pylint: disable=unused-argument entity_similarity_mode: str = "dot_product", rule_namespace: str = 'rule_labels') -> None: super(QuarelSemanticParser, self).__init__(vocab) self._question_embedder = question_embedder self._encoder = encoder self._beam_search = decoder_beam_search self._max_decoding_steps = max_decoding_steps if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._denotation_accuracy = Average() self._action_sequence_accuracy = Average() self._has_logical_form = Average() self._embedding_dim = question_embedder.get_output_dim() self._use_entities = use_entities # Note: there's only one non-trivial entity type in QuaRel for now, so most of the # entity_type stuff is irrelevant self._num_entity_types = 4 # TODO(mattg): get this in a more principled way somehow? self._num_start_types = 1 # Hardcoded until we feed lf syntax into the model self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim) self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim) self._entity_similarity_layer = None self._entity_similarity_mode = entity_similarity_mode if self._entity_similarity_mode == "weighted_dot_product": self._entity_similarity_layer = \ TimeDistributed(torch.nn.Linear(self._embedding_dim, 1, bias=False)) # Center initial values around unweighted dot product self._entity_similarity_layer._module.weight.data += 1 # pylint: disable=protected-access elif self._entity_similarity_mode == "dot_product": pass else: raise ValueError("Invalid entity_similarity_mode: {}".format(self._entity_similarity_mode)) if num_linking_features > 0: self._linking_params = torch.nn.Linear(num_linking_features, 1) else: self._linking_params = None self._decoder_trainer = MaximumMarginalLikelihood() self._encoder_output_dim = self._encoder.get_output_dim() if entity_bits_output: self._encoder_output_dim += num_entity_bits self._entity_bits_output = entity_bits_output self._debug_count = 10 self._num_denotation_cats = 2 # Hardcoded for simplicity self._denotation_only = denotation_only if self._denotation_only: self._denotation_accuracy_cat = CategoricalAccuracy() self._denotation_classifier = torch.nn.Linear(self._encoder_output_dim, self._num_denotation_cats) # Rest of init not needed for denotation only where no decoding to actions needed return self._action_padding_index = -1 # the padding value used by IndexField num_actions = vocab.get_vocab_size(self._rule_namespace) self._num_actions = num_actions self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) # We are tying the action embeddings used for input and output # self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) self._output_action_embedder = self._action_embedder # tied weights self._add_action_bias = add_action_bias if self._add_action_bias: self._action_biases = 
Embedding(num_embeddings=num_actions, embedding_dim=1) # This is what we pass as input in the first step of decoding, when we don't have a # previous action, or a previous question attention. self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) self._first_attended_question = torch.nn.Parameter(torch.FloatTensor(self._encoder_output_dim)) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_question) self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=attention, num_start_types=self._num_start_types, predict_start_type_separately=False, add_action_bias=self._add_action_bias, mixture_feedforward=mixture_feedforward, dropout=dropout) @overrides def forward(self, # type: ignore question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[QuarelWorld], actions: List[List[ProductionRule]], entity_bits: torch.Tensor = None, denotation_target: torch.Tensor = None, target_action_sequences: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # pylint: disable=unused-argument """ In this method we encode the table entities, link them to words in the question, then encode the question. Then we set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference, if we're not. Parameters ---------- question : Dict[str, torch.LongTensor] The output of ``TextField.as_array()`` applied on the question ``TextField``. This will be passed through a ``TextFieldEmbedder`` and then through an encoder. table : ``Dict[str, torch.LongTensor]`` The output of ``KnowledgeGraphField.as_array()`` applied on the table ``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to get embeddings for each entity. world : ``List[QuarelWorld]`` We use a ``MetadataField`` to get the ``World`` for each input instance. Because of how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``, actions : ``List[List[ProductionRule]]`` A list of all possible actions for each ``World`` in the batch, indexed into a ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder. target_action_sequences : torch.Tensor, optional (default=None) A list of possibly valid action sequences, where each action is an index into the list of possible actions. This tensor has shape ``(batch_size, num_action_sequences, sequence_length)``. 
""" table_text = table['text'] self._debug_count -= 1 # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table) if self._use_entities: if self._entity_similarity_mode == "dot_product": # Compute entity and question word cosine similarity. Need to add a small value to # to the table norm since there are padding values which cause a divide by 0. embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13) embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13) question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) linking_scores = question_entity_similarity_max_score elif self._entity_similarity_mode == "weighted_dot_product": embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13) embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13) eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1) ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim) ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1) product = torch.mul(eqe, ete) product = product.view(batch_size, num_question_tokens*num_entities*num_entity_tokens, self._embedding_dim) question_entity_similarity = self._entity_similarity_layer(product) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) linking_scores = question_entity_similarity_max_score # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) encoder_input = embedded_question else: if entity_bits is not None and not self._entity_bits_output: encoder_input = torch.cat([embedded_question, entity_bits], 2) else: encoder_input = embedded_question # Fake linking_scores added for downstream code to not object linking_scores = question_mask.clone().fill_(0).unsqueeze(1) 
linking_probabilities = None # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) if self._entity_bits_output and entity_bits is not None: encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) # For predicting a categorical denotation directly if self._denotation_only: denotation_logits = self._denotation_classifier(final_encoder_output) loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1)) self._denotation_accuracy_cat(denotation_logits, denotation_target) return {"loss": loss} memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim) _, num_entities, num_question_tokens = linking_scores.size() if target_action_sequences is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). target_action_sequences = target_action_sequences.squeeze(-1) target_mask = target_action_sequences != self._action_padding_index else: target_mask = None # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size) initial_score_list = [initial_score[i] for i in range(batch_size)] initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, possible_actions=actions, extras=None, debug_info=None) if self.training: outputs = self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask)) return outputs else: action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] outputs = {'action_mapping': action_mapping} if target_action_sequences is not None: outputs['loss'] = self._decoder_trainer.decode(initial_state, self._decoder_step, (target_action_sequences, target_mask))['loss'] num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. 
initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search(num_steps, initial_state, self._decoder_step, keep_final_unfinished_states=False) outputs['best_action_sequence'] = [] outputs['debug_info'] = [] outputs['entities'] = [] if self._linking_params is not None: outputs['linking_scores'] = linking_scores outputs['feature_scores'] = feature_scores outputs['linking_features'] = linking_features if self._use_entities: outputs['linking_probabilities'] = linking_probabilities if entity_bits is not None: outputs['entity_bits'] = entity_bits # outputs['similarity_scores'] = question_entity_similarity_max_score outputs['logical_form'] = [] outputs['denotation_acc'] = [] outputs['score'] = [] outputs['parse_acc'] = [] outputs['answer_index'] = [] if metadata is not None: outputs['question_tokens'] = [] outputs['world_extractions'] = [] for i in range(batch_size): if metadata is not None: outputs['question_tokens'].append(metadata[i].get('question_tokens', [])) if metadata is not None: outputs['world_extractions'].append(metadata[i].get('world_extractions', {})) outputs['entities'].append(world[i].table_graph.entities) # Decoding may not have terminated with any completed logical forms, if `num_steps` # isn't long enough (or if the model is not trained enough and gets into an # infinite action loop). if i in best_final_states: best_action_indices = best_final_states[i][0].action_history[0] sequence_in_targets = 0 if target_action_sequences is not None: targets = target_action_sequences[i].data sequence_in_targets = self._action_history_match(best_action_indices, targets) self._action_sequence_accuracy(sequence_in_targets) action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices] try: self._has_logical_form(1.0) logical_form = world[i].get_logical_form(action_strings, add_var_function=False) except ParsingError: self._has_logical_form(0.0) logical_form = 'Error producing logical form' denotation_accuracy = 0.0 predicted_answer_index = world[i].execute(logical_form) if metadata is not None and 'answer_index' in metadata[i]: answer_index = metadata[i]['answer_index'] denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index) self._denotation_accuracy(denotation_accuracy) score = math.exp(best_final_states[i][0].score[0].data.cpu().item()) outputs['answer_index'].append(predicted_answer_index) outputs['score'].append(score) outputs['parse_acc'].append(sequence_in_targets) outputs['best_action_sequence'].append(action_strings) outputs['logical_form'].append(logical_form) outputs['denotation_acc'].append(denotation_accuracy) outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore else: outputs['parse_acc'].append(0) outputs['logical_form'].append('') outputs['denotation_acc'].append(0) outputs['score'].append(0) outputs['answer_index'].append(-1) outputs['best_action_sequence'].append([]) outputs['debug_info'].append([]) self._has_logical_form(0.0) return outputs @staticmethod def _get_type_vector(worlds: List[QuarelWorld], num_entities: int, tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]: """ Produces a tensor with shape ``(batch_size, num_entities)`` that encodes each entity's type. In addition, a map from a flattened entity index to type is returned to combine entity type operations into one method. 
Parameters ---------- worlds : ``List[WikiTablesWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities)``. entity_types : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. """ entity_types = {} batch_types = [] for batch_index, world in enumerate(worlds): types = [] for entity_index, entity in enumerate(world.table_graph.entities): # We need numbers to be first, then cells, then parts, then row, because our # entities are going to be sorted. We do a split by type and then a merge later, # and it relies on this sorting. if entity.startswith('fb:cell'): entity_type = 1 elif entity.startswith('fb:part'): entity_type = 2 elif entity.startswith('fb:row'): entity_type = 3 else: entity_type = 0 types.append(entity_type) # For easier lookups later, we're actually using a _flattened_ version # of (batch_index, entity_index) for the key, because this is how the # linking scores are stored. flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return tensor.new_tensor(batch_types, dtype=torch.long), entity_types def _get_linking_probabilities(self, worlds: List[QuarelWorld], linking_scores: torch.FloatTensor, question_mask: torch.LongTensor, entity_type_dict: Dict[int, int]) -> torch.FloatTensor: """ Produces the probability of an entity given a question word and type. The logic below separates the entities by type since the softmax normalization term sums over entities of a single type. Parameters ---------- worlds : ``List[QuarelWorld]`` linking_scores : ``torch.FloatTensor`` Has shape (batch_size, num_question_tokens, num_entities). question_mask: ``torch.LongTensor`` Has shape (batch_size, num_question_tokens). entity_type_dict : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. Returns ------- batch_probabilities : ``torch.FloatTensor`` Has shape ``(batch_size, num_question_tokens, num_entities)``. Contains all the probabilities for an entity given a question word. """ _, num_question_tokens, num_entities = linking_scores.size() batch_probabilities = [] for batch_index, world in enumerate(worlds): all_probabilities = [] num_entities_in_instance = 0 # NOTE: The way that we're doing this here relies on the fact that entities are # implicitly sorted by their types when we sort them by name, and that numbers come # before "fb:cell", and "fb:cell" comes before "fb:row". This is not a great # assumption, and could easily break later, but it should work for now. for type_index in range(self._num_entity_types): # This index of 0 is for the null entity for each type, representing the case where a # word doesn't link to any entity. entity_indices = [0] entities = world.table_graph.entities for entity_index, _ in enumerate(entities): if entity_type_dict[batch_index * num_entities + entity_index] == type_index: entity_indices.append(entity_index) if len(entity_indices) == 1: # No entities of this type; move along... continue # We're subtracting one here because of the null entity we added above. num_entities_in_instance += len(entity_indices) - 1 # We separate the scores by type, since normalization is done per type. There's an # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. 
We're # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_, # so we get back something of shape (num_question_tokens,) for each index we're # selecting. All of the selected indices together then make a tensor of shape # (num_question_tokens, num_entities_per_type + 1). indices = linking_scores.new_tensor(entity_indices, dtype=torch.long) entity_scores = linking_scores[batch_index].index_select(1, indices) # We used index 0 for the null entity, so this will actually have some values in it. # But we want the null entity's score to be 0, so we set that here. entity_scores[:, 0] = 0 # No need for a mask here, as this is done per batch instance, with no padding. type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1) all_probabilities.append(type_probabilities[:, 1:]) # We need to add padding here if we don't have the right number of entities. if num_entities_in_instance != num_entities: zeros = linking_scores.new_zeros(num_question_tokens, num_entities - num_entities_in_instance) all_probabilities.append(zeros) # (num_question_tokens, num_entities) probabilities = torch.cat(all_probabilities, dim=1) batch_probabilities.append(probabilities) batch_probabilities = torch.stack(batch_probabilities, dim=0) return batch_probabilities * question_mask.unsqueeze(-1).float() @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(1): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:, :len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item() def _denotation_match(self, predicted_answer_index: int, target_answer_index: int) -> float: if predicted_answer_index < 0: # Logical form doesn't properly resolve, we do random guess with appropriate credit return 1.0/self._num_denotation_cats elif predicted_answer_index == target_answer_index: return 1.0 return 0.0 @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: """ We track three metrics here: 1. parse_acc, which is the percentage of the time that our best output action sequence corresponds to a correct logical form 2. denotation_acc, which is the percentage of examples where we get the correct denotation, including spurious correct answers using the wrong logical form 3. lf_percent, which is the percentage of time that decoding actually produces a finished logical form. We might not produce a valid logical form if the decoder gets into a repetitive loop, or we're trying to produce a super long logical form and run out of time steps, or something. """ if self._denotation_only: metrics = {'denotation_acc': self._denotation_accuracy_cat.get_metric(reset)} else: metrics = { 'parse_acc': self._action_sequence_accuracy.get_metric(reset), 'denotation_acc': self._denotation_accuracy.get_metric(reset), 'lf_percent': self._has_logical_form.get_metric(reset), } return metrics def _create_grammar_state(self, world: QuarelWorld, possible_actions: List[ProductionRule], linking_scores: torch.Tensor, entity_types: torch.Tensor) -> GrammarStatelet: """ This method creates the GrammarStatelet object that's used for decoding. 
Part of creating that is creating the `valid_actions` dictionary, which contains embedded representations of all of the valid actions. So, we create that here as well. The inputs to this method are for a `single instance in the batch`; none of the tensors we create here are batched. We grab the global action ids from the input ``ProductionRules``, and we use those to embed the valid actions for every non-terminal type. We use the input ``linking_scores`` for non-global actions. Parameters ---------- world : ``QuarelWorld`` From the input to ``forward`` for a single batch instance. possible_actions : ``List[ProductionRule]`` From the input to ``forward`` for a single batch instance. linking_scores : ``torch.Tensor`` Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch dimension). entity_types : ``torch.Tensor`` Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension). """ action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index entity_map = {} for entity_index, entity in enumerate(world.table_graph.entities): entity_map[entity] = entity_index valid_actions = world.get_valid_actions() translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [action_map[action_string] for action_string in action_strings] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append((production_rule_array[2], action_index)) else: linked_actions.append((production_rule_array[0], action_index)) # Then we get the embedded representations of the global actions. global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0) global_input_embeddings = self._action_embedder(global_action_tensor) if self._add_action_bias: global_action_biases = self._action_biases(global_action_tensor) global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1) global_output_embeddings = self._output_action_embedder(global_action_tensor) translated_valid_actions[key]['global'] = (global_input_embeddings, global_output_embeddings, list(global_action_ids)) # Then the representations of the linked actions. 
if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = [rule.split(' -> ')[1] for rule in linked_rules] entity_ids = [entity_map[entity] for entity in entities] # (num_linked_actions, num_question_tokens) entity_linking_scores = linking_scores[entity_ids] # (num_linked_actions,) entity_type_tensor = entity_types[entity_ids] # (num_linked_actions, entity_type_embedding_dim) entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor) translated_valid_actions[key]['linked'] = (entity_linking_scores, entity_type_embeddings, list(linked_action_ids)) return GrammarStatelet([START_SYMBOL], translated_valid_actions, type_declaration.is_nonterminal) @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" in "encoder/decoder", where that decoder logic lives in ``FrictionQDecoderStep``. This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. """ action_mapping = output_dict['action_mapping'] best_actions = output_dict["best_action_sequence"] debug_infos = output_dict['debug_info'] batch_action_info = [] for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_actions, debug_infos)): instance_action_info = [] for predicted_action, action_debug_info in zip(predicted_actions, debug_info): action_info = {} action_info['predicted_action'] = predicted_action considered_actions = action_debug_info['considered_actions'] probabilities = action_debug_info['probabilities'] actions = [] for action, probability in zip(considered_actions, probabilities): if action != -1: actions.append((action_mapping[(batch_index, action)], probability)) actions.sort() considered_actions, probabilities = zip(*actions) action_info['considered_actions'] = considered_actions action_info['action_probabilities'] = probabilities action_info['question_attention'] = action_debug_info.get('question_attention', []) instance_action_info.append(action_info) batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict
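# ---------------------------------------------------------------------------
# Illustrative sketch (made-up action indices): the sequence-accuracy logic in
# ``QuarelSemanticParser._action_history_match`` above counts a prediction as
# correct when it exactly matches the prefix of *any* of the padded target
# action sequences.  The reformulation below uses ``all``/``any`` instead of
# the min/max trick in the model, but implements the same check.
def _sketch_action_history_match() -> None:
    import torch

    predicted = [3, 7, 1]
    # Two candidate target sequences, padded with -1 (the IndexField padding value).
    targets = torch.tensor([[3, 7, 1, -1, -1],
                            [3, 7, 2,  5, -1]])
    if len(predicted) > targets.size(1):
        match = 0
    else:
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        row_matches = targets_trimmed.eq(predicted_tensor).all(dim=1)  # one flag per target
        match = int(row_matches.any())
    assert match == 1  # the first target sequence matches the prediction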
class AclClassifier(Model):
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 node_embedder: TokenEmbedder,
                 verbose_metrics: bool,
                 classifier_feedforward: FeedForward,
                 use_node_vector: bool = True,
                 use_abstract: bool = True,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(AclClassifier, self).__init__(vocab, regularizer)
        self.node_embedder = node_embedder
        self.text_field_embedder = text_field_embedder
        self.use_node_vector = use_node_vector
        self.use_abstract = use_abstract
        self.dropout = torch.nn.Dropout(dropout)
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.classifier_feedforward = classifier_feedforward
        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = {}
        self.verbose_metrics = verbose_metrics
        for i in range(self.num_classes):
            label_name = vocab.get_token_from_index(index=i, namespace="labels")
            self.label_f1_metrics[label_name] = F1Measure(positive_label=i)
        self.loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    @overrides
    def forward(self,
                abstract: Dict[str, torch.LongTensor],
                paper_id: torch.LongTensor,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        if self.use_abstract and self.use_node_vector:
            embedding = torch.cat([
                self.text_field_embedder(abstract)[:, 0, :],
                self.node_embedder(paper_id)], dim=-1)
        elif self.use_abstract:
            embedding = self.text_field_embedder(abstract)[:, 0, :]
        elif self.use_node_vector:
            embedding = self.node_embedder(paper_id)
        else:
            embedding = self.node_embedder(paper_id)
        logits = self.classifier_feedforward(self.dropout(embedding))
        class_probs = F.softmax(logits, dim=1)
        output_dict = {"logits": logits}
        if label is not None:
            loss = self.loss(logits, label)
            output_dict["label"] = label
            output_dict["loss"] = loss
            for i in range(self.num_classes):
                label_name = self.vocab.get_token_from_index(index=i, namespace="labels")
                metric = self.label_f1_metrics[label_name]
                metric(class_probs, label)
            self.label_accuracy(logits, label)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        class_probs = F.softmax(output_dict['logits'], dim=-1)
        output_dict['pred_label'] = [
            self.vocab.get_token_from_index(index=int(np.argmax(probs)), namespace="labels")
            for probs in class_probs.cpu()
        ]
        output_dict['label'] = [
            self.vocab.get_token_from_index(index=int(label), namespace="labels")
            for label in output_dict['label'].cpu()
        ]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metric_dict = {}
        sum_f1 = 0.0
        for name, metric in self.label_f1_metrics.items():
            metric_val = metric.get_metric(reset)
            if self.verbose_metrics:
                metric_dict[name + '_P'] = metric_val[0]
                metric_dict[name + '_R'] = metric_val[1]
                metric_dict[name + '_F1'] = metric_val[2]
            sum_f1 += metric_val[2]
        names = list(self.label_f1_metrics.keys())
        total_len = len(names)
        if total_len > 0:
            average_f1 = sum_f1 / total_len
        else:
            average_f1 = 0.0
        metric_dict['average_F1'] = average_f1
        metric_dict['accuracy'] = self.label_accuracy.get_metric(reset)
        return metric_dict
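# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical label names and scores): this is how
# ``AclClassifier.get_metrics`` above reduces the per-label F1Measure outputs
# to the single macro-averaged 'average_F1' value, whether or not
# ``verbose_metrics`` also exposes the per-label numbers.
def _sketch_macro_average_f1() -> None:
    # Pretend each F1Measure.get_metric() already returned (precision, recall, f1)
    # and we kept only the f1 component per label.
    per_label_f1 = {"label_a": 0.80, "label_b": 0.60, "label_c": 0.70}
    total = len(per_label_f1)
    average_f1 = sum(per_label_f1.values()) / total if total > 0 else 0.0
    assert abs(average_f1 - 0.70) < 1e-9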
class DialogQA(Model): """ This class implements modified version of BiDAF (with self attention and residual layer, from Clark and Gardner ACL 17 paper) model as used in Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf]. In this set-up, a single instance is a dialog, list of question answer pairs. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. span_start_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span end predictions into the passage state. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). num_context_answers : ``int``, optional (default=0) If greater than 0, the model will consider previous question answering context. max_span_length: ``int``, optional (default=0) Maximum token length of the output span. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, phrase_layer: Seq2SeqEncoder, residual_encoder: Seq2SeqEncoder, span_start_encoder: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, initializer: InitializerApplicator, dropout: float = 0.2, num_context_answers: int = 0, marker_embedding_dim: int = 10, max_span_length: int = 25) -> None: super().__init__(vocab) self._num_context_answers = num_context_answers self._max_span_length = max_span_length self._text_field_embedder = text_field_embedder self._phrase_layer = phrase_layer self._marker_embedding_dim = marker_embedding_dim self._encoding_dim = phrase_layer.get_output_dim() max_turn_length = 15 self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') self._merge_atten = TimeDistributed( torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim)) self._residual_encoder = residual_encoder if num_context_answers > 0: self._question_num_marker = torch.nn.Embedding( max_turn_length, marker_embedding_dim * num_context_answers) self._prev_ans_marker = torch.nn.Embedding( (num_context_answers * 4) + 1, marker_embedding_dim) self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') self._followup_lin = torch.nn.Linear(self._encoding_dim, 3) self._merge_self_attention = TimeDistributed( torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim)) self._span_start_encoder = span_start_encoder self._span_end_encoder = span_end_encoder self._span_start_predictor = TimeDistributed( torch.nn.Linear(self._encoding_dim, 1)) self._span_end_predictor = TimeDistributed( torch.nn.Linear(self._encoding_dim, 1)) self._span_yesno_predictor = TimeDistributed( torch.nn.Linear(self._encoding_dim, 3)) self._span_followup_predictor = TimeDistributed(self._followup_lin) check_dimensions_match( phrase_layer.get_input_dim(), text_field_embedder.get_output_dim() + marker_embedding_dim * num_context_answers, "phrase layer input dim", "embedding dim + marker dim * num context answers") initializer(self) self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() 
self._span_yesno_accuracy = CategoricalAccuracy() self._span_followup_accuracy = CategoricalAccuracy() self._span_gt_yesno_accuracy = CategoricalAccuracy() self._span_gt_followup_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._official_f1 = Average() self._variational_dropout = InputVariationalDropout(dropout) def forward( self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, p1_answer_marker: torch.IntTensor = None, p2_answer_marker: torch.IntTensor = None, p3_answer_marker: torch.IntTensor = None, yesno_list: torch.IntTensor = None, followup_list: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. p1_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 0. This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length]. Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>. For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac p2_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 1. It is similar to p1_answer_marker, but marking previous previous answer in passage. p3_answer_marker : ``torch.IntTensor``, optional This is one of the inputs, but only when num_context_answers > 2. It is similar to p1_answer_marker, but marking previous previous previous answer in passage. yesno_list : ``torch.IntTensor``, optional This is one of the outputs that we are trying to predict. Three way classification (the yes/no/not a yes no question). followup_list : ``torch.IntTensor``, optional This is one of the outputs that we are trying to predict. Three way classification (followup / maybe followup / don't followup). metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question ID, original passage text, and token offsets into the passage for each instance in the batch. We use this for computing official metrics using the official SQuAD evaluation script. The length of this list should be the batch size, and each dictionary should have the keys ``id``, ``original_passage``, and ``token_offsets``. If you only want the best span string and don't care about official metrics, you can omit the ``id`` key. Returns ------- An output dictionary consisting of the followings. 
Each of the followings is a nested list because first iterates over dialog, then questions in dialog. qid : List[List[str]] A list of list, consisting of question ids. followup : List[List[int]] A list of list, consisting of continuation marker prediction index. (y :yes, m: maybe follow up, n: don't follow up) yesno : List[List[int]] A list of list, consisting of affirmation marker prediction index. (y :yes, x: not a yes/no question, n: np) best_span_str : List[List[str]] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ batch_size, max_qa_count, max_q_len, _ = question[ 'token_characters'].size() total_qa_count = batch_size * max_qa_count qa_mask = torch.ge(followup_list, 0).view(total_qa_count) embedded_question = self._text_field_embedder(question, num_wrapping_dims=1) embedded_question = embedded_question.reshape( total_qa_count, max_q_len, self._text_field_embedder.get_output_dim()) embedded_question = self._variational_dropout(embedded_question) embedded_passage = self._variational_dropout( self._text_field_embedder(passage)) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float() question_mask = question_mask.reshape(total_qa_count, max_q_len) passage_mask = util.get_text_field_mask(passage).float() repeated_passage_mask = passage_mask.unsqueeze(1).repeat( 1, max_qa_count, 1) repeated_passage_mask = repeated_passage_mask.view( total_qa_count, passage_length) if self._num_context_answers > 0: # Encode question turn number inside the dialog into question embedding. question_num_ind = util.get_range_vector( max_qa_count, util.get_device_of(embedded_question)) question_num_ind = question_num_ind.unsqueeze(-1).repeat( 1, max_q_len) question_num_ind = question_num_ind.unsqueeze(0).repeat( batch_size, 1, 1) question_num_ind = question_num_ind.reshape( total_qa_count, max_q_len) question_num_marker_emb = self._question_num_marker( question_num_ind) embedded_question = torch.cat( [embedded_question, question_num_marker_emb], dim=-1) # Encode the previous answers in passage embedding. repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). 
\ view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim()) # batch_size * max_qa_count, passage_length, word_embed_dim p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length) p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker) repeated_embedded_passage = torch.cat( [repeated_embedded_passage, p1_answer_marker_emb], dim=-1) if self._num_context_answers > 1: p2_answer_marker = p2_answer_marker.view( total_qa_count, passage_length) p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker) repeated_embedded_passage = torch.cat( [repeated_embedded_passage, p2_answer_marker_emb], dim=-1) if self._num_context_answers > 2: p3_answer_marker = p3_answer_marker.view( total_qa_count, passage_length) p3_answer_marker_emb = self._prev_ans_marker( p3_answer_marker) repeated_embedded_passage = torch.cat( [repeated_embedded_passage, p3_answer_marker_emb], dim=-1) repeated_encoded_passage = self._variational_dropout( self._phrase_layer(repeated_embedded_passage, repeated_passage_mask)) else: encoded_passage = self._variational_dropout( self._phrase_layer(embedded_passage, passage_mask)) repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat( 1, max_qa_count, 1, 1) repeated_encoded_passage = repeated_encoded_passage.view( total_qa_count, passage_length, self._encoding_dim) encoded_question = self._variational_dropout( self._phrase_layer(embedded_question, question_mask)) # Shape: (batch_size * max_qa_count, passage_length, question_length) passage_question_similarity = self._matrix_attention( repeated_encoded_passage, encoded_question) # Shape: (batch_size * max_qa_count, passage_length, question_length) passage_question_attention = util.masked_softmax( passage_question_similarity, question_mask) # Shape: (batch_size * max_qa_count, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum( encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
masked_similarity = util.replace_masked_values( passage_question_similarity, question_mask.unsqueeze(1), -1e7) question_passage_similarity = masked_similarity.max( dim=-1)[0].squeeze(-1) question_passage_attention = util.masked_softmax( question_passage_similarity, repeated_passage_mask) # Shape: (batch_size * max_qa_count, encoding_dim) question_passage_vector = util.weighted_sum( repeated_encoded_passage, question_passage_attention) tiled_question_passage_vector = question_passage_vector.unsqueeze( 1).expand(total_qa_count, passage_length, self._encoding_dim) # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([ repeated_encoded_passage, passage_question_vectors, repeated_encoded_passage * passage_question_vectors, repeated_encoded_passage * tiled_question_passage_vector ], dim=-1) final_merged_passage = F.relu(self._merge_atten(final_merged_passage)) residual_layer = self._variational_dropout( self._residual_encoder(final_merged_passage, repeated_passage_mask)) self_attention_matrix = self._self_attention(residual_layer, residual_layer) mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \ * repeated_passage_mask.reshape(total_qa_count, 1, passage_length) self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device) self_mask = self_mask.reshape(1, passage_length, passage_length) mask = mask * (1 - self_mask) self_attention_probs = util.masked_softmax(self_attention_matrix, mask) # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim) self_attention_vecs = torch.matmul(self_attention_probs, residual_layer) self_attention_vecs = torch.cat([ self_attention_vecs, residual_layer, residual_layer * self_attention_vecs ], dim=-1) residual_layer = F.relu( self._merge_self_attention(self_attention_vecs)) final_merged_passage = final_merged_passage + residual_layer # batch_size * maxqa_pair_len * max_passage_len * 200 final_merged_passage = self._variational_dropout(final_merged_passage) start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask) span_start_logits = self._span_start_predictor(start_rep).squeeze(-1) end_rep = self._span_end_encoder( torch.cat([final_merged_passage, start_rep], dim=-1), repeated_passage_mask) span_end_logits = self._span_end_predictor(end_rep).squeeze(-1) span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1) span_followup_logits = self._span_followup_predictor(end_rep).squeeze( -1) span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7) # batch_size * maxqa_len_pair, max_document_len span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7) best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits, span_yesno_logits, span_followup_logits, self._max_span_length) output_dict: Dict[str, Any] = {} # Compute the loss. 
if span_start is not None: loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1), ignore_index=-1) self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask) loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask), span_end.view(-1), ignore_index=-1) self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask) self._span_accuracy(best_span[:, 0:2], torch.stack([span_start, span_end], -1).view(total_qa_count, 2), mask=qa_mask.unsqueeze(1).expand(-1, 2).long()) # add a select for the right span to compute loss gold_span_end_loc = [] span_end = span_end.view(total_qa_count).squeeze().cpu().detach( ).numpy().reshape(total_qa_count) # print("span_end = {}, type(span_end)={} total_qa_count = {}".format(span_end, type(span_end), total_qa_count)) print("span_end.shape = {}".format(span_end.shape)) for i in range(0, total_qa_count): gold_span_end_loc.append( max(span_end[i] * 3 + i * passage_length * 3, 0)) gold_span_end_loc.append( max(span_end[i] * 3 + i * passage_length * 3 + 1, 0)) gold_span_end_loc.append( max(span_end[i] * 3 + i * passage_length * 3 + 2, 0)) # print("i = {}, gold_span_end_loc = {}".format(i, gold_span_end_loc)) gold_span_end_loc = span_start.new(gold_span_end_loc) pred_span_end_loc = [] for i in range(0, total_qa_count): pred_span_end_loc.append( max(best_span[i][1] * 3 + i * passage_length * 3, 0)) pred_span_end_loc.append( max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0)) pred_span_end_loc.append( max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0)) predicted_end = span_start.new(pred_span_end_loc) _yesno = span_yesno_logits.view(-1).index_select( 0, gold_span_end_loc).view(-1, 3) _followup = span_followup_logits.view(-1).index_select( 0, gold_span_end_loc).view(-1, 3) loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1) loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1) _yesno = span_yesno_logits.view(-1).index_select( 0, predicted_end).view(-1, 3) _followup = span_followup_logits.view(-1).index_select( 0, predicted_end).view(-1, 3) self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask) self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask) output_dict["loss"] = loss # Compute F1 and preparing the output dictionary. output_dict['best_span_str'] = [] output_dict['qid'] = [] output_dict['followup'] = [] output_dict['yesno'] = [] best_span_cpu = best_span.detach().cpu().numpy() for i in range(batch_size): passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] f1_score = 0.0 per_dialog_best_span_list = [] per_dialog_yesno_list = [] per_dialog_followup_list = [] per_dialog_query_id_list = [] for per_dialog_query_index, (iid, answer_texts) in enumerate( zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])): predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index]) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] yesno_pred = predicted_span[2] followup_pred = predicted_span[3] per_dialog_yesno_list.append(yesno_pred) per_dialog_followup_list.append(followup_pred) per_dialog_query_id_list.append(iid) best_span_string = passage_str[start_offset:end_offset] per_dialog_best_span_list.append(best_span_string) if answer_texts: if len(answer_texts) > 1: t_f1 = [] # Compute F1 over N-1 human references and averages the scores. 
for answer_index in range(len(answer_texts)): idxes = list(range(len(answer_texts))) idxes.pop(answer_index) refs = [answer_texts[z] for z in idxes] t_f1.append( squad_eval.metric_max_over_ground_truths( squad_eval.f1_score, best_span_string, refs)) f1_score = 1.0 * sum(t_f1) / len(t_f1) else: f1_score = squad_eval.metric_max_over_ground_truths( squad_eval.f1_score, best_span_string, answer_texts) self._official_f1(100 * f1_score) output_dict['qid'].append(per_dialog_query_id_list) output_dict['best_span_str'].append(per_dialog_best_span_list) output_dict['yesno'].append(per_dialog_yesno_list) output_dict['followup'].append(per_dialog_followup_list) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] \ for yn_list in output_dict.pop("yesno")] followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list] \ for followup_list in output_dict.pop("followup")] output_dict['yesno'] = yesno_tags output_dict['followup'] = followup_tags return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'yesno': self._span_yesno_accuracy.get_metric(reset), 'followup': self._span_followup_accuracy.get_metric(reset), 'f1': self._official_f1.get_metric(reset), } @staticmethod def _get_best_span_yesno_followup(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor, span_yesno_logits: torch.Tensor, span_followup_logits: torch.Tensor, max_span_length: int) -> torch.Tensor: # Returns the index of highest-scoring span that is not longer than 30 tokens, as well as # yesno prediction bit and followup prediction bit from the predicted span end token. if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError( "Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() max_span_log_prob = [-1e20] * batch_size span_start_argmax = [0] * batch_size best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long) span_start_logits = span_start_logits.data.cpu().numpy() span_end_logits = span_end_logits.data.cpu().numpy() span_yesno_logits = span_yesno_logits.data.cpu().numpy() span_followup_logits = span_followup_logits.data.cpu().numpy() for b_i in range(batch_size): # pylint: disable=invalid-name for j in range(passage_length): val1 = span_start_logits[b_i, span_start_argmax[b_i]] if val1 < span_start_logits[b_i, j]: span_start_argmax[b_i] = j val1 = span_start_logits[b_i, j] val2 = span_end_logits[b_i, j] if val1 + val2 > max_span_log_prob[b_i]: if j - span_start_argmax[b_i] > max_span_length: continue best_word_span[b_i, 0] = span_start_argmax[b_i] best_word_span[b_i, 1] = j max_span_log_prob[b_i] = val1 + val2 for b_i in range(batch_size): j = best_word_span[b_i, 1] yesno_pred = np.argmax(span_yesno_logits[b_i, j]) followup_pred = np.argmax(span_followup_logits[b_i, j]) best_word_span[b_i, 2] = int(yesno_pred) best_word_span[b_i, 3] = int(followup_pred) return best_word_span
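# ---------------------------------------------------------------------------
# Illustrative sketch (made-up logits, single example): the greedy span search
# in ``DialogQA._get_best_span_yesno_followup`` above scans the passage left to
# right, tracks the best start logit seen so far, and keeps the highest-scoring
# (start, end) pair whose length stays within ``max_span_length``.
def _sketch_best_span_search() -> None:
    start_logits = [0.1, 2.0, 0.3, 0.0, 1.5]
    end_logits = [0.0, 0.2, 1.0, 3.0, 0.1]
    max_span_length = 2

    best_score = float("-inf")
    best_span = (0, 0)
    start_argmax = 0
    for j in range(len(start_logits)):
        if start_logits[j] > start_logits[start_argmax]:
            start_argmax = j  # best start position seen so far
        score = start_logits[start_argmax] + end_logits[j]
        if score > best_score and j - start_argmax <= max_span_length:
            best_score = score
            best_span = (start_argmax, j)
    # start=1 (logit 2.0), end=3 (logit 3.0): the best span of length <= 2
    assert best_span == (1, 3)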
class Baseline(Model): def __init__(self, word_embeddings: TextFieldEmbedder, vocab: Vocabulary) -> None: super().__init__(vocab) self.word_embeddings = word_embeddings self.text_seq_encoder = PytorchSeq2VecWrapper(LSTM(word_embeddings.get_output_dim(), int(word_embeddings.get_output_dim()/2), batch_first=True, bidirectional=True)) self.out = torch.nn.Linear( in_features=self.text_seq_encoder.get_output_dim()*4, out_features=vocab.get_vocab_size('labels') ) self.accuracy = CategoricalAccuracy() self.f_score_0 = F1Measure(positive_label=0) self.f_score_1 = F1Measure(positive_label=1) self.f_score_2 = F1Measure(positive_label=2) self.loss = CrossEntropyLoss() def forward(self, article: Dict[str, torch.Tensor], outcome: Dict[str, torch.Tensor], intervention: Dict[str, torch.Tensor], comparator: Dict[str, torch.Tensor], labels: torch.Tensor = None) -> Dict[str, torch.Tensor]: a_mask = get_text_field_mask(article) a_embeddings = self.word_embeddings(article) a_vec = self.text_seq_encoder(a_embeddings, a_mask) o_mask = get_text_field_mask(outcome) o_embeddings = self.word_embeddings(outcome) o_vec = self.text_seq_encoder(o_embeddings, o_mask) i_mask = get_text_field_mask(intervention) i_embeddings = self.word_embeddings(intervention) i_vec = self.text_seq_encoder(i_embeddings, i_mask) c_mask = get_text_field_mask(comparator) c_embeddings = self.word_embeddings(comparator) c_vec = self.text_seq_encoder(c_embeddings, c_mask) logits = self.out(torch.cat((a_vec, o_vec, i_vec, c_vec), dim=1)) output = {'logits': logits} if labels is not None: self.accuracy(logits, labels) self.f_score_0(logits, labels) self.f_score_1(logits, labels) self.f_score_2(logits, labels) output['loss'] = self.loss(logits, labels) return output def get_metrics(self, reset: bool = False) -> Dict[str, float]: _, _, f_score0 = self.f_score_0.get_metric(reset) _, _, f_score1 = self.f_score_1.get_metric(reset) _, _, f_score2 = self.f_score_2.get_metric(reset) return {'accuracy': self.accuracy.get_metric(reset), 'f-score': np.mean([f_score0, f_score1, f_score2])}
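# ---------------------------------------------------------------------------
# Illustrative dimension check (made-up sizes) for the ``Baseline`` model
# above: the bidirectional LSTM wrapper with hidden size emb_dim // 2 yields
# emb_dim-sized sentence vectors, so concatenating the four encoded fields
# (article, outcome, intervention, comparator) matches the ``* 4`` input size
# of ``self.out``.
def _sketch_baseline_feature_dims() -> None:
    import torch

    batch_size, emb_dim = 3, 8
    lstm_output_dim = 2 * (emb_dim // 2)  # forward + backward hidden states
    a_vec, o_vec, i_vec, c_vec = (torch.randn(batch_size, lstm_output_dim) for _ in range(4))
    features = torch.cat((a_vec, o_vec, i_vec, c_vec), dim=1)
    assert features.shape == (batch_size, 4 * lstm_output_dim)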
class ESIM(Model): """ This ``Model`` implements the ESIM sequence model described in `"Enhanced LSTM for Natural Language Inference" <https://www.semanticscholar.org/paper/Enhanced-LSTM-for-Natural-Language-Inference-Chen-Zhu/83e7654d545fbbaaf2328df365a781fb67b841b4>`_ by Chen et al., 2017. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the model. encoder : ``Seq2SeqEncoder`` Used to encode the premise and hypothesis. similarity_function : ``SimilarityFunction`` This is the similarity function used when computing the similarity matrix between encoded words in the premise and words in the hypothesis. projection_feedforward : ``FeedForward`` The feedforward network used to project down the encoded and enhanced premise and hypothesis. inference_encoder : ``Seq2SeqEncoder`` Used to encode the projected premise and hypothesis for prediction. output_feedforward : ``FeedForward`` Used to prepare the concatenated premise and hypothesis for prediction. output_logit : ``FeedForward`` This feedforward network computes the output logits. dropout : ``float``, optional (default=0.5) Dropout percentage to use. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, similarity_function: SimilarityFunction, projection_feedforward: FeedForward, inference_encoder: Seq2SeqEncoder, output_feedforward: FeedForward, output_logit: FeedForward, dropout: float = 0.5, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super().__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder self._encoder = encoder self._matrix_attention = LegacyMatrixAttention(similarity_function) self._projection_feedforward = projection_feedforward self._inference_encoder = inference_encoder if dropout: self.dropout = torch.nn.Dropout(dropout) self.rnn_input_dropout = InputVariationalDropout(dropout) else: self.dropout = None self.rnn_input_dropout = None self._output_feedforward = output_feedforward self._output_logit = output_logit self._num_labels = vocab.get_vocab_size(namespace="labels") check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), "text field embedding dim", "encoder input dim") check_dimensions_match(encoder.get_output_dim() * 4, projection_feedforward.get_input_dim(), "encoder output dim", "projection feedforward input") check_dimensions_match(projection_feedforward.get_output_dim(), inference_encoder.get_input_dim(), "proj feedforward output dim", "inference lstm input dim") self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward(self, # type: ignore premise: Dict[str, torch.LongTensor], hypothesis: Dict[str, torch.LongTensor], label: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None # pylint:disable=unused-argument ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- premise : Dict[str, torch.LongTensor] From a ``TextField`` hypothesis : Dict[str, torch.LongTensor] From a ``TextField`` label : torch.IntTensor, optional (default = None) From a 
``LabelField`` metadata : ``List[Dict[str, Any]]``, optional, (default = None) Metadata containing the original tokenization of the premise and hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively. Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the entailment label. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the entailment label. loss : torch.FloatTensor, optional A scalar loss to be optimised. """ embedded_premise = self._text_field_embedder(premise) embedded_hypothesis = self._text_field_embedder(hypothesis) premise_mask = get_text_field_mask(premise).float() hypothesis_mask = get_text_field_mask(hypothesis).float() # apply dropout for LSTM if self.rnn_input_dropout: embedded_premise = self.rnn_input_dropout(embedded_premise) embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis) # encode premise and hypothesis encoded_premise = self._encoder(embedded_premise, premise_mask) encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask) # Shape: (batch_size, premise_length, hypothesis_length) similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis) # Shape: (batch_size, premise_length, hypothesis_length) p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask) # Shape: (batch_size, premise_length, embedding_dim) attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention) # Shape: (batch_size, hypothesis_length, premise_length) h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask) # Shape: (batch_size, hypothesis_length, embedding_dim) attended_premise = weighted_sum(encoded_premise, h2p_attention) # the "enhancement" layer premise_enhanced = torch.cat( [encoded_premise, attended_hypothesis, encoded_premise - attended_hypothesis, encoded_premise * attended_hypothesis], dim=-1 ) hypothesis_enhanced = torch.cat( [encoded_hypothesis, attended_premise, encoded_hypothesis - attended_premise, encoded_hypothesis * attended_premise], dim=-1 ) # The projection layer down to the model dimension. Dropout is not applied before # projection. projected_enhanced_premise = self._projection_feedforward(premise_enhanced) projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced) # Run the inference layer if self.rnn_input_dropout: projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise) projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis) v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask) v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask) # The pooling layer -- max and avg pooling. 
# (batch_size, model_dim) v_a_max, _ = replace_masked_values( v_ai, premise_mask.unsqueeze(-1), -1e7 ).max(dim=1) v_b_max, _ = replace_masked_values( v_bi, hypothesis_mask.unsqueeze(-1), -1e7 ).max(dim=1) v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum( premise_mask, 1, keepdim=True ) v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum( hypothesis_mask, 1, keepdim=True ) # Now concat # (batch_size, model_dim * 2 * 4) v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1) # the final MLP -- apply dropout to input, and MLP applies to output & hidden if self.dropout: v_all = self.dropout(v_all) output_hidden = self._output_feedforward(v_all) label_logits = self._output_logit(output_hidden) label_probs = torch.nn.functional.softmax(label_logits, dim=-1) output_dict = {"label_logits": label_logits, "label_probs": label_probs} if label is not None: loss = self._loss(label_logits, label.long().view(-1)) self._accuracy(label_logits, label) output_dict["loss"] = loss return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return {'accuracy': self._accuracy.get_metric(reset)}
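# A minimal sketch (random toy tensors, plain torch instead of the AllenNLP helpers) of the
# masked max- and average-pooling that ESIM applies to the inference-encoder output above.
import torch

batch, seq_len, dim = 2, 5, 4
v = torch.randn(batch, seq_len, dim)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])

# Max pooling: padded positions are pushed to a very negative value so they never win the max.
v_max, _ = v.masked_fill((1 - mask).bool().unsqueeze(-1), -1e7).max(dim=1)

# Average pooling: zero out padded positions, then divide by the number of real tokens.
v_avg = torch.sum(v * mask.unsqueeze(-1), dim=1) / torch.sum(mask, dim=1, keepdim=True)

v_all = torch.cat([v_avg, v_max], dim=1)  # (batch, 2 * dim), fed to the output feedforward
print(v_all.shape)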
class MultiGranularityHierarchicalAttentionFusionNetworks(Model): def __init__( self, vocab: Vocabulary, elmo_embedder: TextFieldEmbedder, tokens_embedder: TextFieldEmbedder, features_embedder: TextFieldEmbedder, phrase_layer: Seq2SeqEncoder, projected_layer: Seq2SeqEncoder, contextual_passage: Seq2SeqEncoder, contextual_question: Seq2SeqEncoder, dropout: float = 0.2, regularizer: Optional[RegularizerApplicator] = None, initializer: InitializerApplicator = InitializerApplicator(), ): super(MultiGranularityHierarchicalAttentionFusionNetworks, self).__init__(vocab, regularizer) self.elmo_embedder = elmo_embedder self.tokens_embedder = tokens_embedder self.features_embedder = features_embedder self._phrase_layer = phrase_layer self._encoding_dim = self._phrase_layer.get_output_dim() self.projected_layer = torch.nn.Linear(self._encoding_dim + 1024, self._encoding_dim) self.fuse_p = FusionLayer(self._encoding_dim) self.fuse_q = FusionLayer(self._encoding_dim) self.fuse_s = FusionLayer(self._encoding_dim) self.projected_lstm = projected_layer self.contextual_layer_p = contextual_passage self.contextual_layer_q = contextual_question self.linear_self_align = torch.nn.Linear(self._encoding_dim, 1) # self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') self._self_attention = BilinearMatrixAttention(self._encoding_dim, self._encoding_dim) self.bilinear_layer_s = BilinearSeqAtt(self._encoding_dim, self._encoding_dim) self.bilinear_layer_e = BilinearSeqAtt(self._encoding_dim, self._encoding_dim) self.yesno_predictor = FeedForward(self._encoding_dim, self._encoding_dim, 3) self.relu = torch.nn.ReLU() self._max_span_length = 30 self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() self._span_yesno_accuracy = CategoricalAccuracy() self._official_f1 = Average() self._variational_dropout = InputVariationalDropout(dropout) self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward( self, question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, yesno_list: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: batch_size, max_qa_count, max_q_len, _ = question[ 'token_characters'].size() total_qa_count = batch_size * max_qa_count qa_mask = torch.ge(yesno_list, 0).view(total_qa_count) # GloVe and simple cnn char embedding, embedding dim = 100 + 100 = 200 word_emb_ques = self.tokens_embedder( question, num_wrapping_dims=1).reshape(total_qa_count, max_q_len, self.tokens_embedder.get_output_dim()) word_emb_pass = self.tokens_embedder(passage) # Elmo embedding, embedding dim = 1024 elmo_ques = self.elmo_embedder(question, num_wrapping_dims=1).reshape( total_qa_count, max_q_len, self.elmo_embedder.get_output_dim()) elmo_pass = self.elmo_embedder(passage) # Passage features embedding, embedding dim = 20 + 20 = 40 pass_feat = self.features_embedder(passage) # GloVe + cnn + Elmo embedded_question = self._variational_dropout( torch.cat([word_emb_ques, elmo_ques], dim=2)) embedded_passage = self._variational_dropout( torch.cat([word_emb_pass, elmo_pass], dim=2)) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float() question_mask = question_mask.reshape(total_qa_count, max_q_len) passage_mask = util.get_text_field_mask(passage).float() repeated_passage_mask = 
passage_mask.unsqueeze(1).repeat( 1, max_qa_count, 1) repeated_passage_mask = repeated_passage_mask.view( total_qa_count, passage_length) # Concatenate Elmo after encoded passage encode_passage = self._phrase_layer(embedded_passage, passage_mask) projected_passage = self.relu( self.projected_layer(torch.cat([encode_passage, elmo_pass], dim=2))) # Concatenate Elmo after encoded question encode_question = self._phrase_layer(embedded_question, question_mask) projected_question = self.relu( self.projected_layer(torch.cat([encode_question, elmo_ques], dim=2))) encoded_passage = self._variational_dropout(projected_passage) repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat( 1, max_qa_count, 1, 1) repeated_encoded_passage = repeated_encoded_passage.view( total_qa_count, passage_length, self._encoding_dim) repeated_pass_feat = (pass_feat.unsqueeze(1).repeat( 1, max_qa_count, 1, 1)).view(total_qa_count, passage_length, 40) encoded_question = self._variational_dropout(projected_question) # total_qa_count * max_q_len * passage_length # cnt * m * n s = torch.bmm(encoded_question, repeated_encoded_passage.transpose(2, 1)) alpha = util.masked_softmax(s, question_mask.unsqueeze(2).expand( s.size()), dim=1) # cnt * n * h aligned_p = torch.bmm(alpha.transpose(2, 1), encoded_question) # cnt * m * n beta = util.masked_softmax(s, repeated_passage_mask.unsqueeze(1).expand( s.size()), dim=2) # cnt * m * h aligned_q = torch.bmm(beta, repeated_encoded_passage) fused_p = self.fuse_p(repeated_encoded_passage, aligned_p) fused_q = self.fuse_q(encoded_question, aligned_q) # add manual features here q_aware_p = self._variational_dropout( self.projected_lstm( torch.cat([fused_p, repeated_pass_feat], dim=2), repeated_passage_mask)) # cnt * n * n # self_p = torch.bmm(q_aware_p, q_aware_p.transpose(2, 1)) # self_p = self.bilinear_self_align(q_aware_p) self_p = self._self_attention(q_aware_p, q_aware_p) mask = repeated_passage_mask.reshape( total_qa_count, passage_length, 1) * repeated_passage_mask.reshape( total_qa_count, 1, passage_length) self_mask = torch.eye(passage_length, passage_length, device=self_p.device) self_mask = self_mask.reshape(1, passage_length, passage_length) mask = mask * (1 - self_mask) lamb = util.masked_softmax(self_p, mask, dim=2) # lamb = util.masked_softmax(self_p, repeated_passage_mask, dim=2) # cnt * n * h self_aligned_p = torch.bmm(lamb, q_aware_p) # cnt * n * h fused_self_p = self.fuse_s(q_aware_p, self_aligned_p) contextual_p = self._variational_dropout( self.contextual_layer_p(fused_self_p, repeated_passage_mask)) # contextual_p = self.contextual_layer_p(fused_self_p, repeated_passage_mask) contextual_q = self._variational_dropout( self.contextual_layer_q(fused_q, question_mask)) # contextual_q = self.contextual_layer_q(fused_q, question_mask) # cnt * m gamma = util.masked_softmax( self.linear_self_align(contextual_q).squeeze(2), question_mask, dim=1) # cnt * h weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1) span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p) span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p) # cnt * n * 1 cnt * 1 * h span_yesno_logits = self.yesno_predictor( torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1))) # span_yesno_logits = self.yesno_predictor(contextual_p) span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7) best_span = 
self._get_best_span_yesno_followup(span_start_logits, span_end_logits, span_yesno_logits, self._max_span_length) output_dict: Dict[str, Any] = {} # Compute the loss for training if span_start is not None: loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1), ignore_index=-1) self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask) loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask), span_end.view(-1), ignore_index=-1) self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask) self._span_accuracy(best_span[:, 0:2], torch.stack([span_start, span_end], -1).view(total_qa_count, 2), mask=qa_mask.unsqueeze(1).expand(-1, 2).long()) # add a select for the right span to compute loss gold_span_end_loc = [] span_end = span_end.view( total_qa_count).squeeze().data.cpu().numpy() for i in range(0, total_qa_count): gold_span_end_loc.append( max(span_end[i] * 3 + i * passage_length * 3, 0)) gold_span_end_loc.append( max(span_end[i] * 3 + i * passage_length * 3 + 1, 0)) gold_span_end_loc.append( max(span_end[i] * 3 + i * passage_length * 3 + 2, 0)) gold_span_end_loc = span_start.new(gold_span_end_loc) pred_span_end_loc = [] for i in range(0, total_qa_count): pred_span_end_loc.append( max(best_span[i][1] * 3 + i * passage_length * 3, 0)) pred_span_end_loc.append( max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0)) pred_span_end_loc.append( max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0)) predicted_end = span_start.new(pred_span_end_loc) _yesno = span_yesno_logits.view(-1).index_select( 0, gold_span_end_loc).view(-1, 3) loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1) _yesno = span_yesno_logits.view(-1).index_select( 0, predicted_end).view(-1, 3) self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask) output_dict["loss"] = loss # Compute the EM and F1 on SQuAD and add the tokenized input to the output. output_dict['best_span_str'] = [] output_dict['qid'] = [] output_dict['yesno'] = [] best_span_cpu = best_span.detach().cpu().numpy() for i in range(batch_size): passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] f1_score = 0.0 per_dialog_best_span_list = [] per_dialog_yesno_list = [] per_dialog_query_id_list = [] for per_dialog_query_index, (iid, answer_texts) in enumerate( zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])): predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index]) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] yesno_pred = predicted_span[2] per_dialog_yesno_list.append(yesno_pred) per_dialog_query_id_list.append(iid) best_span_string = passage_str[start_offset:end_offset] per_dialog_best_span_list.append(best_span_string) if answer_texts: if len(answer_texts) > 1: t_f1 = [] # Compute F1 over N-1 human references and averages the scores. 
for answer_index in range(len(answer_texts)): idxes = list(range(len(answer_texts))) idxes.pop(answer_index) refs = [answer_texts[z] for z in idxes] t_f1.append( squad_eval.metric_max_over_ground_truths( squad_eval.f1_score, best_span_string, refs)) f1_score = 1.0 * sum(t_f1) / len(t_f1) else: f1_score = squad_eval.metric_max_over_ground_truths( squad_eval.f1_score, best_span_string, answer_texts) self._official_f1(100 * f1_score) output_dict['qid'].append(per_dialog_query_id_list) output_dict['best_span_str'].append(per_dialog_best_span_list) output_dict['yesno'].append(per_dialog_yesno_list) return output_dict def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: yesno_tags = [[ self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list ] for yn_list in output_dict.pop("yesno")] output_dict['yesno'] = yesno_tags return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'yesno': self._span_yesno_accuracy.get_metric(reset), 'f1': self._official_f1.get_metric(reset), } @staticmethod def _get_best_span_yesno_followup(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor, span_yesno_logits: torch.Tensor, max_span_length: int) -> torch.Tensor: if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError( "Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() max_span_log_prob = [-1e20] * batch_size span_start_argmax = [0] * batch_size best_word_span = span_start_logits.new_zeros((batch_size, 3), dtype=torch.long) span_start_logits = span_start_logits.data.cpu().numpy() span_end_logits = span_end_logits.data.cpu().numpy() span_yesno_logits = span_yesno_logits.data.cpu().numpy() for b_i in range(batch_size): # pylint: disable=invalid-name for j in range(passage_length): val1 = span_start_logits[b_i, span_start_argmax[b_i]] if val1 < span_start_logits[b_i, j]: span_start_argmax[b_i] = j val1 = span_start_logits[b_i, j] val2 = span_end_logits[b_i, j] if val1 + val2 > max_span_log_prob[b_i]: if j - span_start_argmax[b_i] > max_span_length: continue best_word_span[b_i, 0] = span_start_argmax[b_i] best_word_span[b_i, 1] = j max_span_log_prob[b_i] = val1 + val2 for b_i in range(batch_size): j = best_word_span[b_i, 1] yesno_pred = np.argmax(span_yesno_logits[b_i, j]) best_word_span[b_i, 2] = int(yesno_pred) return best_word_span
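# A toy sketch (assumed shapes, plain torch) of the diagonal masking used in the passage
# self-alignment above: the pairwise mask is the outer product of the passage mask with
# itself, with the identity removed so a token never attends to itself.
import torch

passage_length = 4
passage_mask = torch.tensor([[1., 1., 1., 0.]])                     # (batch=1, passage_length)

pair_mask = passage_mask.unsqueeze(2) * passage_mask.unsqueeze(1)   # (1, n, n)
self_mask = torch.eye(passage_length).unsqueeze(0)                  # (1, n, n)
pair_mask = pair_mask * (1 - self_mask)

print(pair_mask[0])
# tensor([[0., 1., 1., 0.],
#         [1., 0., 1., 0.],
#         [1., 1., 0., 0.],
#         [0., 0., 0., 0.]])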
class IntentParamClassifier(Model): """ This ``Model`` performs intent classification for user_utterance. We assume we're given a user_utterance, a prev_user_utterance and a prev_sys_utterance, and we predict some output label for intent. The basic model structure: we'll embed the user_utterance, the prev_user_utterance and the prev_sys_utterance, and encode each of them with separate Seq2VecEncoders, getting a single vector representing the content of each. We'll then concatenate those three vectors, and pass the result through a feedforward network, the output of which we'll use as our scores for each label. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the ``tokens`` ``TextField`` we get as input to the model. user_utterance_encoder : ``Seq2VecEncoder`` The encoder that we will use to convert the user_utterance to a vector. prev_user_utterance_encoder : ``Seq2VecEncoder`` The encoder that we will use to convert the prev_user_utterance to a vector. prev_sys_utterance_encoder : ``Seq2VecEncoder`` The encoder that we will use to convert the prev_sys_utterance to a vector. classifier_feedforward : ``FeedForward`` initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, user_utterance_encoder: Seq2VecEncoder, prev_user_utterance_encoder: Seq2VecEncoder, prev_sys_utterance_encoder: Seq2VecEncoder, classifier_feedforward: FeedForward, encoder: Seq2SeqEncoder, calculate_span_f1: bool = None, tag_encoding: Optional[str] = None, tag_namespace: str = "tags", verbose_metrics: bool = False, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(IntentParamClassifier, self).__init__(vocab, regularizer) # Intent task self.text_field_embedder = text_field_embedder self.label_num_classes = self.vocab.get_vocab_size("labels") self.user_utterance_encoder = user_utterance_encoder self.prev_user_utterance_encoder = prev_user_utterance_encoder self.prev_sys_utterance_encoder = prev_sys_utterance_encoder self.classifier_feedforward = classifier_feedforward if text_field_embedder.get_output_dim() != user_utterance_encoder.get_input_dim(): raise ConfigurationError("The output dimension of the text_field_embedder must match the " "input dimension of the user_utterance_encoder. Found {} and {}, " "respectively.".format(text_field_embedder.get_output_dim(), user_utterance_encoder.get_input_dim())) if text_field_embedder.get_output_dim() != prev_user_utterance_encoder.get_input_dim(): raise ConfigurationError("The output dimension of the text_field_embedder must match the " "input dimension of the prev_user_utterance_encoder. Found {} and {}, " "respectively.".format(text_field_embedder.get_output_dim(), prev_user_utterance_encoder.get_input_dim())) if text_field_embedder.get_output_dim() != prev_sys_utterance_encoder.get_input_dim(): raise ConfigurationError("The output dimension of the text_field_embedder must match the " "input dimension of the prev_sys_utterance_encoder. 
Found {} and {}, " "respectively.".format(text_field_embedder.get_output_dim(), prev_sys_utterance_encoder.get_input_dim())) self.label_accuracy = CategoricalAccuracy() self.label_f1_metrics = {} for i in range(self.label_num_classes): self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="labels")] = F1Measure(positive_label=i) self.loss = torch.nn.CrossEntropyLoss() # Param task self.tag_namespace = tag_namespace self.tag_num_classes = self.vocab.get_vocab_size(tag_namespace) self.encoder = encoder self._verbose_metrics = verbose_metrics self.tag_projection_layer = TimeDistributed(Linear(self.encoder.get_output_dim(), self.tag_num_classes)) check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), "text field embedding dim", "encoder input dim") # We keep calculate_span_f1 as a constructor argument for API consistency with # the CrfTagger, even it is redundant in this class # (tag_encoding serves the same purpose). if calculate_span_f1 and not tag_encoding: raise ConfigurationError("calculate_span_f1 is True, but " "no tag_encoding was specified.") self.tag_accuracy = CategoricalAccuracy() if calculate_span_f1 or tag_encoding: self._f1_metric = SpanBasedF1Measure(vocab, tag_namespace=tag_namespace, tag_encoding=tag_encoding) else: self._f1_metric = None self.f1 = SpanBasedF1Measure(vocab, tag_namespace=tag_namespace) self.tag_f1_metrics = {} for k in range(self.tag_num_classes): self.tag_f1_metrics[vocab.get_token_from_index(index=k, namespace=tag_namespace)] = F1Measure( positive_label=k) initializer(self) @overrides def forward(self, # type: ignore user_utterance: Dict[str, torch.LongTensor], prev_user_utterance: Dict[str, torch.LongTensor], prev_sys_utterance: Dict[str, torch.LongTensor], tokens: Dict[str, torch.LongTensor], label: torch.LongTensor = None, tags: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- user_utterance : Dict[str, Variable], required The output of ``TextField.as_array()``. prev_user_utterance : Dict[str, Variable], required The output of ``TextField.as_array()``. prev_sys_utterance : Dict[str, Variable], required The output of ``TextField.as_array()``. label : Variable, optional (default = None) A variable representing the intent label for each instance in the batch. Returns ------- An output dictionary consisting of: class_probabilities : torch.FloatTensor A tensor of shape ``(batch_size, num_classes)`` representing a distribution over the label classes for each instance. loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" # Intent task embedded_user_utterance = self.text_field_embedder(user_utterance) user_utterance_mask = util.get_text_field_mask(user_utterance) encoded_user_utterance = self.user_utterance_encoder(embedded_user_utterance, user_utterance_mask) embedded_prev_user_utterance = self.text_field_embedder(prev_user_utterance) prev_user_utterance_mask = util.get_text_field_mask(prev_user_utterance) encoded_prev_user_utterance = self.prev_user_utterance_encoder(embedded_prev_user_utterance, prev_user_utterance_mask) embedded_prev_sys_utterance = self.text_field_embedder(prev_sys_utterance) prev_sys_utterance_mask = util.get_text_field_mask(prev_sys_utterance) encoded_prev_sys_utterance = self.prev_sys_utterance_encoder(embedded_prev_sys_utterance, prev_sys_utterance_mask) # Param task embedded_text_input = self.text_field_embedder(tokens) batch_size, sequence_length, _ = embedded_text_input.size() mask = get_text_field_mask(tokens) encoded_text = self.encoder(embedded_text_input, mask) label_logits = self.classifier_feedforward(torch.cat([encoded_user_utterance, encoded_prev_user_utterance, encoded_prev_sys_utterance], dim=-1)) label_class_probs = F.softmax(label_logits, dim=1) output_dict = {"label_logits": label_logits, "label_class_probs": label_class_probs} tag_logits = self.tag_projection_layer(encoded_text) reshaped_log_probs = tag_logits.view(-1, self.tag_num_classes) tag_class_probs = F.softmax(reshaped_log_probs, dim=-1).view([batch_size, sequence_length, self.tag_num_classes]) output_dict["tag_logits"] = tag_logits output_dict["tag_class_probs"] = tag_class_probs if label is not None: if tags is not None: loss = self.loss(label_logits, label) + sequence_cross_entropy_with_logits(tag_logits, tags, mask) output_dict["loss"] = loss # compute intent F1 per label for i in range(self.label_num_classes): metric = self.label_f1_metrics[self.vocab.get_token_from_index(index=i, namespace="labels")] metric(label_class_probs, label) self.label_accuracy(label_logits, label) # compute param F1 per tag for i in range(self.tag_num_classes): metric = self.tag_f1_metrics[self.vocab.get_token_from_index(index=i, namespace="tags")] metric(tag_class_probs, tags, mask.float()) self.tag_accuracy(tag_logits, tags, mask.float()) if metadata is not None: output_dict["words"] = [x["words"] for x in metadata] return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Does a simple argmax over the class probabilities, converts indices to string labels, and adds a ``"label"`` key to the dictionary with the result. 
""" # Intent task # label_class_probs = F.softmax(output_dict["label_logits"], dim=-1) # output_dict["label_class_probs"] = label_class_probs label_class_probs = output_dict["label_class_probs"] label_predictions = label_class_probs.cpu().data.numpy() label_argmax_indices = numpy.argmax(label_predictions, axis=-1) labels = [self.vocab.get_token_from_index(x, namespace="labels") for x in label_argmax_indices] output_dict["label"] = labels # Param task tag_all_predictions = output_dict["tag_class_probs"] tag_all_predictions = tag_all_predictions.cpu().data.numpy() if tag_all_predictions.ndim == 3: tag_predictions_list = [tag_all_predictions[i] for i in range(tag_all_predictions.shape[0])] else: tag_predictions_list = [tag_all_predictions] all_tags = [] for tag_predictions in tag_predictions_list: tag_argmax_indices = numpy.argmax(tag_predictions, axis=-1) tags = [self.vocab.get_token_from_index(y, namespace="tags") for y in tag_argmax_indices] all_tags.append(tags) output_dict["tags"] = all_tags return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: metric_dict = {} # intent task label_sum_f1 = 0.0 label_count = 0 for label_name, label_metric in self.label_f1_metrics.items(): label_metric_val = label_metric.get_metric(reset) metric_dict[label_name + "_P"] = label_metric_val[0] metric_dict[label_name + "_R"] = label_metric_val[1] metric_dict[label_name + "_F1"] = label_metric_val[2] if label_metric_val[2]: label_sum_f1 += label_metric_val[2] label_count += 1 if label_count: label_average_f1 = label_sum_f1 / label_count else: label_average_f1 = label_sum_f1 metric_dict["intent_average_F1"] = label_average_f1 metric_dict["label_accuracy"] = self.label_accuracy.get_metric(reset) # param task tag_sum_f1 = 0.0 tag_count = 0 for tag_name, tag_metric in self.tag_f1_metrics.items(): tag_metric_val = tag_metric.get_metric(reset) # if self.verbose_metrics: metric_dict[tag_name + "_P"] = tag_metric_val[0] metric_dict[tag_name + "_R"] = tag_metric_val[1] metric_dict[tag_name + "_F1"] = tag_metric_val[2] if tag_metric_val[2]: tag_sum_f1 += tag_metric_val[2] tag_count += 1 if tag_count: tag_average_f1 = tag_sum_f1 / tag_count else: tag_average_f1 = tag_sum_f1 metric_dict["param_average_F1"] = tag_average_f1 metric_dict["tag_accuracy"] = self.tag_accuracy.get_metric(reset) return metric_dict
class LstmSwag(Model): """ This model performs four-way multiple-choice classification (e.g. for the SWAG dataset): each of the four hypothesis sequences is embedded, encoded with a ``Seq2SeqEncoder``, max-pooled over time, and projected to a single scalar score; the four scores are concatenated and trained with cross-entropy against the index of the correct hypothesis. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the hypothesis ``TextFields`` we get as input to the model. encoder : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and scoring each hypothesis. embedding_dropout : ``float``, optional (default = 0.0) Dropout probability applied to the embedded hypotheses. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. """ def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, # binary_feature_dim: int, embedding_dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(LstmSwag, self).__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.encoder = encoder self.embedding_dropout = Dropout(p=embedding_dropout) self.output_prediction = Linear(self.encoder.get_output_dim(), 1, bias=False) check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), "text embedding dim", "eq encoder input dim") self._accuracy = CategoricalAccuracy() self._loss = torch.nn.CrossEntropyLoss() initializer(self) def forward( self, # type: ignore hypothesis0: Dict[str, torch.LongTensor], hypothesis1: Dict[str, torch.LongTensor], hypothesis2: Dict[str, torch.LongTensor], hypothesis3: Dict[str, torch.LongTensor], label: torch.IntTensor = None, ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- hypothesis0, hypothesis1, hypothesis2, hypothesis3 : Dict[str, torch.LongTensor] The four candidate sequences, each from a ``TextField``. label : ``torch.IntTensor``, optional (default = None) The index of the correct hypothesis. Returns ------- An output dictionary consisting of: label_logits : torch.FloatTensor A tensor of shape ``(batch_size, 4)`` representing unnormalised scores for the four hypotheses. label_probs : torch.FloatTensor A tensor of shape ``(batch_size, 4)`` representing a probability distribution over the four hypotheses. loss : torch.FloatTensor, optional A scalar loss to be optimised.
""" logits = [] for tokens in [hypothesis0, hypothesis1, hypothesis2, hypothesis3]: if isinstance(self.text_field_embedder, ElmoTokenEmbedder): self.text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states( ) embedded_text_input = self.embedding_dropout( self.text_field_embedder(tokens)) mask = get_text_field_mask(tokens) batch_size, sequence_length, _ = embedded_text_input.size() encoded_text = self.encoder(embedded_text_input, mask) logits.append(self.output_prediction(encoded_text.max(1)[0])) logits = torch.cat(logits, -1) class_probabilities = F.softmax(logits, dim=-1).view([batch_size, 4]) output_dict = { "label_logits": logits, "label_probs": class_probabilities } if label is not None: loss = self._loss(logits, label.long().view(-1)) self._accuracy(logits, label.squeeze(-1)) output_dict["loss"] = loss return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { 'accuracy': self._accuracy.get_metric(reset), } @classmethod def from_params(cls, vocab: Vocabulary, params: Params) -> 'LstmSwag': embedder_params = params.pop("text_field_embedder") text_field_embedder = TextFieldEmbedder.from_params( vocab, embedder_params) encoder = Seq2SeqEncoder.from_params(params.pop("encoder")) initializer = InitializerApplicator.from_params( params.pop('initializer', [])) regularizer = RegularizerApplicator.from_params( params.pop('regularizer', [])) params.assert_empty(cls.__name__) return cls(vocab=vocab, text_field_embedder=text_field_embedder, encoder=encoder, initializer=initializer, regularizer=regularizer)
class BidirectionalAttentionFlow(Model): """ This class implements Minjoon Seo's `Bidirectional Attention Flow model <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_ for answering reading comprehension questions (ICLR 2017). The basic layout is pretty simple: encode words as a combination of word embeddings and a character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of attentions to put question information into the passage word representations (this is the only part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end. Parameters ---------- vocab : ``Vocabulary`` text_field_embedder : ``TextFieldEmbedder`` Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. num_highway_layers : ``int`` The number of highway layers to use in between embedding the input and passing it through the phrase layer. phrase_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between embedding tokens and doing the bidirectional attention. similarity_function : ``SimilarityFunction`` The similarity function that we will use when comparing encoded passage and question representations. modeling_layer : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use in between the bidirectional attention and predicting span start and end. span_end_encoder : ``Seq2SeqEncoder`` The encoder that we will use to incorporate span start predictions into the passage state before predicting span end. dropout : ``float``, optional (default=0.2) If greater than 0, we will apply dropout with this probability after all encoders (pytorch LSTMs do not apply dropout to their last layer). mask_lstms : ``bool``, optional (default=True) If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup, with only a slight performance decrease, if any. We haven't experimented much with this yet, but have confirmed that we still get very similar performance with much faster training times. We still use the mask for all softmaxes, but avoid the shuffling that's required when using masking with pytorch LSTMs. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training. 
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, num_highway_layers: int, phrase_layer: Seq2SeqEncoder, modeling_layer: Seq2SeqEncoder, span_end_encoder: Seq2SeqEncoder, dropout: float = 0.2, mask_lstms: bool = True, initializer: InitializerApplicator = InitializerApplicator(), regularizer: RegularizerApplicator = RegularizerApplicator()): super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer) self._text_field_embedder = text_field_embedder self._highway_layer = TimeDistributed( Highway(text_field_embedder.get_output_dim(), num_highway_layers)) self._phrase_layer = phrase_layer self._matrix_attention = CosineMatrixAttention() self._modeling_layer = modeling_layer self._span_end_encoder = span_end_encoder encoding_dim = phrase_layer.get_output_dim() modeling_dim = modeling_layer.get_output_dim() span_start_input_dim = encoding_dim * 4 + modeling_dim self._span_start_predictor = TimeDistributed( torch.nn.Linear(span_start_input_dim, 1)) span_end_encoding_dim = span_end_encoder.get_output_dim() span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim self._span_end_predictor = TimeDistributed( torch.nn.Linear(span_end_input_dim, 1)) # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily # obvious from the configuration files, so we check here. check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim, "modeling layer input dim", "4 * encoding dim") check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(), "text field embedder output dim", "phrase layer input dim") check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim, "span end encoder input dim", "4 * encoding dim + 3 * modeling dim") self._span_start_accuracy = CategoricalAccuracy() self._span_end_accuracy = CategoricalAccuracy() self._span_accuracy = BooleanAccuracy() self._squad_metrics = SquadEmAndF1() if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._mask_lstms = mask_lstms initializer(self) def forward( self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], span_start: torch.IntTensor = None, span_end: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- question : Dict[str, torch.LongTensor] From a ``TextField``. passage : Dict[str, torch.LongTensor] From a ``TextField``. The model assumes that this passage contains the answer to the question, and predicts the beginning and ending positions of the answer within the passage. span_start : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the beginning position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. span_end : ``torch.IntTensor``, optional From an ``IndexField``. This is one of the things we are trying to predict - the ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional metadata : ``List[Dict[str, Any]]``, optional If present, this should contain the question tokens, passage tokens, original passage text, and token offsets into the passage for each instance in the batch. 
The length of this list should be the batch size, and each dictionary should have the keys ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``. Returns ------- An output dictionary consisting of: span_start_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span start position. span_start_probs : torch.FloatTensor The result of ``softmax(span_start_logits)``. span_end_logits : torch.FloatTensor A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log probabilities of the span end position (inclusive). span_end_probs : torch.FloatTensor The result of ``softmax(span_end_logits)``. best_span : torch.IntTensor The result of a constrained inference over ``span_start_logits`` and ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)`` and each offset is a token index. loss : torch.FloatTensor, optional A scalar loss to be optimised. best_span_str : List[str] If sufficient metadata was provided for the instances in the batch, we also return the string from the original passage that the model thinks is the best answer to the question. """ embedded_question = self._highway_layer( self._text_field_embedder(question)) embedded_passage = self._highway_layer( self._text_field_embedder(passage)) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None encoded_question = self._dropout( self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout( self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention( encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.masked_softmax( passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum( encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. 
masked_similarity = util.replace_masked_values( passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max( dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax( question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum( encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze( 1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([ encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * tiled_question_passage_vector ], dim=-1) modeled_passage = self._dropout( self._modeling_layer(final_merged_passage, passage_lstm_mask)) modeling_dim = modeled_passage.size(-1) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)) span_start_input = self._dropout( torch.cat([final_merged_passage, modeled_passage], dim=-1)) # Shape: (batch_size, passage_length) span_start_logits = self._span_start_predictor( span_start_input).squeeze(-1) # Shape: (batch_size, passage_length) span_start_probs = util.masked_softmax(span_start_logits, passage_mask) # Shape: (batch_size, modeling_dim) span_start_representation = util.weighted_sum(modeled_passage, span_start_probs) # Shape: (batch_size, passage_length, modeling_dim) tiled_start_representation = span_start_representation.unsqueeze( 1).expand(batch_size, passage_length, modeling_dim) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3) span_end_representation = torch.cat([ final_merged_passage, modeled_passage, tiled_start_representation, modeled_passage * tiled_start_representation ], dim=-1) # Shape: (batch_size, passage_length, encoding_dim) encoded_span_end = self._dropout( self._span_end_encoder(span_end_representation, passage_lstm_mask)) # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim) span_end_input = self._dropout( torch.cat([final_merged_passage, encoded_span_end], dim=-1)) span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1) span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) best_span = get_best_span(span_start_logits, span_end_logits) output_dict = { "passage_question_attention": passage_question_attention, "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span, } # Compute the loss for training. if span_start is not None: loss = nll_loss( util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) loss += nll_loss( util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) output_dict["loss"] = loss # Compute the EM and F1 on SQuAD and add the tokenized input to the output. 
if metadata is not None: output_dict['best_span_str'] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens']) passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output_dict['best_span_str'].append(best_span_string) answer_texts = metadata[i].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) output_dict['question_tokens'] = question_tokens output_dict['passage_tokens'] = passage_tokens return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: exact_match, f1_score = self._squad_metrics.get_metric(reset) return { 'start_acc': self._span_start_accuracy.get_metric(reset), 'end_acc': self._span_end_accuracy.get_metric(reset), 'span_acc': self._span_accuracy.get_metric(reset), 'em': exact_match, 'f1': f1_score, } @staticmethod def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor: # We call the inputs "logits" - they could either be unnormalized logits or normalized log # probabilities. A log_softmax operation is a constant shifting of the entire logit # vector, so taking an argmax over either one gives the same result. if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: raise ValueError( "Input shapes must be (batch_size, passage_length)") batch_size, passage_length = span_start_logits.size() device = span_start_logits.device # (batch_size, passage_length, passage_length) span_log_probs = span_start_logits.unsqueeze( 2) + span_end_logits.unsqueeze(1) # Only the upper triangle of the span matrix is valid; the lower triangle has entries where # the span ends before it starts. span_log_mask = torch.triu( torch.ones((passage_length, passage_length), device=device)).log().unsqueeze(0) valid_span_log_probs = span_log_probs + span_log_mask # Here we take the span matrix and flatten it, then find the best span using argmax. We # can recover the start and end indices from this flattened list using simple modular # arithmetic. # (batch_size, passage_length * passage_length) best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1) span_start_indices = best_spans // passage_length span_end_indices = best_spans % passage_length return torch.stack([span_start_indices, span_end_indices], dim=-1)
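# A toy check (hand-written logits) of the vectorised span search in get_best_span above:
# pairwise scores start[i] + end[j] are masked to the upper triangle (j >= i), flattened,
# and the argmax is converted back to (start, end) with integer division and modulo.
import torch

span_start_logits = torch.tensor([[0.1, 2.0, 0.3]])
span_end_logits = torch.tensor([[1.5, 0.2, 1.0]])

scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)   # (1, n, n)
mask = torch.triu(torch.ones(3, 3)).log().unsqueeze(0)                   # 0 on/above diagonal, -inf below
best = (scores + mask).view(1, -1).argmax(-1)
start, end = best // 3, best % 3
print(start.item(), end.item())   # 1 2 -> span (1, 2); a naive independent argmax would pick the invalid (1, 0)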
class ModelSQUAD(Model):
    """
    A BiDAF-style reading-comprehension model with a residual self-attention block and an
    explicit no-answer score.  ``forward`` expects the passage text field to hold four
    paragraphs per question (shape ``(batch_size, 4, passage_length)``), which are flattened
    onto the batch dimension before encoding.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 feed_forward: FeedForward,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(ModelSQUAD, self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._residual_encoder = residual_encoder
        self._span_end_encoder = span_end_encoder
        self._span_start_encoder = span_start_encoder
        self._feed_forward = feed_forward

        encoding_dim = phrase_layer.get_output_dim()
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(encoding_dim, 1))
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(encoding_dim, 1))
        self._no_answer_predictor = TimeDistributed(torch.nn.Linear(encoding_dim, 1))

        self._self_matrix_attention = MatrixAttention(attention_similarity_function)
        self._linear_layer = TimeDistributed(torch.nn.Linear(4 * encoding_dim, encoding_dim))
        self._residual_linear_layer = TimeDistributed(torch.nn.Linear(3 * encoding_dim, encoding_dim))

        # Weights of the trilinear-style self-attention similarity (w_x . x + w_y . y + w_xy . (x * y)).
        self._w_x = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_y = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_xy = torch.nn.Parameter(torch.Tensor(encoding_dim))
        std = math.sqrt(6 / (encoding_dim * 3 + 1))
        self._w_x.data.uniform_(-std, std)
        self._w_y.data.uniform_(-std, std)
        self._w_xy.data.uniform_(-std, std)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.LongTensor = None,
                span_end: torch.LongTensor = None,
                spans=None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        # Shape: (batch_size, 4, passage_length, embedding_dim)
        embedded_passage = self._text_field_embedder(passage)
        (batch_size, q_length, embedding_dim) = embedded_question.size()
        passage_length = embedded_passage.size(2)
        # Reshape to (batch_size * 4, passage_length, embedding_dim) and repeat the question
        # once per paragraph so questions and paragraphs line up on the batch dimension.
        embedded_passage = embedded_passage.view(-1, passage_length, embedding_dim)
        embedded_passage = self._highway_layer(embedded_passage)
        embedded_question = embedded_question.unsqueeze(0).expand(
            4, -1, -1, -1).contiguous().view(-1, q_length, embedding_dim)

        question_mask = util.get_text_field_mask(question).float()
        question_mask = question_mask.unsqueeze(0).expand(4, -1, -1).contiguous().view(-1, q_length)
        passage_mask = util.get_text_field_mask(passage, 1).float()
        passage_mask = passage_mask.view(-1, passage_length)
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)
        cuda_device = encoded_question.get_device()

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
            batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        question_attended_passage = relu(self._linear_layer(final_merged_passage))

        # TODO: attach residual self-attention layer
        # Shape: (batch_size, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage), passage_lstm_mask))
        mask = passage_mask.resize(batch_size, passage_length, 1) * \
            passage_mask.resize(batch_size, 1, passage_length)
        self_mask = Variable(torch.eye(passage_length, passage_length).cuda(cuda_device)).resize(
            1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        # Shape: (batch_size, passage_length, passage_length)
        x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
        y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
        dot_similarity = torch.bmm(residual_passage * self._w_xy, residual_passage.transpose(1, 2))
        passage_self_similarity = dot_similarity + x_similarity + y_similarity
        # for i in range(passage_length):
        #     passage_self_similarity[:, i, i] = float('-Inf')
        # Shape: (batch_size, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity, mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage, passage_self_attention)
        # Shape: (batch_size, passage_length, encoding_dim * 3)
        merged_passage = torch.cat([residual_passage,
                                    passage_vectors,
                                    residual_passage * passage_vectors],
                                   dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        self_attended_passage = relu(self._residual_linear_layer(merged_passage))

        # Shape: (batch_size, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage

        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_start = self._dropout(self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start], dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Shape: (batch_size, encoding_dim)
        v_1 = util.weighted_sum(encoded_span_start, span_start_probs)
        v_2 = util.weighted_sum(encoded_span_end, span_end_probs)
        no_span_logits = self._no_answer_predictor(self_attended_passage).squeeze(-1)
        no_span_probs = util.masked_softmax(no_span_logits, passage_mask)
        v_3 = util.weighted_sum(self_attended_passage, no_span_probs)
        # Shape: (batch_size, 1)
        z_score = self._feed_forward(torch.cat([v_1, v_2, v_3], dim=-1))  # compute no-answer score

        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Create the target tensor, including the no-answer label: the span (start, end) maps to
        # class start * passage_length + end, and class passage_length ** 2 means "no answer".
        span_target = Variable(torch.ones(batch_size).long()).cuda(cuda_device)
        for b in range(batch_size):
            span_target[b].data[0] = span_start[b, 0].data[0] * passage_length + span_end[b, 0].data[0]
        span_target[span_target < 0] = passage_length ** 2
        # Shape: (batch_size, passage_length, passage_length)
        span_start_logits_tiled = span_start_logits.unsqueeze(1).expand(
            batch_size, passage_length, passage_length)
        span_end_logits_tiled = span_end_logits.unsqueeze(-1).expand(
            batch_size, passage_length, passage_length)
        span_logits = (span_start_logits_tiled + span_end_logits_tiled).view(batch_size, -1)
        answer_mask = torch.bmm(passage_mask.unsqueeze(-1), passage_mask.unsqueeze(1)).view(batch_size, -1)
        no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        all_logits = torch.cat([span_logits, z_score], dim=-1)
        loss = nll_loss(util.masked_log_softmax(all_logits, combined_mask), span_target)
        output_dict["loss"] = loss

        # Alternative (commented-out) loss that credits any of the provided gold answers.
        # Shape: (batch_size, max_answers, num_span)
        # max_answers = spans.size(1)
        # span_logits = torch.bmm(span_start_logits.unsqueeze(-1), span_end_logits.unsqueeze(1)).view(batch_size, -1)
        # answer_mask = torch.bmm(passage_mask.unsqueeze(-1), passage_mask.unsqueeze(1)).view(batch_size, -1)
        # no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        # combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        # # Shape: (batch_size, passage_length**2 + 1)
        # all_logits = torch.cat([span_logits, z_score], dim=-1)
        # # Shape: (batch_size, max_answers)
        # spans_combined = spans[:, :, 0] * passage_length + spans[:, :, 1]
        # spans_combined[spans_combined < 0] = passage_length*passage_length
        #
        # all_modified_logits = []
        # for b in range(batch_size):
        #     idxs = Variable(torch.LongTensor(range(passage_length**2 + 1))).cuda(cuda_device)
        #     for i in range(max_answers):
        #         idxs[spans_combined[b, i].data[0]].data = idxs[spans_combined[b, 0].data[0]].data
        #     idxs[passage_length**2].data[0] = passage_length**2
        #     modified_logits = Variable(torch.zeros(all_logits.size(-1))).cuda(cuda_device)
        #     modified_logits.index_add_(0, idxs, all_logits[b])
        #     all_modified_logits.append(modified_logits)
        # all_modified_logits = torch.stack(all_modified_logits, dim=0)
        # loss = nll_loss(util.masked_log_softmax(all_modified_logits, combined_mask), spans_combined[:, 0])
        # output_dict["loss"] = loss

        if span_start is not None:
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new().resize_(batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]
                val2 = span_end_logits[b, j]
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ModelSQUAD':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        num_highway_layers = params.pop_int("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(params.pop("similarity_function"))
        residual_encoder = Seq2SeqEncoder.from_params(params.pop("residual_encoder"))
        span_start_encoder = Seq2SeqEncoder.from_params(params.pop("span_start_encoder"))
        span_end_encoder = Seq2SeqEncoder.from_params(params.pop("span_end_encoder"))
        feed_forward = FeedForward.from_params(params.pop("feed_forward"))
        dropout = params.pop_float('dropout', 0.2)
        initializer = InitializerApplicator.from_params(params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))
        mask_lstms = params.pop_bool('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   residual_encoder=residual_encoder,
                   span_start_encoder=span_start_encoder,
                   span_end_encoder=span_end_encoder,
                   feed_forward=feed_forward,
                   dropout=dropout,
                   mask_lstms=mask_lstms,
                   initializer=initializer,
                   regularizer=regularizer)
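# A minimal, self-contained sketch (toy sizes, plain PyTorch, hypothetical variable names)
# of the joint span / no-answer target encoding used by ModelSQUAD above: every (start, end)
# pair is flattened to class index start * passage_length + end, one extra class
# passage_length ** 2 stands for "no answer", and the loss is NLL over a log-softmax of the
# concatenated scores.  Masking is omitted for brevity; this is an illustration, not the
# model's own code.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, passage_length = 2, 5

# Toy per-token scores plus a single no-answer score per example.
span_start_logits = torch.randn(batch_size, passage_length)
span_end_logits = torch.randn(batch_size, passage_length)
no_answer_score = torch.randn(batch_size, 1)

# Score every (start, end) pair as the sum of its endpoint scores, then flatten so that
# index start * passage_length + end addresses the pair (start, end).
pair_logits = (span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)).view(batch_size, -1)
all_logits = torch.cat([pair_logits, no_answer_score], dim=1)  # last class = "no answer"

# Gold spans; (-1, -1) marks an unanswerable question and maps to the no-answer class.
gold_spans = torch.tensor([[1, 3], [-1, -1]])
targets = gold_spans[:, 0] * passage_length + gold_spans[:, 1]
targets[targets < 0] = passage_length ** 2

loss = F.nll_loss(F.log_softmax(all_logits, dim=-1), targets)
print(loss.item())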
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string
            and don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:

        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                            span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                             span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]
                val2 = span_end_logits[b, j]
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
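# As a quick sanity check of the constrained decoding above, the static ``get_best_span``
# method can be exercised on its own with toy logits (this assumes the
# BidirectionalAttentionFlow class defined above is in scope); for each row it returns the
# (start, end) pair with end >= start that maximises
# span_start_logits[start] + span_end_logits[end].
import torch

toy_start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0],
                                 [1.5, 0.2, 0.1, 0.4]])
toy_end_logits = torch.tensor([[0.0, 0.1, 1.8, 0.2],
                               [0.3, 0.2, 0.1, 2.0]])

best = BidirectionalAttentionFlow.get_best_span(toy_start_logits, toy_end_logits)
print(best)  # tensor([[1, 2], [0, 3]]) -- inclusive token indices (start, end)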