def test_span_metrics_are_computed_correctly(self):
    """Score one partially wrong prediction and check every reported metric.

    The predicted and gold sequences below differ only in their second tag
    (`B-ARG1` is predicted where gold has `I-ARG0`). The assertions pin the
    per-label and overall precision / recall / F1 that `SrlEvalScorer`
    reports for that pair when the `V` class is ignored.
    """
    from allennlp_models.structured_prediction.models.srl import (
        convert_bio_tags_to_conll_format,
    )

    verb_positions = [2]
    sentences = [["The", "cat", "loves", "hats", "."]]
    predicted_bio = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
    gold_bio = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]

    predicted_conll = [
        convert_bio_tags_to_conll_format(sequence) for sequence in predicted_bio
    ]
    gold_conll = [
        convert_bio_tags_to_conll_format(sequence) for sequence in gold_bio
    ]

    scorer = SrlEvalScorer(ignore_classes=["V"])
    scorer(verb_positions, sentences, predicted_conll, gold_conll)
    scores = scorer.get_metric()

    # Three metrics for each of ARG0, ARG1 and "overall".
    assert len(scores) == 9

    expected = {
        "precision-ARG0": 0.0,
        "recall-ARG0": 0.0,
        "f1-measure-ARG0": 0.0,
        "precision-ARG1": 0.5,
        "recall-ARG1": 1.0,
        "f1-measure-ARG1": 2 / 3,
        "precision-overall": 1 / 3,
        "recall-overall": 1 / 2,
        # Harmonic mean of the overall precision and recall above.
        "f1-measure-overall": (2 * (1 / 3) * (1 / 2)) / ((1 / 3) + (1 / 2)),
    }
    for metric_name, expected_value in expected.items():
        assert_allclose(scores[metric_name], expected_value)
def test_bio_tags_correctly_convert_to_conll_format(self):
    """Check the BIO -> CoNLL bracket conversion on a small tag sequence.

    The sequence covers a two-token span, single-token spans and `O` tags.
    """
    from allennlp_models.structured_prediction.models.srl import (
        convert_bio_tags_to_conll_format,
    )

    tag_sequence = ["B-ARG-1", "I-ARG-1", "O", "B-V", "B-ARGM-ADJ", "O"]
    expected_conll = ["(ARG-1*", "*)", "*", "(V*)", "(ARGM-ADJ*)", "*"]

    assert convert_bio_tags_to_conll_format(tag_sequence) == expected_conll
