def test_attachment_scores_can_ignore_labels(self):
    scorer = AttachmentScores(ignore_classes=[1])
    label_predictions = self.label_predictions
    # Change the predictions where the gold label is 1;
    # as we are ignoring 1, we should still get a perfect score.
    label_predictions[0, 3] = 2
    scorer(self.predictions, label_predictions, self.gold_indices, self.gold_labels, self.mask)
    for value in scorer.get_metric().values():
        assert value == 1.0
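# A minimal sketch of what ignore_classes does, assuming the AllenNLP
# AttachmentScores API used in the test above: positions whose *gold* label
# is in ignore_classes are masked out before any of the scores are computed.
# The toy tensors below are illustrative, not from the test suite.
import torch
from allennlp.training.metrics import AttachmentScores

scorer = AttachmentScores(ignore_classes=[1])
gold_indices = torch.tensor([[0, 2, 0]])
gold_labels = torch.tensor([[0, 1, 2]])        # position 1 carries the ignored label 1
predicted_indices = torch.tensor([[0, 2, 0]])  # all heads correct
predicted_labels = torch.tensor([[0, 9, 2]])   # wrong label only where gold == 1
mask = torch.ones(1, 3, dtype=torch.bool)

scorer(predicted_indices, predicted_labels, gold_indices, gold_labels, mask)
print(scorer.get_metric())  # all four scores are 1.0, since the only error is ignored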
class ParsingEvaluator(Evaluator):
    def __init__(self):
        self._attachment_scores = AttachmentScores()

    def add(self, gold_idx, gold_label, prediction_idx, prediction_label, mask):
        self._attachment_scores(prediction_idx, prediction_label, gold_idx, gold_label, mask)

    def get_metric(self):
        return self._attachment_scores.get_metric(reset=False)

    def reset(self):
        self._attachment_scores.reset()
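# A hypothetical usage sketch for the ParsingEvaluator wrapper above. Because
# the underlying metric accumulates counts across add() calls, get_metric()
# reflects everything seen since the last reset(). The toy batch is illustrative.
import torch

evaluator = ParsingEvaluator()
batch = {
    "gold_heads": torch.tensor([[0, 2, 0]]),
    "gold_labels": torch.tensor([[0, 1, 2]]),
    "pred_heads": torch.tensor([[0, 2, 0]]),
    "pred_labels": torch.tensor([[0, 1, 2]]),
    "mask": torch.ones(1, 3, dtype=torch.bool),
}
evaluator.add(batch["gold_heads"], batch["gold_labels"],
              batch["pred_heads"], batch["pred_labels"], batch["mask"])
print(evaluator.get_metric())  # perfect predictions: UAS/LAS/UEM/LEM all 1.0
evaluator.reset()              # clear accumulated counts before the next pass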
def multiple_runs(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: AttachmentScores,
    metric_kwargs: Dict[str, List[Any]],
    desired_values: Dict[str, Any],
    exact: Union[bool, Tuple[float, float]] = True,
):
    kwargs = {}
    # Use the arguments meant for the process with rank `global_rank`.
    for argname in metric_kwargs:
        kwargs[argname] = metric_kwargs[argname][global_rank]

    for i in range(200):
        metric(**kwargs)

    metrics = metric.get_metric()
    for key in metrics:
        assert desired_values[key] == metrics[key]
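# Why calling the metric 200 times above still matches `desired_values`:
# AttachmentScores accumulates raw correct/total counts, so repeating the same
# update scales numerator and denominator by the same factor and leaves every
# ratio unchanged. A small self-contained sketch of that invariance:
import torch
from allennlp.training.metrics import AttachmentScores

metric = AttachmentScores()
pred_heads = torch.tensor([[0, 1, 2]])
gold_heads = torch.tensor([[0, 1, 0]])  # one of three heads disagrees
labels = torch.tensor([[0, 0, 0]])      # labels agree everywhere
mask = torch.ones(1, 3, dtype=torch.bool)

metric(pred_heads, labels, gold_heads, labels, mask)
once = metric.get_metric(reset=True)["UAS"]
for _ in range(200):
    metric(pred_heads, labels, gold_heads, labels, mask)
assert metric.get_metric()["UAS"] == once  # 2/3 either way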
class AttachmentScoresTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.scorer = AttachmentScores()

        self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])
        self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])
        self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])
        self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])
        self.mask = torch.tensor(
            [[True, True, True, True, True, True], [True, True, True, True, False, False]]
        )

    def _send_tensors_to_device(self, device: str):
        self.predictions = self.predictions.to(device)
        self.gold_indices = self.gold_indices.to(device)
        self.label_predictions = self.label_predictions.to(device)
        self.gold_labels = self.gold_labels.to(device)
        self.mask = self.mask.to(device)

    @multi_device
    def test_perfect_scores(self, device: str):
        self._send_tensors_to_device(device)

        self.scorer(
            self.predictions, self.label_predictions, self.gold_indices, self.gold_labels, self.mask
        )
        for value in self.scorer.get_metric().values():
            assert value == 1.0

    @multi_device
    def test_unlabeled_accuracy_ignores_incorrect_labels(self, device: str):
        self._send_tensors_to_device(device)

        label_predictions = self.label_predictions
        # Change some stuff so that 4 of our label predictions are wrong.
        label_predictions[0, 3:] = 3
        label_predictions[1, 0] = 7
        self.scorer(
            self.predictions, label_predictions, self.gold_indices, self.gold_labels, self.mask
        )

        metrics = self.scorer.get_metric()

        assert metrics["UAS"] == 1.0
        assert metrics["UEM"] == 1.0

        # 4 of the 12 labels were wrong and 2 positions
        # are masked, so 6/10 = 0.6 LAS.
        assert metrics["LAS"] == 0.6
        # Neither batch element should have a labeled exact match.
        assert metrics["LEM"] == 0.0

    @multi_device
    def test_labeled_accuracy_is_affected_by_incorrect_heads(self, device: str):
        self._send_tensors_to_device(device)

        predictions = self.predictions
        # Change some stuff so that 4 of our predictions are wrong.
        predictions[0, 3:] = 3
        predictions[1, 0] = 7
        # This one is in the padded part, so it shouldn't affect anything.
        predictions[1, 5] = 7
        self.scorer(
            predictions, self.label_predictions, self.gold_indices, self.gold_labels, self.mask
        )

        metrics = self.scorer.get_metric()

        # 4 heads are incorrect, so the unlabeled score should be
        # 6/10 = 0.6 UAS.
        assert metrics["UAS"] == 0.6
        # All the labels were correct, but some heads
        # were wrong, so the LAS should equal the UAS.
        assert metrics["LAS"] == 0.6

        # Neither batch element had a perfect labeled or unlabeled EM.
        assert metrics["LEM"] == 0.0
        assert metrics["UEM"] == 0.0

    @multi_device
    def test_attachment_scores_can_ignore_labels(self, device: str):
        self._send_tensors_to_device(device)

        scorer = AttachmentScores(ignore_classes=[1])

        label_predictions = self.label_predictions
        # Change the predictions where the gold label is 1;
        # as we are ignoring 1, we should still get a perfect score.
        label_predictions[0, 3] = 2
        scorer(self.predictions, label_predictions, self.gold_indices, self.gold_labels, self.mask)

        for value in scorer.get_metric().values():
            assert value == 1.0

    def test_distributed_attachment_scores(self):
        predictions = [torch.Tensor([[0, 1, 3, 5, 2, 4]]), torch.Tensor([[0, 3, 2, 1, 0, 0]])]
        gold_indices = [torch.Tensor([[0, 1, 3, 5, 2, 4]]), torch.Tensor([[0, 3, 2, 1, 0, 0]])]
        label_predictions = [
            torch.Tensor([[0, 5, 2, 3, 3, 3]]),
            torch.Tensor([[7, 4, 8, 2, 0, 0]]),
        ]
        gold_labels = [torch.Tensor([[0, 5, 2, 1, 4, 2]]), torch.Tensor([[0, 4, 8, 2, 0, 0]])]
        mask = [
            torch.tensor([[True, True, True, True, True, True]]),
            torch.tensor([[True, True, True, True, False, False]]),
        ]

        metric_kwargs = {
            "predicted_indices": predictions,
            "gold_indices": gold_indices,
            "predicted_labels": label_predictions,
            "gold_labels": gold_labels,
            "mask": mask,
        }
        desired_metrics = {
            "UAS": 1.0,
            "LAS": 0.6,
            "UEM": 1.0,
            "LEM": 0.0,
        }
        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            AttachmentScores(),
            metric_kwargs,
            desired_metrics,
            exact=True,
        )
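# Why the distributed test above expects LAS == 0.6: AttachmentScores keeps raw
# correct/total counts, and distributed aggregation pools those counts across
# workers before dividing. A sketch of the arithmetic only (not the actual
# reduction code); the per-worker counts are read off the test tensors above.
worker_0 = (3, 6)  # rank 0: 3 of 6 unmasked labels correct
worker_1 = (3, 4)  # rank 1: 3 of 4 unmasked labels correct (2 positions masked)
las = (worker_0[0] + worker_1[0]) / (worker_0[1] + worker_1[1])
assert las == 0.6  # pooled counts, not an average of per-worker scores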
class BiaffineDependencyParser(Model):
    """
    This dependency parser follows the model of
    `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
    <https://arxiv.org/abs/1611.01734>`_ .

    Word representations are generated using a bidirectional LSTM,
    followed by separate biaffine classifiers for pairs of words,
    predicting whether a directed arc exists between the two words
    and the dependency label the arc should have. Decoding can either
    be done greedily, or the optimal Minimum Spanning Tree can be
    decoded using Edmonds' algorithm by viewing the dependency tree as
    a MST on a fully connected graph, where nodes are words
    and edges are scored dependency arcs.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    tag_representation_dim : ``int``, required.
        The dimension of the MLPs used for dependency tag prediction.
    arc_representation_dim : ``int``, required.
        The dimension of the MLPs used for head arc prediction.
    tag_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce tag representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    arc_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce arc representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    use_mst_decoding_for_validation : ``bool``, optional (default = True).
        Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : ``float``, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : ``float``, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, pos_tag_embedding: Embedding = None, use_mst_decoding_for_validation: bool = True, dropout: float = 0.0, input_dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(BiaffineDependencyParser, self).__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.encoder = encoder encoder_dim = encoder.get_output_dim() self.head_arc_feedforward = arc_feedforward or \ FeedForward(encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention(arc_representation_dim, arc_representation_dim, use_input_biases=True) num_labels = self.vocab.get_vocab_size("head_tags") self.head_tag_feedforward = tag_feedforward or \ FeedForward(encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, tag_representation_dim, num_labels) self._pos_tag_embedding = pos_tag_embedding or None self._dropout = InputVariationalDropout(dropout) self._input_dropout = Dropout(input_dropout) self._head_sentinel = torch.nn.Parameter( torch.randn([1, 1, encoder.get_output_dim()])) representation_dim = text_field_embedder.get_output_dim() if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match(representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim") check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim") check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim") self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = { tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE } self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info( f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() initializer(self) @overrides def forward( self, # type: ignore words: Dict[str, torch.LongTensor], pos_tags: torch.LongTensor, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- words : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. 
            The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required.
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_loss : ``torch.FloatTensor``, optional
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_types : ``torch.FloatTensor``
            The predicted head types for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()
        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
        }

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        head_tags = output_dict["head_tags"].cpu().detach().numpy()
        heads = output_dict["heads"].cpu().detach().numpy()
        lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"])
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict

    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
""" float_mask = mask.float() batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector( batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = last_dim_log_softmax( attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = last_dim_log_softmax( head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = timestep_index.view(1, sequence_length).expand( batch_size, sequence_length).long() # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then again, predicting the head tags of these greedily chosen arcs indpendently. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachements of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag( attended_arcs.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).byte().unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. 
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using Edmonds' algorithm
        for finding minimum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g. are a pair which includes a padded element) anyway below.
        # Shape (batch, num_labels, sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = (1 - mask.float()) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits)

        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu().numpy(), lengths):
            head, head_tag = decode_mst(energy, length)
            heads.append(head)
            head_tags.append(head_tag)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))

    def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self,
                           mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
""" new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return self._attachment_scores.get_metric(reset)
class GraphDependencyParser(Model):
    """
    This dependency parser is a blueprint for several graph-based dependency parsers.

    There are several possible edge models and loss functions.

    For decoding, the CLE algorithm is used (during training, attachment scores are
    usually based on greedy decoding).

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    edge_model : ``components.edge_models.EdgeModel``, required.
        The edge model to be used.
    loss_function : ``components.losses.EdgeLoss``, required.
        The loss function to be used.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    use_mst_decoding_for_validation : ``bool``, optional (default = True).
        Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : ``float``, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : ``float``, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 edge_model: graph_dependency_parser.components.edge_models.EdgeModel,
                 loss_function: graph_dependency_parser.components.losses.EdgeLoss,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 validation_evaluator: Optional[ValidationEvaluator] = None) -> None:
        super(GraphDependencyParser, self).__init__(vocab, regularizer)

        self.validation_evaluator = validation_evaluator

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(encoder.get_output_dim(), edge_model.encoder_dim(),
                               "encoder output dim", "input dim edge model")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
" "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() initializer(self) self.edge_model = edge_model self.loss_function = loss_function #Being able to detect what state we are in, probably not the best idea. self.current_epoch = 1 self.pass_over_data_just_started = True @overrides def forward( self, # type: ignore words: Dict[str, torch.LongTensor], pos_tags: torch.LongTensor, metadata: List[Dict[str, Any]], head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- words : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. pos_tags : ``torch.LongTensor``, required The output of a ``SequenceLabelField`` containing POS tags. POS tags are required regardless of whether they are used in the model, because they are used to filter the evaluation metric to only consider heads of words which are not punctuation. metadata : List[Dict[str, Any]], optional (default=None) A dictionary of metadata for each batch element which has keys: words : ``List[str]``, required. The tokens in the original sentence. pos : ``List[str]``, required. The dependencies POS tags for each word. head_tags : = edge_labels torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer gold edge labels for the arcs in the dependency parse. Has shape ``(batch_size, sequence_length)``. head_indices : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer indices denoting the parent of every word in the dependency parse. Has shape ``(batch_size, sequence_length)``. Returns ------- An output dictionary consisting of: loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. arc_loss : ``torch.FloatTensor`` The loss contribution from the unlabeled arcs. edge_label_loss : ``torch.FloatTensor`` The loss contribution from the edge labels. heads : ``torch.FloatTensor`` The predicted head indices for each word. A tensor of shape (batch_size, sequence_length). edge_labels : ``torch.FloatTensor`` The predicted head types for each arc. A tensor of shape (batch_size, sequence_length). mask : ``torch.LongTensor`` A mask denoting the padded elements in the batch. 
""" embedded_text_input = self.text_field_embedder(words) if pos_tags is not None and self._pos_tag_embedding is not None: embedded_pos_tags = self._pos_tag_embedding(pos_tags) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) elif self._pos_tag_embedding is not None: raise ConfigurationError( "Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(words) embedded_text_input = self._input_dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation. encoded_text = torch.cat([head_sentinel, encoded_text], 1) mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat( [head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat( [head_tags.new_zeros(batch_size, 1), head_tags], 1) encoded_text = self._dropout(encoded_text) edge_existence_scores = self.edge_model.edge_existence( encoded_text, mask) if self.training or not self.use_mst_decoding_for_validation: predicted_heads = self._greedy_decode_arcs(edge_existence_scores, mask) edge_label_logits = self.edge_model.label_scores( encoded_text, predicted_heads) predicted_edge_labels = self._greedy_decode_edge_labels( edge_label_logits) else: #Find best tree with CLE predicted_heads = cle_decode(edge_existence_scores, mask.data.sum(dim=1).long()) #With info about tree structure, get edge label scores edge_label_logits = self.edge_model.label_scores( encoded_text, predicted_heads) #Predict edge labels predicted_edge_labels = self._greedy_decode_edge_labels( edge_label_logits) output_dict = { "heads": predicted_heads, "edge_labels": predicted_edge_labels, "mask": mask, "words": [meta["words"] for meta in metadata], "pos": [meta["pos"] for meta in metadata], "position_in_corpus": [meta["position_in_corpus"] for meta in metadata], } if head_indices is not None and head_tags is not None: gold_edge_label_logits = self.edge_model.label_scores( encoded_text, head_indices) edge_label_loss = self.loss_function.label_loss( gold_edge_label_logits, mask, head_tags) arc_nll = self.loss_function.edge_existence_loss( edge_existence_scores, head_indices, mask) loss = arc_nll + edge_label_loss evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags) # We calculate attachment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. self._attachment_scores(predicted_heads[:, 1:], predicted_edge_labels[:, 1:], head_indices[:, 1:], head_tags[:, 1:], evaluation_mask) output_dict["arc_loss"] = arc_nll output_dict["edge_label_loss"] = edge_label_loss output_dict["loss"] = loss if self.pass_over_data_just_started: # here we could decide if we want to start collecting info for the decoder. pass self.pass_over_data_just_started = False return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]): """ Takes the result of forward and creates human readable, non-padded dependency trees. :param output_dict: :return: output_dict with two new keys, "predicted_labels" and "predicted_heads", which are lists of lists. 
""" head_tags = output_dict.pop("edge_labels").cpu().detach().numpy() heads = output_dict.pop("heads").cpu().detach().numpy() mask = output_dict.pop("mask") lengths = get_lengths_from_binary_sequence_mask(mask) edge_labels = [] head_indices = [] for instance_heads, instance_tags, length in zip( heads, head_tags, lengths): instance_heads = list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [ self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags ] edge_labels.append(labels) head_indices.append(instance_heads) output_dict["predicted_labels"] = edge_labels output_dict["predicted_heads"] = head_indices return output_dict def _greedy_decode_edge_labels( self, edge_label_logits: torch.Tensor) -> torch.Tensor: """ Assigns edge labels according to (existing) edges. Parameters ---------- edge_label_logits: ``torch.Tensor`` of shape (batch_size, sequence_length, num_head_tags) Returns ------- head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded head tags (labels of incoming edges) of each word. """ _, head_tags = edge_label_logits.max(dim=2) return head_tags def _greedy_decode_arcs(self, existence_scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """ Decodes the head predictions by decoding the unlabeled arcs independently for each word. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). Parameters ---------- existence_scores : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. mask: torch.Tensor, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. existence_scores = existence_scores + torch.diag( existence_scores.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).byte().unsqueeze(2) existence_scores.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = existence_scores.max(dim=2) return heads def _get_mask_for_eval(self, mask: torch.LongTensor, pos_tags: torch.LongTensor) -> torch.LongTensor: """ Dependency evaluation excludes words are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. Parameters ---------- mask : ``torch.LongTensor``, required. The original mask. pos_tags : ``torch.LongTensor``, required. The pos tags for the sequence. Returns ------- A new mask, where any indices equal to labels we should be ignoring are masked. 
""" new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: r = self._attachment_scores.get_metric(reset) if reset: if self.training: #done on the training data pass self.current_epoch += 1 self.pass_over_data_just_started = True else: #done on dev/test data if self.validation_evaluator: metrics = self.validation_evaluator.eval(self) for name, val in metrics.items(): r[name] = val return r
class DecompTransformerSyntaxParser(DecompTransformerParser):

    def __init__(self,
                 vocab: Vocabulary,
                 # source-side
                 bert_encoder: BaseBertWrapper,
                 encoder_token_embedder: TextFieldEmbedder,
                 encoder_pos_embedding: Embedding,
                 encoder: Seq2SeqEncoder,
                 # target-side
                 decoder_token_embedder: TextFieldEmbedder,
                 decoder_node_index_embedding: Embedding,
                 decoder_pos_embedding: Embedding,
                 decoder: MisoTransformerDecoder,
                 extended_pointer_generator: ExtendedPointerGenerator,
                 tree_parser: DecompTreeParser,
                 node_attribute_module: NodeAttributeDecoder,
                 edge_attribute_module: EdgeAttributeDecoder,
                 # misc
                 label_smoothing: LabelSmoothing,
                 target_output_namespace: str,
                 pos_tag_namespace: str,
                 edge_type_namespace: str,
                 syntax_edge_type_namespace: str = None,
                 biaffine_parser: DeepTreeParser = None,
                 syntactic_method: str = None,
                 dropout: float = 0.0,
                 beam_size: int = 5,
                 max_decoding_steps: int = 50,
                 eps: float = 1e-20,
                 loss_mixer: LossMixer = None,
                 intermediate_graph: bool = False,
                 pretrained_weights: str = None) -> None:
        super().__init__(vocab=vocab,
                         # source-side
                         bert_encoder=bert_encoder,
                         encoder_token_embedder=encoder_token_embedder,
                         encoder_pos_embedding=encoder_pos_embedding,
                         encoder=encoder,
                         # target-side
                         decoder_token_embedder=decoder_token_embedder,
                         decoder_node_index_embedding=decoder_node_index_embedding,
                         decoder_pos_embedding=decoder_pos_embedding,
                         decoder=decoder,
                         extended_pointer_generator=extended_pointer_generator,
                         tree_parser=tree_parser,
                         node_attribute_module=node_attribute_module,
                         edge_attribute_module=edge_attribute_module,
                         # misc
                         label_smoothing=label_smoothing,
                         target_output_namespace=target_output_namespace,
                         pos_tag_namespace=pos_tag_namespace,
                         edge_type_namespace=edge_type_namespace,
                         syntax_edge_type_namespace=syntax_edge_type_namespace,
                         dropout=dropout,
                         beam_size=beam_size,
                         max_decoding_steps=max_decoding_steps,
                         eps=eps,
                         pretrained_weights=pretrained_weights)

        self.syntactic_method = syntactic_method
        self.biaffine_parser = biaffine_parser
        self.intermediate_graph = intermediate_graph
        self.loss_mixer = loss_mixer
        self._syntax_metrics = AttachmentScores()
        self.syntax_las = 0.0
        self.syntax_uas = 0.0

        if self.pretrained_weights is not None:
            self.load_partial(self.pretrained_weights)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        node_pred_metrics = self._node_pred_metrics.get_metric(reset)
        edge_pred_metrics = self._edge_pred_metrics.get_metric(reset)
        decomp_metrics = self._decomp_metrics.get_metric(reset)
        syntax_metrics = self._syntax_metrics.get_metric(reset)
        metrics = OrderedDict(
            ppl=node_pred_metrics["ppl"],
            node_pred=node_pred_metrics["accuracy"] * 100,
            generate=node_pred_metrics["generate"] * 100,
            src_copy=node_pred_metrics["src_copy"] * 100,
            tgt_copy=node_pred_metrics["tgt_copy"] * 100,
            node_pearson=decomp_metrics["node_pearson_r"],
            edge_pearson=decomp_metrics["edge_pearson_r"],
            pearson=decomp_metrics["pearson_r"],
            uas=edge_pred_metrics["UAS"] * 100,
            las=edge_pred_metrics["LAS"] * 100,
            syn_uas=syntax_metrics["UAS"] * 100,
            syn_las=syntax_metrics["LAS"] * 100,
        )
        metrics["s_f1"] = self.val_s_f1
        metrics["syn_las"] = self.syntax_las
        metrics["syn_uas"] = self.syntax_uas
        return metrics

    def _update_syntax_scores(self):
        scores = self._syntax_metrics.get_metric(reset=True)
        self.syntax_las = scores["LAS"] * 100
        self.syntax_uas = scores["UAS"] * 100

    def _compute_biaffine_loss(self, biaffine_outputs, inputs):
        edge_prediction_loss = self._compute_edge_prediction_loss(
            biaffine_outputs['edge_head_ll'],
            biaffine_outputs['edge_type_ll'],
            biaffine_outputs['edge_heads'],
            biaffine_outputs['edge_types'],
            inputs['syn_edge_heads'],
            inputs['syn_edge_types']['syn_edge_types'],
            inputs['syn_valid_node_mask'],
            syntax=True)
        return edge_prediction_loss['loss_per_node']

    def _parse_syntax(self,
                      encoder_outputs: torch.Tensor,
                      edge_head_mask: torch.Tensor,
                      edge_heads: torch.Tensor = None,
                      valid_node_mask: torch.Tensor = None,
                      do_mst=False) -> Dict:
        parser_outputs = self.biaffine_parser(
            query=encoder_outputs,
            key=encoder_outputs,
            edge_head_mask=edge_head_mask,
            gold_edge_heads=edge_heads,
            decode_mst=do_mst,
            valid_node_mask=valid_node_mask)
        return parser_outputs

    @overrides
    def _training_forward(self, inputs: Dict) -> Dict[str, torch.Tensor]:
        encoding_outputs = self._encode(
            tokens=inputs["source_tokens"],
            pos_tags=inputs["source_pos_tags"],
            subtoken_ids=inputs["source_subtoken_ids"],
            token_recovery_matrix=inputs["source_token_recovery_matrix"],
            mask=inputs["source_mask"])

        just_syntax = False
        encoder_side = False
        # if we're doing encoder-side
        if "syn_tokens_str" in inputs.keys():
            biaffine_outputs = self._parse_syntax(encoding_outputs['encoder_outputs'],
                                                  inputs["syn_edge_head_mask"],
                                                  inputs["syn_edge_heads"],
                                                  do_mst=False)
            biaffine_loss = self._compute_biaffine_loss(biaffine_outputs, inputs)
            self._update_syntax_scores()
            encoder_side = True
            if self.intermediate_graph:
                encoding_outputs = DecompSyntaxParser._add_biaffine_to_encoder(encoding_outputs, biaffine_outputs)
        else:
            biaffine_loss = 0.0

        if self.intermediate_graph:
            # TODO: put op vec back in
            decoding_outputs = self._decode(
                tokens=inputs["target_tokens"],
                # op_vec=inputs["op_vec"],
                node_indices=inputs["target_node_indices"],
                pos_tags=inputs["target_pos_tags"],
                encoder_outputs=encoding_outputs["encoder_outputs"],
                source_mask=inputs["source_mask"],
                target_mask=inputs["target_mask"])
        else:
            decoding_outputs = self._decode(
                tokens=inputs["target_tokens"],
                node_indices=inputs["target_node_indices"],
                pos_tags=inputs["target_pos_tags"],
                encoder_outputs=encoding_outputs["encoder_outputs"],
                source_mask=inputs["source_mask"],
                target_mask=inputs["target_mask"])

        node_prediction_outputs = self._extended_pointer_generator(
            inputs=decoding_outputs["attentional_tensors"],
            source_attention_weights=decoding_outputs["source_attention_weights"],
            target_attention_weights=decoding_outputs["target_attention_weights"],
            source_attention_map=inputs["source_attention_map"],
            target_attention_map=inputs["target_attention_map"])

        try:
            # compute node attributes
            node_attribute_outputs = self._node_attribute_predict(
                decoding_outputs["outputs"][:, :-1, :],
                inputs["node_attribute_truth"],
                inputs["node_attribute_mask"])
        except ValueError:
            # concat-just-syntax case
            node_attribute_outputs = {"loss": 0.0, "pred_dict": {"pred_attributes": []}}
            just_syntax = True

        edge_prediction_outputs = self._parse(
            decoding_outputs["outputs"][:, :, :],
            edge_head_mask=inputs["edge_head_mask"],
            edge_heads=inputs["edge_heads"])

        try:
            edge_attribute_outputs = self._edge_attribute_predict(
                edge_prediction_outputs["edge_type_query"],
                edge_prediction_outputs["edge_type_key"],
                edge_prediction_outputs["edge_heads"],
                inputs["edge_attribute_truth"],
                inputs["edge_attribute_mask"])
        except ValueError:
            # concat-just-syntax case
            edge_attribute_outputs = {"loss": 0.0, "pred_dict": {"pred_attributes": []}}
            just_syntax = True

        node_pred_loss = self._compute_node_prediction_loss(
            prob_dist=node_prediction_outputs["hybrid_prob_dist"],
            generation_outputs=inputs["generation_outputs"],
            source_copy_indices=inputs["source_copy_indices"],
            target_copy_indices=inputs["target_copy_indices"],
            source_dynamic_vocab_size=inputs["source_dynamic_vocab_size"],
            source_attention_weights=decoding_outputs["source_attention_weights"],
            coverage_history=decoding_outputs["coverage_history"]
            # coverage_history=None
        )
        edge_pred_loss = self._compute_edge_prediction_loss(
            edge_head_ll=edge_prediction_outputs["edge_head_ll"],
            edge_type_ll=edge_prediction_outputs["edge_type_ll"],
            pred_edge_heads=edge_prediction_outputs["edge_heads"],
            pred_edge_types=edge_prediction_outputs["edge_types"],
            gold_edge_heads=inputs["edge_heads"],
            gold_edge_types=inputs["edge_types"],
            valid_node_mask=inputs["valid_node_mask"])

        if encoder_side:
            # learn a loss ratio
            loss = self.compute_training_loss(node_pred_loss["loss_per_node"],
                                              edge_pred_loss["loss_per_node"],
                                              node_attribute_outputs["loss"],
                                              edge_attribute_outputs["loss"],
                                              biaffine_loss)
        else:
            # no biaffine loss
            loss = node_pred_loss["loss_per_node"] + edge_pred_loss["loss_per_node"] + \
                node_attribute_outputs["loss"] + edge_attribute_outputs["loss"]

        if not just_syntax:
            # compute combined pearson
            self._decomp_metrics(None, None, None, None, "both")

        return dict(loss=loss,
                    node_attributes=node_attribute_outputs["pred_dict"]["pred_attributes"],
                    edge_attributes=edge_attribute_outputs["pred_dict"]["pred_attributes"])

    @overrides
    def _test_forward(self, inputs: Dict) -> Dict:
        encoding_outputs = self._encode(
            tokens=inputs["source_tokens"],
            pos_tags=inputs["source_pos_tags"],
            subtoken_ids=inputs["source_subtoken_ids"],
            token_recovery_matrix=inputs["source_token_recovery_matrix"],
            mask=inputs["source_mask"])

        # if we're doing encoder-side
        if self.biaffine_parser is not None:
            biaffine_outputs = self._parse_syntax(encoding_outputs['encoder_outputs'],
                                                  inputs["syn_edge_head_mask"],
                                                  None,
                                                  valid_node_mask=inputs["syn_valid_node_mask"],
                                                  do_mst=True)
            if self.intermediate_graph:
                encoding_outputs = DecompSyntaxParser._add_biaffine_to_encoder(encoding_outputs, biaffine_outputs)

        start_predictions, start_state, auxiliaries, misc = self._prepare_decoding_start_state(inputs, encoding_outputs)

        # all_predictions: [batch_size, beam_size, max_steps]
        # outputs: [batch_size, beam_size, max_steps, hidden_vector_dim]
        # log_probs: [batch_size, beam_size]
        all_predictions, outputs, log_probs, target_dynamic_vocabs = self._beam_search.search(
            start_predictions=start_predictions,
            start_state=start_state,
            auxiliaries=auxiliaries,
            step=lambda x, y, z: self._take_one_step_node_prediction(x, y, z, misc),
            tracked_state_name="output",
            tracked_auxiliary_name="target_dynamic_vocabs")

        node_predictions, node_index_predictions, edge_head_mask, valid_node_mask = self._read_node_predictions(
            # Remove the last one because we can't get the RNN state for the last one.
            predictions=all_predictions[:, 0, :-1],
            meta_data=inputs["instance_meta"],
            target_dynamic_vocabs=target_dynamic_vocabs[0],
            source_dynamic_vocab_size=inputs["source_dynamic_vocab_size"])

        node_attribute_outputs = self._node_attribute_predict(
            outputs[:, :, :-1, :], None, None)

        edge_predictions = self._parse(
            rnn_outputs=outputs[:, 0],
            edge_head_mask=edge_head_mask)

        (edge_head_predictions,
         edge_type_predictions,
         edge_type_ind_predictions) = self._read_edge_predictions(edge_predictions, is_syntax=False)

        edge_attribute_outputs = self._edge_attribute_predict(
            edge_predictions["edge_type_query"],
            edge_predictions["edge_type_key"],
            edge_predictions["edge_heads"],
            None, None)

        edge_pred_loss = self._compute_edge_prediction_loss(
            edge_head_ll=edge_predictions["edge_head_ll"],
            edge_type_ll=edge_predictions["edge_type_ll"],
            pred_edge_heads=edge_predictions["edge_heads"],
            pred_edge_types=edge_predictions["edge_types"],
            gold_edge_heads=edge_predictions["edge_heads"],
            gold_edge_types=edge_predictions["edge_types"],
            valid_node_mask=valid_node_mask)

        loss = -log_probs[:, 0].sum() / edge_pred_loss["num_nodes"] + edge_pred_loss["loss_per_node"]

        if "syn_tokens_str" not in inputs:
            inputs["syn_tokens_str"] = []
            syn_edge_head_predictions, syn_edge_type_predictions, syn_edge_type_inds = [], [], []
        else:
            syn_edge_head_predictions, syn_edge_type_predictions, syn_edge_type_inds = \
                self._read_edge_predictions(biaffine_outputs, is_syntax=True)

        outputs = dict(
            loss=loss,
            nodes=node_predictions,
            node_indices=node_index_predictions,
            syn_nodes=inputs["syn_tokens_str"],
            syn_edge_heads=syn_edge_head_predictions,
            syn_edge_types=syn_edge_type_predictions,
            syn_edge_type_inds=syn_edge_type_inds,
            edge_heads=edge_head_predictions,
            edge_types=edge_type_predictions,
            edge_types_inds=edge_type_ind_predictions,
            node_attributes=node_attribute_outputs["pred_dict"]["pred_attributes"],
            node_attributes_mask=node_attribute_outputs["pred_dict"]["pred_mask"],
            edge_attributes=edge_attribute_outputs["pred_dict"]["pred_attributes"],
            edge_attributes_mask=edge_attribute_outputs["pred_dict"]["pred_mask"],
        )

        return outputs

    @overrides
    def _prepare_decoding_start_state(self, inputs: Dict, encoding_outputs: Dict[str, torch.Tensor]) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict]:
        batch_size = inputs["source_tokens"]["source_tokens"].size(0)
        bos = self.vocab.get_token_index(START_SYMBOL, self._target_output_namespace)
        start_predictions = inputs["source_tokens"]["source_tokens"].new_full((batch_size,), bos)

        start_state = {
            # [batch_size, *]
            "source_memory_bank": encoding_outputs["encoder_outputs"],
            "source_mask": inputs["source_mask"],
            "target_mask": inputs["target_mask"],
            "source_attention_map": inputs["source_attention_map"],
            "target_attention_map": inputs["source_attention_map"].new_zeros(
                (batch_size, self._max_decoding_steps, self._max_decoding_steps + 1)),
            "input_history": None,
        }

        if "op_vec" in inputs.keys() and inputs["op_vec"] is not None:
            start_state["op_vec"] = inputs["op_vec"]

        auxiliaries = {
            "target_dynamic_vocabs": inputs["target_dynamic_vocab"]
        }
        misc = {
            "batch_size": batch_size,
            "last_decoding_step": -1,  # At <BOS>, we set it to -1.
"source_dynamic_vocab_size": inputs["source_dynamic_vocab_size"], "instance_meta": inputs["instance_meta"] } return start_predictions, start_state, auxiliaries, misc @overrides def _take_one_step_node_prediction(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], auxiliaries: Dict[str, List[Any]], misc: Dict, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, List[Any]]]: inputs = self._prepare_next_inputs( predictions=last_predictions, target_attention_map=state["target_attention_map"], target_dynamic_vocabs=auxiliaries["target_dynamic_vocabs"], meta_data=misc["instance_meta"], batch_size=misc["batch_size"], last_decoding_step=misc["last_decoding_step"], source_dynamic_vocab_size=misc["source_dynamic_vocab_size"] ) # TODO: HERE we go, just concatenate "inputs" to history stored in the state # need a node index history and a token history # no need to update history inside of _prepare_next_inputs or double-iterate decoder_inputs = torch.cat([ self._decoder_token_embedder(inputs["tokens"]), self._decoder_node_index_embedding(inputs["node_indices"]), ], dim=2) # if previously decoded steps, concat them in before current input if state['input_history'] is not None: decoder_inputs = torch.cat([state['input_history'], decoder_inputs], dim = 1) # set previously decoded to current step state['input_history'] = decoder_inputs if self.intermediate_graph: # TODO: put op vec back in decoding_outputs = self._decoder.one_step_forward( inputs=decoder_inputs, source_memory_bank=state["source_memory_bank"], # op_vec=state["op_vec"], source_mask=state["source_mask"], decoding_step=misc["last_decoding_step"] + 1, total_decoding_steps=self._max_decoding_steps, coverage=state.get("coverage", None) ) else: decoding_outputs = self._decoder.one_step_forward( inputs=decoder_inputs, source_memory_bank=state["source_memory_bank"], source_mask=state["source_mask"], decoding_step=misc["last_decoding_step"] + 1, total_decoding_steps=self._max_decoding_steps, coverage=state.get("coverage", None) ) state['attentional_tensor'] = decoding_outputs['attentional_tensor'].squeeze(1) state['output'] = decoding_outputs['output'].squeeze(1) if decoding_outputs["coverage"] is not None: state["coverage"] = decoding_outputs["coverage"] node_prediction_outputs = self._extended_pointer_generator( inputs=decoding_outputs["attentional_tensor"], source_attention_weights=decoding_outputs["source_attention_weights"], target_attention_weights=decoding_outputs["target_attention_weights"], source_attention_map=state["source_attention_map"], target_attention_map=state["target_attention_map"] ) log_probs = (node_prediction_outputs["hybrid_prob_dist"] + self._eps).squeeze(1).log() misc["last_decoding_step"] += 1 return log_probs, state, auxiliaries
class BiaffineDependencyParser(Model):
    """
    This dependency parser follows the model of
    [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)]
    (https://arxiv.org/abs/1611.01734).

    Word representations are generated using a bidirectional LSTM,
    followed by separate biaffine classifiers for pairs of words,
    predicting whether a directed arc exists between the two words
    and the dependency label the arc should have. Decoding can either
    be done greedily, or the optimal Minimum Spanning Tree can be
    decoded using Edmonds' algorithm by viewing the dependency tree as
    a MST on a fully connected graph, where nodes are words and edges
    are scored dependency arcs.

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : `TextFieldEmbedder`, required
        Used to embed the `tokens` `TextField` we get as input to the model.
    encoder : `Seq2SeqEncoder`
        The encoder (with its own internal stacking) that we will use to generate
        representations of tokens.
    tag_representation_dim : `int`, required.
        The dimension of the MLPs used for dependency tag prediction.
    arc_representation_dim : `int`, required.
        The dimension of the MLPs used for head arc prediction.
    tag_feedforward : `FeedForward`, optional, (default = None).
        The feedforward network used to produce tag representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    arc_feedforward : `FeedForward`, optional, (default = None).
        The feedforward network used to produce arc representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    pos_tag_embedding : `Embedding`, optional.
        Used to embed the `pos_tags` `SequenceLabelField` we get as input to the model.
    use_mst_decoding_for_validation : `bool`, optional (default = True).
        Whether to use Edmonds' algorithm to find the optimal minimum spanning tree
        during validation. If false, decoding is greedy.
    dropout : `float`, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : `float`, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    regularizer : `RegularizerApplicator`, optional (default=`None`)
        If provided, will be used to calculate the regularization penalty during training.
""" def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, model_name: str = None, tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, pos_tag_embedding: Embedding = None, use_mst_decoding_for_validation: bool = True, dropout: float = 0.0, input_dropout: float = 0.0, word_dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None, ) -> None: super().__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.encoder = encoder if model_name: from src.data.token_indexers import PretrainedAutoTokenizer self._tokenizer = PretrainedAutoTokenizer.load(model_name) encoder_dim = encoder.get_output_dim() self.head_arc_feedforward = arc_feedforward or FeedForward( encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention(arc_representation_dim, arc_representation_dim, use_input_biases=True) num_labels = self.vocab.get_vocab_size("head_tags") self.head_tag_feedforward = tag_feedforward or FeedForward( encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, tag_representation_dim, num_labels) self._pos_tag_embedding = pos_tag_embedding or None self._dropout = InputVariationalDropout(dropout) self._input_dropout = Dropout(input_dropout) self._word_dropout = word_dropout self._head_sentinel = torch.nn.Parameter( torch.randn([1, 1, encoder.get_output_dim()])) representation_dim = text_field_embedder.get_output_dim() if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match( representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim", ) check_dimensions_match( tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim", ) check_dimensions_match( arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim", ) self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = { tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE } self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info( f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() initializer(self) @overrides def forward( self, # type: ignore words: TextFieldTensors, pos_tags: torch.LongTensor, metadata: List[Dict[str, Any]], head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, lemmas: TextFieldTensors = None, feats: TextFieldTensors = None, ) -> Dict[str, torch.Tensor]: """ # Parameters words : TextFieldTensors, required The output of `TextField.as_array()`, which should typically be passed directly to a `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer` tensors. At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens": Tensor(batch_size, sequence_length)}`. 
            This dictionary will have the same keys as were used for the `TokenIndexers`
            when you created the `TextField` representing your sequence. The dictionary
            is designed to be passed directly to a `TextFieldEmbedder`, which knows how
            to combine different word representations into a single vector per token
            in your input.
        pos_tags : `torch.LongTensor`, required
            The output of a `SequenceLabelField` containing POS tags. POS tags are
            required regardless of whether they are used in the model, because they
            are used to filter the evaluation metric to only consider heads of words
            which are not punctuation.
        metadata : `List[Dict[str, Any]]`, required
            A dictionary of metadata for each batch element which has keys:
                words : `List[str]`, required.
                    The tokens in the original sentence.
                pos : `List[str]`, required.
                    The POS tags for each word.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            for the arcs in the dependency parse.
            Has shape `(batch_size, sequence_length)`.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the
            parent of every word in the dependency parse.
            Has shape `(batch_size, sequence_length)`.

        # Returns

        An output dictionary consisting of:

        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        arc_loss : `torch.FloatTensor`
            The loss contribution from the unlabeled arcs.
        tag_loss : `torch.FloatTensor`, optional
            The loss contribution from predicting the dependency tags for the gold arcs.
        heads : `torch.FloatTensor`
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : `torch.FloatTensor`
            The predicted head tags for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : `torch.LongTensor`
            A mask denoting the padded elements in the batch.
        """
        mask = get_text_field_mask(words)
        words = self._apply_token_dropout(words)

        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            pos_tags_dict = {"tokens": pos_tags, "mask": mask}
            self._apply_token_dropout(pos_tags_dict)
            pos_tags = pos_tags_dict["tokens"]
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll = self._parse(
            encoded_text, mask, head_tags, head_indices)

        loss = arc_nll + tag_nll

        if head_indices is not None and head_tags is not None:
            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
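            # Alignment sketch: after `_parse` prepends the ROOT sentinel, the
            # predictions cover positions 0..n (ROOT, w1, ..., wn), while the gold
            # head_indices/head_tags cover only w1..wn, e.g.
            #   predicted_heads[:, 1:]  ->  predicted heads of w1..wn
            #   head_indices            ->  gold heads of w1..wn
            # so slicing off position 0 lines the two up one-to-one.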
            self._attachment_scores(
                predicted_heads[:, 1:],
                predicted_head_tags[:, 1:],
                head_indices,
                head_tags,
                evaluation_mask,
            )

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
        }
        self._add_metadata_to_output_dict(metadata, output_dict)
        return output_dict

    def _add_metadata_to_output_dict(self, metadata, output_dict):
        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
            output_dict["upos"] = [x["upos"] for x in metadata]
            output_dict["xpos"] = [x["xpos"] for x in metadata]
            output_dict["feats"] = [x["feats"] for x in metadata]
            output_dict["lemmas"] = [x["lemmas"] for x in metadata]
            output_dict["ids"] = [x["ids"] for x in metadata if "ids" in x]
            output_dict["multiword_ids"] = [
                x["multiword_ids"] for x in metadata if "multiword_ids" in x
            ]
            output_dict["multiword_forms"] = [
                x["multiword_forms"] for x in metadata if "multiword_forms" in x
            ]
        return output_dict

    def _apply_token_dropout(self, words):
        # Word dropout
        def mask_words(tokens, drop_mask, drop_token):
            drop_fill = tokens.new_empty(tokens.size()).long().fill_(drop_token)
            return torch.where(drop_mask, drop_fill, tokens)

        if "tokens" in words:
            drop_mask = self._get_dropout_mask(words["mask"].bool(),
                                               p=self._word_dropout,
                                               training=self.training)
            drop_token = self.vocab.get_token_index(self.vocab._oov_token)
            words["tokens"] = mask_words(words["tokens"], drop_mask, drop_token)

        def mask_subwords(words, drop_mask, drop_token):
            token_ids = words["token_ids"]
            offsets = words["offsets"]
            subword_drop_mask = token_ids.new_zeros(token_ids.size()).bool()
            batch_size, seq_len, _ = offsets.size()
            for i in range(batch_size):
                for j in range(seq_len):
                    start, end = offsets[i, j].tolist()
                    subword_drop_mask[i, start:(end + 1)] = drop_mask[i, j]
            drop_fill = token_ids.new_empty(token_ids.size()).long().fill_(drop_token)
            return torch.where(subword_drop_mask, drop_fill, token_ids)

        if "roberta" in words:
            drop_mask = self._get_dropout_mask(words["roberta"]["mask"].bool(),
                                               p=self._word_dropout,
                                               training=self.training)
            drop_token = self._tokenizer.encode("<mask>", add_special_tokens=False)[0]
            words["roberta"]["token_ids"] = mask_subwords(words["roberta"], drop_mask, drop_token)

        # The input is modified in place; it is also returned so that call sites
        # may write `words = self._apply_token_dropout(words)`.
        return words

    @staticmethod
    def _get_dropout_mask(mask: torch.Tensor,
                          p: float = 0.0,
                          training: bool = True) -> torch.BoolTensor:
        """
        During training, randomly selects non-padding positions whose tokens
        should be replaced by a mask/OOV token with probability ``p``.

        :param mask: A boolean mask marking the non-padding positions of the batch
        :param p: The probability with which a non-padding token is selected for dropout
        :param training: Applies the dropout only if set to ``True``
        :return: A boolean mask selecting the positions whose tokens should be replaced
        """
        if training and p > 0:
            # Create a uniformly random mask selecting either the original words or OOV tokens
            dropout_mask = (mask.new_empty(mask.size()).float().uniform_() < p)
            drop_mask = dropout_mask & mask
            return drop_mask
        else:
            return mask.new_zeros(mask.size()).bool()

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
            instance_heads =
list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [ self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags ] head_tag_labels.append(labels) head_indices.append(instance_heads) output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices return output_dict def _parse( self, encoded_text: torch.Tensor, mask: torch.LongTensor, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation. encoded_text = torch.cat([head_sentinel, encoded_text], 1) mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat( [head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat( [head_tags.new_zeros(batch_size, 1), head_tags], 1) float_mask = mask.float() encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self._dropout( self.head_arc_feedforward(encoded_text)) child_arc_representation = self._dropout( self.child_arc_feedforward(encoded_text)) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout( self.head_tag_feedforward(encoded_text)) child_tag_representation = self._dropout( self.child_tag_feedforward(encoded_text)) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation) minus_inf = -1e8 minus_mask = (1 - float_mask) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) if self.training or not self.use_mst_decoding_for_validation: predicted_heads, predicted_head_tags = self._greedy_decode( head_tag_representation, child_tag_representation, attended_arcs, mask) else: predicted_heads, predicted_head_tags = self._mst_decode( head_tag_representation, child_tag_representation, attended_arcs, mask) if head_indices is not None and head_tags is not None: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=head_indices, head_tags=head_tags, mask=mask, ) else: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=predicted_heads.long(), head_tags=predicted_head_tags.long(), mask=mask, ) return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll def _construct_loss( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. 
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.
        head_tags : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length). The dependency labels of the
            heads for every word.
        mask : `torch.Tensor`, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded elements
            in the sequence.

        # Returns

        arc_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc loss.
        tag_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = (masked_log_softmax(attended_arcs, mask)
                                 * float_mask.unsqueeze(2) * float_mask.unsqueeze(1))

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = (timestep_index.view(1, sequence_length)
                       .expand(batch_size, sequence_length).long())
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll

    def _greedy_decode(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. # Returns heads : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag( attended_arcs.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, heads) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _mst_decode( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions using the Edmonds' Algorithm for finding minimum spanning trees on directed graphs. Nodes in the graph are the words in the sentence, and between each pair of nodes, there is an edge in each direction, where the weight of the edge corresponds to the most likely dependency label probability for that arc. The MST is then generated from this directed graph. # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. # Returns heads : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the optimally decoded heads of each word. 
""" batch_size, sequence_length, tag_representation_dim = head_tag_representation.size( ) lengths = mask.data.sum(dim=1).long().cpu().numpy() expanded_shape = [ batch_size, sequence_length, sequence_length, tag_representation_dim ] head_tag_representation = head_tag_representation.unsqueeze(2) head_tag_representation = head_tag_representation.expand( *expanded_shape).contiguous() child_tag_representation = child_tag_representation.unsqueeze(1) child_tag_representation = child_tag_representation.expand( *expanded_shape).contiguous() # Shape (batch_size, sequence_length, sequence_length, num_head_tags) pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation) # Note that this log_softmax is over the tag dimension, and we don't consider pairs # of tags which are invalid (e.g are a pair which includes a padded element) anyway below. # Shape (batch, num_labels,sequence_length, sequence_length) normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute( 0, 3, 1, 2) # Mask padded tokens, because we only want to consider actual words as heads. minus_inf = -1e8 minus_mask = (1 - mask.float()) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) # Shape (batch_size, sequence_length, sequence_length) normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2) # Shape (batch_size, num_head_tags, sequence_length, sequence_length) # This energy tensor expresses the following relation: # energy[i,j] = "Score that i is the head of j". In this # case, we have heads pointing to their children. batch_energy = torch.exp( normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits) return self._run_mst_decoding(batch_energy, lengths) @staticmethod def _run_mst_decoding( batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: heads = [] head_tags = [] for energy, length in zip(batch_energy.detach().cpu(), lengths): scores, tag_ids = energy.max(dim=0) # Although we need to include the root node so that the MST includes it, # we do not want any word to be the parent of the root node. # Here, we enforce this by setting the scores for all word -> ROOT edges # edges to be 0. scores[0, :] = 0 # Decode the heads. Because we modify the scores to prevent # adding in word -> ROOT edges, we need to find the labels ourselves. instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False) # Find the labels which correspond to the edges in the max spanning tree. instance_head_tags = [] for child, parent in enumerate(instance_heads): instance_head_tags.append(tag_ids[parent, child].item()) # We don't care what the head or tag is for the root token, but by default it's # not necesarily the same in the batched vs unbatched case, which is annoying. # Here we'll just set them to zero. instance_heads[0] = 0 instance_head_tags[0] = 0 heads.append(instance_heads) head_tags.append(instance_head_tags) return torch.from_numpy(numpy.stack(heads)), torch.from_numpy( numpy.stack(head_tags)) def _get_head_tags( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor, ) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. 
        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        # Returns

        head_tag_logits : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        # Parameters

        mask : `torch.LongTensor`, required.
            The original mask.
        pos_tags : `torch.LongTensor`, required.
            The pos tags for the sequence.

        # Returns

        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return self._attachment_scores.get_metric(reset)
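
# ---------------------------------------------------------------------------
# A small, self-contained sketch (not called by the parser) of the advanced
# indexing used in `_get_head_tags` above: `reps[range_vector, heads]` picks,
# for every word, the representation of its gold or predicted head. All names
# below are illustrative only; it relies on the module's existing `torch` import.
# ---------------------------------------------------------------------------
def _demo_head_selection() -> torch.Tensor:
    batch_size, seq_len, dim = 2, 4, 3
    # reps[b, i] is the (fake) representation of word i in sentence b.
    reps = torch.arange(batch_size * seq_len * dim, dtype=torch.float).view(batch_size, seq_len, dim)
    # heads[b, i] is the index of the head of word i in sentence b.
    heads = torch.tensor([[0, 2, 0, 1],
                          [0, 3, 3, 0]])
    # shape (batch_size, 1); broadcasts against `heads` of shape (batch_size, seq_len).
    range_vector = torch.arange(batch_size).unsqueeze(1)
    # shape (batch_size, seq_len, dim): selected[b, i] == reps[b, heads[b, i]]
    selected = reps[range_vector, heads]
    assert selected.shape == (batch_size, seq_len, dim)
    assert torch.equal(selected[0, 1], reps[0, 2])
    return selected
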
class AMTask(Model): """ A class that implements a task-specific model. It conceptually belongs to a formalism or corpus. """ loss_names = ["edge_existence", "edge_label", "supertagging", "lexlabel"] def __init__(self, vocab: Vocabulary, name: str, edge_model: EdgeModel, loss_function: EdgeLoss, supertagger: Supertagger, lexlabeltagger: Supertagger, supertagger_loss: SupertaggingLoss, lexlabel_loss: SupertaggingLoss, output_null_lex_label: bool = True, loss_mixing: Dict[str, float] = None, dropout: float = 0.0, validation_evaluator: Optional[Evaluator] = None, regularizer: Optional[RegularizerApplicator] = None): super().__init__(vocab, regularizer) self.name = name self.edge_model = edge_model self.supertagger = supertagger self.lexlabeltagger = lexlabeltagger self.supertagger_loss = supertagger_loss self.lexlabel_loss = lexlabel_loss self.loss_function = loss_function self.loss_mixing = loss_mixing or dict() self.validation_evaluator = validation_evaluator self.output_null_lex_label = output_null_lex_label self._dropout = InputVariationalDropout(dropout) for loss_name in AMTask.loss_names: if loss_name not in self.loss_mixing: self.loss_mixing[loss_name] = 1.0 logger.info( f"Loss name {loss_name} not found in loss_mixing, using a weight of 1.0" ) else: if self.loss_mixing[loss_name] is None: if loss_name not in ["supertagging", "lexlabel"]: raise ConfigurationError( "Only the loss mixing coefficients for supertagging and lexlabel may be None, but not " + loss_name) not_contained = set(self.loss_mixing.keys()) - set(AMTask.loss_names) if len(not_contained): logger.critical( f"The following loss name(s) are unknown: {not_contained}") raise ValueError( f"The following loss name(s) are unknown: {not_contained}") self._supertagging_acc = CategoricalAccuracy() self._top_6supertagging_acc = CategoricalAccuracy(top_k=6) self._lexlabel_acc = CategoricalAccuracy() self._attachment_scores = AttachmentScores() self.current_epoch = 0 tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = { tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE } self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info( f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") self.compute_softmax_for_scores = False # set to true when dumping scores to incorporate softmax computation into computation time def check_all_dimensions_match(self, encoder_output_dim): check_dimensions_match(encoder_output_dim, self.edge_model.encoder_dim(), "encoder output dim", self.name + " input dim edge model") check_dimensions_match(encoder_output_dim, self.supertagger.encoder_dim(), "encoder output dim", self.name + " supertagger input dim") check_dimensions_match(encoder_output_dim, self.lexlabeltagger.encoder_dim(), "encoder output dim", self.name + " lexical label tagger input dim") @overrides def forward( self, # type: ignore encoded_text_parsing: torch.Tensor, encoded_text_tagging: torch.Tensor, mask: torch.Tensor, pos_tags: torch.LongTensor, metadata: List[Dict[str, Any]], supertags: torch.LongTensor = None, lexlabels: torch.LongTensor = None, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: """ Takes a batch of encoded sentences and returns a dictionary with loss and predictions. 
        :param encoded_text_parsing: sentence representation of shape (batch_size, seq_len, encoder_output_dim)
        :param encoded_text_tagging: sentence representation of shape (batch_size, seq_len, encoder_output_dim)
            or None if formalism of batch doesn't need supertagging
        :param mask: matching the sentence representation of shape (batch_size, seq_len)
        :param pos_tags: the accompanying pos tags (batch_size, seq_len)
        :param metadata:
        :param supertags: the accompanying supertags (batch_size, seq_len)
        :param lexlabels: the accompanying lexical labels (batch_size, seq_len)
        :param head_tags: the gold edge labels for each word (incoming edge, see amconll files) (batch_size, seq_len)
        :param head_indices: the gold heads of every word (batch_size, seq_len)
        :return:
        """
        encoded_text_parsing = self._dropout(encoded_text_parsing)
        if encoded_text_tagging is not None:
            encoded_text_tagging = self._dropout(encoded_text_tagging)

        batch_size, seq_len, _ = encoded_text_parsing.shape

        edge_existence_scores = self.edge_model.edge_existence(
            encoded_text_parsing, mask)  # shape (batch_size, seq_len, seq_len)

        # shape (batch_size, seq_len, num_supertags)
        if encoded_text_tagging is not None and self.supertagger is not None and self.lexlabeltagger is not None \
                and self.loss_mixing["supertagging"] is not None \
                and self.loss_mixing["lexlabel"] is not None:
            supertagger_logits = self.supertagger.compute_logits(encoded_text_tagging)
            lexlabel_logits = self.lexlabeltagger.compute_logits(
                encoded_text_tagging)  # shape (batch_size, seq_len, num label tags)
        else:
            supertagger_logits = None
            lexlabel_logits = None

        # Make predictions on data:
        if self.training:
            predicted_heads = self._greedy_decode_arcs(edge_existence_scores, mask)
            edge_label_logits = self.edge_model.label_scores(
                encoded_text_parsing, predicted_heads)  # shape (batch_size, seq_len, num edge labels)
            predicted_edge_labels = self._greedy_decode_edge_labels(edge_label_logits)
        else:
            # Find best tree with CLE
            predicted_heads = cle_decode(edge_existence_scores, mask.data.sum(dim=1).long())
            # With info about tree structure, get edge label scores
            edge_label_logits = self.edge_model.label_scores(encoded_text_parsing, predicted_heads)
            # Predict edge labels
            predicted_edge_labels = self._greedy_decode_edge_labels(edge_label_logits)

        output_dict = {
            "heads": predicted_heads,
            "edge_existence_scores": edge_existence_scores,
            "label_logits": edge_label_logits,  # shape (batch_size, seq_len, num edge labels)
            "full_label_logits": self.edge_model.full_label_scores(
                encoded_text_parsing),  # these are mostly required for the projective decoder
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            "attributes": [meta["attributes"] for meta in metadata],
            "token_ranges": [meta["token_ranges"] for meta in metadata],
            "encoded_text_parsing": encoded_text_parsing,
            "encoded_text_tagging": encoded_text_tagging,
            "position_in_corpus": [meta["position_in_corpus"] for meta in metadata],
            "formalism": self.name
        }

        if encoded_text_tagging is not None and self.loss_mixing["supertagging"] is not None:
            output_dict["supertag_scores"] = supertagger_logits  # shape (batch_size, seq_len, num supertags)
            output_dict["best_supertags"] = Supertagger.top_k_supertags(
                supertagger_logits, 1).squeeze(2)  # shape (batch_size, seq_len)

        if encoded_text_tagging is not None and self.loss_mixing["lexlabel"] is not None:
            if not self.output_null_lex_label:
                bottom_lex_label_index = self.vocab.get_token_index(
                    "_", namespace=self.name + "_lex_labels")
                masked_lexlabel_logits = lexlabel_logits.clone().detach()  # shape
(batch_size, seq_len, num label tags) masked_lexlabel_logits[:, :, bottom_lex_label_index] = -1e20 else: masked_lexlabel_logits = lexlabel_logits output_dict["lexlabels"] = Supertagger.top_k_supertags( masked_lexlabel_logits, 1).squeeze(2) # shape (batch_size, seq_len) is_annotated = metadata[0]["is_annotated"] if any(metadata[i]["is_annotated"] != is_annotated for i in range(batch_size)): raise ValueError( "Batch contained inconsistent information if data is annotated." ) # Compute loss: if is_annotated and head_indices is not None and head_tags is not None: gold_edge_label_logits = self.edge_model.label_scores( encoded_text_parsing, head_indices) edge_label_loss = self.loss_function.label_loss( gold_edge_label_logits, mask, head_tags) edge_existence_loss = self.loss_function.edge_existence_loss( edge_existence_scores, head_indices, mask) # compute loss, remove loss for artificial root if encoded_text_tagging is not None and self.loss_mixing[ "supertagging"] is not None: supertagger_logits = supertagger_logits[:, 1:, :].contiguous() supertagging_nll = self.supertagger_loss.loss( supertagger_logits, supertags, mask[:, 1:]) else: supertagging_nll = None if encoded_text_tagging is not None and self.loss_mixing[ "lexlabel"] is not None: lexlabel_logits = lexlabel_logits[:, 1:, :].contiguous() lexlabel_nll = self.lexlabel_loss.loss(lexlabel_logits, lexlabels, mask[:, 1:]) else: lexlabel_nll = None loss = mix_loss(self.loss_mixing["edge_existence"], edge_existence_loss) + mix_loss( self.loss_mixing["edge_label"], edge_label_loss) if supertagging_nll is not None: loss += mix_loss(self.loss_mixing["supertagging"], supertagging_nll) if lexlabel_nll is not None: loss += mix_loss(self.loss_mixing["lexlabel"], lexlabel_nll) # Compute LAS/UAS/Supertagging acc/Lex label acc: evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags) # We calculate attachment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. if edge_existence_loss is not None and edge_label_loss is not None: self._attachment_scores(predicted_heads[:, 1:], predicted_edge_labels[:, 1:], head_indices[:, 1:], head_tags[:, 1:], evaluation_mask) evaluation_mask = mask[:, 1:].contiguous() if supertagging_nll is not None: self._top_6supertagging_acc(supertagger_logits, supertags, evaluation_mask) self._supertagging_acc( supertagger_logits, supertags, evaluation_mask) # compare against gold data if lexlabel_nll is not None: self._lexlabel_acc( lexlabel_logits, lexlabels, evaluation_mask) # compare against gold data output_dict["arc_loss"] = edge_existence_loss output_dict["edge_label_loss"] = edge_label_loss output_dict["supertagging_loss"] = supertagging_nll output_dict["lexlabel_loss"] = lexlabel_nll output_dict["loss"] = loss if self.compute_softmax_for_scores: # We don't use the results but we want it to be included in the time measurement # See dump_scores what part of computation is done outside of the time measurement in forward() F.log_softmax(output_dict["full_label_logits"], 3) F.log_softmax(output_dict["supertag_scores"], 2) torch.argsort(output_dict["supertag_scores"], descending=True, dim=2) return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]): """ In contrast to its name, this function does not perform the decoding but only prepares it. 
        :param output_dict:
        :return:
        """
        if self.supertagger is not None and self.loss_mixing["supertagging"] is not None:
            # we have a supertagger, so this is proper AM dependency parsing
            return self.prepare_for_ftd(output_dict)
        else:
            # we don't have a supertagger, perform good old dependency parsing
            return self.only_cle(output_dict)

    def only_cle(self, output_dict: Dict[str, torch.Tensor]):
        """
        Takes the result of forward and, for each sentence in the batch,
        removes the padding (and the artificial root).

        :param output_dict: result of forward
        :return: output_dict with the following keys added:
            - predicted_heads: nested list: the head of each word (w/o artificial root)
            - label_logits: per-sentence edge label logits (w/o artificial root)
            - full_label_logits: per-sentence full label logits, cut to sentence length
            - edge_existence_scores: per-sentence edge existence scores, cut to sentence length
        """
        full_label_logits = output_dict.pop("full_label_logits").cpu().detach().numpy()  # shape (batch size, seq len, seq len, num edge labels)
        edge_existence_scores = output_dict.pop("edge_existence_scores").cpu().detach().numpy()  # shape (batch size, seq len, seq len)
        heads = output_dict.pop("heads")
        heads_cpu = heads.cpu().detach().numpy()
        mask = output_dict.pop("mask")
        edge_label_logits = output_dict.pop("label_logits").cpu().detach().numpy()  # shape (batch_size, seq_len, num edge labels)
        output_dict.pop("encoded_text_parsing")
        output_dict.pop("encoded_text_tagging")  # don't need that

        lengths = get_lengths_from_binary_sequence_mask(mask)

        # here we collect things, in the end we will have one entry for each sentence:
        all_edge_label_logits = []
        head_indices = []
        all_full_label_logits = []
        all_edge_existence_scores = []

        for i, length in enumerate(lengths):
            instance_heads_cpu = list(heads_cpu[i, 1:length])
            # apply changes to instance_heads tensor:
            instance_heads = heads[i, :]
            for j, x in enumerate(instance_heads_cpu):
                instance_heads[j + 1] = torch.tensor(x)  # +1 because we removed the first position from instance_heads_cpu

            all_edge_label_logits.append(edge_label_logits[i, 1:length, :])
            all_full_label_logits.append(full_label_logits[i, :length, :length, :])
            all_edge_existence_scores.append(edge_existence_scores[i, :length, :length])
            head_indices.append(instance_heads_cpu)

        output_dict["label_logits"] = all_edge_label_logits
        output_dict["predicted_heads"] = head_indices
        output_dict["full_label_logits"] = all_full_label_logits
        output_dict["edge_existence_scores"] = all_edge_existence_scores
        return output_dict

    def prepare_for_ftd(self, output_dict: Dict[str, torch.Tensor]):
        """
        This function does not perform the decoding but only prepares it.
Therefore, we take the result of forward and perform the following steps (for each sentence in batch): - remove padding - identify the root of the sentence, group other root-candidates under the proper root - collect a selection of supertags to speed up computation (top k selection is done later) :param output_dict: result of forward :return: output_dict with the following keys added: - lexlabels: nested list: contains for each sentence, for each word the most likely lexical label (w/o artificial root) - supertags: nested list: contains for each sentence, for each word the most likely lexical label (w/o artificial root) """ t0 = time() best_supertags = output_dict.pop( "best_supertags").cpu().detach().numpy() supertag_scores = output_dict.pop( "supertag_scores") # shape (batch_size, seq_len, num supertags) full_label_logits = output_dict.pop("full_label_logits").cpu().detach( ).numpy() #shape (batch size, seq len, seq len, num edge labels) edge_existence_scores = output_dict.pop( "edge_existence_scores").cpu().detach().numpy( ) #shape (batch size, seq len, seq len, num edge labels) k = 10 if self.validation_evaluator: #retrieve k supertags from validation evaluator. if isinstance(self.validation_evaluator.predictor, AMconllPredictor): k = self.validation_evaluator.predictor.k k += 10 # perhaps there are some ill-formed supertags, make that very unlikely that there are not enough left after filtering. top_k_supertags = Supertagger.top_k_supertags( supertag_scores, k).cpu().detach().numpy() # shape (batch_size, seq_len, k) supertag_scores = supertag_scores.cpu().detach().numpy() lexlabels = output_dict.pop( "lexlabels").cpu().detach().numpy() #shape (batch_size, seq_len) heads = output_dict.pop("heads") heads_cpu = heads.cpu().detach().numpy() mask = output_dict.pop("mask") edge_label_logits = output_dict.pop("label_logits").cpu().detach( ).numpy() # shape (batch_size, seq_len, num edge labels) encoded_text_parsing = output_dict.pop("encoded_text_parsing") output_dict.pop("encoded_text_tagging") #don't need that lengths = get_lengths_from_binary_sequence_mask(mask) #here we collect things, in the end we will have one entry for each sentence: all_edge_label_logits = [] all_supertags = [] head_indices = [] roots = [] all_predicted_lex_labels = [] all_full_label_logits = [] all_edge_existence_scores = [] all_supertag_scores = [] #we need the following to identify the root root_edge_label_id = self.vocab.get_token_index("ROOT", namespace=self.name + "_head_tags") bot_id = self.vocab.get_token_index(AMSentence.get_bottom_supertag(), namespace=self.name + "_supertag_labels") for i, length in enumerate(lengths): instance_heads_cpu = list(heads_cpu[i, 1:length]) #Postprocess heads and find root of sentence: instance_heads_cpu, root = find_root( instance_heads_cpu, best_supertags[i, 1:length], edge_label_logits[i, 1:length, :], root_edge_label_id, bot_id, modify=True) roots.append(root) #apply changes to instance_heads tensor: instance_heads = heads[i, :] for j, x in enumerate(instance_heads_cpu): instance_heads[j + 1] = torch.tensor( x ) #+1 because we removed the first position from instance_heads_cpu # re-calculate edge label logits since heads might have changed: label_logits = self.edge_model.label_scores( encoded_text_parsing[i].unsqueeze(0), instance_heads.unsqueeze(0)).squeeze(0).detach().cpu().numpy() #(un)squeeze: fake batch dimension all_edge_label_logits.append(label_logits[1:length, :]) all_full_label_logits.append( full_label_logits[i, :length, :length, :]) 
all_edge_existence_scores.append( edge_existence_scores[i, :length, :length]) #calculate supertags for this sentence: all_supertag_scores.append(supertag_scores[ i, 1:length, :]) #new shape (sent length, num supertags) supertags_for_this_sentence = [] for word in range(1, length): supertags_for_this_word = [] for top_k in top_k_supertags[i, word]: fragment, typ = AMSentence.split_supertag( self.vocab.get_token_from_index(top_k, namespace=self.name + "_supertag_labels")) score = supertag_scores[i, word, top_k] supertags_for_this_word.append((score, fragment, typ)) if bot_id not in top_k_supertags[ i, word]: #\bot is not in the top k, but we have to add it anyway in order for the decoder to work properly. fragment, typ = AMSentence.split_supertag( AMSentence.get_bottom_supertag()) supertags_for_this_word.append( (supertag_scores[i, word, bot_id], fragment, typ)) supertags_for_this_sentence.append(supertags_for_this_word) all_supertags.append(supertags_for_this_sentence) all_predicted_lex_labels.append([ self.vocab.get_token_from_index(label, namespace=self.name + "_lex_labels") for label in lexlabels[i, 1:length] ]) head_indices.append(instance_heads_cpu) t1 = time() normalized_diff = (t1 - t0) / len(lengths) output_dict["normalized_prepare_ftd_time"] = [ normalized_diff for _ in range(len(lengths)) ] output_dict["lexlabels"] = all_predicted_lex_labels output_dict["supertags"] = all_supertags output_dict["root"] = roots output_dict["label_logits"] = all_edge_label_logits output_dict["predicted_heads"] = head_indices output_dict["full_label_logits"] = all_full_label_logits output_dict["edge_existence_scores"] = all_edge_existence_scores output_dict["supertag_scores"] = all_supertag_scores return output_dict def _greedy_decode_edge_labels( self, edge_label_logits: torch.Tensor) -> torch.Tensor: """ Assigns edge labels according to (existing) edges. Parameters ---------- edge_label_logits: ``torch.Tensor`` of shape (batch_size, sequence_length, num_head_tags) Returns ------- head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded head tags (labels of incoming edges) of each word. """ _, head_tags = edge_label_logits.max(dim=2) return head_tags def _greedy_decode_arcs(self, existence_scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """ Decodes the head predictions by decoding the unlabeled arcs independently for each word. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). Parameters ---------- existence_scores : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. mask: torch.Tensor, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. existence_scores = existence_scores + torch.diag( existence_scores.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).byte().unsqueeze(2) existence_scores.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. 
        # shape (batch_size, sequence_length)
        _, heads = existence_scores.max(dim=2)
        return heads

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    def metrics(self, parser_model, reset: bool = False, model_path=None) -> Dict[str, float]:
        """
        Is called by a GraphDependencyParser.

        :param parser_model: a GraphDependencyParser
        :param reset:
        :return:
        """
        r = self.get_metrics(reset)
        if reset:  # epoch done
            if self.training:  # done on the training data
                self.current_epoch += 1
            else:  # done on dev/test data
                if self.validation_evaluator:
                    metrics = self.validation_evaluator.eval(parser_model, self.current_epoch, model_path)
                    for name, val in metrics.items():
                        r[name] = val
        return r

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        r = self._attachment_scores.get_metric(reset)
        if self.loss_mixing["supertagging"] is not None:
            r["Constant_Acc"] = self._supertagging_acc.get_metric(reset)
            r["Constant_Acc_6_best"] = self._top_6supertagging_acc.get_metric(reset)
        if self.loss_mixing["lexlabel"] is not None:
            r["Label_Acc"] = self._lexlabel_acc.get_metric(reset)
        las = r["LAS"]
        if "Constant_Acc" in r:
            r["mean_constant_acc_las"] = (las + r["Constant_Acc"]) / 2
        return r
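
# ---------------------------------------------------------------------------
# Sketch only: a vectorized equivalent of the `_get_mask_for_eval` loop that
# appears twice in this file. It assumes torch >= 1.10 for `torch.isin`; the
# function name is ours and nothing in this file calls it.
# ---------------------------------------------------------------------------
def _punctuation_eval_mask(mask: torch.LongTensor,
                           pos_tags: torch.LongTensor,
                           pos_to_ignore: set) -> torch.LongTensor:
    if not pos_to_ignore:
        return mask.detach()
    # True wherever a token's POS tag is one of the punctuation-like tags.
    ignore = torch.isin(pos_tags, torch.tensor(sorted(pos_to_ignore), device=pos_tags.device))
    # Zero out punctuation positions, keep the rest of the original mask.
    return mask.detach() * (~ignore).long()
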
class DecompSyntaxParser(DecompParser): def __init__( self, vocab: Vocabulary, # source-side bert_encoder: BaseBertWrapper, encoder_token_embedder: TextFieldEmbedder, encoder_pos_embedding: Embedding, encoder: Seq2SeqEncoder, # target-side decoder_token_embedder: TextFieldEmbedder, decoder_node_index_embedding: Embedding, decoder_pos_embedding: Embedding, decoder: RNNDecoder, extended_pointer_generator: ExtendedPointerGenerator, tree_parser: DecompTreeParser, node_attribute_module: NodeAttributeDecoder, edge_attribute_module: EdgeAttributeDecoder, # misc label_smoothing: LabelSmoothing, target_output_namespace: str, pos_tag_namespace: str, edge_type_namespace: str, syntax_edge_type_namespace: str = None, biaffine_parser: DeepTreeParser = None, syntactic_method: str = None, dropout: float = 0.0, beam_size: int = 5, max_decoding_steps: int = 50, eps: float = 1e-20, loss_mixer: LossMixer = None, intermediate_graph: bool = False, pretrained_weights: str = None, ) -> None: super(DecompSyntaxParser, self).__init__( vocab=vocab, # source-side bert_encoder=bert_encoder, encoder_token_embedder=encoder_token_embedder, encoder_pos_embedding=encoder_pos_embedding, encoder=encoder, # target-side decoder_token_embedder=decoder_token_embedder, decoder_node_index_embedding=decoder_node_index_embedding, decoder_pos_embedding=decoder_pos_embedding, decoder=decoder, extended_pointer_generator=extended_pointer_generator, tree_parser=tree_parser, node_attribute_module=node_attribute_module, edge_attribute_module=edge_attribute_module, # misc label_smoothing=label_smoothing, target_output_namespace=target_output_namespace, pos_tag_namespace=pos_tag_namespace, edge_type_namespace=edge_type_namespace, syntax_edge_type_namespace=syntax_edge_type_namespace, dropout=dropout, beam_size=beam_size, max_decoding_steps=max_decoding_steps, eps=eps, pretrained_weights=pretrained_weights) self.syntactic_method = syntactic_method self.biaffine_parser = biaffine_parser self.loss_mixer = loss_mixer self.intermediate_graph = intermediate_graph self._syntax_metrics = AttachmentScores() self.syntax_las = 0.0 self.syntax_uas = 0.0 if self.pretrained_weights is not None: self.load_partial(self.pretrained_weights) @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: node_pred_metrics = self._node_pred_metrics.get_metric(reset) edge_pred_metrics = self._edge_pred_metrics.get_metric(reset) decomp_metrics = self._decomp_metrics.get_metric(reset) syntax_metrics = self._syntax_metrics.get_metric(reset) metrics = OrderedDict( ppl=node_pred_metrics["ppl"], node_pred=node_pred_metrics["accuracy"] * 100, generate=node_pred_metrics["generate"] * 100, src_copy=node_pred_metrics["src_copy"] * 100, tgt_copy=node_pred_metrics["tgt_copy"] * 100, node_pearson=decomp_metrics["node_pearson_r"], edge_pearson=decomp_metrics["edge_pearson_r"], pearson=decomp_metrics["pearson_r"], uas=edge_pred_metrics["UAS"] * 100, las=edge_pred_metrics["LAS"] * 100, syn_uas=syntax_metrics["UAS"] * 100, syn_las=syntax_metrics["LAS"] * 100, ) metrics["s_f1"] = self.val_s_f1 metrics["syn_las"] = self.syntax_las metrics["syn_uas"] = self.syntax_uas return metrics def _update_syntax_scores(self): scores = self._syntax_metrics.get_metric(reset=True) self.syntax_las = scores["LAS"] * 100 self.syntax_uas = scores["UAS"] * 100 def _compute_biaffine_loss(self, biaffine_outputs, inputs): edge_prediction_loss = self._compute_edge_prediction_loss( biaffine_outputs['edge_head_ll'], biaffine_outputs['edge_type_ll'], biaffine_outputs['edge_heads'], 
biaffine_outputs['edge_types'], inputs['syn_edge_heads'], inputs['syn_edge_types']['syn_edge_types'], inputs['syn_valid_node_mask'], syntax=True) return edge_prediction_loss['loss_per_node'] def _parse_syntax(self, encoder_outputs: torch.Tensor, edge_head_mask: torch.Tensor, edge_heads: torch.Tensor = None, valid_node_mask: torch.Tensor = None, do_mst=False) -> Dict: parser_outputs = self.biaffine_parser(query=encoder_outputs, key=encoder_outputs, edge_head_mask=edge_head_mask, gold_edge_heads=edge_heads, decode_mst=do_mst, valid_node_mask=valid_node_mask) return parser_outputs @staticmethod def _add_biaffine_to_encoder(encoding_outputs, biaffine_outputs): enc_outputs = encoding_outputs["encoder_outputs"] # concatenate in biaffine reps enc_outputs = torch.cat([enc_outputs, biaffine_outputs["edge_reps"]], dim=2) encoding_outputs["encoder_outputs"] = enc_outputs return encoding_outputs @overrides def _training_forward(self, inputs: Dict) -> Dict[str, torch.Tensor]: encoding_outputs = self._encode( tokens=inputs["source_tokens"], pos_tags=inputs["source_pos_tags"], subtoken_ids=inputs["source_subtoken_ids"], token_recovery_matrix=inputs["source_token_recovery_matrix"], mask=inputs["source_mask"]) just_syntax = False encoder_side = False # if we're doing encoder-side if "syn_tokens_str" in inputs.keys(): biaffine_outputs = self._parse_syntax( encoding_outputs['encoder_outputs'], inputs["syn_edge_head_mask"], inputs["syn_edge_heads"], do_mst=False) biaffine_loss = self._compute_biaffine_loss( biaffine_outputs, inputs) self._update_syntax_scores() encoder_side = True if self.intermediate_graph: encoding_outputs = DecompSyntaxParser._add_biaffine_to_encoder( encoding_outputs, biaffine_outputs) else: biaffine_loss = 0.0 decoding_outputs = self._decode( tokens=inputs["target_tokens"], node_indices=inputs["target_node_indices"], pos_tags=inputs["target_pos_tags"], encoder_outputs=encoding_outputs["encoder_outputs"], hidden_states=encoding_outputs["final_states"], mask=inputs["source_mask"]) node_prediction_outputs = self._extended_pointer_generator( inputs=decoding_outputs["attentional_tensors"], source_attention_weights=decoding_outputs[ "source_attention_weights"], target_attention_weights=decoding_outputs[ "target_attention_weights"], source_attention_map=inputs["source_attention_map"], target_attention_map=inputs["target_attention_map"]) try: # compute node attributes node_attribute_outputs = self._node_attribute_predict( decoding_outputs["rnn_outputs"][:, :-1, :], inputs["node_attribute_truth"], inputs["node_attribute_mask"]) except ValueError: # concat-just-syntax case node_attribute_outputs = { "loss": 0.0, "pred_dict": { "pred_attributes": [] } } just_syntax = True edge_prediction_outputs = self._parse( rnn_outputs=decoding_outputs["rnn_outputs"], edge_head_mask=inputs["edge_head_mask"], edge_heads=inputs["edge_heads"]) try: edge_attribute_outputs = self._edge_attribute_predict( edge_prediction_outputs["edge_type_query"], edge_prediction_outputs["edge_type_key"], edge_prediction_outputs["edge_heads"], inputs["edge_attribute_truth"], inputs["edge_attribute_mask"]) except ValueError: # concat-just-syntax case edge_attribute_outputs = { "loss": 0.0, "pred_dict": { "pred_attributes": [] } } just_syntax = True node_pred_loss = self._compute_node_prediction_loss( prob_dist=node_prediction_outputs["hybrid_prob_dist"], generation_outputs=inputs["generation_outputs"], source_copy_indices=inputs["source_copy_indices"], target_copy_indices=inputs["target_copy_indices"], 
source_dynamic_vocab_size=inputs["source_dynamic_vocab_size"], source_attention_weights=decoding_outputs[ "source_attention_weights"], coverage_history=decoding_outputs["coverage_history"]) edge_pred_loss = self._compute_edge_prediction_loss( edge_head_ll=edge_prediction_outputs["edge_head_ll"], edge_type_ll=edge_prediction_outputs["edge_type_ll"], pred_edge_heads=edge_prediction_outputs["edge_heads"], pred_edge_types=edge_prediction_outputs["edge_types"], gold_edge_heads=inputs["edge_heads"], gold_edge_types=inputs["edge_types"], valid_node_mask=inputs["valid_node_mask"]) if encoder_side: # learn a loss ratio loss = self.compute_training_loss(node_pred_loss["loss_per_node"], edge_pred_loss["loss_per_node"], node_attribute_outputs["loss"], edge_attribute_outputs["loss"], biaffine_loss) else: # no biaffine loss loss = node_pred_loss["loss_per_node"] + edge_pred_loss["loss_per_node"] + \ node_attribute_outputs['loss'] + edge_attribute_outputs['loss'] if not just_syntax: # compute combined pearson self._decomp_metrics(None, None, None, None, "both") return dict(loss=loss, node_attributes=node_attribute_outputs['pred_dict'] ['pred_attributes'], edge_attributes=edge_attribute_outputs['pred_dict'] ['pred_attributes']) def compute_training_loss(self, node_loss, edge_loss, node_attr_loss, edge_attr_loss, biaffine_loss): sem_loss = node_loss + edge_loss + node_attr_loss + edge_attr_loss syn_loss = biaffine_loss if self.loss_mixer is not None: return self.loss_mixer(sem_loss, syn_loss) # default to 1-to-1 weighting return sem_loss + syn_loss @overrides def _test_forward(self, inputs: Dict) -> Dict: encoding_outputs = self._encode( tokens=inputs["source_tokens"], pos_tags=inputs["source_pos_tags"], subtoken_ids=inputs["source_subtoken_ids"], token_recovery_matrix=inputs["source_token_recovery_matrix"], mask=inputs["source_mask"]) # if we're doing encoder-side if self.biaffine_parser is not None: biaffine_outputs = self._parse_syntax( encoding_outputs['encoder_outputs'], inputs["syn_edge_head_mask"], None, valid_node_mask=inputs["syn_valid_node_mask"], do_mst=True) if self.intermediate_graph: encoding_outputs = self._add_biaffine_to_encoder( encoding_outputs, biaffine_outputs) start_predictions, start_state, auxiliaries, misc = self._prepare_decoding_start_state( inputs, encoding_outputs) # all_predictions: [batch_size, beam_size, max_steps] # log_probs: [batch_size, beam_size] all_predictions, rnn_outputs, log_probs, target_dynamic_vocabs = self._beam_search.search( start_predictions=start_predictions, start_state=start_state, auxiliaries=auxiliaries, step=lambda x, y, z: self._take_one_step_node_prediction( x, y, z, misc), tracked_state_name="rnn_output", tracked_auxiliary_name="target_dynamic_vocabs") node_predictions, node_index_predictions, edge_head_mask, valid_node_mask = self._read_node_predictions( # Remove the last one because we can't get the RNN state for the last one. predictions=all_predictions[:, 0, :-1], meta_data=inputs["instance_meta"], target_dynamic_vocabs=target_dynamic_vocabs[0], source_dynamic_vocab_size=inputs["source_dynamic_vocab_size"]) node_attribute_outputs = self._node_attribute_predict( rnn_outputs[:, :, :-1, :], None, None) edge_predictions = self._parse( # Remove the first RNN state because it represents <BOS>. 
rnn_outputs=rnn_outputs[:, 0], edge_head_mask=edge_head_mask) (edge_head_predictions, edge_type_predictions, edge_type_ind_predictions ) = self._read_edge_predictions(edge_predictions) edge_attribute_outputs = self._edge_attribute_predict( edge_predictions["edge_type_query"], edge_predictions["edge_type_key"], edge_predictions["edge_heads"], None, None) edge_pred_loss = self._compute_edge_prediction_loss( edge_head_ll=edge_predictions["edge_head_ll"], edge_type_ll=edge_predictions["edge_type_ll"], pred_edge_heads=edge_predictions["edge_heads"], pred_edge_types=edge_predictions["edge_types"], gold_edge_heads=edge_predictions["edge_heads"], gold_edge_types=edge_predictions["edge_types"], valid_node_mask=valid_node_mask) loss = -log_probs[:, 0].sum( ) / edge_pred_loss["num_nodes"] + edge_pred_loss["loss_per_node"] if "syn_tokens_str" not in inputs: inputs['syn_tokens_str'] = [] #biaffine_outputs = {"edge_heads": [], "edge_types":[]} syn_edge_head_predictions, syn_edge_type_predictions, syn_edge_type_inds = [], [], [] else: syn_edge_head_predictions, syn_edge_type_predictions, syn_edge_type_inds = self._read_edge_predictions( biaffine_outputs, is_syntax=True) outputs = dict( loss=loss, nodes=node_predictions, node_indices=node_index_predictions, syn_nodes=inputs['syn_tokens_str'], syn_edge_heads=syn_edge_head_predictions, syn_edge_types=syn_edge_type_predictions, syn_edge_type_inds=syn_edge_type_inds, edge_heads=edge_head_predictions, edge_types=edge_type_predictions, edge_types_inds=edge_type_ind_predictions, node_attributes=node_attribute_outputs['pred_dict'] ['pred_attributes'], node_attributes_mask=node_attribute_outputs['pred_dict'] ['pred_mask'], edge_attributes=edge_attribute_outputs['pred_dict'] ['pred_attributes'], edge_attributes_mask=edge_attribute_outputs['pred_dict'] ['pred_mask'], ) return outputs
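# NOTE (editor sketch): compute_training_loss above delegates semantic/syntactic
# loss weighting to a LossMixer when one is configured, falling back to an
# unweighted sum. The class below is a minimal illustrative sketch of such a
# mixer, NOT the project's actual LossMixer; it only assumes the interface used
# in this file: __call__(sem_loss, syn_loss) and update_weights(curr_epoch, total_epochs).
class LinearScheduleLossMixer:
    """Hypothetical mixer that anneals the syntactic loss weight over training."""

    def __init__(self, initial_syn_weight: float = 1.0, final_syn_weight: float = 0.1):
        self.initial_syn_weight = initial_syn_weight
        self.final_syn_weight = final_syn_weight
        self.syn_weight = initial_syn_weight

    def update_weights(self, curr_epoch: int, total_epochs: int) -> None:
        # Linearly interpolate from the initial weight to the final weight.
        frac = curr_epoch / max(1, total_epochs - 1)
        self.syn_weight = (1.0 - frac) * self.initial_syn_weight + frac * self.final_syn_weight

    def __call__(self, sem_loss, syn_loss):
        # The semantic loss keeps unit weight; only the syntactic term is scaled.
        return sem_loss + self.syn_weight * syn_loss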
class AttachmentScoresTest(AllenNlpTestCase): def setUp(self): super().setUp() self.scorer = AttachmentScores() self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]]) self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]]) self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]]) self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]]) self.mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]) def test_perfect_scores(self): self.scorer(self.predictions, self.label_predictions, self.gold_indices, self.gold_labels, self.mask) for value in self.scorer.get_metric().values(): assert value == 1.0 def test_unlabeled_accuracy_ignores_incorrect_labels(self): label_predictions = self.label_predictions # Change some predictions so that 4 of our label predictions are wrong. label_predictions[0, 3:] = 3 label_predictions[1, 0] = 7 self.scorer(self.predictions, label_predictions, self.gold_indices, self.gold_labels, self.mask) metrics = self.scorer.get_metric() assert metrics["UAS"] == 1.0 assert metrics["UEM"] == 1.0 # 4 / 12 labels were wrong and 2 positions # are masked, so 6/10 = 0.6 LAS. assert metrics["LAS"] == 0.6 # Neither should have labeled exact match. assert metrics["LEM"] == 0.0 def test_labeled_accuracy_is_affected_by_incorrect_heads(self): predictions = self.predictions # Change some predictions so that 4 of our head predictions are wrong. predictions[0, 3:] = 3 predictions[1, 0] = 7 # This one is in the padded part, so it shouldn't affect anything. predictions[1, 5] = 7 self.scorer(predictions, self.label_predictions, self.gold_indices, self.gold_labels, self.mask) metrics = self.scorer.get_metric() # 4 heads are incorrect, so the unlabeled score should be # 6/10 = 0.6 UAS. assert metrics["UAS"] == 0.6 # All the labels were correct, but some heads # were wrong, so the LAS should equal the UAS. assert metrics["LAS"] == 0.6 # Neither batch element had a perfect labeled or unlabeled EM. assert metrics["LEM"] == 0.0 assert metrics["UEM"] == 0.0 def test_attachment_scores_can_ignore_labels(self): scorer = AttachmentScores(ignore_classes=[1]) label_predictions = self.label_predictions # Change the predictions where the gold label is 1; # as we are ignoring 1, we should still get a perfect score. label_predictions[0, 3] = 2 scorer(self.predictions, label_predictions, self.gold_indices, self.gold_labels, self.mask) for value in scorer.get_metric().values(): assert value == 1.0
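# NOTE (editor sketch): the expected values in the tests above follow directly
# from how attachment scores are defined. The function below is a self-contained
# illustrative reimplementation of the UAS/LAS arithmetic, not AllenNLP's
# AttachmentScores class.
import torch

def attachment_scores_sketch(pred_heads, pred_labels, gold_heads, gold_labels, mask):
    mask = mask.bool()
    correct_heads = (pred_heads == gold_heads) & mask
    correct_heads_and_labels = correct_heads & (pred_labels == gold_labels)
    total = mask.sum().item()
    # UAS counts correct heads; LAS additionally requires a correct label.
    return {"UAS": correct_heads.sum().item() / total,
            "LAS": correct_heads_and_labels.sum().item() / total}

# E.g. with 4 of 12 labels wrong and 2 positions masked out, LAS = 6 / 10 = 0.6
# while UAS stays 1.0, matching test_unlabeled_accuracy_ignores_incorrect_labels.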
class DependencyParser(Model): """ This dependency parser follows the model of ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . Word representations are generated using a bidirectional LSTM, followed by separate biaffine classifiers for pairs of words, predicting whether a directed arc exists between the two words and the dependency label the arc should have. Decoding can either be done greedily, or the optimal minimum spanning tree can be decoded using Edmonds' algorithm by viewing the dependency tree as an MST on a fully connected graph, where nodes are words and edges are scored dependency arcs. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the ``tokens`` ``TextField`` we get as input to the model. encoder : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use to generate representations of tokens. tag_representation_dim : ``int``, required. The dimension of the MLPs used for dependency tag prediction. arc_representation_dim : ``int``, required. The dimension of the MLPs used for head arc prediction. tag_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce tag representations. By default, a 1 layer feedforward network with an elu activation is used. arc_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce arc representations. By default, a 1 layer feedforward network with an elu activation is used. pos_tag_embedding : ``Embedding``, optional. Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model. use_mst_decoding_for_validation : ``bool``, optional (default = True). Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation. If false, decoding is greedy. dropout : ``float``, optional, (default = 0.0) The variational dropout applied to the output of the encoder and MLP layers. input_dropout : ``float``, optional, (default = 0.0) The dropout applied to the embedded text input. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training.
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, lemmatize_helper: LemmatizeHelper, task_config: TaskConfig, morpho_vector_dim: int = 0, gram_val_representation_dim: int = -1, lemma_representation_dim: int = -1, tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, pos_tag_embedding: Embedding = None, use_mst_decoding_for_validation: bool = True, dropout: float = 0.0, input_dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(DependencyParser, self).__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.encoder = encoder self.lemmatize_helper = lemmatize_helper self.task_config = task_config encoder_dim = encoder.get_output_dim() self.head_arc_feedforward = arc_feedforward or \ FeedForward(encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention(arc_representation_dim, arc_representation_dim, use_input_biases=True) num_labels = self.vocab.get_vocab_size("head_tags") self.head_tag_feedforward = tag_feedforward or \ FeedForward(encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, tag_representation_dim, num_labels) self._pos_tag_embedding = pos_tag_embedding or None assert self.task_config.params.get("use_pos_tag", False) == (self._pos_tag_embedding is not None) self._dropout = InputVariationalDropout(dropout) self._input_dropout = Dropout(input_dropout) self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()])) # рекуррентная сеть, порождающая цепочку вариантов разбора self.multilabeler_lstm = torch.nn.LSTM(encoder_dim, encoder_dim, num_layers=1, batch_first=True, bidirectional=False) if gram_val_representation_dim <= 0: self._gram_val_output = torch.nn.Linear(encoder_dim, self.vocab.get_vocab_size("grammar_value_tags")) else: self._gram_val_output = torch.nn.Sequential( Dropout(dropout), torch.nn.Linear(encoder_dim, gram_val_representation_dim), Dropout(dropout), torch.nn.Linear(gram_val_representation_dim, self.vocab.get_vocab_size("grammar_value_tags")) ) if lemma_representation_dim <= 0: self._lemma_output = torch.nn.Linear(encoder_dim, len(lemmatize_helper)) else: self._lemma_output = torch.nn.Sequential( Dropout(dropout), torch.nn.Linear(encoder_dim, lemma_representation_dim), Dropout(dropout), torch.nn.Linear(lemma_representation_dim, len(lemmatize_helper)) ) representation_dim = text_field_embedder.get_output_dim() + morpho_vector_dim if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match(representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim") check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim") check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim") self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = {tag: index for tag, index in tags.items() if 
tag in POS_TO_IGNORE} self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() self._gram_val_prediction_accuracy = CategoricalAccuracy() self._lemma_prediction_accuracy = CategoricalAccuracy() initializer(self) @overrides def forward(self, # type: ignore words: Dict[str, torch.LongTensor], metadata: List[Dict[str, Any]], morpho_embedding: torch.FloatTensor = None, pos_tags: torch.LongTensor = None, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, grammar_values: torch.LongTensor = None, lemma_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- words : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. pos_tags : ``torch.LongTensor``, required The output of a ``SequenceLabelField`` containing POS tags. POS tags are required regardless of whether they are used in the model, because they are used to filter the evaluation metric to only consider heads of words which are not punctuation. metadata : List[Dict[str, Any]], optional (default=None) A dictionary of metadata for each batch element which has keys: words : ``List[str]``, required. The tokens in the original sentence. pos : ``List[str]``, required. The POS tags for each word. head_tags : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer gold class labels for the arcs in the dependency parse. Has shape ``(batch_size, sequence_length)``. head_indices : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer indices denoting the parent of every word in the dependency parse. Has shape ``(batch_size, sequence_length)``. Returns ------- An output dictionary consisting of: loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. arc_loss : ``torch.FloatTensor`` The loss contribution from the unlabeled arcs. tag_loss : ``torch.FloatTensor``, optional The loss contribution from predicting the dependency tags for the gold arcs. heads : ``torch.FloatTensor`` The predicted head indices for each word. A tensor of shape (batch_size, sequence_length). head_types : ``torch.FloatTensor`` The predicted head types for each arc. A tensor of shape (batch_size, sequence_length). mask : ``torch.LongTensor`` A mask denoting the padded elements in the batch.
""" embedded_text_input = self.text_field_embedder(words) # для отладки # for name, param in self.named_parameters(): # if torch.any(torch.isnan(param)): # assert False, "NaN in {} layer".format(name) # if torch.any(torch.isinf(param)): # assert False, "INF in {} layer".format(name) if morpho_embedding is not None: embedded_text_input = torch.cat([embedded_text_input, morpho_embedding], -1) if grammar_values is not None and self._pos_tag_embedding is not None: embedded_pos_tags = self._pos_tag_embedding(grammar_values) embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1) elif self._pos_tag_embedding is not None: raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(words) output_dict = self._parse(embedded_text_input, mask, head_tags, head_indices, grammar_values, lemma_indices) if self.task_config.task_type == "multitask": losses = ["arc_nll", "tag_nll", "grammar_nll", "lemma_nll"] elif self.task_config.task_type == "single": if self.task_config.params["model"] == "morphology": losses = ["grammar_nll"] elif self.task_config.params["model"] == "lemmatization": losses = ["lemma_nll"] elif self.task_config.params["model"] == "syntax": losses = ["arc_nll", "tag_nll"] else: assert False, "Unknown model type {}".format(self.task_config.params["model"]) else: assert False, "Unknown task type {}".format(self.task_config.task_type) output_dict["loss"] = sum(output_dict[loss_name] for loss_name in losses) if head_indices is not None and head_tags is not None: evaluation_mask = self._get_mask_for_eval(mask, pos_tags) # We calculate attatchment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. 
self._attachment_scores(output_dict["heads"][:, 1:], output_dict["head_tags"][:, 1:], head_indices, head_tags, evaluation_mask) output_dict["words"] = [meta["words"] for meta in metadata] if metadata and "pos" in metadata[0]: output_dict["pos"] = [meta["pos"] for meta in metadata] return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: head_tags = output_dict.pop("head_tags").cpu().detach().numpy() heads = output_dict.pop("heads").cpu().detach().numpy() predicted_gram_vals = output_dict.pop("gram_vals").cpu().detach().numpy() predicted_lemmas = output_dict.pop("lemmas").cpu().detach().numpy() mask = output_dict.pop("mask") lengths = get_lengths_from_binary_sequence_mask(mask) assert len(head_tags) == len(heads) == len(lengths) == len(predicted_gram_vals) == len(predicted_lemmas) head_tag_labels, head_indices, decoded_gram_vals, decoded_lemmas = [], [], [], [] for instance_index in range(len(head_tags)): instance_heads, instance_tags = heads[instance_index], head_tags[instance_index] words, length = output_dict["words"][instance_index], lengths[instance_index] gram_vals, lemmas = predicted_gram_vals[instance_index], predicted_lemmas[instance_index] words = words[: length.item() - 1] gram_vals = gram_vals[: length.item() - 1, :] lemmas = lemmas[: length.item() - 1, :] instance_heads = list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags] head_tag_labels.append(labels) head_indices.append(instance_heads) inst_gram_vals = [] for tok_gram_vals in gram_vals: dtgv = [self.vocab.get_token_from_index(gram_val, "grammar_value_tags") for gram_val in tok_gram_vals] inst_gram_vals.append(dtgv) decoded_gram_vals.append(inst_gram_vals) # print("\n\n------------------------------------------------- ITLOG-BEGIN ------------------------------------------\n") # print( "ITLOG: decoded_gram_vals = {}".format(decoded_gram_vals) ) # print("\n------------------------------------------------- ITLOG-END ------------------------------------------\n") inst_lemmas = [] for word, word_lrules in zip(words, lemmas): dtl = [self.lemmatize_helper.lemmatize(word, lrule) for lrule in word_lrules] inst_lemmas.append(dtl) decoded_lemmas.append(inst_lemmas) if self.task_config.task_type == "multitask": output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices output_dict["predicted_gram_vals"] = decoded_gram_vals output_dict["predicted_lemmas"] = decoded_lemmas elif self.task_config.task_type == "single": if self.task_config.params["model"] == "morphology": output_dict["predicted_gram_vals"] = decoded_gram_vals elif self.task_config.params["model"] == "lemmatization": output_dict["predicted_lemmas"] = decoded_lemmas elif self.task_config.params["model"] == "syntax": output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices else: assert False, "Unknown model type {}".format(self.task_config.params["model"]) else: assert False, "Unknown task type {}".format(self.task_config.task_type) return output_dict def _parse(self, embedded_text_input: torch.Tensor, mask: torch.LongTensor, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, grammar_values: torch.LongTensor = None, lemma_indices: torch.LongTensor = None): embedded_text_input = self._input_dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) # добавим измеремение, которое 
каждому выходу энкодера ставит в соответствие три его копии encoded_text_3 = encoded_text encoded_text_3 = torch.unsqueeze(encoded_text_3, 2) encoded_text_3 = encoded_text_3.repeat(1,1,3,1) # пропустим три копии вектора (с выхода энкодера) через lstm seq_len = encoded_text.size()[1] emb_div_val = encoded_text.size()[2] multi_triplets = torch.reshape(encoded_text_3, (-1, 3, emb_div_val)) label_variants, _ = self.multilabeler_lstm(multi_triplets) batched_label_variants = torch.reshape(label_variants, (-1, seq_len, 3, emb_div_val)) # # отладочный вывод # print("\n\n------------------------------------------------- ITLOG-BEGIN ------------------------------------------\n") # print( "ITLOG: encoded_text.size() = {}".format(encoded_text.size()) ) # print( "ITLOG: encoded_text_3.size() = {}".format(encoded_text_3.size()) ) # print( "ITLOG: multi_triplets.size() = {}".format(multi_triplets.size()) ) # print( "ITLOG: label_variants.size() = {}".format(label_variants.size()) ) # print( "ITLOG: batched_label_variants.size() = {}".format(batched_label_variants.size()) ) # print("\n------------------------------------------------- ITLOG-END ------------------------------------------\n") # grammar_value_logits = self._gram_val_output(encoded_text) grammar_value_logits = self._gram_val_output(batched_label_variants) # print("\n\n------------------------------------------------- ITLOG-BEGIN ------------------------------------------\n") # print( "ITLOG: grammar_value_logits.size() = {}".format(grammar_value_logits.size()) ) # print("\n------------------------------------------------- ITLOG-END ------------------------------------------\n") # grammar_value_logits = grammar_value_logits.select(2, 0) predicted_gram_vals = grammar_value_logits.argmax(-1) # lemma_logits = self._lemma_output(encoded_text) lemma_logits = self._lemma_output(batched_label_variants) predicted_lemmas = lemma_logits.argmax(-1) batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation. 
encoded_text = torch.cat([head_sentinel, encoded_text], 1) token_mask = mask.float() mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1) float_mask = mask.float() encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text)) child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text)) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text)) child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text)) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation) minus_inf = -1e8 minus_mask = (1 - float_mask) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) if self.training or not self.use_mst_decoding_for_validation: predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation, child_tag_representation, attended_arcs, mask) else: predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation, child_tag_representation, attended_arcs, mask) if head_indices is not None and head_tags is not None: arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=head_indices, head_tags=head_tags, mask=mask) else: arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=predicted_heads.long(), head_tags=predicted_head_tags.long(), mask=mask) grammar_nll = torch.tensor(0.) if grammar_values is not None: token_mask_3 = token_mask token_mask_3 = torch.unsqueeze(token_mask_3, 2) token_mask_3 = token_mask_3.repeat(1,1,3) grammar_nll = self._update_multiclass_prediction_metrics_3( logits=grammar_value_logits, targets=grammar_values, mask=token_mask_3, accuracy_metric=self._gram_val_prediction_accuracy ) lemma_nll = torch.tensor(0.)
if lemma_indices is not None: token_mask_3 = token_mask token_mask_3 = torch.unsqueeze(token_mask_3, 2) token_mask_3 = token_mask_3.repeat(1,1,3) lemma_nll = self._update_multiclass_prediction_metrics_3( logits=lemma_logits, targets=lemma_indices, mask=token_mask_3, accuracy_metric=self._lemma_prediction_accuracy #, masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX ) output_dict = { "heads": predicted_heads, "head_tags": predicted_head_tags, "gram_vals": predicted_gram_vals, "lemmas": predicted_lemmas, "mask": mask, "arc_nll": arc_nll, "tag_nll": tag_nll, "grammar_nll": grammar_nll, "lemma_nll": lemma_nll, } return output_dict @staticmethod def _update_multiclass_prediction_metrics(logits, targets, mask, accuracy_metric, masked_index=None): accuracy_metric(logits, targets, mask) logits = logits.view(-1, logits.shape[-1]) loss = F.cross_entropy(logits, targets.view(-1), reduction='none') if masked_index is not None: mask = mask * (targets != masked_index) loss_mask = mask.view(-1) return (loss * loss_mask).sum() / loss_mask.sum() @staticmethod def _update_multiclass_prediction_metrics_3(logits, targets, mask, accuracy_metric, masked_index=None): accuracy_metric(logits, targets, mask) # compute the cross-entropy only over the unmasked elements of the tensor bmask = torch.unsqueeze(mask, -1) logits_m = torch.masked_select(logits, bmask.bool()) logits_m = torch.reshape(logits_m, (-1, logits.shape[-1])) targets_m = torch.masked_select(targets, mask.bool()) loss = F.cross_entropy(logits_m, targets_m, reduction='none') return loss.sum() def _construct_loss(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. Parameters ---------- head_tag_representation : ``torch.Tensor``, required.
A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. head_tags : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The dependency labels of the heads for every word. mask : ``torch.Tensor``, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- arc_nll : ``torch.Tensor``, required. The negative log likelihood from the arc loss. tag_nll : ``torch.Tensor``, required. The negative log likelihood from the arc tag loss. """ float_mask = mask.float() batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = masked_log_softmax(attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = masked_log_softmax(head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long() # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then predicting the head tags of these greedily chosen arcs, again independently. Note that this method of decoding is not guaranteed to produce trees (i.e. there may be multiple roots, or cycles when children are attached to their parents). Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs.
child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, heads) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _mst_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions using Edmonds' algorithm for finding minimum spanning trees on directed graphs. Nodes in the graph are the words in the sentence, and between each pair of nodes, there is an edge in each direction, where the weight of the edge corresponds to the most likely dependency label probability for that arc. The MST is then generated from this directed graph. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the optimally decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the optimally decoded heads of each word.
""" batch_size, sequence_length, tag_representation_dim = head_tag_representation.size() lengths = mask.data.sum(dim=1).long().cpu().numpy() expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim] head_tag_representation = head_tag_representation.unsqueeze(2) head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous() child_tag_representation = child_tag_representation.unsqueeze(1) child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous() # Shape (batch_size, sequence_length, sequence_length, num_head_tags) pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation) # Note that this log_softmax is over the tag dimension, and we don't consider pairs # of tags which are invalid (e.g are a pair which includes a padded element) anyway below. # Shape (batch, num_labels,sequence_length, sequence_length) normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2) # Mask padded tokens, because we only want to consider actual words as heads. minus_inf = -1e8 minus_mask = (1 - mask.float()) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) # Shape (batch_size, sequence_length, sequence_length) normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2) # Shape (batch_size, num_head_tags, sequence_length, sequence_length) # This energy tensor expresses the following relation: # energy[i,j] = "Score that i is the head of j". In this # case, we have heads pointing to their children. batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits) return self._run_mst_decoding(batch_energy, lengths) @staticmethod def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: heads = [] head_tags = [] for energy, length in zip(batch_energy.detach().cpu(), lengths): scores, tag_ids = energy.max(dim=0) # Although we need to include the root node so that the MST includes it, # we do not want any word to be the parent of the root node. # Here, we enforce this by setting the scores for all word -> ROOT edges # edges to be 0. scores[0, :] = 0 # Decode the heads. Because we modify the scores to prevent # adding in word -> ROOT edges, we need to find the labels ourselves. instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False) # Find the labels which correspond to the edges in the max spanning tree. instance_head_tags = [] for child, parent in enumerate(instance_heads): instance_head_tags.append(tag_ids[parent, child].item()) # We don't care what the head or tag is for the root token, but by default it's # not necesarily the same in the batched vs unbatched case, which is annoying. # Here we'll just set them to zero. instance_heads[0] = 0 instance_head_tags[0] = 0 heads.append(instance_heads) head_tags.append(instance_head_tags) return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags)) def _get_head_tags(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. 
Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. Returns ------- head_tag_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size, 1) range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch. # shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous() # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_bilinear(selected_head_tag_representations, child_tag_representation) return head_tag_logits def _get_mask_for_eval(self, mask: torch.LongTensor, pos_tags: torch.LongTensor) -> torch.LongTensor: """ Dependency evaluation excludes words that are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. Parameters ---------- mask : ``torch.LongTensor``, required. The original mask. pos_tags : ``torch.LongTensor``, required. The pos tags for the sequence. Returns ------- A new mask, where any indices equal to labels we should be ignoring are masked. """ new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = self._attachment_scores.get_metric(reset) metrics['GramValAcc'] = self._gram_val_prediction_accuracy.get_metric(reset) metrics['LemmaAcc'] = self._lemma_prediction_accuracy.get_metric(reset) metrics['MeanAcc'] = (metrics['GramValAcc'] + metrics['LemmaAcc'] + metrics['LAS']) / 3. return metrics
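# NOTE (editor sketch): the "complex piece of indexing" in _get_head_tags above
# is standard advanced indexing: for batch element b and word position j,
# head_tag_representation[range_vector, head_indices][b, j] selects the
# representation of word head_indices[b, j] in batch element b. A tiny worked
# example (the tensor values here are illustrative only):
import torch

reps = torch.arange(2 * 3 * 4, dtype=torch.float).view(2, 3, 4)  # (batch, seq, dim)
head_indices = torch.tensor([[0, 2, 1], [1, 1, 0]])  # head position of each word
range_vector = torch.arange(2).unsqueeze(1)  # shape (batch_size, 1)

selected = reps[range_vector, head_indices]  # shape (batch, seq, dim)
assert torch.equal(selected[0, 1], reps[0, 2])  # word 1 in batch 0 has head 2
assert torch.equal(selected[1, 0], reps[1, 1])  # word 0 in batch 1 has head 1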
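# NOTE (editor sketch): _mst_decode above feeds Edmonds' algorithm an "energy"
# tensor that combines the arc distribution with the label distribution for
# each arc. A small standalone rendering of that combination, using random
# scores in place of the model's logits:
import torch
import torch.nn.functional as F

batch_size, seq_len, num_labels = 1, 4, 3
attended_arcs = torch.randn(batch_size, seq_len, seq_len)
pairwise_head_logits = torch.randn(batch_size, seq_len, seq_len, num_labels)

# log P(head | child), transposed so that energy[i, j] reads "i is the head of j".
normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2)
# log P(label | head, child), with the label axis moved to dim 1.
normalized_label_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)
# Joint energy; _run_mst_decoding then maxes over the label dimension per arc.
batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_label_logits)
assert batch_energy.shape == (batch_size, num_labels, seq_len, seq_len)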
class DecompSyntaxTrainer(DecompTrainer): def __init__(self, validation_data_path: str, validation_prediction_path: str, semantics_only: bool, drop_syntax: bool, include_attribute_scores: bool = False, warmup_epochs: int = 0, syntactic_method: str = 'concat-after', *args, **kwargs): super(DecompSyntaxTrainer, self).__init__(validation_data_path, validation_prediction_path, semantics_only, drop_syntax, include_attribute_scores, warmup_epochs, *args, **kwargs) self.attachment_scorer = AttachmentScores() self.syntactic_method = syntactic_method if self.model.loss_mixer is not None: self.model.loss_mixer.update_weights(curr_epoch=0, total_epochs=self._num_epochs) def _update_attachment_scores(self, pred_instances, true_instances): las = [] uas = [] # flatten true instances if self.syntactic_method.startswith("concat"): token_key = "tgt_tokens_str" head_key = "edge_heads" pred_label_key = "edge_types_inds" true_label_key = "edge_types" mask_key = "valid_node_mask" pred_node_key = "nodes" else: token_key = "syn_tokens_str" head_key = "syn_edge_heads" pred_label_key = "syn_edge_type_inds" true_label_key = "syn_edge_types" mask_key = "syn_valid_node_mask" pred_node_key = "syn_nodes" all_true_nodes = [ true_inst for batch in true_instances for true_inst in batch[0][token_key] ] all_true_edge_heads = [ true_inst for batch in true_instances for true_inst in batch[0][head_key] ] all_true_edge_types = [ true_inst for batch in true_instances for true_inst in batch[0][true_label_key][true_label_key] ] all_true_masks = [ true_inst for batch in true_instances for true_inst in batch[0][mask_key] ] assert (len(all_true_nodes) == len(all_true_edge_heads) == len(all_true_edge_types) == len(all_true_masks) == len(pred_instances)) for i in range(len(pred_instances)): # get rid of @start@ symbol true_nodes = all_true_nodes[i] pred_nodes = pred_instances[i][pred_node_key] if self.syntactic_method.startswith("concat"): if self.syntactic_method == "concat-just-syntax": split_point = -1 end_point = min( true_nodes.index("@end@") - 1, len(pred_nodes) - 1) else: split_point = true_nodes.index("@syntax-sep@") - 1 end_point = min( true_nodes.index("@end@") - 1, len(pred_nodes) - 1) else: split_point = -1 end_point = len(true_nodes) try: pred_edge_heads = pred_instances[i][head_key][split_point + 1:end_point] pred_edge_types = pred_instances[i][pred_label_key][ split_point + 1:end_point] except IndexError: las.append(0) uas.append(0) continue gold_edge_heads = all_true_edge_heads[i][split_point + 1:end_point] gold_edge_types = all_true_edge_types[i][split_point + 1:end_point] valid_node_mask = all_true_masks[i][split_point + 1:end_point] pred_edge_heads = torch.tensor(pred_edge_heads) pred_edge_types = torch.tensor(pred_edge_types) try: self.attachment_scorer(predicted_indices=pred_edge_heads, predicted_labels=pred_edge_types, gold_indices=gold_edge_heads, gold_labels=gold_edge_types, mask=valid_node_mask) except RuntimeError: continue scores = self.attachment_scorer.get_metric(reset=True) self.model.syntax_las = scores["LAS"] * 100 self.model.syntax_uas = scores["UAS"] * 100 @overrides def _update_validation_s_score(self, pred_instances: List[Dict[str, numpy.ndarray]], true_instances): """Write the validation output in pkl format, and compute the S score.""" # compute attachment scores here without having to override another function self._update_attachment_scores(pred_instances, true_instances) if isinstance(self.model, DecompSyntaxOnlyParser) or \ isinstance(self.model, DecompTransformerSyntaxOnlyParser) or \
isinstance(self.model, UDParser): return logger.info("Computing S") for batch in true_instances: assert (len(batch) == 1) true_graphs = [ true_inst for batch in true_instances for true_inst in batch[0]['graph'] ] true_sents = [ true_inst for batch in true_instances for true_inst in batch[0]['src_tokens_str'] ] pred_graphs = [ DecompGraphWithSyntax.from_prediction(pred_inst, self.syntactic_method) for pred_inst in pred_instances ] pred_sem_graphs, pred_syn_graphs, __ = zip(*pred_graphs) ret = compute_s_metric(true_graphs, pred_sem_graphs, true_sents, self.semantics_only, self.drop_syntax, self.include_attribute_scores) self.model.val_s_precision = float(ret[0]) * 100 self.model.val_s_recall = float(ret[1]) * 100 self.model.val_s_f1 = float(ret[2]) * 100
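# NOTE (editor sketch): _update_attachment_scores above repeatedly flattens a
# list of single-element batches into one flat list per field before scoring.
# A minimal standalone illustration of that comprehension pattern (the data
# below is made up for the example):
batches = [({"tokens": [["a", "b"], ["c"]]},), ({"tokens": [["d"]]},)]
flat = [inst for batch in batches for inst in batch[0]["tokens"]]
assert flat == [["a", "b"], ["c"], ["d"]]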
class KGParser(Model): """ A parser based on the edge-scoring method of Kiperwasser and Goldberg (2016). Based on the implementation in: https://github.com/coli-saar/am-parser At the moment, the parser doesn't use cost-augmentation or hinge-loss. Registered as a `Model` with name "kg_parser". # Parameters vocab : `Vocabulary`, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : `TextFieldEmbedder`, required Used to embed the `tokens` `TextField` we get as input to the model. encoder : `Seq2SeqEncoder` The encoder (with its own internal stacking) that we will use to generate representations of tokens. tag_representation_dim : `int`, required. The dimension of the MLPs used for arc tag prediction. arc_representation_dim : `int`, required. The dimension of the MLPs used for arc prediction. tag_feedforward : `FeedForward`, optional, (default = None). The feedforward network used to produce tag representations. By default, a 1 layer feedforward network with an elu activation is used. arc_feedforward : `FeedForward`, optional, (default = None). The feedforward network used to produce arc representations. By default, a 1 layer feedforward network with an elu activation is used. pos_tag_embedding : `Embedding`, optional. Used to embed the `pos_tags` `SequenceLabelField` we get as input to the model. dropout : `float`, optional, (default = 0.0) The variational dropout applied to the output of the encoder and MLP layers. input_dropout : `float`, optional, (default = 0.0) The dropout applied to the embedded text input. edge_prediction_threshold : `float`, optional (default = 0.5) The probability at which to consider a scored edge to be 'present' in the decoded graph. Must be between 0 and 1. initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`) Used to initialize the model parameters.
""" def __init__( self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, activation = Activation.by_name("tanh")(), tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, pos_tag_embedding: Embedding = None, use_mst_decoding_for_validation: bool = False, dropout: float = 0.0, input_dropout: float = 0.0, edge_prediction_threshold: float = 0.5, initializer: InitializerApplicator = InitializerApplicator(), **kwargs, ) -> None: super().__init__(vocab, **kwargs) self.text_field_embedder = text_field_embedder self.encoder = encoder self.activation = activation encoder_dim = encoder.get_output_dim() # edge FeedForward self.head_arc_feedforward = arc_feedforward or FeedForward( encoder_dim, 1, arc_representation_dim, Activation.by_name("tanh")() ) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) # label FeedForward self.head_tag_feedforward = tag_feedforward or FeedForward( encoder_dim, 1, tag_representation_dim, Activation.by_name("tanh")() ) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.arc_out_layer = Linear(arc_representation_dim, 1) num_labels = self.vocab.get_vocab_size("head_tags") self.tag_out_layer = Linear(arc_representation_dim, num_labels) self._pos_tag_embedding = pos_tag_embedding or None self._dropout = InputVariationalDropout(dropout) self._input_dropout = Dropout(input_dropout) # add a head sentinel to accommodate for extra root token self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()])) representation_dim = text_field_embedder.get_output_dim() if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match( representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim", ) check_dimensions_match( tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim", ) check_dimensions_match( arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim", ) self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = { tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE } self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info( f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation." ) self._attachment_scores = AttachmentScores() initializer(self) @overrides def forward( self, # type: ignore tokens: TextFieldTensors, pos_tags: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, enhanced_tags: torch.LongTensor = None ) -> Dict[str, torch.Tensor]: """ # Parameters tokens : TextFieldTensors, required The output of `TextField.as_array()`. pos_tags : torch.LongTensor, optional (default = None) The output of a `SequenceLabelField` containing POS tags. metadata : List[Dict[str, Any]], optional (default = None) A dictionary of metadata for each batch element which has keys: tokens : `List[str]`, required. The original string tokens in the sentence. 
enhanced_tags : torch.LongTensor, optional (default = None) A torch tensor encoding the arcs of the enhanced dependency graph between every pair of words in the sentence. Has shape ``(batch_size, sequence_length, sequence_length)``. # Returns An output dictionary. """ embedded_text_input = self.text_field_embedder(tokens) if pos_tags is not None and self._pos_tag_embedding is not None: embedded_pos_tags = self._pos_tag_embedding(pos_tags) embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1) elif self._pos_tag_embedding is not None: raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(tokens) predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll = self._parse( embedded_text_input, mask, head_tags, head_indices ) loss = arc_nll + tag_nll if head_indices is not None and head_tags is not None: evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags) # We calculate attachment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. self._attachment_scores( predicted_heads[:, 1:], predicted_head_tags[:, 1:], head_indices, head_tags, evaluation_mask, ) output_dict = { "heads": predicted_heads, "head_tags": predicted_head_tags, "arc_loss": arc_nll, "tag_loss": tag_nll, "loss": loss, "mask": mask, "tokens": [meta["tokens"] for meta in metadata], "pos_tags": [meta["pos_tags"] for meta in metadata], } return output_dict @overrides def make_output_human_readable( self, output_dict: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: head_tags = output_dict.pop("head_tags").cpu().detach().numpy() heads = output_dict.pop("heads").cpu().detach().numpy() mask = output_dict.pop("mask") lengths = get_lengths_from_binary_sequence_mask(mask) head_tag_labels = [] head_indices = [] for instance_heads, instance_tags, length in zip(heads, head_tags, lengths): instance_heads = list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [ self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags ] head_tag_labels.append(labels) head_indices.append(instance_heads) output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices return output_dict def _parse( self, embedded_text_input: torch.Tensor, mask: torch.BoolTensor, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: embedded_text_input = self._input_dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) batch_size, sequence_length, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation.
encoded_text = torch.cat([head_sentinel, encoded_text], 1) mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1) encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self.head_arc_feedforward(encoded_text) child_arc_representation = self.child_arc_feedforward(encoded_text) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self.head_tag_feedforward(encoded_text) child_tag_representation = self.child_tag_feedforward(encoded_text) # calculate dimensions again as sequence_length is now + 1 from adding the head_sentinel batch_size, sequence_length, arc_dim = head_arc_representation.size() # now repeat the token representations to form a matrix: # shape (batch_size, sequence_length, sequence_length, arc_representation_dim) heads = head_arc_representation.repeat(1, sequence_length, 1).reshape(batch_size, sequence_length, sequence_length, arc_dim) # heads in one direction deps = child_arc_representation.repeat(1, sequence_length, 1).reshape(batch_size, sequence_length, sequence_length, arc_dim).transpose(1, 2) # deps in the other direction # shape (batch_size, sequence_length, sequence_length, arc_representation_dim) combined_arcs = self.activation(heads + deps) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_out_layer(combined_arcs).squeeze(3) minus_inf = -1e8 minus_mask = ~mask * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) if self.training or not self.use_mst_decoding_for_validation: predicted_heads, predicted_head_tags = self._greedy_decode( head_tag_representation, child_tag_representation, attended_arcs, mask ) else: predicted_heads, predicted_head_tags = self._mst_decode( head_tag_representation, child_tag_representation, attended_arcs, mask ) if head_indices is not None and head_tags is not None: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=head_indices, head_tags=head_tags, mask=mask, ) else: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=predicted_heads.long(), head_tags=predicted_head_tags.long(), mask=mask, ) return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll def _construct_loss( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.BoolTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : `torch.Tensor`, required. 
A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. head_indices : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. head_tags : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length). The dependency labels of the heads for every word. mask : `torch.BoolTensor`, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. # Returns arc_nll : `torch.Tensor`, required. The negative log likelihood from the arc loss. tag_nll : `torch.Tensor`, required. The negative log likelihood from the arc tag loss. """ batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = ( masked_log_softmax(attended_arcs, mask) * mask.unsqueeze(2) * mask.unsqueeze(1) ) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags( head_tag_representation, child_tag_representation, head_indices ) normalised_head_tag_logits = masked_log_softmax( head_tag_logits, mask.unsqueeze(-1) ) * mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = ( timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long() ) # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.BoolTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then again, predicting the head tags of these greedily chosen arcs independently. Note that this method of decoding is not guaranteed to produce trees (i.e. there may be multiple roots, or cycles when children are attached to their parents). # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. # Returns heads : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word.
head_tags : `torch.Tensor` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag( attended_arcs.new(mask.size(1)).fill_(-numpy.inf) ) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = ~mask.unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags( head_tag_representation, child_tag_representation, heads ) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _get_head_tags( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor, ) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. # Parameters head_tag_representation : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : `torch.Tensor`, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : `torch.Tensor`, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. # Returns head_tag_logits : `torch.Tensor` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector( batch_size, get_device_of(head_tag_representation) ).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch. # shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous() # shape (batch_size, sequence_length, num_head_tags) combined = self.activation(selected_head_tag_representations + child_tag_representation) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_out_layer(combined) return head_tag_logits def _get_mask_for_eval( self, mask: torch.BoolTensor, pos_tags: torch.LongTensor ) -> torch.LongTensor: """ Dependency evaluation excludes words that are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. # Parameters mask : `torch.BoolTensor`, required. The original mask. pos_tags : `torch.LongTensor`, required. The pos tags for the sequence.
# Returns A new mask, where any indices equal to labels we should be ignoring are masked. """ new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label) new_mask = new_mask & ~label_mask return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return self._attachment_scores.get_metric(reset)
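The `get_metrics` hook above simply forwards whatever the `AttachmentScores` metric has accumulated across batches. A minimal, self-contained sketch of that accumulate-then-read cycle (the tensors below are illustrative toy values, not taken from the model):

import torch
from allennlp.training.metrics import AttachmentScores

scorer = AttachmentScores()
# One toy batch in which the predictions happen to match the gold annotations exactly.
predicted_heads = torch.tensor([[0, 2, 0, 2]])
predicted_labels = torch.tensor([[0, 3, 1, 3]])
gold_heads = torch.tensor([[0, 2, 0, 2]])
gold_labels = torch.tensor([[0, 3, 1, 3]])
mask = torch.tensor([[True, True, True, True]])

# Calling the metric accumulates counts; get_metric() reads them out.
scorer(predicted_heads, predicted_labels, gold_heads, gold_labels, mask)
metrics = scorer.get_metric()  # keys: "UAS", "LAS", "UEM", "LEM"
assert all(value == 1.0 for value in metrics.values())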
class DecompSyntaxOnlyParser(DecompSyntaxParser): def __init__( self, vocab: Vocabulary, # source-side bert_encoder: BaseBertWrapper, encoder_token_embedder: TextFieldEmbedder, encoder_pos_embedding: Embedding, encoder: Seq2SeqEncoder, # target-side decoder_token_embedder: TextFieldEmbedder, decoder_node_index_embedding: Embedding, decoder_pos_embedding: Embedding, decoder: RNNDecoder, extended_pointer_generator: ExtendedPointerGenerator, tree_parser: DecompTreeParser, node_attribute_module: NodeAttributeDecoder, edge_attribute_module: EdgeAttributeDecoder, # misc label_smoothing: LabelSmoothing, target_output_namespace: str, pos_tag_namespace: str, edge_type_namespace: str, syntax_edge_type_namespace: str = None, biaffine_parser: DeepTreeParser = None, syntactic_method: str = None, dropout: float = 0.0, beam_size: int = 5, max_decoding_steps: int = 50, eps: float = 1e-20, loss_mixer: LossMixer = None) -> None: super(DecompSyntaxOnlyParser, self).__init__( vocab=vocab, # source-side bert_encoder=bert_encoder, encoder_token_embedder=encoder_token_embedder, encoder_pos_embedding=encoder_pos_embedding, encoder=encoder, # target-side decoder_token_embedder=decoder_token_embedder, decoder_node_index_embedding=decoder_node_index_embedding, decoder_pos_embedding=decoder_pos_embedding, decoder=decoder, extended_pointer_generator=extended_pointer_generator, tree_parser=tree_parser, node_attribute_module=node_attribute_module, edge_attribute_module=edge_attribute_module, # misc label_smoothing=label_smoothing, target_output_namespace=target_output_namespace, pos_tag_namespace=pos_tag_namespace, edge_type_namespace=edge_type_namespace, syntax_edge_type_namespace=syntax_edge_type_namespace, syntactic_method=syntactic_method, dropout=dropout, beam_size=beam_size, max_decoding_steps=max_decoding_steps, eps=eps) self.syntactic_method = "encoder-side" self.biaffine_parser = biaffine_parser self.loss_mixer = None self._syntax_metrics = AttachmentScores() self.syntax_las = 0.0 self.syntax_uas = 0.0 @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: syntax_metrics = self._syntax_metrics.get_metric(reset) metrics = OrderedDict( syn_uas=syntax_metrics["UAS"] * 100, syn_las=syntax_metrics["LAS"] * 100, ) metrics["syn_las"] = self.syntax_las metrics["syn_uas"] = self.syntax_uas return metrics def _update_syntax_scores(self): scores = self._syntax_metrics.get_metric(reset=True) self.syntax_las = scores["LAS"] * 100 self.syntax_uas = scores["UAS"] * 100 def _compute_biaffine_loss(self, biaffine_outputs, inputs): edge_prediction_loss = self._compute_edge_prediction_loss( biaffine_outputs['edge_head_ll'], biaffine_outputs['edge_type_ll'], biaffine_outputs['edge_heads'], biaffine_outputs['edge_types'], inputs['syn_edge_heads'], inputs['syn_edge_types']['syn_edge_types'], inputs['syn_valid_node_mask'], syntax=True) return edge_prediction_loss['loss_per_node'] def _parse_syntax(self, encoder_outputs: torch.Tensor, edge_head_mask: torch.Tensor, edge_heads: torch.Tensor = None, valid_node_mask: torch.Tensor = None, do_mst=False) -> Dict: parser_outputs = self.biaffine_parser(query=encoder_outputs, key=encoder_outputs, edge_head_mask=edge_head_mask, gold_edge_heads=edge_heads, decode_mst=do_mst, valid_node_mask=valid_node_mask) return parser_outputs @overrides def _training_forward(self, inputs: Dict) -> Dict[str, torch.Tensor]: encoding_outputs = self._encode( tokens=inputs["source_tokens"], pos_tags=inputs["source_pos_tags"], subtoken_ids=inputs["source_subtoken_ids"], 
token_recovery_matrix=inputs["source_token_recovery_matrix"], mask=inputs["source_mask"]) biaffine_outputs = self._parse_syntax( encoding_outputs['encoder_outputs'], inputs["syn_edge_head_mask"], inputs["syn_edge_heads"], do_mst=False) biaffine_loss = self._compute_biaffine_loss(biaffine_outputs, inputs) self._update_syntax_scores() return dict(loss=biaffine_loss) @overrides def _test_forward(self, inputs: Dict) -> Dict: encoding_outputs = self._encode( tokens=inputs["source_tokens"], pos_tags=inputs["source_pos_tags"], subtoken_ids=inputs["source_subtoken_ids"], token_recovery_matrix=inputs["source_token_recovery_matrix"], mask=inputs["source_mask"]) biaffine_outputs = self._parse_syntax( encoding_outputs['encoder_outputs'], inputs["syn_edge_head_mask"], None, valid_node_mask=inputs["syn_valid_node_mask"], do_mst=True) syn_edge_head_predictions, syn_edge_type_predictions, syn_edge_type_inds = self._read_edge_predictions( biaffine_outputs, is_syntax=True) bsz, __ = inputs["source_tokens"]["source_tokens"].shape outputs = dict(syn_nodes=inputs['syn_tokens_str'], syn_edge_heads=syn_edge_head_predictions, syn_edge_types=syn_edge_type_predictions, syn_edge_type_inds=syn_edge_type_inds, loss=torch.tensor([0.0]), nodes=torch.ones((bsz, 1)), node_indices=torch.ones((bsz, 1)), edge_heads=torch.ones((bsz, 1)), edge_types=torch.ones((bsz, 1)), edge_types_inds=torch.ones((bsz, 1)), node_attributes=torch.ones((bsz, 1, 44)), node_attributes_mask=torch.ones((bsz, 1, 44)), edge_attributes=torch.ones((bsz, 1, 14)), edge_attributes_mask=torch.ones((bsz, 1, 14))) return outputs
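Note the caching idiom in `DecompSyntaxOnlyParser` above: `_update_syntax_scores` calls `get_metric(reset=True)` inside the training forward pass, so the `syn_las`/`syn_uas` assignments at the end of `get_metrics` deliberately overwrite the freshly computed (already-reset) values with the cached copies. A small sketch of why the cache is needed, using toy tensors rather than model outputs (this assumes the standard AllenNLP behaviour that a reset metric reports zeros):

import torch
from allennlp.training.metrics import AttachmentScores

metric = AttachmentScores()
heads = torch.tensor([[1, 0, 1]])
labels = torch.tensor([[2, 0, 2]])
mask = torch.tensor([[True, True, True]])
metric(heads, labels, heads, labels, mask)

cached = metric.get_metric(reset=True)  # returns the scores and clears the counts
assert cached["LAS"] == 1.0
# After the reset the internal counts are empty, so a later read no longer
# reflects the batch above; anything reported afterwards must use the cache.
assert metric.get_metric()["LAS"] == 0.0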
class AttachmentScoresTest(AllenNlpTestCase): def setUp(self): super(AttachmentScoresTest, self).setUp() self.scorer = AttachmentScores() self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]]) self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]]) self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]]) self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]]) self.mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]) def test_perfect_scores(self): self.scorer(self.predictions, self.label_predictions, self.gold_indices, self.gold_labels, self.mask) for value in list(self.scorer.get_metric().values()): assert value == 1.0 def test_unlabeled_accuracy_ignores_incorrect_labels(self): label_predictions = self.label_predictions # Change some stuff so 4 of our label predictions are wrong. label_predictions[0, 3:] = 3 label_predictions[1, 0] = 7 self.scorer(self.predictions, label_predictions, self.gold_indices, self.gold_labels, self.mask) metrics = self.scorer.get_metric() assert metrics[u"UAS"] == 1.0 assert metrics[u"UEM"] == 1.0 # 4 / 12 labels were wrong and 2 positions # are masked, so 6/10 = 0.6 LAS. assert metrics[u"LAS"] == 0.6 # Neither should have labeled exact match. assert metrics[u"LEM"] == 0.0 def test_labeled_accuracy_is_affected_by_incorrect_heads(self): predictions = self.predictions # Change some stuff so 4 of our predictions are wrong. predictions[0, 3:] = 3 predictions[1, 0] = 7 # This one is in the padded part, so it shouldn't affect anything. predictions[1, 5] = 7 self.scorer(predictions, self.label_predictions, self.gold_indices, self.gold_labels, self.mask) metrics = self.scorer.get_metric() # 4 heads are incorrect, so the unlabeled score should be # 6/10 = 0.6 UAS. assert metrics[u"UAS"] == 0.6 # All the labels were correct, but some heads # were wrong, so the LAS should equal the UAS. assert metrics[u"LAS"] == 0.6 # Neither batch element had a perfect labeled or unlabeled EM. assert metrics[u"LEM"] == 0.0 assert metrics[u"UEM"] == 0.0 def test_attachment_scores_can_ignore_labels(self): scorer = AttachmentScores(ignore_classes=[1]) label_predictions = self.label_predictions # Change the predictions where the gold label is 1; # as we are ignoring 1, we should still get a perfect score. label_predictions[0, 3] = 2 scorer(self.predictions, label_predictions, self.gold_indices, self.gold_labels, self.mask) for value in list(scorer.get_metric().values()): assert value == 1.0
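As a quick sanity check of the arithmetic asserted in the tests above (a worked example, not library code): two of the twelve positions are masked, so ten positions are scored; four wrong labels leave six correct, giving LAS = 6/10 while the unlabeled score is unaffected.

# Worked check of the accuracy arithmetic used in the tests above.
total_positions = 2 * 6      # two sequences of length six
masked_positions = 2         # padding at the end of the second sequence
wrong_labels = 4             # corrupted label predictions
counted = total_positions - masked_positions   # 10 scored positions
las = (counted - wrong_labels) / counted       # 6 / 10
assert las == 0.6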
class Transduction(Model): def __init__(self, vocab: Vocabulary, # source-side bert_encoder: Seq2SeqBertEncoder, encoder_token_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, # target-side decoder_token_embedder: TextFieldEmbedder, decoder_node_index_embedding: Embedding, decoder: RNNDecoder, extended_pointer_generator: ExtendedPointerGenerator, tree_parser: DeepTreeParser, # misc label_smoothing: LabelSmoothing, target_output_namespace: str, dropout: float = 0.0, eps: float = 1e-20, pretrained_weights: str = None, ) -> None: super().__init__(vocab=vocab) # source-side self._bert_encoder = bert_encoder self._encoder_token_embedder = encoder_token_embedder self._encoder = encoder # target-side self._decoder_token_embedder = decoder_token_embedder self._decoder_node_index_embedding = decoder_node_index_embedding self._decoder = decoder self._extended_pointer_generator = extended_pointer_generator self._tree_parser = tree_parser # metrics self._node_pred_metrics = ExtendedPointerGeneratorMetrics() self._edge_pred_metrics = AttachmentScores() self._synt_edge_pred_metrics = AttachmentScores() self._label_smoothing = label_smoothing self._dropout = InputVariationalDropout(p=dropout) self._eps = eps # dynamic initialization self._target_output_namespace = target_output_namespace self._vocab_size = self.vocab.get_vocab_size(target_output_namespace) self._vocab_pad_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, target_output_namespace) # loading partial weights self.pretrained_weights = pretrained_weights @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: node_pred_metrics = self._node_pred_metrics.get_metric(reset) edge_pred_metrics = self._edge_pred_metrics.get_metric(reset) synt_edge_pred_metrics = self._synt_edge_pred_metrics.get_metric(reset) metrics = OrderedDict( ppl=node_pred_metrics["ppl"], node_pred=node_pred_metrics["accuracy"] * 100, generate=node_pred_metrics["generate"] * 100, src_copy=node_pred_metrics["src_copy"] * 100, tgt_copy=node_pred_metrics["tgt_copy"] * 100, uas=edge_pred_metrics["UAS"] * 100, las=edge_pred_metrics["LAS"] * 100, ) return metrics @overrides def forward(self, **raw_inputs: Dict) -> Dict: inputs = self._prepare_inputs(raw_inputs) if self.training: return self._training_forward(inputs) else: return self._test_forward(inputs) def _compute_edge_prediction_loss(self, edge_head_ll: torch.Tensor, edge_type_ll: torch.Tensor, pred_edge_heads: torch.Tensor, pred_edge_types: torch.Tensor, gold_edge_heads: torch.Tensor, gold_edge_types: torch.Tensor, valid_node_mask: torch.Tensor, syntax: bool = False) -> Dict: """ Compute the edge prediction loss. :param edge_head_ll: [batch_size, target_length, target_length + 1 (for sentinel)]. :param edge_type_ll: [batch_size, target_length, num_labels]. :param pred_edge_heads: [batch_size, target_length]. :param pred_edge_types: [batch_size, target_length]. :param gold_edge_heads: [batch_size, target_length]. :param gold_edge_types: [batch_size, target_length]. :param valid_node_mask: [batch_size, target_length]. """ # Index the log-likelihood (ll) of gold edge heads and types. 
batch_size, target_length, _ = edge_head_ll.size() batch_indices = torch.arange(0, batch_size).view(batch_size, 1).type_as(gold_edge_heads) node_indices = torch.arange(0, target_length).view(1, target_length) \ .expand(batch_size, target_length).type_as(gold_edge_heads) gold_edge_head_ll = edge_head_ll[batch_indices, node_indices, gold_edge_heads] gold_edge_type_ll = edge_type_ll[batch_indices, node_indices, gold_edge_types] # Set the ll of invalid nodes to 0. num_nodes = valid_node_mask.sum().float() if not syntax: # don't incur loss on EOS/SOS token valid_node_mask[gold_edge_heads == -1] = 0 valid_node_mask = valid_node_mask.bool() gold_edge_head_ll.masked_fill_(~valid_node_mask, 0) gold_edge_type_ll.masked_fill_(~valid_node_mask, 0) # Negative log-likelihood. loss = -(gold_edge_head_ll.sum() + gold_edge_type_ll.sum()) # Update metrics. if self.training and not syntax: self._edge_pred_metrics( predicted_indices=pred_edge_heads, predicted_labels=pred_edge_types, gold_indices=gold_edge_heads, gold_labels=gold_edge_types, mask=valid_node_mask ) elif self.training and syntax: # `_syntax_metrics` is expected to be provided by syntax-aware subclasses # (e.g. DecompSyntaxOnlyParser above); the base class itself only defines # `_synt_edge_pred_metrics`. self._syntax_metrics( predicted_indices=pred_edge_heads, predicted_labels=pred_edge_types, gold_indices=gold_edge_heads, gold_labels=gold_edge_types, mask=valid_node_mask ) return dict( loss=loss, num_nodes=num_nodes, loss_per_node=loss / num_nodes, ) def _compute_node_prediction_loss(self, prob_dist: torch.Tensor, generation_outputs: torch.Tensor, source_copy_indices: torch.Tensor, target_copy_indices: torch.Tensor, source_dynamic_vocab_size: int, source_attention_weights: torch.Tensor = None, coverage_history: torch.Tensor = None) -> Dict: """ Compute the node prediction loss based on the final hybrid probability distribution. :param prob_dist: probability distribution, [batch_size, target_length, vocab_size + source_dynamic_vocab_size + target_dynamic_vocab_size]. :param generation_outputs: generated node indices in the pre-defined vocabulary, [batch_size, target_length]. :param source_copy_indices: source-side copied node indices in the source dynamic vocabulary, [batch_size, target_length]. :param target_copy_indices: target-side copied node indices in the target dynamic vocabulary, [batch_size, target_length]. :param source_dynamic_vocab_size: int. :param source_attention_weights: None or [batch_size, target_length, source_length]. :param coverage_history: None or a tensor recording the source-side coverage history. [batch_size, target_length, source_length]. """ _, prediction = prob_dist.max(2) batch_size, target_length = prediction.size() not_pad_mask = generation_outputs.ne(self._vocab_pad_index) num_nodes = not_pad_mask.sum() # Priority: target_copy > source_copy > generation # Prepare mask. valid_target_copy_mask = target_copy_indices.ne(0) & not_pad_mask # 0 for sentinel. valid_source_copy_mask = (~valid_target_copy_mask & not_pad_mask & source_copy_indices.ne(1) & source_copy_indices.ne(0)) # 1 for unk; 0 for pad. valid_generation_mask = ~(valid_target_copy_mask | valid_source_copy_mask) & not_pad_mask # Prepare hybrid targets. _target_copy_indices = ((target_copy_indices + self._vocab_size + source_dynamic_vocab_size) * valid_target_copy_mask.long()) _source_copy_indices = (source_copy_indices + self._vocab_size) * valid_source_copy_mask.long() _generation_outputs = generation_outputs * valid_generation_mask.long() hybrid_targets = _target_copy_indices + _source_copy_indices + _generation_outputs # Compute loss.
log_prob_dist = (prob_dist.view(batch_size * target_length, -1) + self._eps).log() flat_hybrid_targets = hybrid_targets.view(batch_size * target_length) loss = self._label_smoothing(log_prob_dist, flat_hybrid_targets) # Coverage loss. if coverage_history is not None: #coverage_loss = torch.sum(torch.min(coverage_history.unsqueeze(-1), source_attention_weights), 2) coverage_loss = torch.sum(torch.min(coverage_history, source_attention_weights), 2) coverage_loss = (coverage_loss * not_pad_mask.float()).sum() loss = loss + coverage_loss # Update metric stats. self._node_pred_metrics( loss=loss, prediction=prediction, generation_outputs=_generation_outputs, valid_generation_mask=valid_generation_mask, source_copy_indices=_source_copy_indices, valid_source_copy_mask=valid_source_copy_mask, target_copy_indices=_target_copy_indices, valid_target_copy_mask=valid_target_copy_mask ) return dict( loss=loss, num_nodes=num_nodes, loss_per_node=loss / num_nodes, ) def _decode(self, tokens: Dict[str, torch.Tensor], node_indices: torch.Tensor, encoder_outputs: torch.Tensor, hidden_states: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor, **kwargs) -> Dict: # [batch, num_tokens, embedding_size] decoder_inputs = torch.cat([ self._decoder_token_embedder(tokens), self._decoder_node_index_embedding(node_indices), ], dim=2) decoder_inputs = self._dropout(decoder_inputs) decoder_outputs = self._decoder( inputs=decoder_inputs, source_memory_bank=encoder_outputs, source_mask=mask, hidden_state=hidden_states ) return decoder_outputs def _encode(self, tokens: Dict[str, torch.Tensor], subtoken_ids: torch.Tensor, token_recovery_matrix: torch.Tensor, mask: torch.Tensor, **kwargs) -> Dict: # [batch, num_tokens, embedding_size] encoder_inputs = [self._encoder_token_embedder(tokens)] if subtoken_ids is not None and self._bert_encoder is not None: bert_embeddings = self._bert_encoder( input_ids=subtoken_ids, attention_mask=subtoken_ids.ne(0), output_all_encoded_layers=False, token_recovery_matrix=token_recovery_matrix ).detach() encoder_inputs += [bert_embeddings] encoder_inputs = torch.cat(encoder_inputs, 2) encoder_inputs = self._dropout(encoder_inputs) # [batch, num_tokens, encoder_output_size] encoder_outputs = self._encoder(encoder_inputs, mask) encoder_outputs = self._dropout(encoder_outputs) # A tuple of (state, memory) with shape [num_layers, batch, encoder_output_size] encoder_final_states = self._encoder.get_final_states() self._encoder.reset_states() return dict( encoder_outputs=encoder_outputs, final_states=encoder_final_states ) def _parse(self, rnn_outputs: torch.Tensor, edge_head_mask: torch.Tensor, edge_heads: torch.Tensor = None) -> Dict: """ Based on the vector representation for each node, predict its head and edge type. :param rnn_outputs: vector representations of nodes, including <BOS>. [batch_size, target_length + 1, hidden_vector_dim]. :param edge_head_mask: mask used in the edge head search. [batch_size, target_length, target_length]. :param edge_heads: None or gold head indices, [batch_size, target_length] """ # Exclude <BOS>. # <EOS> is already excluded in ``_prepare_inputs''. 
rnn_outputs = self._dropout(rnn_outputs[:, 1:]) parser_outputs = self._tree_parser( query=rnn_outputs, key=rnn_outputs, edge_head_mask=edge_head_mask, gold_edge_heads=edge_heads ) return parser_outputs def _prepare_inputs(self, raw_inputs: Dict) -> Dict: return raw_inputs def _test_forward(self, inputs: Dict) -> Dict: raise NotImplementedError def _training_forward(self, inputs: Dict) -> Dict[str, torch.Tensor]: encoding_outputs = self._encode( tokens=inputs["source_tokens"], subtoken_ids=inputs["source_subtoken_ids"], token_recovery_matrix=inputs["source_token_recovery_matrix"], mask=inputs["source_mask"] ) decoding_outputs = self._decode( tokens=inputs["target_tokens"], node_indices=inputs["target_node_indices"], encoder_outputs=encoding_outputs["encoder_outputs"], hidden_states=encoding_outputs["final_states"], mask=inputs["source_mask"] ) node_prediction_outputs = self._extended_pointer_generator( inputs=decoding_outputs["attentional_tensors"], source_attention_weights=decoding_outputs["source_attention_weights"], target_attention_weights=decoding_outputs["target_attention_weights"], source_attention_map=inputs["source_attention_map"], target_attention_map=inputs["target_attention_map"] ) edge_prediction_outputs = self._parse( rnn_outputs=decoding_outputs["rnn_outputs"], edge_head_mask=inputs["edge_head_mask"], edge_heads=inputs["edge_heads"] ) node_pred_loss = self._compute_node_prediction_loss( prob_dist=node_prediction_outputs["hybrid_prob_dist"], generation_outputs=inputs["generation_outputs"], source_copy_indices=inputs["source_copy_indices"], target_copy_indices=inputs["target_copy_indices"], source_dynamic_vocab_size=inputs["source_dynamic_vocab_size"], source_attention_weights=decoding_outputs["source_attention_weights"], coverage_history=decoding_outputs["coverage_history"] ) edge_pred_loss = self._compute_edge_prediction_loss( edge_head_ll=edge_prediction_outputs["edge_head_ll"], edge_type_ll=edge_prediction_outputs["edge_type_ll"], pred_edge_heads=edge_prediction_outputs["edge_heads"], pred_edge_types=edge_prediction_outputs["edge_types"], gold_edge_heads=inputs["edge_heads"], gold_edge_types=inputs["edge_types"], valid_node_mask=inputs["valid_node_mask"] ) loss = node_pred_loss["loss_per_node"] + edge_pred_loss["loss_per_node"] return dict(loss=loss) def load_partial(self, param_file: str): """ Loads pretrained weights, matching the parameters it can. """ logger.info(f"Attempting to load pretrained weights from {param_file}") pretrained_state_dict = torch.load(param_file) current_state_dict = self.state_dict() for k, v in pretrained_state_dict.items(): if isinstance(v, torch.nn.Parameter): v = v.data try: current_state_dict[k].copy_(v) print(f"matched {k}") logger.info(f"matched {k}") except RuntimeError: new_shape = pretrained_state_dict[k].shape og_shape = current_state_dict[k].shape print(f"Unable to match {k} due to shape error: pretrained: {new_shape} vs original: {og_shape}") logger.warning(f"Unable to match {k} due to shape error: pretrained: {new_shape} vs original: {og_shape}") continue except KeyError: logger.warning(f"Unable to match {k} because it does not exist in original model") print(f"Unable to match {k} because it does not exist in original model") continue self.load_state_dict(current_state_dict)
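`load_partial` above copies whatever checkpoint parameters line up with the current model and skips the rest. A condensed, self-contained sketch of the same pattern (the module and checkpoint here are throwaway stand-ins, not the real model):

import torch

model = torch.nn.Linear(4, 2)
checkpoint = {
    "weight": torch.zeros(2, 4),    # shapes match -> copied
    "bias": torch.zeros(3),         # shape mismatch -> RuntimeError, skipped
    "extra.param": torch.zeros(1),  # absent from the model -> KeyError, skipped
}

state = model.state_dict()
for name, value in checkpoint.items():
    try:
        state[name].copy_(value)
    except (RuntimeError, KeyError):
        continue  # keep the model's own initialisation for unmatched parameters
model.load_state_dict(state)
assert (model.weight == 0).all()  # only the matching tensor was overwritten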
class DependencyDecoder(Model): """ Modifies BiaffineDependencyParser, removing the input TextFieldEmbedder dependency to allow the model to essentially act as a decoder when given intermediate word embeddings instead of as a standalone model. """ def __init__(self, vocab: Vocabulary, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, pos_embed_dim: int = None, tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, use_mst_decoding_for_validation: bool = True, dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(DependencyDecoder, self).__init__(vocab, regularizer) self.pos_tag_embedding = None if pos_embed_dim is not None: self.pos_tag_embedding = Embedding(self.vocab.get_vocab_size("upos"), pos_embed_dim) self.dropout = torch.nn.Dropout(p=dropout) self.encoder = encoder encoder_output_dim = encoder.get_output_dim() self.head_arc_feedforward = arc_feedforward or \ FeedForward(encoder_output_dim, 1, arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention(arc_representation_dim, arc_representation_dim, use_input_biases=True) num_labels = self.vocab.get_vocab_size("head_tags") self.head_tag_feedforward = tag_feedforward or \ FeedForward(encoder_output_dim, 1, tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, tag_representation_dim, num_labels) self._dropout = InputVariationalDropout(dropout) self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder_output_dim])) check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim") check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim") self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE} self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. 
" "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() initializer(self) @overrides def forward(self, # type: ignore # words: Dict[str, torch.LongTensor], encoded_text: torch.FloatTensor, mask: torch.LongTensor, pos_logits: torch.LongTensor = None, # predicted head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: batch_size, _, _ = encoded_text.size() pos_tags = None if pos_logits is not None and self.pos_tag_embedding is not None: # Embed the predicted POS tags and concatenate the embeddings to the input num_pos_classes = pos_logits.size(-1) pos_logits = pos_logits.view(-1, num_pos_classes) _, pos_tags = pos_logits.max(-1) pos_embed_size = self.pos_tag_embedding.get_output_dim() embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags)) embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size) encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1) encoded_text = self.encoder(encoded_text, mask) batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation. encoded_text = torch.cat([head_sentinel, encoded_text], 1) mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1) float_mask = mask.float() encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text)) child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text)) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text)) child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text)) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation) minus_inf = -1e8 minus_mask = (1 - float_mask) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) if self.training or not self.use_mst_decoding_for_validation: predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation, child_tag_representation, attended_arcs, mask) else: predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation, child_tag_representation, attended_arcs, mask) if head_indices is not None and head_tags is not None: arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=head_indices, head_tags=head_tags, mask=mask) loss = arc_nll + tag_nll evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags) # We calculate attachment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. 
self._attachment_scores(predicted_heads[:, 1:], predicted_head_tags[:, 1:], head_indices[:, 1:], head_tags[:, 1:], evaluation_mask) else: arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=predicted_heads.long(), head_tags=predicted_head_tags.long(), mask=mask) loss = arc_nll + tag_nll output_dict = { "heads": predicted_heads, "head_tags": predicted_head_tags, "arc_loss": arc_nll, "tag_loss": tag_nll, "loss": loss, "mask": mask, "words": [meta["words"] for meta in metadata], # "pos": [meta["pos"] for meta in metadata] } return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: head_tags = output_dict.pop("head_tags").cpu().detach().numpy() heads = output_dict.pop("heads").cpu().detach().numpy() mask = output_dict.pop("mask") lengths = get_lengths_from_binary_sequence_mask(mask) head_tag_labels = [] head_indices = [] for instance_heads, instance_tags, length in zip(heads, head_tags, lengths): instance_heads = list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags] head_tag_labels.append(labels) head_indices.append(instance_heads) output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices return output_dict def _construct_loss(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. head_tags : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The dependency labels of the heads for every word. mask : ``torch.Tensor``, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- arc_nll : ``torch.Tensor``, required. The negative log likelihood from the arc loss. tag_nll : ``torch.Tensor``, required. The negative log likelihood from the arc tag loss. 
""" float_mask = mask.float() batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = masked_log_softmax(attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = masked_log_softmax(head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long() # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then again, predicting the head tags of these greedily chosen arcs independently. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).byte().unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. 
# shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, heads) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _mst_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions using Edmonds' algorithm for finding minimum spanning trees on directed graphs. Nodes in the graph are the words in the sentence, and between each pair of nodes, there is an edge in each direction, where the weight of the edge corresponds to the most likely dependency label probability for that arc. The MST is then generated from this directed graph. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the optimally decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the optimally decoded heads of each word. """ batch_size, sequence_length, tag_representation_dim = head_tag_representation.size() lengths = mask.data.sum(dim=1).long().cpu().numpy() expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim] head_tag_representation = head_tag_representation.unsqueeze(2) head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous() child_tag_representation = child_tag_representation.unsqueeze(1) child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous() # Shape (batch_size, sequence_length, sequence_length, num_head_tags) pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation) # Note that this log_softmax is over the tag dimension, and we don't consider pairs # of tags which are invalid (e.g. are a pair which includes a padded element) anyway below. # Shape (batch, num_labels, sequence_length, sequence_length) normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2) # Mask padded tokens, because we only want to consider actual words as heads. minus_inf = -1e8 minus_mask = (1 - mask.float()) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) # Shape (batch_size, sequence_length, sequence_length) normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2) # Shape (batch_size, num_head_tags, sequence_length, sequence_length) # This energy tensor expresses the following relation: # energy[i,j] = "Score that i is the head of j". In this # case, we have heads pointing to their children.
batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits) return self._run_mst_decoding(batch_energy, lengths) @staticmethod def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: heads = [] head_tags = [] for energy, length in zip(batch_energy.detach().cpu(), lengths): scores, tag_ids = energy.max(dim=0) # Although we need to include the root node so that the MST includes it, # we do not want any word to be the parent of the root node. # Here, we enforce this by setting the scores for all word -> ROOT edges # to be 0. scores[0, :] = 0 # Decode the heads. Because we modify the scores to prevent # adding in word -> ROOT edges, we need to find the labels ourselves. instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False) # Find the labels which correspond to the edges in the max spanning tree. instance_head_tags = [] for child, parent in enumerate(instance_heads): instance_head_tags.append(tag_ids[parent, child].item()) # We don't care what the head or tag is for the root token, but by default it's # not necessarily the same in the batched vs unbatched case, which is annoying. # Here we'll just set them to zero. instance_heads[0] = 0 instance_head_tags[0] = 0 heads.append(instance_heads) head_tags.append(instance_head_tags) return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags)) def _get_head_tags(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. Returns ------- head_tag_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch.
# shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous() # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_bilinear(selected_head_tag_representations, child_tag_representation) return head_tag_logits def _get_mask_for_eval(self, mask: torch.LongTensor, pos_tags: torch.LongTensor) -> torch.LongTensor: """ Dependency evaluation excludes words that are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. Parameters ---------- mask : ``torch.LongTensor``, required. The original mask. pos_tags : ``torch.LongTensor``, required. The pos tags for the sequence. Returns ------- A new mask, where any indices equal to labels we should be ignoring are masked. """ new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return {f".run/deps/{metric_name}": metric for metric_name, metric in self._attachment_scores.get_metric(reset).items()}
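The `get_metrics` override above is just a key-prefixing idiom: namespacing the attachment scores lets several components report into one flat metrics dictionary without key collisions. A tiny illustration with made-up values:

# Illustrative only; the scores here are placeholders, not model output.
raw = {"UAS": 0.91, "LAS": 0.88, "UEM": 0.40, "LEM": 0.35}
namespaced = {f".run/deps/{name}": value for name, value in raw.items()}
assert namespaced[".run/deps/LAS"] == 0.88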
class DependencyParser(Model): """ This dependency parser follows the model of ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . Word representations are generated using a bidirectional LSTM, followed by separate biaffine classifiers for pairs of words, predicting whether a directed arc exists between the two words and the dependency label the arc should have. Decoding can either be done greedily, or the optimal Minimum Spanning Tree can be decoded using Edmonds' algorithm by viewing the dependency tree as an MST on a fully connected graph, where nodes are words and edges are scored dependency arcs. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the ``tokens`` ``TextField`` we get as input to the model. encoder : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use to generate representations of tokens. tag_representation_dim : ``int``, required. The dimension of the MLPs used for dependency tag prediction. arc_representation_dim : ``int``, required. The dimension of the MLPs used for head arc prediction. tag_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce tag representations. By default, a 1 layer feedforward network with an elu activation is used. arc_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce arc representations. By default, a 1 layer feedforward network with an elu activation is used. pos_tag_embedding : ``Embedding``, optional. Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model. use_mst_decoding_for_validation : ``bool``, optional (default = True). Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation. If false, decoding is greedy. dropout : ``float``, optional, (default = 0.0) The variational dropout applied to the output of the encoder and MLP layers. input_dropout : ``float``, optional, (default = 0.0) The dropout applied to the embedded text input. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training.
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, lemmatize_helper: LemmatizeHelper, task_config: TaskConfig, morpho_vector_dim: int = 0, gram_val_representation_dim: int = -1, lemma_representation_dim: int = -1, tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, pos_tag_embedding: Embedding = None, use_mst_decoding_for_validation: bool = True, dropout: float = 0.0, input_dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(DependencyParser, self).__init__(vocab, regularizer) self.TopNCnt = 3 self.text_field_embedder = text_field_embedder self.encoder = encoder self.lemmatize_helper = lemmatize_helper self.task_config = task_config encoder_dim = encoder.get_output_dim() self.head_arc_feedforward = arc_feedforward or \ FeedForward(encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention(arc_representation_dim, arc_representation_dim, use_input_biases=True) num_labels = self.vocab.get_vocab_size("head_tags") self.head_tag_feedforward = tag_feedforward or \ FeedForward(encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, tag_representation_dim, num_labels) self._pos_tag_embedding = pos_tag_embedding or None assert self.task_config.params.get("use_pos_tag", False) == (self._pos_tag_embedding is not None) self._dropout = InputVariationalDropout(dropout) self._input_dropout = Dropout(input_dropout) self._head_sentinel = torch.nn.Parameter( torch.randn([1, 1, encoder.get_output_dim()])) if gram_val_representation_dim <= 0: self._gram_val_output = torch.nn.Linear( encoder_dim, self.vocab.get_vocab_size("grammar_value_tags")) else: self._gram_val_output = torch.nn.Sequential( Dropout(dropout), torch.nn.Linear(encoder_dim, gram_val_representation_dim), Dropout(dropout), torch.nn.Linear( gram_val_representation_dim, self.vocab.get_vocab_size("grammar_value_tags"))) if lemma_representation_dim <= 0: self._lemma_output = torch.nn.Linear(encoder_dim, len(lemmatize_helper)) else: # Заведем выход предсказания грамматической метки на вход лемматизатора -- ЭКСПЕРИМЕНТАЛЬНОЕ #actual_input_dim = encoder_dim actual_input_dim = encoder_dim + self.vocab.get_vocab_size( "grammar_value_tags") self._lemma_output = torch.nn.Sequential( Dropout(dropout), torch.nn.Linear(actual_input_dim, lemma_representation_dim), Dropout(dropout), torch.nn.Linear(lemma_representation_dim, len(lemmatize_helper))) representation_dim = text_field_embedder.get_output_dim( ) + morpho_vector_dim if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match(representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim") check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim") check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim") self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") 
punctuation_tag_indices = { tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE } self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info( f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() self._gram_val_prediction_accuracy = CategoricalAccuracy() self._lemma_prediction_accuracy = CategoricalAccuracy() initializer(self) @overrides def forward( self, # type: ignore words: Dict[str, torch.LongTensor], metadata: List[Dict[str, Any]], morpho_embedding: torch.FloatTensor = None, pos_tags: torch.LongTensor = None, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, grammar_values: torch.LongTensor = None, lemma_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- words : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. pos_tags : ``torch.LongTensor``, required The output of a ``SequenceLabelField`` containing POS tags. POS tags are required regardless of whether they are used in the model, because they are used to filter the evaluation metric to only consider heads of words which are not punctuation. metadata : List[Dict[str, Any]], optional (default=None) A dictionary of metadata for each batch element which has keys: words : ``List[str]``, required. The tokens in the original sentence. pos : ``List[str]``, required. The POS tags for each word. head_tags : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer gold class labels for the arcs in the dependency parse. Has shape ``(batch_size, sequence_length)``. head_indices : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer indices denoting the parent of every word in the dependency parse. Has shape ``(batch_size, sequence_length)``. Returns ------- An output dictionary consisting of: loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. arc_loss : ``torch.FloatTensor`` The loss contribution from the unlabeled arcs. tag_loss : ``torch.FloatTensor`` The loss contribution from predicting the dependency tags for the gold arcs. heads : ``torch.FloatTensor`` The predicted head indices for each word. A tensor of shape (batch_size, sequence_length). head_types : ``torch.FloatTensor`` The predicted head types for each arc. A tensor of shape (batch_size, sequence_length). mask : ``torch.LongTensor`` A mask denoting the padded elements in the batch.
""" embedded_text_input = self.text_field_embedder(words) if morpho_embedding is not None: embedded_text_input = torch.cat( [embedded_text_input, morpho_embedding], -1) if grammar_values is not None and self._pos_tag_embedding is not None: embedded_pos_tags = self._pos_tag_embedding(grammar_values) embedded_text_input = torch.cat( [embedded_text_input, embedded_pos_tags], -1) elif self._pos_tag_embedding is not None: raise ConfigurationError( "Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(words) output_dict = self._parse(embedded_text_input, mask, head_tags, head_indices, grammar_values, lemma_indices) if self.task_config.task_type == "multitask": losses = ["arc_nll", "tag_nll", "grammar_nll", "lemma_nll"] elif self.task_config.task_type == "single": if self.task_config.params["model"] == "morphology": losses = ["grammar_nll"] elif self.task_config.params["model"] == "lemmatization": losses = ["lemma_nll"] elif self.task_config.params["model"] == "syntax": losses = ["arc_nll", "tag_nll"] else: assert False, "Unknown model type {}".format( self.task_config.params["model"]) else: assert False, "Unknown task type {}".format( self.task_config.task_type) output_dict["loss"] = sum(output_dict[loss_name] for loss_name in losses) if head_indices is not None and head_tags is not None: evaluation_mask = self._get_mask_for_eval(mask, pos_tags) # We calculate attatchment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence. self._attachment_scores(output_dict["heads"][:, 1:], output_dict["head_tags"][:, 1:], head_indices, head_tags, evaluation_mask) output_dict["words"] = [meta["words"] for meta in metadata] output_dict["original_words"] = [ meta["original_words"] for meta in metadata ] output_dict["token_nos"] = [meta["token_nos"] for meta in metadata] if metadata and "pos" in metadata[0]: output_dict["pos"] = [meta["pos"] for meta in metadata] return output_dict @overrides def decode( self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: head_tags = output_dict.pop("head_tags").cpu().detach().numpy() heads = output_dict.pop("heads").cpu().detach().numpy() predicted_gram_vals = output_dict.pop( "gram_vals").cpu().detach().numpy() predicted_lemmas = output_dict.pop("lemmas").cpu().detach().numpy() mask = output_dict.pop("mask") lengths = get_lengths_from_binary_sequence_mask(mask) top_lemmas = output_dict.pop("top_lemmas").cpu().detach().numpy() top_lemmas_prob = output_dict.pop( "top_lemmas_prob").cpu().detach().numpy() top_gramms = output_dict.pop("top_gramms").cpu().detach().numpy() top_gramms_prob = output_dict.pop( "top_gramms_prob").cpu().detach().numpy() top_heads = output_dict.pop("top_heads").cpu().detach().numpy() top_deprels = output_dict.pop("top_deprels").cpu().detach().numpy() top_dr_probs = output_dict.pop( "top_deprels_prob").cpu().detach().numpy() assert len(head_tags) == len(heads) == len(lengths) == len( predicted_gram_vals) == len(predicted_lemmas) == len( top_lemmas) == len(top_gramms) head_tag_labels, head_indices, decoded_gram_vals, decoded_lemmas = [], [], [], [] decoded_top_lemmas, decoded_top_gramms, decoded_top_deprels = [], [], [] for instance_index in range(len(head_tags)): instance_heads, instance_tags = heads[instance_index], head_tags[ instance_index] words, length = output_dict["words"][instance_index], lengths[ instance_index] gram_vals, lemmas = predicted_gram_vals[ instance_index], predicted_lemmas[instance_index] 
words = words[:length.item() - 1] gram_vals = gram_vals[:length.item() - 1] lemmas = lemmas[:length.item() - 1] instance_heads = list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [ self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags ] head_tag_labels.append(labels) head_indices.append(instance_heads) decoded_gram_vals.append([ self.vocab.get_token_from_index(gram_val, "grammar_value_tags") for gram_val in gram_vals ]) decoded_lemmas.append([ self.lemmatize_helper.lemmatize(word, lemmatize_rule_index) for word, lemmatize_rule_index in zip(words, lemmas) ]) sent_top_lemmas = [] for w_i in range(len(words)): word_top_lemmas = [] for top_i in range(self.TopNCnt): word_top_lemmas.append( self.lemmatize_helper.lemmatize( words[w_i], top_lemmas[instance_index][w_i][top_i])) sent_top_lemmas.append(word_top_lemmas) decoded_top_lemmas.append(sent_top_lemmas) sent_top_gramms = [] for w_i in range(len(words)): word_top_gramms = [] for top_i in range(self.TopNCnt): word_top_gramms.append( self.vocab.get_token_from_index( top_gramms[instance_index][w_i][top_i], "grammar_value_tags")) sent_top_gramms.append(word_top_gramms) decoded_top_gramms.append(sent_top_gramms) sent_top_deprels = [] itop_heads = top_heads[instance_index][1:length] itop_deprels = top_deprels[instance_index][1:length] itop_dr_probs = top_dr_probs[instance_index][1:length] for w_i in range(len(words)): word_top_deprels = [""] for top_i in range(self.TopNCnt - 1): top_i_head_no = itop_heads[w_i][top_i] top_i_rel = itop_deprels[w_i][top_i] top_i_prb = itop_dr_probs[w_i][top_i] lbl = self.vocab.get_token_from_index( floor(top_i_rel), "head_tags") word_top_deprels.append( str(top_i_head_no) + " " + lbl + " ({:.5f})".format(top_i_prb)) sent_top_deprels.append(word_top_deprels) decoded_top_deprels.append(sent_top_deprels) if self.task_config.task_type == "multitask": output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices output_dict["predicted_gram_vals"] = decoded_gram_vals output_dict["predicted_lemmas"] = decoded_lemmas output_dict["top_lemmas"] = decoded_top_lemmas output_dict["top_lemmas_prob"] = top_lemmas_prob output_dict["top_gramms"] = decoded_top_gramms output_dict["top_gramms_prob"] = top_gramms_prob output_dict["top_deprels"] = decoded_top_deprels elif self.task_config.task_type == "single": if self.task_config.params["model"] == "morphology": output_dict["predicted_gram_vals"] = decoded_gram_vals elif self.task_config.params["model"] == "lemmatization": output_dict["predicted_lemmas"] = decoded_lemmas elif self.task_config.params["model"] == "syntax": output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices else: assert False, "Unknown model type {}".format( self.task_config.params["model"]) else: assert False, "Unknown task type {}".format( self.task_config.task_type) return output_dict def _parse(self, embedded_text_input: torch.Tensor, mask: torch.LongTensor, head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None, grammar_values: torch.LongTensor = None, lemma_indices: torch.LongTensor = None): embedded_text_input = self._input_dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) grammar_value_logits = self._gram_val_output(encoded_text) predicted_gram_vals = grammar_value_logits.argmax(-1) # Feed the grammar-value prediction output into the lemmatizer input -- EXPERIMENTAL #l_ext_input = encoded_text l_ext_input =
torch.cat([encoded_text, grammar_value_logits], -1) lemma_logits = self._lemma_output(l_ext_input) predicted_lemmas = lemma_logits.argmax(-1) # Obtain the top-N most likely lemmatization variants and their probability estimates lemma_probs = torch.nn.functional.softmax(lemma_logits, -1) top_lemmas_indices = (-lemma_logits).argsort(-1)[:, :, :self.TopNCnt] #top_lemmas_indices = (-lemma_probs).argsort(-1)[:,:,:self.TopNCnt] top_lemmas_prob = torch.gather(lemma_probs, -1, top_lemmas_indices) #top_lemmas_prob = torch.gather(lemma_logits, -1, top_lemmas_indices) # The same for grammemes gramm_probs = torch.nn.functional.softmax(grammar_value_logits, -1) top_gramms_indices = ( -grammar_value_logits).argsort(-1)[:, :, :self.TopNCnt] top_gramms_prob = torch.gather(gramm_probs, -1, top_gramms_indices) batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation. encoded_text = torch.cat([head_sentinel, encoded_text], 1) token_mask = mask.float() mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat( [head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat( [head_tags.new_zeros(batch_size, 1), head_tags], 1) float_mask = mask.float() encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self._dropout( self.head_arc_feedforward(encoded_text)) child_arc_representation = self._dropout( self.child_arc_feedforward(encoded_text)) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout( self.head_tag_feedforward(encoded_text)) child_tag_representation = self._dropout( self.child_tag_feedforward(encoded_text)) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation) minus_inf = -1e8 minus_mask = (1 - float_mask) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) if self.training or not self.use_mst_decoding_for_validation: predicted_heads, predicted_head_tags = self._greedy_decode( head_tag_representation, child_tag_representation, attended_arcs, mask) else: synt_prediction, benrg = self._mst_decode( head_tag_representation, child_tag_representation, attended_arcs, mask) predicted_heads, predicted_head_tags = synt_prediction # Obtain the top-N most likely LOCAL (not MST) parse variants and their probability estimates benrgf = torch.flatten(benrg, start_dim=1, end_dim=2).permute( 0, 2, 1) # fuses the dependency relation type with the parent index
top_deprels_indices = (-benrgf).argsort( -1)[:, :, :self.TopNCnt] # select the best combinations top_deprels_prob = torch.gather(benrgf, -1, top_deprels_indices) seqlen = benrg.shape[2] top_heads = torch.fmod(top_deprels_indices, seqlen) top_deprels_indices = torch.div(top_deprels_indices, seqlen) # torch.floor does not work here if head_indices is not None and head_tags is not None: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=head_indices, head_tags=head_tags, mask=mask) else: arc_nll, tag_nll = self._construct_loss( head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=predicted_heads.long(), head_tags=predicted_head_tags.long(), mask=mask) grammar_nll = torch.tensor(0.) if grammar_values is not None: grammar_nll = self._update_multiclass_prediction_metrics( logits=grammar_value_logits, targets=grammar_values, mask=token_mask, accuracy_metric=self._gram_val_prediction_accuracy) lemma_nll = torch.tensor(0.) if lemma_indices is not None: lemma_nll = self._update_multiclass_prediction_metrics( logits=lemma_logits, targets=lemma_indices, mask=token_mask, accuracy_metric=self._lemma_prediction_accuracy, masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX) output_dict = { "heads": predicted_heads, "head_tags": predicted_head_tags, "gram_vals": predicted_gram_vals, "lemmas": predicted_lemmas, "mask": mask, "arc_nll": arc_nll, "tag_nll": tag_nll, "grammar_nll": grammar_nll, "lemma_nll": lemma_nll, "top_lemmas": top_lemmas_indices, "top_lemmas_prob": top_lemmas_prob, "top_gramms": top_gramms_indices, "top_gramms_prob": top_gramms_prob, "top_heads": top_heads, "top_deprels": top_deprels_indices, "top_deprels_prob": top_deprels_prob, } return output_dict @staticmethod def _update_multiclass_prediction_metrics(logits, targets, mask, accuracy_metric, masked_index=None): accuracy_metric(logits, targets, mask) logits = logits.view(-1, logits.shape[-1]) loss = F.cross_entropy(logits, targets.view(-1), reduction='none') if masked_index is not None: mask = mask * (targets != masked_index) loss_mask = mask.view(-1) return (loss * loss_mask).sum() / loss_mask.sum() def _construct_loss( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. head_tags : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length).
The dependency labels of the heads for every word. mask : ``torch.Tensor``, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- arc_nll : ``torch.Tensor``, required. The negative log likelihood from the arc loss. tag_nll : ``torch.Tensor``, required. The negative log likelihood from the arc tag loss. """ float_mask = mask.float() batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector( batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = masked_log_softmax( attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = masked_log_softmax( head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = timestep_index.view(1, sequence_length).expand( batch_size, sequence_length).long() # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode( self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then again, predicting the head tags of these greedily chosen arcs independently. Note that this method of decoding is not guaranteed to produce trees (i.e. there may be multiple roots, or cycles when children are attached to their parents). Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself.
attended_arcs = attended_arcs + torch.diag( attended_arcs.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, heads) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _mst_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions using Edmonds' algorithm for finding minimum spanning trees on directed graphs. Nodes in the graph are the words in the sentence, and between each pair of nodes, there is an edge in each direction, where the weight of the edge corresponds to the most likely dependency label probability for that arc. The MST is then generated from this directed graph. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the optimally decoded heads of each word. """ batch_size, sequence_length, tag_representation_dim = head_tag_representation.size( ) lengths = mask.data.sum(dim=1).long().cpu().numpy() expanded_shape = [ batch_size, sequence_length, sequence_length, tag_representation_dim ] head_tag_representation = head_tag_representation.unsqueeze(2) head_tag_representation = head_tag_representation.expand( *expanded_shape).contiguous() child_tag_representation = child_tag_representation.unsqueeze(1) child_tag_representation = child_tag_representation.expand( *expanded_shape).contiguous() # Shape (batch_size, sequence_length, sequence_length, num_head_tags) pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation) # Note that this log_softmax is over the tag dimension, and we don't consider pairs # of tags which are invalid (e.g. are a pair which includes a padded element) anyway below. # Shape (batch, num_labels, sequence_length, sequence_length) normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute( 0, 3, 1, 2) # Mask padded tokens, because we only want to consider actual words as heads.
minus_inf = -1e8 minus_mask = (1 - mask.float()) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze( 2) + minus_mask.unsqueeze(1) # Shape (batch_size, sequence_length, sequence_length) normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2) # Shape (batch_size, num_head_tags, sequence_length, sequence_length) # This energy tensor expresses the following relation: # energy[i,j] = "Score that i is the head of j". In this # case, we have heads pointing to their children. batch_energy = torch.exp( normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits) return self._run_mst_decoding( batch_energy, lengths ), batch_energy #normalized_pairwise_head_logits, normalized_arc_logits @staticmethod def _run_mst_decoding( batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: heads = [] head_tags = [] for energy, length in zip(batch_energy.detach().cpu(), lengths): scores, tag_ids = energy.max(dim=0) # Although we need to include the root node so that the MST includes it, # we do not want any word to be the parent of the root node. # Here, we enforce this by setting the scores for all word -> ROOT # edges to be 0. scores[0, :] = 0 # Decode the heads. Because we modify the scores to prevent # adding in word -> ROOT edges, we need to find the labels ourselves. instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False) # Find the labels which correspond to the edges in the max spanning tree. instance_head_tags = [] for child, parent in enumerate(instance_heads): instance_head_tags.append(tag_ids[parent, child].item()) # We don't care what the head or tag is for the root token, but by default it's # not necessarily the same in the batched vs unbatched case, which is annoying. # Here we'll just set them to zero. instance_heads[0] = 0 instance_head_tags[0] = 0 heads.append(instance_heads) head_tags.append(instance_head_tags) return torch.from_numpy(numpy.stack(heads)), torch.from_numpy( numpy.stack(head_tags)) def _get_head_tags(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. Returns ------- head_tag_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector( batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand.
See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch. # shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[ range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous( ) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_bilinear(selected_head_tag_representations, child_tag_representation) return head_tag_logits def _get_mask_for_eval(self, mask: torch.LongTensor, pos_tags: torch.LongTensor) -> torch.LongTensor: """ Dependency evaluation excludes words that are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. Parameters ---------- mask : ``torch.LongTensor``, required. The original mask. pos_tags : ``torch.LongTensor``, required. The pos tags for the sequence. Returns ------- A new mask, where any indices equal to labels we should be ignoring are masked. """ new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = self._attachment_scores.get_metric(reset) metrics['GramValAcc'] = self._gram_val_prediction_accuracy.get_metric( reset) metrics['LemmaAcc'] = self._lemma_prediction_accuracy.get_metric(reset) metrics['MeanAcc'] = (metrics['GramValAcc'] + metrics['LemmaAcc'] + metrics['LAS']) / 3. return metrics
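# Hedged usage sketch (added for exposition; not part of the original source). It shows how
# the combined metric dictionary produced by get_metrics() above could be consumed after an
# epoch. `model` is assumed to be a trained DependencyParser; the LAS/UAS keys come from
# AttachmentScores, while GramValAcc/LemmaAcc/MeanAcc are added by get_metrics() itself.
def _demo_report_parser_metrics(model: "DependencyParser") -> str:
    metrics = model.get_metrics(reset=True)  # reset=True also clears the underlying counters
    return ("LAS={LAS:.3f} UAS={UAS:.3f} GramValAcc={GramValAcc:.3f} "
            "LemmaAcc={LemmaAcc:.3f} MeanAcc={MeanAcc:.3f}").format(**metrics)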
class BiaffineDependencyParser(Model): """ This dependency parser follows the model of ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . Word representations are generated using a bidirectional LSTM, followed by separate biaffine classifiers for pairs of words, predicting whether a directed arc exists between the two words and the dependency label the arc should have. Decoding can either be done greedily, or the optimal Minimum Spanning Tree can be decoded using Edmonds' algorithm by viewing the dependency tree as an MST on a fully connected graph, where nodes are words and edges are scored dependency arcs. Parameters ---------- vocab : ``Vocabulary``, required A Vocabulary, required in order to compute sizes for input/output projections. text_field_embedder : ``TextFieldEmbedder``, required Used to embed the ``tokens`` ``TextField`` we get as input to the model. encoder : ``Seq2SeqEncoder`` The encoder (with its own internal stacking) that we will use to generate representations of tokens. tag_representation_dim : ``int``, required. The dimension of the MLPs used for dependency tag prediction. arc_representation_dim : ``int``, required. The dimension of the MLPs used for head arc prediction. tag_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce tag representations. By default, a 1 layer feedforward network with an elu activation is used. arc_feedforward : ``FeedForward``, optional, (default = None). The feedforward network used to produce arc representations. By default, a 1 layer feedforward network with an elu activation is used. pos_tag_embedding : ``Embedding``, optional. Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model. use_mst_decoding_for_validation : ``bool``, optional (default = True). Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation. If false, decoding is greedy. dropout : ``float``, optional, (default = 0.0) The variational dropout applied to the output of the encoder and MLP layers. input_dropout : ``float``, optional, (default = 0.0) The dropout applied to the embedded text input. initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) Used to initialize the model parameters. regularizer : ``RegularizerApplicator``, optional (default=``None``) If provided, will be used to calculate the regularization penalty during training.
""" def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, encoder: Seq2SeqEncoder, tag_representation_dim: int, arc_representation_dim: int, tag_feedforward: FeedForward = None, arc_feedforward: FeedForward = None, pos_tag_embedding: Embedding = None, use_mst_decoding_for_validation: bool = True, dropout: float = 0.0, input_dropout: float = 0.0, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None) -> None: super(BiaffineDependencyParser, self).__init__(vocab, regularizer) self.text_field_embedder = text_field_embedder self.encoder = encoder encoder_dim = encoder.get_output_dim() self.head_arc_feedforward = arc_feedforward or \ FeedForward(encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()) self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) self.arc_attention = BilinearMatrixAttention(arc_representation_dim, arc_representation_dim, use_input_biases=True) num_labels = self.vocab.get_vocab_size("head_tags") self.head_tag_feedforward = tag_feedforward or \ FeedForward(encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()) self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, tag_representation_dim, num_labels) self._pos_tag_embedding = pos_tag_embedding or None self._dropout = InputVariationalDropout(dropout) self._input_dropout = Dropout(input_dropout) self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()])) representation_dim = text_field_embedder.get_output_dim() if pos_tag_embedding is not None: representation_dim += pos_tag_embedding.get_output_dim() check_dimensions_match(representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim") check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), "tag representation dim", "tag feedforward output dim") check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), "arc representation dim", "arc feedforward output dim") self.use_mst_decoding_for_validation = use_mst_decoding_for_validation tags = self.vocab.get_token_to_index_vocabulary("pos") punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE} self._pos_to_ignore = set(punctuation_tag_indices.values()) logger.info(f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. " "Ignoring words with these POS tags for evaluation.") self._attachment_scores = AttachmentScores() initializer(self) @overrides def forward(self, # type: ignore words: Dict[str, torch.LongTensor], pos_tags: torch.LongTensor, metadata: List[Dict[str, Any]], head_tags: torch.LongTensor = None, head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- words : Dict[str, torch.LongTensor], required The output of ``TextField.as_array()``, which should typically be passed directly to a ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used for the ``TokenIndexers`` when you created the ``TextField`` representing your sequence. 
The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, which knows how to combine different word representations into a single vector per token in your input. pos_tags : ``torch.LongTensor``, required. The output of a ``SequenceLabelField`` containing POS tags. POS tags are required regardless of whether they are used in the model, because they are used to filter the evaluation metric to only consider heads of words which are not punctuation. head_tags : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer gold class labels for the arcs in the dependency parse. Has shape ``(batch_size, sequence_length)``. head_indices : torch.LongTensor, optional (default = None) A torch tensor representing the sequence of integer indices denoting the parent of every word in the dependency parse. Has shape ``(batch_size, sequence_length)``. Returns ------- An output dictionary consisting of: loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. arc_loss : ``torch.FloatTensor`` The loss contribution from the unlabeled arcs. tag_loss : ``torch.FloatTensor``, optional The loss contribution from predicting the dependency tags for the gold arcs. heads : ``torch.FloatTensor`` The predicted head indices for each word. A tensor of shape (batch_size, sequence_length). head_types : ``torch.FloatTensor`` The predicted head types for each arc. A tensor of shape (batch_size, sequence_length). mask : ``torch.LongTensor`` A mask denoting the padded elements in the batch. """ embedded_text_input = self.text_field_embedder(words) if pos_tags is not None and self._pos_tag_embedding is not None: embedded_pos_tags = self._pos_tag_embedding(pos_tags) embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1) elif self._pos_tag_embedding is not None: raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.") mask = get_text_field_mask(words) embedded_text_input = self._input_dropout(embedded_text_input) encoded_text = self.encoder(embedded_text_input, mask) batch_size, _, encoding_dim = encoded_text.size() head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) # Concatenate the head sentinel onto the sentence representation.
encoded_text = torch.cat([head_sentinel, encoded_text], 1) mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) if head_indices is not None: head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1) if head_tags is not None: head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1) float_mask = mask.float() encoded_text = self._dropout(encoded_text) # shape (batch_size, sequence_length, arc_representation_dim) head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text)) child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text)) # shape (batch_size, sequence_length, tag_representation_dim) head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text)) child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text)) # shape (batch_size, sequence_length, sequence_length) attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation) minus_inf = -1e8 minus_mask = (1 - float_mask) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) if self.training or not self.use_mst_decoding_for_validation: predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation, child_tag_representation, attended_arcs, mask) else: predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation, child_tag_representation, attended_arcs, mask) if head_indices is not None and head_tags is not None: arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=head_indices, head_tags=head_tags, mask=mask) loss = arc_nll + tag_nll evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags) # We calculate attachment scores for the whole sentence # but excluding the symbolic ROOT token at the start, # which is why we start from the second element in the sequence.
self._attachment_scores(predicted_heads[:, 1:], predicted_head_tags[:, 1:], head_indices[:, 1:], head_tags[:, 1:], evaluation_mask) else: arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, child_tag_representation=child_tag_representation, attended_arcs=attended_arcs, head_indices=predicted_heads.long(), head_tags=predicted_head_tags.long(), mask=mask) loss = arc_nll + tag_nll output_dict = { "heads": predicted_heads, "head_tags": predicted_head_tags, "arc_loss": arc_nll, "tag_loss": tag_nll, "loss": loss, "mask": mask, "words": [meta["words"] for meta in metadata], "pos": [meta["pos"] for meta in metadata] } return output_dict @overrides def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: head_tags = output_dict.pop("head_tags").cpu().detach().numpy() heads = output_dict.pop("heads").cpu().detach().numpy() mask = output_dict.pop("mask") lengths = get_lengths_from_binary_sequence_mask(mask) head_tag_labels = [] head_indices = [] for instance_heads, instance_tags, length in zip(heads, head_tags, lengths): instance_heads = list(instance_heads[1:length]) instance_tags = instance_tags[1:length] labels = [self.vocab.get_token_from_index(label, "head_tags") for label in instance_tags] head_tag_labels.append(labels) head_indices.append(instance_heads) output_dict["predicted_dependencies"] = head_tag_labels output_dict["predicted_heads"] = head_indices return output_dict def _construct_loss(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, head_indices: torch.Tensor, head_tags: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the arc and tag loss for a sequence given gold head indices and tags. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. head_tags : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The dependency labels of the heads for every word. mask : ``torch.Tensor``, required. A mask of shape (batch_size, sequence_length), denoting unpadded elements in the sequence. Returns ------- arc_nll : ``torch.Tensor``, required. The negative log likelihood from the arc loss. tag_nll : ``torch.Tensor``, required. The negative log likelihood from the arc tag loss.
""" float_mask = mask.float() batch_size, sequence_length, _ = attended_arcs.size() # shape (batch_size, 1) range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1) # shape (batch_size, sequence_length, sequence_length) normalised_arc_logits = masked_log_softmax(attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices) normalised_head_tag_logits = masked_log_softmax(head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1) # index matrix with shape (batch, sequence_length) timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs)) child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long() # shape (batch_size, sequence_length) arc_loss = normalised_arc_logits[range_vector, child_index, head_indices] tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags] # We don't care about predictions for the symbolic ROOT token's head, # so we remove it from the loss. arc_loss = arc_loss[:, 1:] tag_loss = tag_loss[:, 1:] # The number of valid positions is equal to the number of unmasked elements minus # 1 per sequence in the batch, to account for the symbolic HEAD token. valid_positions = mask.sum() - batch_size arc_nll = -arc_loss.sum() / valid_positions.float() tag_nll = -tag_loss.sum() / valid_positions.float() return arc_nll, tag_nll def _greedy_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions by decoding the unlabeled arcs independently for each word and then again, predicting the head tags of these greedily chosen arcs indpendently. Note that this method of decoding is not guaranteed to produce trees (i.e. there maybe be multiple roots, or cycles when children are attached to their parents). Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachements of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the greedily decoded heads of each word. """ # Mask the diagonal, because the head of a word can't be itself. attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf)) # Mask padded tokens, because we only want to consider actual words as heads. if mask is not None: minus_mask = (1 - mask).byte().unsqueeze(2) attended_arcs.masked_fill_(minus_mask, -numpy.inf) # Compute the heads greedily. # shape (batch_size, sequence_length) _, heads = attended_arcs.max(dim=2) # Given the greedily predicted heads, decode their dependency tags. 
# shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, heads) _, head_tags = head_tag_logits.max(dim=2) return heads, head_tags def _mst_decode(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, attended_arcs: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Decodes the head and head tag predictions using Edmonds' algorithm for finding minimum spanning trees on directed graphs. Nodes in the graph are the words in the sentence, and between each pair of nodes, there is an edge in each direction, where the weight of the edge corresponds to the most likely dependency label probability for that arc. The MST is then generated from this directed graph. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. attended_arcs : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a distribution over attachments of a given word to all other words. Returns ------- heads : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the greedily decoded heads of each word. head_tags : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length) representing the dependency tags of the optimally decoded heads of each word. """ batch_size, sequence_length, tag_representation_dim = head_tag_representation.size() lengths = mask.data.sum(dim=1).long().cpu().numpy() expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim] head_tag_representation = head_tag_representation.unsqueeze(2) head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous() child_tag_representation = child_tag_representation.unsqueeze(1) child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous() # Shape (batch_size, sequence_length, sequence_length, num_head_tags) pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation) # Note that this log_softmax is over the tag dimension, and we don't consider pairs # of tags which are invalid (e.g. are a pair which includes a padded element) anyway below. # Shape (batch, num_labels, sequence_length, sequence_length) normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2) # Mask padded tokens, because we only want to consider actual words as heads. minus_inf = -1e8 minus_mask = (1 - mask.float()) * minus_inf attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) # Shape (batch_size, sequence_length, sequence_length) normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2) # Shape (batch_size, num_head_tags, sequence_length, sequence_length) # This energy tensor expresses the following relation: # energy[i,j] = "Score that i is the head of j". In this # case, we have heads pointing to their children.
batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits) return self._run_mst_decoding(batch_energy, lengths) @staticmethod def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: heads = [] head_tags = [] for energy, length in zip(batch_energy.detach().cpu(), lengths): scores, tag_ids = energy.max(dim=0) # Although we need to include the root node so that the MST includes it, # we do not want any word to be the parent of the root node. # Here, we enforce this by setting the scores for all word -> ROOT # edges to be 0. scores[0, :] = 0 # Decode the heads. Because we modify the scores to prevent # adding in word -> ROOT edges, we need to find the labels ourselves. instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False) # Find the labels which correspond to the edges in the max spanning tree. instance_head_tags = [] for child, parent in enumerate(instance_heads): instance_head_tags.append(tag_ids[parent, child].item()) # We don't care what the head or tag is for the root token, but by default it's # not necessarily the same in the batched vs unbatched case, which is annoying. # Here we'll just set them to zero. instance_heads[0] = 0 instance_head_tags[0] = 0 heads.append(instance_heads) head_tags.append(instance_head_tags) return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags)) def _get_head_tags(self, head_tag_representation: torch.Tensor, child_tag_representation: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor: """ Decodes the head tags given the head and child tag representations and a tensor of head indices to compute tags for. Note that these are either gold or predicted heads, depending on whether this function is being called to compute the loss, or if it's being called during inference. Parameters ---------- head_tag_representation : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. child_tag_representation : ``torch.Tensor``, required A tensor of shape (batch_size, sequence_length, tag_representation_dim), which will be used to generate predictions for the dependency tags for the given arcs. head_indices : ``torch.Tensor``, required. A tensor of shape (batch_size, sequence_length). The indices of the heads for every word. Returns ------- head_tag_logits : ``torch.Tensor`` A tensor of shape (batch_size, sequence_length, num_head_tags), representing logits for predicting a distribution over tags for each arc. """ batch_size = head_tag_representation.size(0) # shape (batch_size,) range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1) # This next statement is quite a complex piece of indexing, which you really # need to read the docs to understand. See here: # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing # In effect, we are selecting the indices corresponding to the heads of each word from the # sequence length dimension for each element in the batch.
# shape (batch_size, sequence_length, tag_representation_dim) selected_head_tag_representations = head_tag_representation[range_vector, head_indices] selected_head_tag_representations = selected_head_tag_representations.contiguous() # shape (batch_size, sequence_length, num_head_tags) head_tag_logits = self.tag_bilinear(selected_head_tag_representations, child_tag_representation) return head_tag_logits def _get_mask_for_eval(self, mask: torch.LongTensor, pos_tags: torch.LongTensor) -> torch.LongTensor: """ Dependency evaluation excludes words that are punctuation. Here, we create a new mask to exclude word indices which have a "punctuation-like" part of speech tag. Parameters ---------- mask : ``torch.LongTensor``, required. The original mask. pos_tags : ``torch.LongTensor``, required. The pos tags for the sequence. Returns ------- A new mask, where any indices equal to labels we should be ignoring are masked. """ new_mask = mask.detach() for label in self._pos_to_ignore: label_mask = pos_tags.eq(label).long() new_mask = new_mask * (1 - label_mask) return new_mask @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return self._attachment_scores.get_metric(reset)
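# Hedged illustration (added for exposition; not part of the original source). It exercises
# the static MST decoding above on a toy energy tensor, assuming `decode_mst` behaves as it
# is used in _run_mst_decoding: the head and tag for the symbolic ROOT position are forced
# to zero before the batch is stacked back into tensors.
def _demo_run_mst_decoding() -> None:
    batch_energy = torch.rand(1, 2, 3, 3)  # (batch, num_labels, length, length)
    lengths = numpy.array([3])
    heads, head_tags = BiaffineDependencyParser._run_mst_decoding(batch_energy, lengths)
    assert heads.shape == (1, 3) and head_tags.shape == (1, 3)
    assert heads[0, 0].item() == 0 and head_tags[0, 0].item() == 0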
class UDParser(Transduction): """ This model uses only the encoder part of a Transduction model; to keep it maximally compatible, it inherits from Transduction and simply leaves the decoder modules unused. """ def __init__( self, vocab: Vocabulary, # source-side bert_encoder: BaseBertWrapper, encoder_token_embedder: TextFieldEmbedder, encoder_pos_embedding: Embedding, encoder: Seq2SeqEncoder, syntax_edge_type_namespace: str = None, biaffine_parser: DeepTreeParser = None, dropout: float = 0.0, eps: float = 1e-20, pretrained_weights: str = None, vocab_dir: str = None, ) -> None: super(UDParser, self).__init__(vocab=vocab, bert_encoder=bert_encoder, encoder_token_embedder=encoder_token_embedder, encoder=encoder, decoder_token_embedder=None, decoder_node_index_embedding=None, decoder=None, extended_pointer_generator=None, tree_parser=None, label_smoothing=None, target_output_namespace=None, pretrained_weights=pretrained_weights, dropout=dropout, eps=eps) # source-side self.encoder_pos_embedding = encoder_pos_embedding # misc self._syntax_edge_type_namespace = syntax_edge_type_namespace self.biaffine_parser = biaffine_parser self.vocab_dir = vocab_dir # metrics self._syntax_metrics = AttachmentScores() self.syntax_las = 0.0 self.syntax_uas = 0.0 # compatibility self.loss_mixer = None self.syntactic_method = "encoder-side" # pretrained if self.pretrained_weights is not None: self.load_partial(self.pretrained_weights) # load vocab if self.vocab_dir is not None: syn_vocab = Vocabulary.from_files(vocab_dir) self.vocab._token_to_index[ self._syntax_edge_type_namespace] = syn_vocab._token_to_index[ self._syntax_edge_type_namespace] def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics = OrderedDict( syn_uas=0.0, syn_las=0.0, ) metrics["syn_las"] = self.syntax_las metrics["syn_uas"] = self.syntax_uas return metrics def _update_syntax_scores(self): scores = self._syntax_metrics.get_metric(reset=True) self.syntax_las = scores["LAS"] * 100 self.syntax_uas = scores["UAS"] * 100 def _compute_biaffine_loss(self, biaffine_outputs, inputs): edge_prediction_loss = self._compute_edge_prediction_loss( biaffine_outputs['edge_head_ll'], biaffine_outputs['edge_type_ll'], biaffine_outputs['edge_heads'], biaffine_outputs['edge_types'], inputs['syn_edge_heads'], inputs['syn_edge_types']['syn_edge_types'], inputs['syn_valid_node_mask'], syntax=True) return edge_prediction_loss['loss_per_node'] def _parse_syntax(self, encoder_outputs: torch.Tensor, edge_head_mask: torch.Tensor, edge_heads: torch.Tensor = None, valid_node_mask: torch.Tensor = None, do_mst=False) -> Dict: parser_outputs = self.biaffine_parser(query=encoder_outputs, key=encoder_outputs, edge_head_mask=edge_head_mask, gold_edge_heads=edge_heads, decode_mst=do_mst, valid_node_mask=valid_node_mask) return parser_outputs def _read_edge_predictions( self, edge_predictions: Dict[str, torch.Tensor], is_syntax=False) -> Tuple[List[List[int]], List[List[str]], List[List[int]]]: edge_type_predictions = [] edge_head_predictions = edge_predictions["edge_heads"].tolist() edge_type_ind_predictions = edge_predictions["edge_types"].tolist() if is_syntax: namespace = self._syntax_edge_type_namespace else: namespace = self._edge_type_namespace for edge_types in edge_type_ind_predictions: edge_type_predictions.append([ self.vocab.get_token_from_index(edge_type,
namespace) for edge_type in edge_types ]) return edge_head_predictions, edge_type_predictions, edge_type_ind_predictions @overrides def _prepare_inputs(self, raw_inputs): inputs = raw_inputs.copy() inputs["source_mask"] = get_text_field_mask( raw_inputs["source_tokens"]) source_subtoken_ids = raw_inputs.get("source_subtoken_ids", None) if source_subtoken_ids is None: inputs["source_subtoken_ids"] = None else: inputs["source_subtoken_ids"] = source_subtoken_ids.long() source_token_recovery_matrix = raw_inputs.get( "source_token_recovery_matrix", None) if source_token_recovery_matrix is None: inputs["source_token_recovery_matrix"] = None else: inputs[ "source_token_recovery_matrix"] = source_token_recovery_matrix.long( ) return inputs def _transformer_encode(self, tokens: Dict[str, torch.Tensor], subtoken_ids: torch.Tensor, token_recovery_matrix: torch.Tensor, mask: torch.Tensor, **kwargs) -> Dict: # [batch, num_tokens, embedding_size] encoder_inputs = [self._encoder_token_embedder(tokens)] if subtoken_ids is not None and self._bert_encoder is not None: bert_embeddings = self._bert_encoder( input_ids=subtoken_ids, attention_mask=subtoken_ids.ne(0), output_all_encoded_layers=False, token_recovery_matrix=token_recovery_matrix) encoder_inputs += [bert_embeddings] encoder_inputs = torch.cat(encoder_inputs, 2) encoder_inputs = self._dropout(encoder_inputs) # [batch, num_tokens, encoder_output_size] encoder_outputs = self._encoder(encoder_inputs, mask) encoder_outputs = self._dropout(encoder_outputs) return dict(encoder_outputs=encoder_outputs, ) @overrides def _encode(self, inputs) -> Dict: if isinstance(self._encoder, MisoTransformerEncoder): encoding_outputs = self._transformer_encode( tokens=inputs["source_tokens"], pos_tags=inputs["source_pos_tags"], subtoken_ids=inputs["source_subtoken_ids"], token_recovery_matrix=inputs["source_token_recovery_matrix"], mask=inputs["source_mask"]) else: encoding_outputs = super()._encode( tokens=inputs["source_tokens"], pos_tags=inputs["source_pos_tags"], subtoken_ids=inputs["source_subtoken_ids"], token_recovery_matrix=inputs["source_token_recovery_matrix"], mask=inputs["source_mask"]) return encoding_outputs def _training_forward(self, inputs: Dict) -> Dict[str, torch.Tensor]: encoding_outputs = self._encode(inputs) biaffine_outputs = self._parse_syntax( encoding_outputs['encoder_outputs'], inputs["syn_edge_head_mask"], inputs["syn_edge_heads"], do_mst=False) biaffine_loss = self._compute_biaffine_loss(biaffine_outputs, inputs) return dict(loss=biaffine_loss) def _test_forward(self, inputs: Dict) -> Dict: encoding_outputs = self._encode(inputs) biaffine_outputs = self._parse_syntax( encoding_outputs['encoder_outputs'], inputs["syn_edge_head_mask"], None, valid_node_mask=inputs["syn_valid_node_mask"], do_mst=True) syn_edge_head_predictions, syn_edge_type_predictions, syn_edge_type_inds = self._read_edge_predictions( biaffine_outputs, is_syntax=True) bsz, __ = inputs["source_tokens"]["source_tokens"].shape outputs = dict(syn_nodes=inputs['syn_tokens_str'], syn_edge_heads=syn_edge_head_predictions, syn_edge_types=syn_edge_type_predictions, syn_edge_type_inds=syn_edge_type_inds, loss=torch.tensor([0.0]), nodes=torch.ones((bsz, 1)), node_indices=torch.ones((bsz, 1)), edge_heads=torch.ones((bsz, 1)), edge_types=torch.ones((bsz, 1)), edge_types_inds=torch.ones((bsz, 1)), node_attributes=torch.ones((bsz, 1, 44)), node_attributes_mask=torch.ones((bsz, 1, 44)), edge_attributes=torch.ones((bsz, 1, 14)), edge_attributes_mask=torch.ones((bsz, 1, 14))) return outputs
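# Hedged example (appended for exposition; not part of the original source). UDParser reports
# syn_uas/syn_las as percentages: _update_syntax_scores scales the AttachmentScores output by
# 100. A minimal standalone check of that scaling, using the same call convention as the
# AttachmentScores tests (predictions, labels, gold indices, gold labels, mask):
def _demo_attachment_score_scaling() -> None:
    scorer = AttachmentScores()
    heads = torch.tensor([[0, 2, 0]])
    labels = torch.tensor([[0, 1, 2]])
    mask = torch.tensor([[True, True, True]])
    scorer(heads, labels, heads, labels, mask)  # predictions identical to gold
    scores = scorer.get_metric(reset=True)
    assert scores["LAS"] * 100 == 100.0 and scores["UAS"] * 100 == 100.0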