def test_distributed_setting_throws_an_error(self):
    """Check that `SrlEvalScorer` refuses to run under distributed aggregation.

    The same single-sentence batch is handed to two workers through
    `run_distributed_test`; the run is expected to fail, and the failure
    text must mention that distributed aggregation is unsupported.
    """
    from allennlp_models.structured_prediction.models.srl import (
        convert_bio_tags_to_conll_format,
    )

    batch_verb_indices = [2]
    batch_sentences = [["The", "cat", "loves", "hats", "."]]
    batch_bio_predicted_tags = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
    batch_conll_predicted_tags = [
        convert_bio_tags_to_conll_format(tags) for tags in batch_bio_predicted_tags
    ]
    batch_bio_gold_tags = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]
    batch_conll_gold_tags = [
        convert_bio_tags_to_conll_format(tags) for tags in batch_bio_gold_tags
    ]

    # One entry per worker; both workers receive the same batch.
    metric_kwargs = {
        "batch_verb_indices": [batch_verb_indices, batch_verb_indices],
        "batch_sentences": [batch_sentences, batch_sentences],
        "batch_conll_formatted_predicted_tags": [
            batch_conll_predicted_tags,
            batch_conll_predicted_tags,
        ],
        "batch_conll_formatted_gold_tags": [batch_conll_gold_tags, batch_conll_gold_tags],
    }

    desired_values = {}  # it does not matter, we expect the run to fail.

    with pytest.raises(Exception) as exc:
        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            SrlEvalScorer(ignore_classes=["V"]),
            metric_kwargs,
            desired_values,
            exact=True,
        )

    # `exc.value` is only populated once the `with` block has exited, and any
    # statement placed after the raising call inside the block never runs, so
    # the message check has to live out here to be executed at all.
    assert (
        "RuntimeError: Distributed aggregation for `SrlEvalScorer` is currently not supported."
        in str(exc.value)
    )
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    verb_indicator: torch.Tensor,
    frame_indicator: torch.Tensor,
    metadata: List[Any],
    tags: torch.LongTensor = None,
    frame_tags: torch.LongTensor = None,
):
    """
    Embed the wordpieces with the transformer, then score role tags for every
    position and frames for the positions flagged by `frame_indicator`.

    # Parameters

    tokens : `TextFieldTensors`, required
        The output of `TextField.as_array()`. For this model, this must be a
        `SingleIdTokenIndexer` which indexes wordpieces from the BERT vocabulary.
    verb_indicator : `torch.LongTensor`, required
        An integer `SequenceFeatureField` representation of the position of the
        verb in the sentence, of shape `(batch_size, num_tokens)`. It can be all
        zeros when the sentence has no verbal predicate. It is handed to the
        transformer as `token_type_ids`.
    frame_indicator : `torch.LongTensor`, required
        An integer `SequenceFeatureField` representation of the position of the
        frame in the sentence, of shape `(batch_size, num_tokens)`. Similar to
        `verb_indicator`, but only the first wordpiece of the predicate is
        marked. Positions equal to 1 select the embeddings that are fed to the
        frame classifier.
    metadata : `List[Dict[str, Any]]`, required
        Per-example metadata. The keys `words`, `verb`, `offsets` and `lemmas`
        are always read; `verb_index` and `gold_tags` are read when the span
        metric is updated.
    tags : `torch.LongTensor`, optional (default = `None`)
        Gold role labels of shape `(batch_size, num_tokens)`.
    frame_tags : `torch.LongTensor`, optional (default = `None`)
        Gold frames of shape `(batch_size, num_tokens)`. It is indexed whenever
        `tags` is given, so the two must be supplied together.

    # Returns

    An output dictionary consisting of:

    logits : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` holding
        unnormalised log probabilities of the tag classes.
    frame_logits : `torch.FloatTensor`
        Unnormalised frame scores, one row per position selected by
        `frame_indicator`.
    class_probabilities : `torch.FloatTensor`
        Softmax of `logits`, of shape `(batch_size, num_tokens, tag_vocab_size)`.
    frame_probabilities : `torch.FloatTensor`
        Softmax of `frame_logits`.
    mask, words, lemma, verb, wordpiece_offsets
        The padding mask and values copied from `metadata`.
    frame_loss, role_loss, loss : `torch.FloatTensor`, optional
        Only present when `tags` is given; `loss` is the mean of the other two.
    """
    mask = get_text_field_mask(tokens)
    wordpiece_ids = util.get_token_ids_from_text_field_tensors(tokens)
    bert_embeddings, _ = self.transformer(
        input_ids=wordpiece_ids,
        token_type_ids=verb_indicator,
        attention_mask=mask,
        return_dict=False,
    )

    embedded = self.embedding_dropout(bert_embeddings)
    # Boolean selector for the frame-bearing positions; reused for the gold frames.
    is_frame_position = frame_indicator == 1
    frame_embeddings = embedded[is_frame_position]
    batch_size, sequence_length, _ = embedded.size()

    logits = self.tag_projection_layer(embedded)
    frame_logits = self.frame_projection_layer(frame_embeddings)

    flat_logits = logits.view(-1, self.num_classes)
    flat_probabilities = F.softmax(flat_logits, dim=-1)
    class_probabilities = flat_probabilities.view(
        [batch_size, sequence_length, self.num_classes]
    )
    frame_probabilities = F.softmax(frame_logits, dim=-1)

    # The mask stays in the output so padding can be cropped when Viterbi
    # inference runs in self.make_output_human_readable.
    output_dict = {
        "logits": logits,
        "frame_logits": frame_logits,
        "class_probabilities": class_probabilities,
        "frame_probabilities": frame_probabilities,
        "mask": mask,
    }

    # The offsets are carried along so the un-wordpieced tags can be computed.
    words, verbs, offsets = zip(
        *[(entry["words"], entry["verb"], entry["offsets"]) for entry in metadata]
    )
    # Lemmas are flattened across the whole batch.
    flat_lemmas = [lemma for entry in metadata for lemma in entry["lemmas"]]
    output_dict["words"] = list(words)
    output_dict["lemma"] = list(flat_lemmas)
    output_dict["verb"] = list(verbs)
    output_dict["wordpiece_offsets"] = list(offsets)

    if tags is None:
        return output_dict

    role_loss = sequence_cross_entropy_with_logits(
        logits, tags, mask, label_smoothing=self._label_smoothing
    )
    gold_frames = frame_tags[is_frame_position]
    frame_loss = self.frame_criterion(frame_logits, gold_frames)

    skip_span_metric = (
        self.ignore_span_metric or self.span_metric is None or self.training
    )
    if not skip_span_metric:
        batch_verb_indices = [entry["verb_index"] for entry in metadata]
        batch_sentences = [entry["words"] for entry in metadata]
        # The BIO tags come from make_output_human_readable().
        readable = self.make_output_human_readable(output_dict)
        predicted_bio = readable.pop("tags")
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format,
        )

        predicted_conll = [
            convert_bio_tags_to_conll_format(sequence) for sequence in predicted_bio
        ]
        gold_bio = [entry["gold_tags"] for entry in metadata]
        gold_conll = [
            convert_bio_tags_to_conll_format(sequence) for sequence in gold_bio
        ]
        self.span_metric(
            batch_verb_indices,
            batch_sentences,
            predicted_conll,
            gold_conll,
        )

    self.f1_frame_metric(frame_logits, gold_frames)
    output_dict["frame_loss"] = frame_loss
    output_dict["role_loss"] = role_loss
    output_dict["loss"] = (role_loss + frame_loss) / 2
    return output_dict
def test_srl_eval_correctly_scores_identical_tags(self):
    """Identical predicted and gold tags must score 1.0 on every metric.

    Four predicate instances are scored (the first two share a sentence but
    have different predicates). With `V` ignored, precision, recall and F1
    are expected to be 1.0 for every remaining label and for "overall".
    """
    from allennlp_models.structured_prediction.models.srl import (
        convert_bio_tags_to_conll_format,
    )

    mali_sentence = [
        "Mali", "government", "officials", "say", "the", "woman", "'s",
        "confession", "was", "forced", ".",
    ]
    prosecution_sentence = [
        "The", "prosecution", "rested", "its", "case", "last", "month",
        "after", "four", "months", "of", "hearings", ".",
    ]
    verb_positions = [3, 8, 2, 0]
    sentences = [
        mali_sentence,
        list(mali_sentence),
        prosecution_sentence,
        ["Come", "in", "and", "buy", "."],
    ]

    predicted_bio = [
        [
            "B-ARG0", "I-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1",
            "I-ARG1", "I-ARG1", "I-ARG1", "I-ARG1", "O",
        ],
        [
            "O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1",
            "B-V", "B-ARG2", "O",
        ],
        [
            "B-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1", "B-ARGM-TMP",
            "I-ARGM-TMP", "B-ARGM-TMP", "I-ARGM-TMP", "I-ARGM-TMP",
            "I-ARGM-TMP", "I-ARGM-TMP", "O",
        ],
        ["B-V", "B-AM-DIR", "O", "O", "O"],
    ]
    # The gold annotation is, by construction, an exact copy of the prediction.
    gold_bio = [list(sequence) for sequence in predicted_bio]

    predicted_conll = [
        convert_bio_tags_to_conll_format(sequence) for sequence in predicted_bio
    ]
    gold_conll = [
        convert_bio_tags_to_conll_format(sequence) for sequence in gold_bio
    ]

    scorer = SrlEvalScorer(ignore_classes=["V"])
    scorer(verb_positions, sentences, predicted_conll, gold_conll)
    scores = scorer.get_metric()

    labels = ("ARG0", "ARG1", "ARG2", "ARGM-TMP", "AM-DIR", "overall")
    measures = ("precision", "recall", "f1-measure")

    # Three measures for each of the six labels.
    assert len(scores) == 18
    for label in labels:
        for measure in measures:
            assert_allclose(scores[f"{measure}-{label}"], 1.0)
def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
        offsets: torch.LongTensor = None):
    """
    Encode the input, project to tag logits, optionally refine the logits with
    a QP layer (`self._lp`) and/or constrained SparseMAP inference
    (`self._lpsmap`), and compute the loss and span metric when gold tags are given.

    # Parameters

    tokens : `TextFieldTensors`, required
        The output of `TextField.as_array()`, which should typically be passed directly to a
        `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
        indexes wordpieces from the BERT vocabulary. When `self.bert_model` is a
        `PretrainedTransformerMismatchedEmbedder`, `tokens["tokens"]` is passed to it as
        keyword arguments and may get a `type_ids` entry written into it.
    verb_indicator: `torch.LongTensor`, required.
        An integer `SequenceFeatureField` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate.
    sentence_end : `torch.LongTensor`, required
        One value per example. Positions at or beyond it are masked out, and the
        sequence is truncated to the batch maximum of these values. Only read when
        `self.bert_model` is not a mismatched embedder.
    tags : `torch.LongTensor`, optional (default = `None`)
        A torch tensor representing the sequence of integer gold class labels
        of shape `(batch_size, num_tokens)`.
        NOTE(review): in the non-mismatched branch `tags` is sliced unconditionally,
        so `None` would fail there; confirm whether inference without tags is expected.
    offsets : `torch.LongTensor`, optional (default = `None`)
        Never read: the name is rebound to the offsets taken from `metadata`.
    metadata : `List[Dict[str, Any]]`, required
        metadata containing the original words in the sentence, the verb to compute the
        frame for, and start offsets for converting wordpieces back to a sequence of words,
        under 'words', 'verb' and 'offsets' keys, respectively. 'verb_index' and
        'gold_tags' are also read when gold tags are given.

    # Returns

    An output dictionary consisting of:
    logits : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        a distribution of the tag classes per word.
    mask, words, verb, wordpiece_offsets
        The (possibly truncated) mask and values copied from `metadata`.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised.
    gold_tags : `List`, optional
        The gold tags copied from `metadata`.
    """
    if isinstance(self.bert_model, PretrainedTransformerMismatchedEmbedder):
        # Word-level embedder path: the indexer output is forwarded as kwargs.
        encoder_inputs = tokens["tokens"]
        if self.bert_config.type_vocab_size > 1:
            # Mutates the caller's `tokens["tokens"]` dict in place.
            encoder_inputs["type_ids"] = verb_indicator
        encoded_text = self.bert_model(**encoder_inputs)
        batch_size = encoded_text.shape[0]
        if self.bert_config.type_vocab_size == 1:
            # No segment ids available to mark the predicate, so the predicate's
            # own encoding is appended to every position instead.
            verb_embeddings = encoded_text[
                torch.arange(batch_size).to(encoded_text.device),
                verb_indicator.argmax(1), :]
            # Rows whose indicator is all zeros get a zero vector.
            verb_embeddings = torch.where(
                (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                    1, verb_embeddings.shape[-1]), verb_embeddings,
                torch.zeros_like(verb_embeddings))
            encoded_text = torch.cat(
                (encoded_text, verb_embeddings.unsqueeze(1).repeat(
                    1, encoded_text.shape[1], 1)),
                dim=2)
        mask = tokens["tokens"]["mask"]
        # Only referenced by the commented-out debug prints below.
        index = mask.sum(1).argmax().item()
        # print(mask.shape, encoded_text.shape, tokens["tokens"]["token_ids"].shape, tags.shape, max([len(x['words']) for x in metadata]), mask.sum(1)[index].item())
        # print(tokens["tokens"]["token_ids"][index,:])
    else:
        mask = get_text_field_mask(tokens)
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        batch_size, _ = mask.size()
        embedded_text_input = self.embedding_dropout(bert_embeddings)
        # Restrict to sentence part: 1 for positions before `sentence_end`, else 0.
        sentence_mask = (torch.arange(mask.shape[1]).unsqueeze(0).repeat(
            batch_size, 1).to(mask.device) < sentence_end.unsqueeze(1).repeat(
                1, mask.shape[1])).long()
        cutoff = sentence_end.max().item()

        if self._encoder is None:
            encoded_text = embedded_text_input
            # Truncate mask, encodings and gold tags to the longest sentence part.
            mask = sentence_mask[:, :cutoff].contiguous()
            encoded_text = encoded_text[:, :cutoff, :]
            tags = tags[:, :cutoff].contiguous()
        else:
            predicate_embeddings = self.predicate_embedding(verb_indicator)
            encoder_inputs = torch.cat(
                (embedded_text_input, predicate_embeddings), dim=-1)
            encoded_text = self._encoder(encoder_inputs,
                                         mask=sentence_mask.bool())
            # print(verb_indicator)
            # Weight the indicator by descending position numbers and take the
            # argmax; presumably this picks the leftmost flagged position
            # (assumes a 0/1 indicator -- TODO confirm).
            predicate_index = (verb_indicator * torch.arange(
                start=verb_indicator.shape[-1] - 1, end=-1,
                step=-1).to(mask.device).unsqueeze(0).repeat(
                    batch_size, 1)).argmax(1)
            # print(predicate_index)
            predicate_hidden = encoded_text[
                torch.arange(batch_size).to(mask.device), predicate_index]
            predicate_exists, _ = verb_indicator.max(1)
            encoded_text = encoded_text[:, :cutoff, :]
            tags = tags[:, :cutoff].contiguous()
            mask = sentence_mask[:, :cutoff].contiguous()
            # Zero the predicate state for rows without a flagged predicate,
            # then append it to every position.
            predicate_exists = predicate_exists.unsqueeze(1).repeat(
                1, encoded_text.shape[-1])
            predicate_hidden = torch.where(
                predicate_exists > 0, predicate_hidden,
                torch.zeros_like(predicate_hidden))
            encoded_text = torch.cat(
                (encoded_text, predicate_hidden.unsqueeze(1).repeat(
                    1, encoded_text.shape[1], 1)),
                dim=-1)

    sequence_length = encoded_text.shape[1]
    logits = self.tag_projection_layer(encoded_text)
    # print(mask, logits)
    if self._lp and sequence_length <= 100:
        # QP layer over the flattened (position, class) scores; skipped for
        # sequences longer than 100 positions.
        eps = 1e-4
        Q = eps * torch.eye(
            sequence_length * self.num_classes,
            sequence_length * self.num_classes).unsqueeze(0).repeat(
                batch_size, 1, 1).to(logits.device).float()
        p = logits.view(batch_size, -1)
        # G and h are built but not passed on: empty tensors are handed to
        # QPFunction in their place below.
        G = -1 * torch.eye(
            sequence_length * self.num_classes).unsqueeze(0).repeat(
                batch_size, 1, 1).to(logits.device).float()
        h = torch.zeros_like(p)
        # Row t of A has ones exactly on the columns belonging to position t,
        # i.e. [t * num_classes, (t + 1) * num_classes).
        A = torch.arange(sequence_length *
                         self.num_classes).unsqueeze(0).repeat(
                             sequence_length, 1)
        A2 = torch.arange(sequence_length).unsqueeze(1).repeat(
            1, sequence_length * self.num_classes) * self.num_classes
        A = torch.where((A >= A2) & (A < A2 + self.num_classes),
                        torch.ones_like(A), torch.zeros_like(A))
        A = A.unsqueeze(0).repeat(batch_size, 1,
                                  1).to(logits.device).float()
        b = torch.ones_like(A[:, :, 0])
        # NOTE(review): presumably the argument order is (Q, p, G, h, A, b),
        # as in qpth -- confirm against the QPFunction in use.
        probs = QPFunction()(Q, p, torch.autograd.Variable(torch.Tensor()),
                             torch.autograd.Variable(torch.Tensor()), A, b)
        probs = probs.view(batch_size, sequence_length, self.num_classes)
        """logits_shape = logits.shape logits = torch.where(mask.bool().unsqueeze(-1).repeat(1, 1, logits.shape[-1]), logits, logits-10000) max_sequence_length = min([l for l in self.lengths if l >= 
sequence_length]) if max_sequence_length > logits_shape[1]: logits = torch.cat((logits, torch.zeros((batch_size, max_sequence_length-logits_shape[1], logits_shape[2])).to(logits.device)), dim=1) lp_layer = self._layer_list[self.length_map[max_sequence_length]] probs, = lp_layer(logits) print(torch.isnan(probs).any()) if max_sequence_length > logits_shape[1]: probs = probs[:,:logits_shape[1],:]"""
        # Clamp negatives and add a small constant before the log so the
        # result stays finite.
        logits = (torch.nn.functional.relu(probs) + 1e-4).log()
    if self._lpsmap:
        if self._lpsmap_core_only:
            all_logits = logits
        else:
            # One extra row (filled with 0.5) is appended after the sequence;
            # it is indexed at position `sequence_length` by the Or factors below.
            all_logits = torch.cat((logits, 0.5 * torch.ones(
                (batch_size, 1, logits.shape[-1])).to(logits.device)),
                                   dim=1)
        probs = []
        # One factor graph is built and solved per example, on CPU tensors.
        for i in range(batch_size):
            if self.constrain_crf_decoding:
                unaries = logits[i, :, :].view(-1).cpu()
                # CRF transition scores, with 10000 subtracted wherever the
                # constraint mask is 0.
                additionals = self.crf.transitions.view(-1).repeat(
                    sequence_length) + 10000 * (
                        self.crf._constraint_mask[:-2, :-2] -
                        1).view(-1).repeat(sequence_length)
                start_transitions = self.crf.start_transitions + 10000 * (
                    self.crf._constraint_mask[-2, :-2] - 1)
                # NOTE(review): built from `start_transitions`, not
                # `end_transitions` -- confirm this is intended.
                end_transitions = self.crf.start_transitions + 10000 * (
                    self.crf._constraint_mask[-1, :-2] - 1)
                additionals = torch.cat(
                    (additionals, start_transitions, end_transitions),
                    dim=0).cpu()
                fg = TorchFactorGraph()
                x = fg.variable_from(unaries)
                f = PFactorSequence()
                f.initialize(
                    [self.num_classes for _ in range(sequence_length)])
                factor = TorchOtherFactor(f, x, additionals)
                fg.add(factor)
                # add budget constraint for each state
                for state in self._core_roles:
                    vars_state = x[state::self.num_classes]
                    fg.add(AtMostOne(vars_state))
                # solve SparseMAP
                fg.solve(max_iter=200)
                # NOTE(review): this appends the input `unaries`, not a value
                # read back from `x` -- confirm this is intended.
                probs.append(
                    unaries.to(logits.device).view(sequence_length,
                                                   self.num_classes))
            else:
                fg = TorchFactorGraph()
                x = fg.variable_from(all_logits[i, :, :].cpu())
                # One Xor factor per position over its class variables.
                for j in range(sequence_length):
                    fg.add(Xor(x[j, :]))
                # One AtMostOne factor per core role over all positions.
                for j in self._core_roles:
                    fg.add(AtMostOne(x[:sequence_length, j]))
                if not self._lpsmap_core_only:
                    # `full_sequence` is only referenced by the disabled code
                    # in the string below.
                    full_sequence = list(range(sequence_length))
                    base_roles = set([
                        second
                        for (_, second) in self._r_roles + self._c_roles
                    ])
                    """for (r_role, base_role) in self._r_roles+self._c_roles: for j in range(sequence_length): fg.add(Imply(x[full_sequence+[j],[base_role]*sequence_length+[r_role]], negated=[True]*(sequence_length+1)))"""
                    for base_role in base_roles:
                        fg.add(OrOut(x[:, base_role]))
                    for (r_role, base_role) in self._r_roles + self._c_roles:
                        fg.add(OrOut(x[:, r_role]))
                        # Links the extra-row variables of the two roles.
                        fg.add(
                            Or(x[[sequence_length, sequence_length],
                                 [r_role, base_role]],
                               negated=[True, False]))
                # Both min() calls leave max_iter at 100, since 100 is below
                # either cap.
                max_iter = 100
                if not self._lpsmap_core_only:
                    max_iter = min(max_iter, 400)
                elif (not self.training) and not self._val_inference:
                    max_iter = min(max_iter, 200)
                fg.solve(max_iter=max_iter)
                # Drop the extra row before collecting the solution.
                probs.append(x.value[:sequence_length, :].contiguous().to(
                    logits.device))
        class_probabilities = torch.stack(probs)
        # class_probabilities = self.lpsmap(logits)
        # Only referenced by the disabled code in the string below.
        max_seq_length = 200
        # if self.lpsmap is None:
        """with torch.no_grad(): # self.lpsmap = LpSparseMap(num_rows=sequence_length, num_cols=self.num_classes, batch_size=batch_size, device=logits.device, constraints=[('xor', ('row', list(range(sequence_length)))), ('budget', ('col', self._core_roles))]) max_iter = 1000 constraint_types = ["xor", "budget"] constraint_dims = ["row", "col"] constraint_sets = [list(range(sequence_length)), self._core_roles] class_probabilities = lpsmap(logits, constraint_types, constraint_dims, constraint_sets, max_iter) # if max_seq_length > sequence_length: # logits = torch.cat((logits, -9999.*torch.ones((batch_size, max_seq_length-sequence_length, self.num_classes)).to(logits.device)), dim=1) # class_probabilities = self.lpsmap.solve(logits, max_iter=max_iter)"""
        # logits = (class_probabilities+1e-4).log()
    else:
        # Plain per-position softmax over the tag classes.
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
    output_dict = {
        "logits": logits,
        "class_probabilities": class_probabilities
    }
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.make_output_human_readable.
    output_dict["mask"] = mask
    # We add in the offsets here so we can compute the un-wordpieced tags.
    # This rebinds `offsets`, shadowing the parameter of the same name.
    words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                  for x in metadata])
    output_dict["words"] = list(words)
    output_dict["verb"] = list(verbs)
    output_dict["wordpiece_offsets"] = list(offsets)

    if tags is not None:
        # print(mask.shape, tags.shape, logits.shape, tags.max(), tags.min())
        # The loss follows the decoding mode: SparseMAP loss, CRF negative
        # log-likelihood, or token-level cross entropy.
        if self._lpsmap:
            loss = LpsmapLoss.apply(logits, class_probabilities, tags, mask)
            # tags_1hot = torch.zeros_like(class_probabilities).scatter_(2, tags.unsqueeze(-1), torch.ones_like(class_probabilities))
            # loss = -(tags_1hot*class_probabilities*mask.unsqueeze(-1).repeat(1, 1, class_probabilities.shape[-1])).sum()
        elif self.constrain_crf_decoding:
            loss = -self.crf(logits, tags, mask)
        else:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
        # The span metric is only updated outside training.
        if not self.ignore_span_metric and self.span_metric is not None and not self.training:
            batch_verb_indices = [
                example_metadata["verb_index"]
                for example_metadata in metadata
            ]
            batch_sentences = [
                example_metadata["words"] for example_metadata in metadata
            ]
            # Get the BIO tags from make_output_human_readable()
            # TODO (nfliu): This is kind of a hack, consider splitting out part
            # of make_output_human_readable() to a separate function.
            batch_bio_predicted_tags = self.make_output_human_readable(
                output_dict).pop("tags")
            from allennlp_models.structured_prediction.models.srl import (
                convert_bio_tags_to_conll_format, )

            if self.constrain_crf_decoding and not self._lpsmap:
                # With plain CRF decoding the predictions come from the CRF's
                # Viterbi path rather than from make_output_human_readable().
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format([
                        self.vocab.get_token_from_index(
                            tag, namespace=self._label_namespace)
                        for tag in seq
                    ]) for (seq, _) in self.crf.viterbi_tags(logits, mask)
                ]
            else:
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
            batch_bio_gold_tags = [
                example_metadata["gold_tags"]
                for example_metadata in metadata
            ]
            # print(batch_bio_gold_tags)
            batch_conll_gold_tags = [
                convert_bio_tags_to_conll_format(tags)
                for tags in batch_bio_gold_tags
            ]
            self.span_metric(
                batch_verb_indices,
                batch_sentences,
                batch_conll_predicted_tags,
                batch_conll_gold_tags,
            )
        output_dict["loss"] = loss
        # NOTE(review): assumed to belong to the `tags is not None` branch;
        # the flattened original does not show the nesting -- confirm.
        output_dict["gold_tags"] = [x["gold_tags"] for x in metadata]
    return output_dict
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    verb_indicator: torch.Tensor,
    sentence_end: torch.LongTensor,
    spans: torch.LongTensor,
    span_labels: torch.LongTensor,
    metadata: List[Any],
    tags: torch.LongTensor = None,
):
    """
    Embed candidate spans, keep the best-scoring ones, and classify each kept
    span together with the embedding of the predicate.

    # Parameters

    tokens : `TextFieldTensors`, required
        The output of `TextField.as_array()`. For this model, this must be a
        `SingleIdTokenIndexer` which indexes wordpieces from the BERT vocabulary.
    verb_indicator : `torch.LongTensor`, required
        An integer `SequenceFeatureField` representation of the position of the
        verb in the sentence, of shape `(batch_size, num_tokens)`. It can be all
        zeros when the sentence has no verbal predicate, in which case a zero
        vector stands in for the predicate embedding.
    sentence_end : `torch.LongTensor`, required
        Accepted but not read by this method.
    spans : `torch.LongTensor`, required
        Candidate spans; the code indexes it as `(batch_size, num_spans, 2)`.
        Entries with a negative first coordinate are treated as padding.
    span_labels : `torch.LongTensor`, required
        One gold label per candidate span (presumably `(batch_size, num_spans)`).
    metadata : `List[Dict[str, Any]]`, required
        Per-example metadata. The keys `words`, `verb` and `offsets` are always
        read; `verb_index` and `gold_tags` are read when the span metric is updated.
    tags : `torch.LongTensor`, optional (default = `None`)
        Only checked for presence: when it is given, the loss (and, outside
        training, the span metric) is computed from `span_labels`.

    # Returns

    An output dictionary consisting of:

    mask : `torch.Tensor`
        The padding mask of the wordpiece sequence.
    words, verb, wordpiece_offsets
        Values copied from `metadata`.
    loss : `torch.FloatTensor`, optional
        The masked mean of the per-span classification loss; only present when
        `tags` is given.
    """
    mask = get_text_field_mask(tokens)
    # Timestamp is taken here but never read afterwards.
    start = time.time()
    wordpiece_ids = util.get_token_ids_from_text_field_tensors(tokens)
    bert_embeddings, _ = self.bert_model(
        input_ids=wordpiece_ids,
        attention_mask=mask,
    )

    # Shape: (batch_size, num_spans). Padding spans have negative indices.
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
    # Padding indices are clamped to zero so they remain valid for indexing
    # even if a masked span survives pruning.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    embedded = self.embedding_dropout(bert_embeddings)
    batch_size, sequence_length, embedding_dim = embedded.size()

    # The attentive extractor always runs, on the embeddings before dropout.
    attended_span_embeddings = self._attentive_span_extractor(bert_embeddings, spans)
    if self._context_layer is None:
        span_embeddings = attended_span_embeddings
    else:
        contextualized = self._context_layer(embedded, mask)
        # With a context layer, only the endpoint representation is used.
        span_embeddings = self._endpoint_span_extractor(contextualized, spans)

    # Prune based on mention scores, keeping at most `_spans_per_word` spans
    # per wordpiece position.
    total_spans = spans.shape[1]
    keep_count = min(
        int(math.floor(self._spans_per_word * sequence_length)), total_spans
    )
    # Shape: (batch_size, num_spans)
    mention_features = self._mention_feedforward(span_embeddings)
    span_mention_scores = self._mention_scorer(mention_features).squeeze(-1)
    _top_scores, top_span_mask, top_span_indices = util.masked_topk(
        span_mention_scores, span_mask, keep_count
    )

    # Gather the embedding at the argmax of the verb indicator for each example.
    gather_index = (
        verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(1, 1, embedded.shape[-1])
    )
    verb_embeddings = torch.gather(embedded, 1, gather_index)
    assert len(verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
    verb_embeddings = verb_embeddings.squeeze(1)
    # Examples without any flagged verb get a zero vector instead.
    has_verb = (verb_indicator.sum(1, keepdim=True) > 0).repeat(
        1, verb_embeddings.shape[-1]
    )
    verb_embeddings = torch.where(
        has_verb, verb_embeddings, torch.zeros_like(verb_embeddings)
    )

    # Select embeddings, boundaries and gold labels of the surviving spans.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, spans.shape[1]
    )
    span_embeddings = util.batched_index_select(
        span_embeddings, top_span_indices, flat_top_span_indices
    )
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
    top_span_labels = util.batched_index_select(
        span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
    ).squeeze(-1)

    # Every kept span is paired with the predicate embedding.
    tiled_verb = verb_embeddings.unsqueeze(1).repeat(1, span_embeddings.shape[1], 1)
    classifier_input = torch.cat((span_embeddings, tiled_verb), dim=2)
    hidden = self.hidden_layer(classifier_input)
    predictions = self.output_layer(hidden)
    # Prepend a constant zero score as class 0.
    zero_scores = torch.zeros_like(predictions[:, :, :1])
    predictions = torch.cat((zero_scores, predictions), dim=-1)

    # The mask stays in the output so padding can be cropped when decoding
    # runs in self.make_output_human_readable.
    output_dict = {}
    output_dict["mask"] = mask
    # The offsets are carried along so the un-wordpieced tags can be computed.
    words, verbs, offsets = zip(
        *[(entry["words"], entry["verb"], entry["offsets"]) for entry in metadata]
    )
    output_dict["words"] = list(words)
    output_dict["verb"] = list(verbs)
    output_dict["wordpiece_offsets"] = list(offsets)

    if tags is None:
        return output_dict

    # Per-span loss, averaged over the spans that are actually valid.
    per_span_loss = self._ce_loss(
        predictions.view(-1, predictions.shape[-1]), top_span_labels.view(-1)
    )
    loss = (
        per_span_loss * top_span_mask.float().view(-1)
    ).sum() / top_span_mask.float().sum()

    skip_span_metric = (
        self.ignore_span_metric or self.span_metric is None or self.training
    )
    if not skip_span_metric:
        batch_verb_indices = [entry["verb_index"] for entry in metadata]
        batch_sentences = [entry["words"] for entry in metadata]
        predicted_bio = self.get_tags(
            top_spans, predictions, mask.shape[1], top_span_mask, output_dict
        )
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format,
        )

        predicted_conll = [
            convert_bio_tags_to_conll_format(sequence) for sequence in predicted_bio
        ]
        gold_bio = [entry["gold_tags"] for entry in metadata]
        gold_conll = [
            convert_bio_tags_to_conll_format(sequence) for sequence in gold_bio
        ]
        self.span_metric(
            batch_verb_indices,
            batch_sentences,
            predicted_conll,
            gold_conll,
        )

    output_dict["loss"] = loss
    return output_dict