def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            encoder: Seq2SeqEncoder,
            edge_model: graph_dependency_parser.components.edge_models.EdgeModel,
            loss_function: graph_dependency_parser.components.losses.EdgeLoss,
            pos_tag_embedding: Embedding = None,
            use_mst_decoding_for_validation: bool = True,
            dropout: float = 0.0,
            input_dropout: float = 0.0,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None,
            validation_evaluator: Optional[ValidationEvaluator] = None
    ) -> None:
        super(GraphDependencyParser, self).__init__(vocab, regularizer)

        self.validation_evaluator = validation_evaluator

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(encoder.get_output_dim(),
                               edge_model.encoder_dim(), "encoder output dim",
                               "input dim edge model")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

        self.edge_model = edge_model
        self.loss_function = loss_function

        # Being able to detect what state we are in; probably not the best idea.
        self.current_epoch = 1
        self.pass_over_data_just_started = True
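Every constructor on this page wires an AttachmentScores instance into a larger model. As a reference point for reading those constructors, here is a minimal, self-contained sketch of how the metric itself is typically driven and read back (the import path and the boolean mask are assumptions based on recent AllenNLP releases; the tensors are made up):

import torch
from allennlp.training.metrics import AttachmentScores  # assumed import path

scorer = AttachmentScores()
predicted_heads = torch.tensor([[0, 2, 0, 2]])    # predicted head index per token (made up)
predicted_labels = torch.tensor([[0, 3, 1, 3]])   # predicted dependency label id per token
gold_heads = torch.tensor([[0, 2, 0, 2]])
gold_labels = torch.tensor([[0, 3, 1, 3]])
mask = torch.tensor([[True, True, True, True]])

# Call order: predicted indices, predicted labels, gold indices, gold labels, mask.
scorer(predicted_heads, predicted_labels, gold_heads, gold_labels, mask)
print(scorer.get_metric())  # e.g. {'UAS': 1.0, 'LAS': 1.0, 'UEM': 1.0, 'LEM': 1.0}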
Example #2
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        self.head_arc_projection = torch.nn.Linear(encoder_dim,
                                                   arc_representation_dim)
        self.child_arc_projection = torch.nn.Linear(encoder_dim,
                                                    arc_representation_dim)
        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")
        self.head_tag_projection = torch.nn.Linear(encoder_dim,
                                                   tag_representation_dim)
        self.child_tag_projection = torch.nn.Linear(encoder_dim,
                                                    tag_representation_dim)
        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
Example #3
    def __init__(
        self,
        vocab: Vocabulary,
        # source-side
        bert_encoder: BaseBertWrapper,
        encoder_token_embedder: TextFieldEmbedder,
        encoder_pos_embedding: Embedding,
        encoder: Seq2SeqEncoder,
        syntax_edge_type_namespace: str = None,
        biaffine_parser: DeepTreeParser = None,
        dropout: float = 0.0,
        eps: float = 1e-20,
        pretrained_weights: str = None,
        vocab_dir: str = None,
    ) -> None:

        super(UDParser,
              self).__init__(vocab=vocab,
                             bert_encoder=bert_encoder,
                             encoder_token_embedder=encoder_token_embedder,
                             encoder=encoder,
                             decoder_token_embedder=None,
                             decoder_node_index_embedding=None,
                             decoder=None,
                             extended_pointer_generator=None,
                             tree_parser=None,
                             label_smoothing=None,
                             target_output_namespace=None,
                             pretrained_weights=pretrained_weights,
                             dropout=dropout,
                             eps=eps)

        # source-side
        self.encoder_pos_embedding = encoder_pos_embedding
        # misc
        self._syntax_edge_type_namespace = syntax_edge_type_namespace
        self.biaffine_parser = biaffine_parser
        self.vocab_dir = vocab_dir
        # metrics
        self._syntax_metrics = AttachmentScores()
        self.syntax_las = 0.0
        self.syntax_uas = 0.0
        # compatibility
        self.loss_mixer = None
        self.syntactic_method = "encoder-side"

        # pretrained
        if self.pretrained_weights is not None:
            self.load_partial(self.pretrained_weights)
        # load vocab
        if self.vocab_dir is not None:
            syn_vocab = Vocabulary.from_files(vocab_dir)
            self.vocab._token_to_index[
                self._syntax_edge_type_namespace] = syn_vocab._token_to_index[
                    self._syntax_edge_type_namespace]
Example #4
    def __init__(self, vocab: Vocabulary,
                 name:str,
                 edge_model: EdgeModel,
                 loss_function: EdgeLoss,
                 supertagger: Supertagger,
                 lexlabeltagger: Supertagger,
                 supertagger_loss : SupertaggingLoss,
                 lexlabel_loss : SupertaggingLoss,
                 output_null_lex_label : bool = True,
                 loss_mixing : Dict[str,float] = None,
                 dropout : float = 0.0,
                 validation_evaluator: Optional[Evaluator] = None,
                 regularizer: Optional[RegularizerApplicator] = None):

        super().__init__(vocab,regularizer)
        self.name = name
        self.edge_model = edge_model
        self.supertagger = supertagger
        self.lexlabeltagger = lexlabeltagger
        self.supertagger_loss = supertagger_loss
        self.lexlabel_loss = lexlabel_loss
        self.loss_function = loss_function
        self.loss_mixing = loss_mixing or dict()
        self.validation_evaluator = validation_evaluator
        self.output_null_lex_label = output_null_lex_label

        self._dropout = InputVariationalDropout(dropout)

        loss_names = ["edge_existence", "edge_label", "supertagging", "lexlabel"]
        for loss_name in loss_names:
            if loss_name not in self.loss_mixing:
                self.loss_mixing[loss_name] = 1.0
                logger.info(f"Loss name {loss_name} not found in loss_mixing, using a weight of 1.0")
            else:
                if self.loss_mixing[loss_name] is None:
                    if loss_name not in ["supertagging", "lexlabel"]:
                        raise ConfigurationError("Only the loss mixing coefficients for supertagging and lexlabel may be None, but not "+loss_name)

        not_contained = set(self.loss_mixing.keys()) - set(loss_names)
        if len(not_contained):
            logger.critical(f"The following loss name(s) are unknown: {not_contained}")
            raise ValueError(f"The following loss name(s) are unknown: {not_contained}")

        self._supertagging_acc = CategoricalAccuracy()
        self._top_6supertagging_acc = CategoricalAccuracy(top_k=6)
        self._lexlabel_acc = CategoricalAccuracy()
        self._attachment_scores = AttachmentScores()
        self.current_epoch = 0

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")
    def test_attachment_scores_can_ignore_labels(self):
        scorer = AttachmentScores(ignore_classes=[1])

        label_predictions = self.label_predictions
        # Change the predictions where the gold label is 1;
        # as we are ignoring 1, we should still get a perfect score.
        label_predictions[0, 3] = 2
        scorer(self.predictions, label_predictions, self.gold_indices,
               self.gold_labels, self.mask)

        for value in list(scorer.get_metric().values()):
            assert value == 1.0
Example #6
    def test_attachment_scores_can_ignore_labels(self):
        scorer = AttachmentScores(ignore_classes=[1])

        label_predictions = self.label_predictions
        # Change the predictions where the gold label is 1;
        # as we are ignoring 1, we should still get a perfect score.
        label_predictions[0, 3] = 2
        scorer(self.predictions, label_predictions,
               self.gold_indices, self.gold_labels, self.mask)

        for value in scorer.get_metric().values():
            assert value == 1.0
Example #7
    def setUp(self):
        super().setUp()
        self.scorer = AttachmentScores()

        self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])

        self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])

        self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])

        self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])

        self.mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
Example #8
class ParsingEvaluator(Evaluator):
    def __init__(self):
        self._attachment_scores = AttachmentScores()

    def add(self, gold_idx, gold_label, prediction_idx, prediction_label,
            mask):
        self._attachment_scores(prediction_idx, prediction_label, gold_idx,
                                gold_label, mask)

    def get_metric(self):
        return self._attachment_scores.get_metric(reset=False)

    def reset(self):
        self._attachment_scores.reset()
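A possible way to exercise this wrapper (the tensor values are made up; the argument order of add() mirrors the AttachmentScores call used throughout these examples):

import torch

evaluator = ParsingEvaluator()
evaluator.add(gold_idx=torch.tensor([[2, 0, 2]]),
              gold_label=torch.tensor([[1, 0, 3]]),
              prediction_idx=torch.tensor([[2, 0, 2]]),
              prediction_label=torch.tensor([[1, 0, 3]]),
              mask=torch.tensor([[True, True, True]]))
print(evaluator.get_metric())  # perfect predictions, so every score is 1.0
evaluator.reset()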
Example #9
    def setup_method(self):
        super().setup_method()
        self.scorer = AttachmentScores()

        self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])

        self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])

        self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])

        self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])

        self.mask = torch.tensor(
            [[True, True, True, True, True, True], [True, True, True, True, False, False]]
        )
Example #10
    def __init__(self,
                 vocab: Vocabulary,
                 # source-side
                 bert_encoder: Seq2SeqBertEncoder,
                 encoder_token_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 # target-side
                 decoder_token_embedder: TextFieldEmbedder,
                 decoder_node_index_embedding: Embedding,
                 decoder: RNNDecoder,
                 extended_pointer_generator: ExtendedPointerGenerator,
                 tree_parser: DeepTreeParser,
                 # misc
                 label_smoothing: LabelSmoothing,
                 target_output_namespace: str,
                 dropout: float = 0.0,
                 eps: float = 1e-20,
                 pretrained_weights: str = None,
                 ) -> None:
        super().__init__(vocab=vocab)
        # source-side
        self._bert_encoder = bert_encoder
        self._encoder_token_embedder = encoder_token_embedder
        self._encoder = encoder

        # target-side
        self._decoder_token_embedder = decoder_token_embedder
        self._decoder_node_index_embedding = decoder_node_index_embedding
        self._decoder = decoder
        self._extended_pointer_generator = extended_pointer_generator
        self._tree_parser = tree_parser

        # metrics
        self._node_pred_metrics = ExtendedPointerGeneratorMetrics()
        self._edge_pred_metrics = AttachmentScores()
        self._synt_edge_pred_metrics = AttachmentScores()

        self._label_smoothing = label_smoothing
        self._dropout = InputVariationalDropout(p=dropout)
        self._eps = eps

        # dynamic initialization
        self._target_output_namespace = target_output_namespace
        self._vocab_size = self.vocab.get_vocab_size(target_output_namespace)
        self._vocab_pad_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, target_output_namespace)

        # loading partial weights
        self.pretrained_weights = pretrained_weights
Example #11
    def test_multiple_distributed_runs(self):
        predictions = [torch.Tensor([[0, 1, 3, 5, 2, 4]]), torch.Tensor([[0, 3, 2, 1, 0, 0]])]

        gold_indices = [torch.Tensor([[0, 1, 3, 5, 2, 4]]), torch.Tensor([[0, 3, 2, 1, 0, 0]])]

        label_predictions = [
            torch.Tensor([[0, 5, 2, 3, 3, 3]]),
            torch.Tensor([[7, 4, 8, 2, 0, 0]]),
        ]

        gold_labels = [torch.Tensor([[0, 5, 2, 1, 4, 2]]), torch.Tensor([[0, 4, 8, 2, 0, 0]])]

        mask = [
            torch.tensor([[True, True, True, True, True, True]]),
            torch.tensor([[True, True, True, True, False, False]]),
        ]

        metric_kwargs = {
            "predicted_indices": predictions,
            "gold_indices": gold_indices,
            "predicted_labels": label_predictions,
            "gold_labels": gold_labels,
            "mask": mask,
        }

        desired_metrics = {
            "UAS": 1.0,
            "LAS": 0.6,
            "UEM": 1.0,
            "LEM": 0.0,
        }
        run_distributed_test(
            [-1, -1], multiple_runs, AttachmentScores(), metric_kwargs, desired_metrics, exact=True,
        )
Example #12
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        self.head_arc_projection = torch.nn.Linear(encoder_dim,
                                                   arc_representation_dim)
        self.child_arc_projection = torch.nn.Linear(encoder_dim,
                                                    arc_representation_dim)
        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")
        self.head_tag_projection = torch.nn.Linear(encoder_dim,
                                                   tag_representation_dim)
        self.child_tag_projection = torch.nn.Linear(encoder_dim,
                                                    tag_representation_dim)
        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        self._attachment_scores = AttachmentScores()
        initializer(self)
Example #13
    def __init__(self,
                 validation_data_path: str,
                 validation_prediction_path: str,
                 semantics_only: bool,
                 drop_syntax: bool,
                 include_attribute_scores: bool = False,
                 warmup_epochs: int = 0,
                 syntactic_method: str = 'concat-after',
                 *args,
                 **kwargs):
        super(DecompSyntaxTrainer,
              self).__init__(validation_data_path, validation_prediction_path,
                             semantics_only, drop_syntax,
                             include_attribute_scores, warmup_epochs, *args,
                             **kwargs)

        self.attachment_scorer = AttachmentScores()
        self.syntactic_method = syntactic_method

        if self.model.loss_mixer is not None:
            self.model.loss_mixer.update_weights(curr_epoch=0,
                                                 total_epochs=self._num_epochs)
Example #14
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        self.head_arc_projection = torch.nn.Linear(encoder_dim, arc_representation_dim)
        self.child_arc_projection = torch.nn.Linear(encoder_dim, arc_representation_dim)
        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")
        self.head_tag_projection = torch.nn.Linear(encoder_dim, tag_representation_dim)
        self.child_tag_projection = torch.nn.Linear(encoder_dim, tag_representation_dim)
        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
Example #15
    def setUp(self):
        super().setUp()
        self.scorer = AttachmentScores()

        self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                         [0, 3, 2, 1, 0, 0]])

        self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                          [0, 3, 2, 1, 0, 0]])

        self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                               [0, 4, 8, 2, 0, 0]])

        self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                         [0, 4, 8, 2, 0, 0]])

        self.mask = torch.Tensor([[1, 1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 0, 0]])
Example #16
def multiple_runs(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: AttachmentScores,
    metric_kwargs: Dict[str, List[Any]],
    desired_values: Dict[str, Any],
    exact: Union[bool, Tuple[float, float]] = True,
):

    kwargs = {}
    # Use the arguments meant for the process with rank `global_rank`.
    for argname in metric_kwargs:
        kwargs[argname] = metric_kwargs[argname][global_rank]

    for i in range(200):
        metric(**kwargs)

    metrics = metric.get_metric()

    for key in metrics:
        assert desired_values[key] == metrics[key]
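For comparison with the distributed test in Example #11, a single-process sketch that feeds both per-rank batches into one AttachmentScores instance should reproduce the same aggregate numbers asserted there (the import path is an assumption based on recent AllenNLP releases):

import torch
from allennlp.training.metrics import AttachmentScores  # assumed import path

scorer = AttachmentScores()
# Batch that went to rank 0 in Example #11
scorer(torch.tensor([[0, 1, 3, 5, 2, 4]]), torch.tensor([[0, 5, 2, 3, 3, 3]]),
       torch.tensor([[0, 1, 3, 5, 2, 4]]), torch.tensor([[0, 5, 2, 1, 4, 2]]),
       torch.tensor([[True, True, True, True, True, True]]))
# Batch that went to rank 1 in Example #11
scorer(torch.tensor([[0, 3, 2, 1, 0, 0]]), torch.tensor([[7, 4, 8, 2, 0, 0]]),
       torch.tensor([[0, 3, 2, 1, 0, 0]]), torch.tensor([[0, 4, 8, 2, 0, 0]]),
       torch.tensor([[True, True, True, True, False, False]]))
print(scorer.get_metric())  # expected: UAS 1.0, LAS 0.6, UEM 1.0, LEM 0.0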
Example #17
    def __init__(self,
                 options,
                 tag_dim,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(None, regularizer)

        self.device = options.device

        encoder = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(tag_dim,
                          options.lstm_dims,
                          batch_first=True,
                          bidirectional=True))
        # encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(tag_dim, options.lstm_dims, batch_first=True))
        self.encoder = encoder
        # TODO: IMPORTANT
        num_labels = options.num_labels
        self.ablation = options.ablation
        # print(num_labels)
        tag_representation_dim = options.tag_representation_dim  # 100
        arc_representation_dim = options.arc_representation_dim  # 200

        encoder_dim = self.encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)


        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = tag_dim  #text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        self._pos_to_ignore = set()
        # tags = self.vocab.get_token_to_index_vocabulary("pos")
        # punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        # self._pos_to_ignore = set(punctuation_tag_indices.values())
        # logger.info(f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. "
        #             "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
Example #18
class DependencyDecoder(Model):
    """
    Modifies BiaffineDependencyParser, removing the input TextFieldEmbedder dependency to allow the model to
    essentially act as a decoder when given intermediate word embeddings instead of as a standalone model.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_embed_dim: int = None,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DependencyDecoder, self).__init__(vocab, regularizer)

        self.pos_tag_embedding = None
        if pos_embed_dim is not None:
            self.pos_tag_embedding = Embedding(self.vocab.get_vocab_size("upos"), pos_embed_dim)

        self.dropout = torch.nn.Dropout(p=dropout)

        self.encoder = encoder
        encoder_output_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_output_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_output_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._dropout = InputVariationalDropout(dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder_output_dim]))

        check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim", "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim", "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                # words: Dict[str, torch.LongTensor],
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                pos_logits: torch.LongTensor = None,  # predicted
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        pos_tags = None
        if pos_logits is not None and self.pos_tag_embedding is not None:
            # Embed the predicted POS tags and concatenate the embeddings to the input
            num_pos_classes = pos_logits.size(-1)
            pos_logits = pos_logits.view(-1, num_pos_classes)
            _, pos_tags = pos_logits.max(-1)

            pos_embed_size = self.pos_tag_embedding.get_output_dim()
            embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
            embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

        encoded_text = self.encoder(encoded_text, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
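        # (1 - float_mask) is 1 exactly at padded positions; adding a large negative
        # value to the corresponding rows and columns of attended_arcs keeps padded
        # tokens out of the arc scores.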
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            # "pos": [meta["pos"] for meta in metadata]
        }

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [self.vocab.get_token_from_index(label, "head_tags")
                      for label in instance_tags]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict

    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.
        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll

    def _greedy_decode(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       attended_arcs: torch.Tensor,
                       mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).byte().unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using the Edmonds' Algorithm
        for finding minimum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g are a pair which includes a padded element) anyway below.
        # Shape (batch, num_labels,sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = (1 - mask.float()) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        # This energy tensor expresses the following relation:
        # energy[i,j] = "Score that i is the head of j". In this
        # case, we have heads pointing to their children.
        batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits)
        return self._run_mst_decoding(batch_energy, lengths)

    @staticmethod
    def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all word -> ROOT
            # edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = []
            for child, parent in enumerate(instance_heads):
                instance_head_tags.append(tag_ids[parent, child].item())
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))

    def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.
        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.
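        # Concretely (hypothetical values): range_vector has shape (batch_size, 1) and
        # head_indices has shape (batch_size, sequence_length); broadcasting them gives
        # result[b, j] = head_tag_representation[b, head_indices[b, j]], i.e. the tag
        # representation of word j's head in batch element b.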

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self,
                           mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.
        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.
        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {f".run/deps/{metric_name}": metric
                for metric_name, metric in self._attachment_scores.get_metric(reset).items()}
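The punctuation-filtering step implemented by _get_mask_for_eval above (and set up via POS_TO_IGNORE in several of the constructors on this page) can be exercised on its own; a small sketch with made-up POS tag ids:

import torch

pos_to_ignore = {3, 7}                   # indices of punctuation-like POS tags (made up)
mask = torch.tensor([[1, 1, 1, 1]])
pos_tags = torch.tensor([[5, 3, 2, 7]])

eval_mask = mask.detach()
for label in pos_to_ignore:
    eval_mask = eval_mask * (1 - pos_tags.eq(label).long())
print(eval_mask)                         # tensor([[1, 0, 1, 0]]): punctuation positions masked out

Example #19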
class AttachmentScoresTest(AllenNlpTestCase):
    def setUp(self):
        super(AttachmentScoresTest, self).setUp()
        self.scorer = AttachmentScores()

        self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                         [0, 3, 2, 1, 0, 0]])

        self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                          [0, 3, 2, 1, 0, 0]])

        self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                               [0, 4, 8, 2, 0, 0]])

        self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                         [0, 4, 8, 2, 0, 0]])

        self.mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])

    def test_perfect_scores(self):
        self.scorer(self.predictions, self.label_predictions,
                    self.gold_indices, self.gold_labels, self.mask)

        for value in list(self.scorer.get_metric().values()):
            assert value == 1.0

    def test_unlabeled_accuracy_ignores_incorrect_labels(self):
        label_predictions = self.label_predictions
        # Change some stuff so 4 of our label predictions are wrong.
        label_predictions[0, 3:] = 3
        label_predictions[1, 0] = 7
        self.scorer(self.predictions, label_predictions, self.gold_indices,
                    self.gold_labels, self.mask)

        metrics = self.scorer.get_metric()

        assert metrics[u"UAS"] == 1.0
        assert metrics[u"UEM"] == 1.0

        # 4 / 12 labels were wrong and 2 positions
        # are masked, so 6/10 = 0.6 LAS.
        assert metrics[u"LAS"] == 0.6
        # Neither should have labeled exact match.
        assert metrics[u"LEM"] == 0.0

    def test_labeled_accuracy_is_affected_by_incorrect_heads(self):
        predictions = self.predictions
        # Change some stuff so 4 of our predictions are wrong.
        predictions[0, 3:] = 3
        predictions[1, 0] = 7
        # This one is in the padded part, so it shouldn't affect anything.
        predictions[1, 5] = 7
        self.scorer(predictions, self.label_predictions, self.gold_indices,
                    self.gold_labels, self.mask)

        metrics = self.scorer.get_metric()

        # 4 heads are incorrect, so the unlabeled score should be
        # 6/10 = 0.6 UAS.
        assert metrics[u"UAS"] == 0.6
        # All the labels were correct, but some heads
        # were wrong, so the LAS should equal the UAS.
        assert metrics[u"LAS"] == 0.6

        # Neither batch element had a perfect labeled or unlabeled EM.
        assert metrics[u"LEM"] == 0.0
        assert metrics[u"UEM"] == 0.0

    def test_attachment_scores_can_ignore_labels(self):
        scorer = AttachmentScores(ignore_classes=[1])

        label_predictions = self.label_predictions
        # Change the predictions where the gold label is 1;
        # as we are ignoring 1, we should still get a perfect score.
        label_predictions[0, 3] = 2
        scorer(self.predictions, label_predictions, self.gold_indices,
               self.gold_labels, self.mask)

        for value in list(scorer.get_metric().values()):
            assert value == 1.0
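
For reference, here is a small plain-PyTorch sketch of what the attachment metrics asserted above measure, using the same tensors as test_unlabeled_accuracy_ignores_incorrect_labels: UAS is the fraction of unmasked tokens whose predicted head matches the gold head, while LAS additionally requires the predicted label to match. This is an illustrative re-implementation, not the AllenNLP AttachmentScores code itself.

import torch

predictions = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])
gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])
label_predictions = torch.Tensor([[0, 5, 2, 3, 3, 3], [7, 4, 8, 2, 0, 0]])
gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])
mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])

correct_heads = predictions.eq(gold_indices).float() * mask
correct_labels_and_heads = correct_heads * label_predictions.eq(gold_labels).float()

uas = correct_heads.sum() / mask.sum()             # 10 / 10 = 1.0
las = correct_labels_and_heads.sum() / mask.sum()  # 6 / 10  = 0.6
print(uas.item(), las.item())
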
Esempio n. 20
class AttachmentScoresTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.scorer = AttachmentScores()

        self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                         [0, 3, 2, 1, 0, 0]])

        self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                          [0, 3, 2, 1, 0, 0]])

        self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                               [0, 4, 8, 2, 0, 0]])

        self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                         [0, 4, 8, 2, 0, 0]])

        self.mask = torch.tensor([[True, True, True, True, True, True],
                                  [True, True, True, True, False, False]])

    def _send_tensors_to_device(self, device: str):
        self.predictions = self.predictions.to(device)
        self.gold_indices = self.gold_indices.to(device)
        self.label_predictions = self.label_predictions.to(device)
        self.gold_labels = self.gold_labels.to(device)
        self.mask = self.mask.to(device)

    @multi_device
    def test_perfect_scores(self, device: str):
        self._send_tensors_to_device(device)

        self.scorer(self.predictions, self.label_predictions,
                    self.gold_indices, self.gold_labels, self.mask)

        for value in self.scorer.get_metric().values():
            assert value == 1.0

    @multi_device
    def test_unlabeled_accuracy_ignores_incorrect_labels(self, device: str):
        self._send_tensors_to_device(device)

        label_predictions = self.label_predictions
        # Change some values so that 4 of our label predictions are wrong.
        label_predictions[0, 3:] = 3
        label_predictions[1, 0] = 7
        self.scorer(self.predictions, label_predictions, self.gold_indices,
                    self.gold_labels, self.mask)

        metrics = self.scorer.get_metric()

        assert metrics["UAS"] == 1.0
        assert metrics["UEM"] == 1.0

        # 4 / 12 labels were wrong and 2 positions
        # are masked, so 6/10 = 0.6 LAS.
        assert metrics["LAS"] == 0.6
        # Neither should have labeled exact match.
        assert metrics["LEM"] == 0.0

    @multi_device
    def test_labeled_accuracy_is_affected_by_incorrect_heads(
            self, device: str):
        self._send_tensors_to_device(device)

        predictions = self.predictions
        # Change some values so that 4 of our predictions are wrong.
        predictions[0, 3:] = 3
        predictions[1, 0] = 7
        # This one is in the padded part, so it shouldn't affect anything.
        predictions[1, 5] = 7
        self.scorer(predictions, self.label_predictions, self.gold_indices,
                    self.gold_labels, self.mask)

        metrics = self.scorer.get_metric()

        # 4 heads are incorrect, so the unlabeled attachment score
        # should be 6/10 = 0.6 UAS.
        assert metrics["UAS"] == 0.6
        # All the labels were correct, but some heads
        # were wrong, so the LAS should equal the UAS.
        assert metrics["LAS"] == 0.6

        # Neither batch element had a perfect labeled or unlabeled EM.
        assert metrics["LEM"] == 0.0
        assert metrics["UEM"] == 0.0

    @multi_device
    def test_attachment_scores_can_ignore_labels(self, device: str):
        self._send_tensors_to_device(device)

        scorer = AttachmentScores(ignore_classes=[1])

        label_predictions = self.label_predictions
        # Change the predictions where the gold label is 1;
        # as we are ignoring 1, we should still get a perfect score.
        label_predictions[0, 3] = 2
        scorer(self.predictions, label_predictions, self.gold_indices,
               self.gold_labels, self.mask)

        for value in scorer.get_metric().values():
            assert value == 1.0

    def test_distributed_attachment_scores(self):
        predictions = [
            torch.Tensor([[0, 1, 3, 5, 2, 4]]),
            torch.Tensor([[0, 3, 2, 1, 0, 0]])
        ]

        gold_indices = [
            torch.Tensor([[0, 1, 3, 5, 2, 4]]),
            torch.Tensor([[0, 3, 2, 1, 0, 0]])
        ]

        label_predictions = [
            torch.Tensor([[0, 5, 2, 3, 3, 3]]),
            torch.Tensor([[7, 4, 8, 2, 0, 0]]),
        ]

        gold_labels = [
            torch.Tensor([[0, 5, 2, 1, 4, 2]]),
            torch.Tensor([[0, 4, 8, 2, 0, 0]])
        ]

        mask = [
            torch.tensor([[True, True, True, True, True, True]]),
            torch.tensor([[True, True, True, True, False, False]]),
        ]

        metric_kwargs = {
            "predicted_indices": predictions,
            "gold_indices": gold_indices,
            "predicted_labels": label_predictions,
            "gold_labels": gold_labels,
            "mask": mask,
        }

        desired_metrics = {
            "UAS": 1.0,
            "LAS": 0.6,
            "UEM": 1.0,
            "LEM": 0.0,
        }
        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            AttachmentScores(),
            metric_kwargs,
            desired_metrics,
            exact=True,
        )
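
A hedged sketch of the ignore_classes behaviour exercised in test_attachment_scores_can_ignore_labels above: positions whose gold label is in the ignored set are effectively dropped from the evaluation mask before the scores are computed. The plain-PyTorch code below illustrates that assumption and is not the library implementation.

import torch

ignore_classes = [1]

gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])
mask = torch.tensor([[True, True, True, True, True, True],
                     [True, True, True, True, False, False]])

eval_mask = mask.clone()
for label in ignore_classes:
    eval_mask = eval_mask & ~gold_labels.eq(label)  # drop positions with an ignored gold label

print(eval_mask)
# tensor([[ True,  True,  True, False,  True,  True],
#         [ True,  True,  True,  True, False, False]])
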
Esempio n. 21
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 lemmatize_helper: LemmatizeHelper,
                 task_config: TaskConfig,
                 morpho_vector_dim: int = 0,
                 gram_val_representation_dim: int = -1,
                 lemma_representation_dim: int = -1,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DependencyParser, self).__init__(vocab, regularizer)

        self.TopNCnt = 3

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.lemmatize_helper = lemmatize_helper
        self.task_config = task_config

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        assert self.task_config.params.get("use_pos_tag",
                                           False) == (self._pos_tag_embedding
                                                      is not None)

        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        if gram_val_representation_dim <= 0:
            self._gram_val_output = torch.nn.Linear(
                encoder_dim, self.vocab.get_vocab_size("grammar_value_tags"))
        else:
            self._gram_val_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(encoder_dim, gram_val_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(
                    gram_val_representation_dim,
                    self.vocab.get_vocab_size("grammar_value_tags")))

        if lemma_representation_dim <= 0:
            self._lemma_output = torch.nn.Linear(encoder_dim,
                                                 len(lemmatize_helper))
        else:
            # Feed the grammar-value prediction into the lemmatizer input -- EXPERIMENTAL
            #actual_input_dim = encoder_dim
            actual_input_dim = encoder_dim + self.vocab.get_vocab_size(
                "grammar_value_tags")
            self._lemma_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(actual_input_dim, lemma_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(lemma_representation_dim,
                                len(lemmatize_helper)))

        representation_dim = text_field_embedder.get_output_dim(
        ) + morpho_vector_dim
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info("HELLO FROM INIT")
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        self._gram_val_prediction_accuracy = CategoricalAccuracy()
        self._lemma_prediction_accuracy = CategoricalAccuracy()

        initializer(self)
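
To make the shape flow of the arc scorer above concrete, here is a plain-PyTorch sketch of biaffine arc scoring: the head-arc and child-arc feedforward outputs are scored pairwise with a bilinear form, with a constant bias feature appended to both sides (mirroring use_input_biases=True), giving a (batch_size, sequence_length, sequence_length) matrix of arc scores. All sizes and weights below are made up for illustration.

import torch

batch_size, seq_len, arc_dim = 2, 5, 8                   # illustrative sizes

head_arc = torch.randn(batch_size, seq_len, arc_dim)     # stand-in for head_arc_feedforward output
child_arc = torch.randn(batch_size, seq_len, arc_dim)    # stand-in for child_arc_feedforward output

# use_input_biases=True: append a constant 1 feature to both inputs.
ones = head_arc.new_ones(batch_size, seq_len, 1)
head_arc = torch.cat([head_arc, ones], dim=-1)
child_arc = torch.cat([child_arc, ones], dim=-1)

weight = torch.randn(arc_dim + 1, arc_dim + 1)           # learned bilinear weight in the real model

# attended_arcs[b, i, j] = head_arc[b, i] @ weight @ child_arc[b, j]
attended_arcs = head_arc @ weight @ child_arc.transpose(1, 2)
print(attended_arcs.shape)                               # torch.Size([2, 5, 5])
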
Esempio n. 22
class DependencyParser(Model):
    """
    This dependency parser follows the model of
    `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
    <https://arxiv.org/abs/1611.01734>`_ .

    Word representations are generated using a bidirectional LSTM,
    followed by separate biaffine classifiers for pairs of words,
    predicting whether a directed arc exists between the two words
    and the dependency label the arc should have. Decoding can either
    be done greedily, or the optimal Minimum Spanning Tree can be
    decoded using Edmonds' algorithm by viewing the dependency tree as
    a MST on a fully connected graph, where nodes are words and edges
    are scored dependency arcs.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    tag_representation_dim : ``int``, required.
        The dimension of the MLPs used for dependency tag prediction.
    arc_representation_dim : ``int``, required.
        The dimension of the MLPs used for head arc prediction.
    tag_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce tag representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    arc_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce arc representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    use_mst_decoding_for_validation : ``bool``, optional (default = True).
        Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : ``float``, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : ``float``, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 lemmatize_helper: LemmatizeHelper,
                 task_config: TaskConfig,
                 morpho_vector_dim: int = 0,
                 gram_val_representation_dim: int = -1,
                 lemma_representation_dim: int = -1,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DependencyParser, self).__init__(vocab, regularizer)

        self.TopNCnt = 3

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.lemmatize_helper = lemmatize_helper
        self.task_config = task_config

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        assert self.task_config.params.get("use_pos_tag",
                                           False) == (self._pos_tag_embedding
                                                      is not None)

        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        if gram_val_representation_dim <= 0:
            self._gram_val_output = torch.nn.Linear(
                encoder_dim, self.vocab.get_vocab_size("grammar_value_tags"))
        else:
            self._gram_val_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(encoder_dim, gram_val_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(
                    gram_val_representation_dim,
                    self.vocab.get_vocab_size("grammar_value_tags")))

        if lemma_representation_dim <= 0:
            self._lemma_output = torch.nn.Linear(encoder_dim,
                                                 len(lemmatize_helper))
        else:
            # Feed the grammar-value prediction into the lemmatizer input -- EXPERIMENTAL
            #actual_input_dim = encoder_dim
            actual_input_dim = encoder_dim + self.vocab.get_vocab_size(
                "grammar_value_tags")
            self._lemma_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(actual_input_dim, lemma_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(lemma_representation_dim,
                                len(lemmatize_helper)))

        representation_dim = text_field_embedder.get_output_dim(
        ) + morpho_vector_dim
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info("HELLO FROM INIT")
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        self._gram_val_prediction_accuracy = CategoricalAccuracy()
        self._lemma_prediction_accuracy = CategoricalAccuracy()

        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            metadata: List[Dict[str, Any]],
            morpho_embedding: torch.FloatTensor = None,
            pos_tags: torch.LongTensor = None,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None,
            grammar_values: torch.LongTensor = None,
            lemma_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        metadata : List[Dict[str, Any]], optional (default=None)
            A dictionary of metadata for each batch element which has keys:
                words : ``List[str]``, required.
                    The tokens in the original sentence.
                pos : ``List[str]``, required.
                    The POS tags for each word in the original sentence.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_nll : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_nll : ``torch.FloatTensor``
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : ``torch.FloatTensor``
            The predicted head tags for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)

        if morpho_embedding is not None:
            embedded_text_input = torch.cat(
                [embedded_text_input, morpho_embedding], -1)

        if grammar_values is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(grammar_values)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)

        output_dict = self._parse(embedded_text_input, mask, head_tags,
                                  head_indices, grammar_values, lemma_indices)

        if self.task_config.task_type == "multitask":
            losses = ["arc_nll", "tag_nll", "grammar_nll", "lemma_nll"]
        elif self.task_config.task_type == "single":
            if self.task_config.params["model"] == "morphology":
                losses = ["grammar_nll"]
            elif self.task_config.params["model"] == "lemmatization":
                losses = ["lemma_nll"]
            elif self.task_config.params["model"] == "syntax":
                losses = ["arc_nll", "tag_nll"]
            else:
                assert False, "Unknown model type {}".format(
                    self.task_config.params["model"])
        else:
            assert False, "Unknown task type {}".format(
                self.task_config.task_type)

        output_dict["loss"] = sum(output_dict[loss_name]
                                  for loss_name in losses)

        if head_indices is not None and head_tags is not None:
            evaluation_mask = self._get_mask_for_eval(mask, pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(output_dict["heads"][:, 1:],
                                    output_dict["head_tags"][:, 1:],
                                    head_indices, head_tags, evaluation_mask)

        output_dict["words"] = [meta["words"] for meta in metadata]
        output_dict["original_words"] = [
            meta["original_words"] for meta in metadata
        ]
        output_dict["token_nos"] = [meta["token_nos"] for meta in metadata]
        if metadata and "pos" in metadata[0]:
            output_dict["pos"] = [meta["pos"] for meta in metadata]

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        predicted_gram_vals = output_dict.pop(
            "gram_vals").cpu().detach().numpy()
        predicted_lemmas = output_dict.pop("lemmas").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        top_lemmas = output_dict.pop("top_lemmas").cpu().detach().numpy()
        top_lemmas_prob = output_dict.pop(
            "top_lemmas_prob").cpu().detach().numpy()
        top_gramms = output_dict.pop("top_gramms").cpu().detach().numpy()
        top_gramms_prob = output_dict.pop(
            "top_gramms_prob").cpu().detach().numpy()
        top_heads = output_dict.pop("top_heads").cpu().detach().numpy()
        top_deprels = output_dict.pop("top_deprels").cpu().detach().numpy()
        top_dr_probs = output_dict.pop(
            "top_deprels_prob").cpu().detach().numpy()

        assert len(head_tags) == len(heads) == len(lengths) == len(
            predicted_gram_vals) == len(predicted_lemmas) == len(
                top_lemmas) == len(top_gramms)

        head_tag_labels, head_indices, decoded_gram_vals, decoded_lemmas = [], [], [], []
        decoded_top_lemmas, decoded_top_gramms, decoded_top_deprels = [], [], []
        for instance_index in range(len(head_tags)):
            instance_heads, instance_tags = heads[instance_index], head_tags[
                instance_index]
            words, length = output_dict["words"][instance_index], lengths[
                instance_index]
            gram_vals, lemmas = predicted_gram_vals[
                instance_index], predicted_lemmas[instance_index]

            words = words[:length.item() - 1]
            gram_vals = gram_vals[:length.item() - 1]
            lemmas = lemmas[:length.item() - 1]

            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [
                self.vocab.get_token_from_index(label, "head_tags")
                for label in instance_tags
            ]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

            decoded_gram_vals.append([
                self.vocab.get_token_from_index(gram_val, "grammar_value_tags")
                for gram_val in gram_vals
            ])

            decoded_lemmas.append([
                self.lemmatize_helper.lemmatize(word, lemmatize_rule_index)
                for word, lemmatize_rule_index in zip(words, lemmas)
            ])

            sent_top_lemmas = []
            for w_i in range(len(words)):
                word_top_lemmas = []
                for top_i in range(self.TopNCnt):
                    word_top_lemmas.append(
                        self.lemmatize_helper.lemmatize(
                            words[w_i],
                            top_lemmas[instance_index][w_i][top_i]))
                sent_top_lemmas.append(word_top_lemmas)
            decoded_top_lemmas.append(sent_top_lemmas)

            sent_top_gramms = []
            for w_i in range(len(words)):
                word_top_gramms = []
                for top_i in range(self.TopNCnt):
                    word_top_gramms.append(
                        self.vocab.get_token_from_index(
                            top_gramms[instance_index][w_i][top_i],
                            "grammar_value_tags"))
                sent_top_gramms.append(word_top_gramms)
            decoded_top_gramms.append(sent_top_gramms)

            sent_top_deprels = []
            itop_heads = top_heads[instance_index][1:length]
            itop_deprels = top_deprels[instance_index][1:length]
            itop_dr_probs = top_dr_probs[instance_index][1:length]
            for w_i in range(len(words)):
                word_top_deprels = [""]
                for top_i in range(self.TopNCnt - 1):
                    top_i_head_no = itop_heads[w_i][top_i]
                    top_i_rel = itop_deprels[w_i][top_i]
                    top_i_prb = itop_dr_probs[w_i][top_i]
                    lbl = self.vocab.get_token_from_index(
                        floor(top_i_rel), "head_tags")
                    word_top_deprels.append(
                        str(top_i_head_no) + " " + lbl +
                        " ({:.5f})".format(top_i_prb))
                sent_top_deprels.append(word_top_deprels)
            decoded_top_deprels.append(sent_top_deprels)

        if self.task_config.task_type == "multitask":
            output_dict["predicted_dependencies"] = head_tag_labels
            output_dict["predicted_heads"] = head_indices
            output_dict["predicted_gram_vals"] = decoded_gram_vals
            output_dict["predicted_lemmas"] = decoded_lemmas
            output_dict["top_lemmas"] = decoded_top_lemmas
            output_dict["top_lemmas_prob"] = top_lemmas_prob
            output_dict["top_gramms"] = decoded_top_gramms
            output_dict["top_gramms_prob"] = top_gramms_prob
            output_dict["top_deprels"] = decoded_top_deprels
        elif self.task_config.task_type == "single":
            if self.task_config.params["model"] == "morphology":
                output_dict["predicted_gram_vals"] = decoded_gram_vals
            elif self.task_config.params["model"] == "lemmatization":
                output_dict["predicted_lemmas"] = decoded_lemmas
            elif self.task_config.params["model"] == "syntax":
                output_dict["predicted_dependencies"] = head_tag_labels
                output_dict["predicted_heads"] = head_indices
            else:
                assert False, "Unknown model type {}".format(
                    self.task_config.params["model"])
        else:
            assert False, "Unknown task type {}".format(
                self.task_config.task_type)

        return output_dict

    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        grammar_value_logits = self._gram_val_output(encoded_text)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        # Feed the grammar-value prediction into the lemmatizer input -- EXPERIMENTAL
        #l_ext_input = encoded_text
        l_ext_input = torch.cat([encoded_text, grammar_value_logits], -1)
        lemma_logits = self._lemma_output(l_ext_input)
        predicted_lemmas = lemma_logits.argmax(-1)

        # Get the top-N most probable lemmatization variants and their probability estimates
        lemma_probs = torch.nn.functional.softmax(lemma_logits, -1)
        top_lemmas_indices = (-lemma_logits).argsort(-1)[:, :, :self.TopNCnt]
        #top_lemmas_indices = (-lemma_probs).argsort(-1)[:,:,:self.TopNCnt]
        top_lemmas_prob = torch.gather(lemma_probs, -1, top_lemmas_indices)
        #top_lemmas_prob = torch.gather(lemma_logits, -1, top_lemmas_indices)

        # The same for the grammemes (grammar values)
        gramm_probs = torch.nn.functional.softmax(grammar_value_logits, -1)
        top_gramms_indices = (
            -grammar_value_logits).argsort(-1)[:, :, :self.TopNCnt]
        top_gramms_prob = torch.gather(gramm_probs, -1, top_gramms_indices)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            synt_prediction, benrg = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
            predicted_heads, predicted_head_tags = synt_prediction

        # Get the top-N most probable LOCAL (not MST) parse variants and their probability estimates
        benrgf = torch.flatten(benrg, start_dim=1, end_dim=2).permute(
            0, 2, 1)  # fuses the dependency relation type with the parent index
        top_deprels_indices = (-benrgf).argsort(
            -1)[:, :, :self.TopNCnt]  # keep the best-scoring combinations
        top_deprels_prob = torch.gather(benrgf, -1, top_deprels_indices)
        seqlen = benrg.shape[2]
        top_heads = torch.fmod(top_deprels_indices, seqlen)
        top_deprels_indices = torch.div(top_deprels_indices,
                                        seqlen)  # torch.floor does not work here

        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            grammar_nll = self._update_multiclass_prediction_metrics(
                logits=grammar_value_logits,
                targets=grammar_values,
                mask=token_mask,
                accuracy_metric=self._gram_val_prediction_accuracy)

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            lemma_nll = self._update_multiclass_prediction_metrics(
                logits=lemma_logits,
                targets=lemma_indices,
                mask=token_mask,
                accuracy_metric=self._lemma_prediction_accuracy,
                masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
            "top_lemmas": top_lemmas_indices,
            "top_lemmas_prob": top_lemmas_prob,
            "top_gramms": top_gramms_indices,
            "top_gramms_prob": top_gramms_prob,
            "top_heads": top_heads,
            "top_deprels": top_deprels_indices,
            "top_deprels_prob": top_deprels_prob,
        }

        return output_dict

    @staticmethod
    def _update_multiclass_prediction_metrics(logits,
                                              targets,
                                              mask,
                                              accuracy_metric,
                                              masked_index=None):
        accuracy_metric(logits, targets, mask)

        logits = logits.view(-1, logits.shape[-1])
        loss = F.cross_entropy(logits, targets.view(-1), reduction='none')
        if masked_index is not None:
            mask = mask * (targets != masked_index)
        loss_mask = mask.view(-1)
        return (loss * loss_mask).sum() / loss_mask.sum()

    def _construct_loss(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor, head_indices: torch.Tensor,
            head_tags: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(
            attended_arcs,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll

    def _greedy_decode(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(
            attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation, heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(self, head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using Edmonds' algorithm
        for finding minimum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size(
        )

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [
            batch_size, sequence_length, sequence_length,
            tag_representation_dim
        ]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(
            *expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(
            *expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation,
                                                 child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g are a pair which includes a padded element) anyway below.
        # Shape (batch, num_labels,sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits,
                                                        dim=3).permute(
                                                            0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = (1 - mask.float()) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs,
                                              dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        # This energy tensor expresses the following relation:
        # energy[i,j] = "Score that i is the head of j". In this
        # case, we have heads pointing to their children.
        batch_energy = torch.exp(
            normalized_arc_logits.unsqueeze(1) +
            normalized_pairwise_head_logits)
        return self._run_mst_decoding(
            batch_energy, lengths
        ), batch_energy  #normalized_pairwise_head_logits, normalized_arc_logits

    @staticmethod
    def _run_mst_decoding(
            batch_energy: torch.Tensor,
            lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all
            # word -> ROOT edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(),
                                           length,
                                           has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = []
            for child, parent in enumerate(instance_heads):
                instance_head_tags.append(tag_ids[parent, child].item())
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(
            numpy.stack(head_tags))

    def _get_head_tags(self, head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(
            batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[
            range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous(
        )
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words which are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = self._attachment_scores.get_metric(reset)
        metrics['GramValAcc'] = self._gram_val_prediction_accuracy.get_metric(
            reset)
        metrics['LemmaAcc'] = self._lemma_prediction_accuracy.get_metric(reset)
        metrics['MeanAcc'] = (metrics['GramValAcc'] + metrics['LemmaAcc'] +
                              metrics['LAS']) / 3.

        return metrics
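
As a closing illustration of the decoding step in the class above, here is a compact plain-PyTorch sketch of greedy head selection in the spirit of _greedy_decode: the diagonal of the arc score matrix is masked so a word cannot be its own head, padded positions are masked out as candidate heads, and each word then independently picks its highest-scoring head. Label prediction and MST decoding are omitted, and the tensors are random placeholders.

import numpy
import torch

batch_size, seq_len = 1, 4                                 # index 0 plays the role of the ROOT sentinel
attended_arcs = torch.randn(batch_size, seq_len, seq_len)  # attended_arcs[b, i, j]: score of j as head of i
mask = torch.tensor([[1, 1, 1, 0]])                        # last position is padding

# A word cannot be its own head: put -inf on the diagonal.
attended_arcs = attended_arcs + torch.diag(
    attended_arcs.new(seq_len).fill_(-numpy.inf))

# Padded positions cannot serve as candidate heads: mask those columns.
head_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(1)   # shape (batch_size, 1, seq_len)
attended_arcs = attended_arcs.masked_fill(head_mask, -numpy.inf)

# Greedy decoding: each word picks its best-scoring head; no tree constraint is enforced.
_, heads = attended_arcs.max(dim=2)
print(heads)                                               # shape (batch_size, seq_len)
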
Esempio n. 23
class Transduction(Model):

    def __init__(self,
                 vocab: Vocabulary,
                 # source-side
                 bert_encoder: Seq2SeqBertEncoder,
                 encoder_token_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 # target-side
                 decoder_token_embedder: TextFieldEmbedder,
                 decoder_node_index_embedding: Embedding,
                 decoder: RNNDecoder,
                 extended_pointer_generator: ExtendedPointerGenerator,
                 tree_parser: DeepTreeParser,
                 # misc
                 label_smoothing: LabelSmoothing,
                 target_output_namespace: str,
                 dropout: float = 0.0,
                 eps: float = 1e-20,
                 pretrained_weights: str = None,
                 ) -> None:
        super().__init__(vocab=vocab)
        # source-side
        self._bert_encoder = bert_encoder
        self._encoder_token_embedder = encoder_token_embedder
        self._encoder = encoder

        # target-side
        self._decoder_token_embedder = decoder_token_embedder
        self._decoder_node_index_embedding = decoder_node_index_embedding
        self._decoder = decoder
        self._extended_pointer_generator = extended_pointer_generator
        self._tree_parser = tree_parser

        # metrics
        self._node_pred_metrics = ExtendedPointerGeneratorMetrics()
        self._edge_pred_metrics = AttachmentScores()
        self._synt_edge_pred_metrics = AttachmentScores()

        self._label_smoothing = label_smoothing
        self._dropout = InputVariationalDropout(p=dropout)
        self._eps = eps

        # dynamic initialization
        self._target_output_namespace = target_output_namespace
        self._vocab_size = self.vocab.get_vocab_size(target_output_namespace)
        self._vocab_pad_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, target_output_namespace)

        # loading partial weights
        self.pretrained_weights = pretrained_weights

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        node_pred_metrics = self._node_pred_metrics.get_metric(reset)
        edge_pred_metrics = self._edge_pred_metrics.get_metric(reset)
        synt_edge_pred_metrics = self._synt_edge_pred_metrics.get_metric(reset)
        metrics = OrderedDict(
            ppl=node_pred_metrics["ppl"],
            node_pred=node_pred_metrics["accuracy"] * 100,
            generate=node_pred_metrics["generate"] * 100,
            src_copy=node_pred_metrics["src_copy"] * 100,
            tgt_copy=node_pred_metrics["tgt_copy"] * 100,
            uas=edge_pred_metrics["UAS"] * 100,
            las=edge_pred_metrics["LAS"] * 100,
        )
        return metrics

    @overrides
    def forward(self, **raw_inputs: Dict) -> Dict:
        inputs = self._prepare_inputs(raw_inputs)
        if self.training:
            return self._training_forward(inputs)
        else:
            return self._test_forward(inputs)

    def _compute_edge_prediction_loss(self,
                                      edge_head_ll: torch.Tensor,
                                      edge_type_ll: torch.Tensor,
                                      pred_edge_heads: torch.Tensor,
                                      pred_edge_types: torch.Tensor,
                                      gold_edge_heads: torch.Tensor,
                                      gold_edge_types: torch.Tensor,
                                      valid_node_mask: torch.Tensor,
                                      syntax: bool = False) -> Dict:
        """
        Compute the edge prediction loss.

        :param edge_head_ll: [batch_size, target_length, target_length + 1 (for sentinel)].
        :param edge_type_ll: [batch_size, target_length, num_labels].
        :param pred_edge_heads: [batch_size, target_length].
        :param pred_edge_types: [batch_size, target_length].
        :param gold_edge_heads: [batch_size, target_length].
        :param gold_edge_types: [batch_size, target_length].
        :param valid_node_mask: [batch_size, target_length].
        """

        # Index the log-likelihood (ll) of gold edge heads and types.
        batch_size, target_length, _ = edge_head_ll.size()
        batch_indices = torch.arange(0, batch_size).view(batch_size, 1).type_as(gold_edge_heads)
        node_indices = torch.arange(0, target_length).view(1, target_length) \
            .expand(batch_size, target_length).type_as(gold_edge_heads)

        gold_edge_head_ll = edge_head_ll[batch_indices, node_indices, gold_edge_heads]
        gold_edge_type_ll = edge_type_ll[batch_indices, node_indices, gold_edge_types]
        # Set the ll of invalid nodes to 0.
        num_nodes = valid_node_mask.sum().float()

        if not syntax: 
            # don't incur loss on EOS/SOS token
            valid_node_mask[gold_edge_heads == -1] = 0

        valid_node_mask = valid_node_mask.bool()
        gold_edge_head_ll.masked_fill_(~valid_node_mask, 0)
        gold_edge_type_ll.masked_fill_(~valid_node_mask, 0)

        # Negative log-likelihood.
        loss = -(gold_edge_head_ll.sum() + gold_edge_type_ll.sum())
        # Update metrics.
        if self.training and not syntax:
            self._edge_pred_metrics(
                predicted_indices=pred_edge_heads,
                predicted_labels=pred_edge_types,
                gold_indices=gold_edge_heads,
                gold_labels=gold_edge_types,
                mask=valid_node_mask
            )

        elif self.training and syntax:
            self._synt_edge_pred_metrics(
                predicted_indices=pred_edge_heads,
                predicted_labels=pred_edge_types,
                gold_indices=gold_edge_heads,
                gold_labels=gold_edge_types,
                mask=valid_node_mask
            )

        return dict(
            loss=loss,
            num_nodes=num_nodes,
            loss_per_node=loss / num_nodes,
        )

    def _compute_node_prediction_loss(self,
                                      prob_dist: torch.Tensor,
                                      generation_outputs: torch.Tensor,
                                      source_copy_indices: torch.Tensor,
                                      target_copy_indices: torch.Tensor,
                                      source_dynamic_vocab_size: int,
                                      source_attention_weights: torch.Tensor = None,
                                      coverage_history: torch.Tensor = None) -> Dict:
        """
        Compute the node prediction loss based on the final hybrid probability distribution.

        :param prob_dist: probability distribution,
            [batch_size, target_length, vocab_size + source_dynamic_vocab_size + target_dynamic_vocab_size].
        :param generation_outputs: generated node indices in the pre-defined vocabulary,
            [batch_size, target_length].
        :param source_copy_indices: source-side copied node indices in the source dynamic vocabulary,
            [batch_size, target_length].
        :param target_copy_indices: target-side copied node indices in the target dynamic vocabulary,
            [batch_size, target_length].
        :param source_dynamic_vocab_size: int.
        :param source_attention_weights: None or [batch_size, target_length, source_length].
        :param coverage_history: None or a tensor recording the source-side coverage history.
            [batch_size, target_length, source_length].
        """
        _, prediction = prob_dist.max(2)

        batch_size, target_length = prediction.size()
        not_pad_mask = generation_outputs.ne(self._vocab_pad_index)
        num_nodes = not_pad_mask.sum()

        # Priority: target_copy > source_copy > generation
        # Prepare mask.
        valid_target_copy_mask = target_copy_indices.ne(0) & not_pad_mask  # 0 for sentinel.
        valid_source_copy_mask = (~valid_target_copy_mask & not_pad_mask &
                                  source_copy_indices.ne(1) & source_copy_indices.ne(0))  # 1 for unk; 0 for pad.
        valid_generation_mask = ~(valid_target_copy_mask | valid_source_copy_mask) & not_pad_mask
        # Prepare hybrid targets.
        _target_copy_indices = ((target_copy_indices + self._vocab_size + source_dynamic_vocab_size) *
                                valid_target_copy_mask.long())
        _source_copy_indices = (source_copy_indices + self._vocab_size) * valid_source_copy_mask.long()

        _generation_outputs = generation_outputs * valid_generation_mask.long()
        hybrid_targets = _target_copy_indices + _source_copy_indices + _generation_outputs
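        # Illustrative example (hypothetical sizes): with vocab_size = 10000 and
        # source_dynamic_vocab_size = 20, a target-copied node with index 3 maps to
        # 10000 + 20 + 3, a source-copied node with index 5 maps to 10000 + 5, and a
        # generated node keeps its vocabulary index (< 10000). The three masks above
        # are disjoint, so exactly one of the terms is non-zero at each position.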

        # Compute loss.
        log_prob_dist = (prob_dist.view(batch_size * target_length, -1) + self._eps).log()
        flat_hybrid_targets = hybrid_targets.view(batch_size * target_length)
        loss = self._label_smoothing(log_prob_dist, flat_hybrid_targets)

        # Coverage loss.
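        # This is a coverage penalty in the style of pointer-generator models:
        # at every decoding step, sum over source positions of
        # min(accumulated coverage, current attention weight), which discourages
        # attending repeatedly to the same source tokens.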
        if coverage_history is not None:
            #coverage_loss = torch.sum(torch.min(coverage_history.unsqueeze(-1), source_attention_weights), 2)
            coverage_loss = torch.sum(torch.min(coverage_history, source_attention_weights), 2)
            coverage_loss = (coverage_loss * not_pad_mask.float()).sum()
            loss = loss + coverage_loss

        # Update metric stats.
        self._node_pred_metrics(
            loss=loss,
            prediction=prediction,
            generation_outputs=_generation_outputs,
            valid_generation_mask=valid_generation_mask,
            source_copy_indices=_source_copy_indices,
            valid_source_copy_mask=valid_source_copy_mask,
            target_copy_indices=_target_copy_indices,
            valid_target_copy_mask=valid_target_copy_mask
        )

        return dict(
            loss=loss,
            num_nodes=num_nodes,
            loss_per_node=loss / num_nodes,
        )

    def _decode(self,
                tokens: Dict[str, torch.Tensor],
                node_indices: torch.Tensor,
                encoder_outputs: torch.Tensor,
                hidden_states: Tuple[torch.Tensor, torch.Tensor],
                mask: torch.Tensor,
                **kwargs) -> Dict:

        # [batch, num_tokens, embedding_size]
        decoder_inputs = torch.cat([
            self._decoder_token_embedder(tokens),
            self._decoder_node_index_embedding(node_indices),
        ], dim=2)
        decoder_inputs = self._dropout(decoder_inputs)

        decoder_outputs = self._decoder(
            inputs=decoder_inputs,
            source_memory_bank=encoder_outputs,
            source_mask=mask,
            hidden_state=hidden_states
        )

        return decoder_outputs

    def _encode(self,
                tokens: Dict[str, torch.Tensor],
                subtoken_ids: torch.Tensor,
                token_recovery_matrix: torch.Tensor,
                mask: torch.Tensor,
                **kwargs) -> Dict:
        # [batch, num_tokens, embedding_size]
        encoder_inputs = [self._encoder_token_embedder(tokens)]
        if subtoken_ids is not None and self._bert_encoder is not None:
            bert_embeddings = self._bert_encoder(
                input_ids=subtoken_ids,
                attention_mask=subtoken_ids.ne(0),
                output_all_encoded_layers=False,
                token_recovery_matrix=token_recovery_matrix
            ).detach()

            encoder_inputs += [bert_embeddings]

        encoder_inputs = torch.cat(encoder_inputs, 2)
        encoder_inputs = self._dropout(encoder_inputs)

        # [batch, num_tokens, encoder_output_size]
        encoder_outputs = self._encoder(encoder_inputs, mask)
        encoder_outputs = self._dropout(encoder_outputs)
        # A tuple of (state, memory) with shape [num_layers, batch, encoder_output_size]
        encoder_final_states = self._encoder.get_final_states()
        self._encoder.reset_states()

        return dict(
            encoder_outputs=encoder_outputs,
            final_states=encoder_final_states
        )

    def _parse(self,
               rnn_outputs: torch.Tensor,
               edge_head_mask: torch.Tensor,
               edge_heads: torch.Tensor = None) -> Dict:
        """
        Based on the vector representation for each node, predict its head and edge type.
        :param rnn_outputs: vector representations of nodes, including <BOS>.
            [batch_size, target_length + 1, hidden_vector_dim].
        :param edge_head_mask: mask used in the edge head search.
            [batch_size, target_length, target_length].
        :param edge_heads: None or gold head indices, [batch_size, target_length]
        """
        # Exclude <BOS>.
        # <EOS> is already excluded in ``_prepare_inputs``.
        rnn_outputs = self._dropout(rnn_outputs[:, 1:])
        parser_outputs = self._tree_parser(
            query=rnn_outputs,
            key=rnn_outputs,
            edge_head_mask=edge_head_mask,
            gold_edge_heads=edge_heads
        )
        return parser_outputs

    def _prepare_inputs(self, raw_inputs: Dict) -> Dict:
        return raw_inputs

    def _test_forward(self, inputs: Dict) -> Dict:
        raise NotImplementedError

    def _training_forward(self, inputs: Dict) -> Dict[str, torch.Tensor]:
        encoding_outputs = self._encode(
            tokens=inputs["source_tokens"],
            subtoken_ids=inputs["source_subtoken_ids"],
            token_recovery_matrix=inputs["source_token_recovery_matrix"],
            mask=inputs["source_mask"]
        )
        decoding_outputs = self._decode(
            tokens=inputs["target_tokens"],
            node_indices=inputs["target_node_indices"],
            encoder_outputs=encoding_outputs["encoder_outputs"],
            hidden_states=encoding_outputs["final_states"],
            mask=inputs["source_mask"]
        )
        node_prediction_outputs = self._extended_pointer_generator(
            inputs=decoding_outputs["attentional_tensors"],
            source_attention_weights=decoding_outputs["source_attention_weights"],
            target_attention_weights=decoding_outputs["target_attention_weights"],
            source_attention_map=inputs["source_attention_map"],
            target_attention_map=inputs["target_attention_map"]
        )
        edge_prediction_outputs = self._parse(
            rnn_outputs=decoding_outputs["rnn_outputs"],
            edge_head_mask=inputs["edge_head_mask"],
            edge_heads=inputs["edge_heads"]
        )
        node_pred_loss = self._compute_node_prediction_loss(
            prob_dist=node_prediction_outputs["hybrid_prob_dist"],
            generation_outputs=inputs["generation_outputs"],
            source_copy_indices=inputs["source_copy_indices"],
            target_copy_indices=inputs["target_copy_indices"],
            source_dynamic_vocab_size=inputs["source_dynamic_vocab_size"],
            source_attention_weights=decoding_outputs["source_attention_weights"],
            coverage_history=decoding_outputs["coverage_history"]
        )
        edge_pred_loss = self._compute_edge_prediction_loss(
            edge_head_ll=edge_prediction_outputs["edge_head_ll"],
            edge_type_ll=edge_prediction_outputs["edge_type_ll"],
            pred_edge_heads=edge_prediction_outputs["edge_heads"],
            pred_edge_types=edge_prediction_outputs["edge_types"],
            gold_edge_heads=inputs["edge_heads"],
            gold_edge_types=inputs["edge_types"],
            valid_node_mask=inputs["valid_node_mask"]
        )
        loss = node_pred_loss["loss_per_node"] + edge_pred_loss["loss_per_node"]
        return dict(loss=loss)

    def load_partial(self, param_file: str): 
        """
        loads weights and matches the ones it can 
        """
        logger.info(f"Attempting to load pretrained weights from {param_file}") 
        pretrained_state_dict = torch.load(param_file)
        current_state_dict = self.state_dict() 
        for k, v in pretrained_state_dict.items():
            if isinstance(v, torch.nn.Parameter):
                v = v.data
            try:
                current_state_dict[k].copy_(v)
                print(f"matched {k}")
                logger.info(f"matched {k}")
            except RuntimeError:
                new_shape = pretrained_state_dict[k].shape
                og_shape = current_state_dict[k].shape
                print(f"Unable to match {k} due to shape error: pretrained: {new_shape} vs original: {og_shape}") 
                logger.warning(f"Unable to match {k} due to shape error: pretrained: {new_shape} vs original: {og_shape}") 
                continue 
            except KeyError:
                logger.warning(f"Unable to match {k} because it does not exist in original model") 
                print(f"Unable to match {k} because it does not exist in original model") 
                continue

        self.load_state_dict(current_state_dict) 
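
# Usage sketch for ``load_partial`` (hypothetical; assumes ``model`` is a constructed
# ``Transduction`` instance and that ``pretrained_weights`` points at a compatible
# checkpoint -- only tensors whose names and shapes match are copied):
#
#     model = Transduction(...)  # built elsewhere, e.g. from a config file
#     if model.pretrained_weights is not None:
#         model.load_partial(model.pretrained_weights)
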
class BiaffineDependencyParser(Model):
    """
    This dependency parser follows the model of
    `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
    <https://arxiv.org/abs/1611.01734>`_ .

    Word representations are generated using a bidirectional LSTM,
    followed by separate biaffine classifiers for pairs of words,
    predicting whether a directed arc exists between the two words
    and the dependency label the arc should have. Decoding can either
    be done greedily, or the optimal Minimum Spanning Tree can be
    decoded using Edmonds' algorithm by viewing the dependency tree as
    an MST on a fully connected graph, where nodes are words and edges
    are scored dependency arcs.
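
    As a rough sketch of the arc scoring used below (``BilinearMatrixAttention``
    with ``use_input_biases=True``), the unlabeled score for a candidate
    (head, dependent) pair is a biaffine function of the two words' arc
    representations: a bilinear term ``h_head^T U h_dep`` plus linear terms in
    each representation and a constant bias.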

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    tag_representation_dim : ``int``, required.
        The dimension of the MLPs used for dependency tag prediction.
    arc_representation_dim : ``int``, required.
        The dimension of the MLPs used for head arc prediction.
    tag_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce tag representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    arc_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce arc representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    use_mst_decoding_for_validation : ``bool``, optional (default = True).
        Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : ``float``, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : ``float``, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required.
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_loss : ``torch.FloatTensor``
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : ``torch.FloatTensor``
            The predicted head tags for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)
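        # Padded positions now carry a score of roughly -1e8 along both the row
        # and column direction, so they receive negligible probability mass after
        # the (log-)softmax over arc scores.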

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask
        }

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict["head_tags"].cpu().detach().numpy()
        heads = output_dict["heads"].cpu().detach().numpy()
        lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"])
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(
                heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [
                self.vocab.get_token_from_index(label, "head_tags")
                for label in instance_tags
            ]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict

    def _construct_loss(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor, head_indices: torch.Tensor,
            head_tags: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = last_dim_log_softmax(
            attended_arcs,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        normalised_head_tag_logits = last_dim_log_softmax(
            head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic ROOT token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll

    def _greedy_decode(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs indpendently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there maybe be multiple roots,
        or cycles when children are attached to their parents).
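        For example, word 2 may pick word 3 as its head while word 3
        simultaneously picks word 2, creating a two-node cycle.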

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(
            attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).byte().unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation, heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(self, head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using Edmonds' algorithm
        for finding minimum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.
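
        Concretely, the code below builds a batch energy tensor of
        ``exp(log P(arc) + log P(label))`` for every (head, child, label) triple
        and passes it to ``decode_mst``, which returns the best-scoring tree and
        its labels for each sentence.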

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            optimally decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [
            batch_size, sequence_length, sequence_length,
            tag_representation_dim
        ]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(
            *expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(
            *expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation,
                                                 child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g are a pair which includes a padded element) anyway below.
        # Shape (batch, num_labels, sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits,
                                                        dim=3).permute(
                                                            0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = (1 - mask.float()) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs,
                                              dim=2).transpose(1, 2)
        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        batch_energy = torch.exp(
            normalized_arc_logits.unsqueeze(1) +
            normalized_pairwise_head_logits)

        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu().numpy(),
                                  lengths):
            head, head_tag = decode_mst(energy, length)
            heads.append(head)
            head_tags.append(head_tag)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(
            numpy.stack(head_tags))

    def _get_head_tags(self, head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[
            range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return self._attachment_scores.get_metric(reset)
Esempio n. 25
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder_0: Seq2SeqEncoder,
                 encoder_1: Seq2SeqEncoder,
                 encoder_2: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 use_layer_normalization: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        a = vocab.get_index_to_token_vocabulary(namespace='tokens')
        # glyph_config['idx2word'] = {k: v for k, v in a.items()}

        # self.glyph = GlyphEmbedding(glyph_config)

        self.text_field_embedder = text_field_embedder

        self.encoder_0 = encoder_0
        self.encoder_1 = encoder_1
        self.encoder_2 = encoder_2

        encoder_dim = self.encoder_2.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        # self._dropout = Dropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, self.encoder_2.get_output_dim()]))

        self.use_layer_normalization = use_layer_normalization

        if use_layer_normalization:
            self.norm_input = torch.nn.LayerNorm(
                self.encoder_0.get_input_dim())
            self.norm_hidden = torch.nn.LayerNorm(
                self.encoder_0.get_output_dim())

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        # check_dimensions_match(representation_dim, encoder.get_input_dim(),
        #                        "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 treebank_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 use_treebank_embedding: bool = False,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParserMonolingual, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._treebank_embedding = treebank_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        if treebank_embedding is not None:
            representation_dim += treebank_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim", "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim", "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation
        self.use_treebank_embedding = use_treebank_embedding

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")
        
        if self.use_treebank_embedding:
            tbids = self.vocab.get_token_to_index_vocabulary("tbids")
            tbid_indices = {tb: index for tb, index in tbids.items()}
            self._tbids = set(tbid_indices.values())
            logger.info(f"Found TBIDs corresponding to the following treebanks : {tbid_indices}. "
                        "Embedding these as additional features.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        tag_representation_dim: int,
        arc_representation_dim: int,
        model_name: str = None,
        tag_feedforward: FeedForward = None,
        arc_feedforward: FeedForward = None,
        pos_tag_embedding: Embedding = None,
        use_mst_decoding_for_validation: bool = True,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        word_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        if model_name:
            from src.data.token_indexers import PretrainedAutoTokenizer
            self._tokenizer = PretrainedAutoTokenizer.load(model_name)

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or FeedForward(
            encoder_dim, 1, arc_representation_dim,
            Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or FeedForward(
            encoder_dim, 1, tag_representation_dim,
            Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._word_dropout = word_dropout
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(
            representation_dim,
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )

        check_dimensions_match(
            tag_representation_dim,
            self.head_tag_feedforward.get_output_dim(),
            "tag representation dim",
            "tag feedforward output dim",
        )
        check_dimensions_match(
            arc_representation_dim,
            self.head_arc_feedforward.get_output_dim(),
            "arc representation dim",
            "arc feedforward output dim",
        )

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
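
The constructor above only wires the modules together; the arc scores themselves are produced later (the _parse method near the end of this listing computes attended_arcs with self.arc_attention). As a rough, self-contained sketch of the idea behind a biaffine arc scorer with input biases, written in plain PyTorch rather than the AllenNLP modules (all tensor names below are illustrative, not taken from the code):

import torch

batch_size, seq_len, dim = 2, 5, 8                    # toy sizes
head_repr = torch.randn(batch_size, seq_len, dim)     # stands in for head_arc_feedforward output
child_repr = torch.randn(batch_size, seq_len, dim)    # stands in for child_arc_feedforward output

# "Input biases": append a constant 1 to both inputs so the bilinear weight
# matrix also captures per-head and per-dependent bias terms (conceptually
# what use_input_biases=True does).
ones = head_repr.new_ones(batch_size, seq_len, 1)
head_b = torch.cat([head_repr, ones], dim=-1)         # (batch, seq, dim + 1)
child_b = torch.cat([child_repr, ones], dim=-1)       # (batch, seq, dim + 1)
weight = torch.randn(dim + 1, dim + 1)                # learned in the real model

# arc_scores[b, i, j]: score for attaching word i to candidate head j.
arc_scores = torch.einsum("bid,de,bje->bij", head_b, weight, child_b)

# Greedy decoding: forbid self-loops, then pick the best-scoring head per word.
arc_scores = arc_scores + torch.diag(arc_scores.new_full((seq_len,), float("-inf")))
predicted_heads = arc_scores.argmax(dim=2)            # (batch, seq)
print(predicted_heads.shape)                          # torch.Size([2, 5])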
Example 28
class AMTask(Model):
    """
    A class that implements a task-specific model. It conceptually belongs to a formalism or corpus.
    """
    loss_names = ["edge_existence", "edge_label", "supertagging", "lexlabel"]

    def __init__(self,
                 vocab: Vocabulary,
                 name: str,
                 edge_model: EdgeModel,
                 loss_function: EdgeLoss,
                 supertagger: Supertagger,
                 lexlabeltagger: Supertagger,
                 supertagger_loss: SupertaggingLoss,
                 lexlabel_loss: SupertaggingLoss,
                 output_null_lex_label: bool = True,
                 loss_mixing: Dict[str, float] = None,
                 dropout: float = 0.0,
                 validation_evaluator: Optional[Evaluator] = None,
                 regularizer: Optional[RegularizerApplicator] = None):

        super().__init__(vocab, regularizer)
        self.name = name
        self.edge_model = edge_model
        self.supertagger = supertagger
        self.lexlabeltagger = lexlabeltagger
        self.supertagger_loss = supertagger_loss
        self.lexlabel_loss = lexlabel_loss
        self.loss_function = loss_function
        self.loss_mixing = loss_mixing or dict()
        self.validation_evaluator = validation_evaluator
        self.output_null_lex_label = output_null_lex_label

        self._dropout = InputVariationalDropout(dropout)

        for loss_name in AMTask.loss_names:
            if loss_name not in self.loss_mixing:
                self.loss_mixing[loss_name] = 1.0
                logger.info(
                    f"Loss name {loss_name} not found in loss_mixing, using a weight of 1.0"
                )
            else:
                if self.loss_mixing[loss_name] is None:
                    if loss_name not in ["supertagging", "lexlabel"]:
                        raise ConfigurationError(
                            "Only the loss mixing coefficients for supertagging and lexlabel may be None, but not "
                            + loss_name)

        not_contained = set(self.loss_mixing.keys()) - set(AMTask.loss_names)
        if len(not_contained):
            logger.critical(
                f"The following loss name(s) are unknown: {not_contained}")
            raise ValueError(
                f"The following loss name(s) are unknown: {not_contained}")

        self._supertagging_acc = CategoricalAccuracy()
        self._top_6supertagging_acc = CategoricalAccuracy(top_k=6)
        self._lexlabel_acc = CategoricalAccuracy()
        self._attachment_scores = AttachmentScores()
        self.current_epoch = 0

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self.compute_softmax_for_scores = False  # set to true when dumping scores to incorporate softmax computation into computation time

    def check_all_dimensions_match(self, encoder_output_dim):

        check_dimensions_match(encoder_output_dim,
                               self.edge_model.encoder_dim(),
                               "encoder output dim",
                               self.name + " input dim edge model")
        check_dimensions_match(encoder_output_dim,
                               self.supertagger.encoder_dim(),
                               "encoder output dim",
                               self.name + " supertagger input dim")
        check_dimensions_match(encoder_output_dim,
                               self.lexlabeltagger.encoder_dim(),
                               "encoder output dim",
                               self.name + " lexical label tagger input dim")

    @overrides
    def forward(
            self,  # type: ignore
            encoded_text_parsing: torch.Tensor,
            encoded_text_tagging: torch.Tensor,
            mask: torch.Tensor,
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            supertags: torch.LongTensor = None,
            lexlabels: torch.LongTensor = None,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Takes a batch of encoded sentences and returns a dictionary with loss and predictions.

        :param encoded_text_parsing: sentence representation of shape (batch_size, seq_len, encoder_output_dim)
        :param encoded_text_tagging: sentence representation of shape (batch_size, seq_len, encoder_output_dim)
            or None if formalism of batch doesn't need supertagging
        :param mask: mask matching the sentence representation, of shape (batch_size, seq_len)
        :param pos_tags: the accompanying pos tags (batch_size, seq_len)
        :param metadata:
        :param supertags: the accompanying supertags (batch_size, seq_len)
        :param lexlabels: the accompanying lexical labels (batch_size, seq_len)
        :param head_tags: the gold edge labels for each word (label of the incoming edge, see amconll files) (batch_size, seq_len)
        :param head_indices: the gold head of every word (batch_size, seq_len)
        :return:
        """

        encoded_text_parsing = self._dropout(encoded_text_parsing)
        if encoded_text_tagging is not None:
            encoded_text_tagging = self._dropout(encoded_text_tagging)

        batch_size, seq_len, _ = encoded_text_parsing.shape

        edge_existence_scores = self.edge_model.edge_existence(
            encoded_text_parsing, mask)  # shape (batch_size, seq_len, seq_len)
        # shape (batch_size, seq_len, num_supertags)
        if encoded_text_tagging is not None and self.supertagger is not None and self.lexlabeltagger is not None\
                and self.loss_mixing["supertagging"] is not None\
                and self.loss_mixing["lexlabel"] is not None:
            supertagger_logits = self.supertagger.compute_logits(
                encoded_text_tagging)
            lexlabel_logits = self.lexlabeltagger.compute_logits(
                encoded_text_tagging
            )  # shape (batch_size, seq_len, num label tags)
        else:
            supertagger_logits = None
            lexlabel_logits = None

        # Make predictions on data:
        if self.training:
            predicted_heads = self._greedy_decode_arcs(edge_existence_scores,
                                                       mask)
            edge_label_logits = self.edge_model.label_scores(
                encoded_text_parsing, predicted_heads
            )  # shape (batch_size, seq_len, num edge labels)
            predicted_edge_labels = self._greedy_decode_edge_labels(
                edge_label_logits)
        else:
            # Find best tree with CLE
            predicted_heads = cle_decode(edge_existence_scores,
                                         mask.data.sum(dim=1).long())
            # With info about tree structure, get edge label scores
            edge_label_logits = self.edge_model.label_scores(
                encoded_text_parsing, predicted_heads)
            # Predict edge labels
            predicted_edge_labels = self._greedy_decode_edge_labels(
                edge_label_logits)

        output_dict = {
            "heads":
            predicted_heads,
            "edge_existence_scores":
            edge_existence_scores,
            "label_logits":
            edge_label_logits,  # shape (batch_size, seq_len, num edge labels)
            "full_label_logits":
            self.edge_model.full_label_scores(
                encoded_text_parsing
            ),  #these are mostly required for the projective decoder
            "mask":
            mask,
            "words": [meta["words"] for meta in metadata],
            "attributes": [meta["attributes"] for meta in metadata],
            "token_ranges": [meta["token_ranges"] for meta in metadata],
            "encoded_text_parsing":
            encoded_text_parsing,
            "encoded_text_tagging":
            encoded_text_tagging,
            "position_in_corpus":
            [meta["position_in_corpus"] for meta in metadata],
            "formalism":
            self.name
        }

        if encoded_text_tagging is not None and self.loss_mixing[
                "supertagging"] is not None:
            output_dict[
                "supertag_scores"] = supertagger_logits  # shape (batch_size, seq_len, num supertags)
            output_dict["best_supertags"] = Supertagger.top_k_supertags(
                supertagger_logits,
                1).squeeze(2)  # shape (batch_size, seq_len)

        if encoded_text_tagging is not None and self.loss_mixing[
                "lexlabel"] is not None:
            if not self.output_null_lex_label:
                bottom_lex_label_index = self.vocab.get_token_index(
                    "_", namespace=self.name + "_lex_labels")
                masked_lexlabel_logits = lexlabel_logits.clone().detach(
                )  # shape (batch_size, seq_len, num label tags)
                masked_lexlabel_logits[:, :, bottom_lex_label_index] = -1e20
            else:
                masked_lexlabel_logits = lexlabel_logits

            output_dict["lexlabels"] = Supertagger.top_k_supertags(
                masked_lexlabel_logits,
                1).squeeze(2)  # shape (batch_size, seq_len)

        is_annotated = metadata[0]["is_annotated"]
        if any(metadata[i]["is_annotated"] != is_annotated
               for i in range(batch_size)):
            raise ValueError(
                "Batch contained inconsistent information if data is annotated."
            )

        # Compute loss:
        if is_annotated and head_indices is not None and head_tags is not None:
            gold_edge_label_logits = self.edge_model.label_scores(
                encoded_text_parsing, head_indices)
            edge_label_loss = self.loss_function.label_loss(
                gold_edge_label_logits, mask, head_tags)
            edge_existence_loss = self.loss_function.edge_existence_loss(
                edge_existence_scores, head_indices, mask)

            # compute loss, remove loss for artificial root
            if encoded_text_tagging is not None and self.loss_mixing[
                    "supertagging"] is not None:
                supertagger_logits = supertagger_logits[:, 1:, :].contiguous()
                supertagging_nll = self.supertagger_loss.loss(
                    supertagger_logits, supertags, mask[:, 1:])
            else:
                supertagging_nll = None

            if encoded_text_tagging is not None and self.loss_mixing[
                    "lexlabel"] is not None:
                lexlabel_logits = lexlabel_logits[:, 1:, :].contiguous()
                lexlabel_nll = self.lexlabel_loss.loss(lexlabel_logits,
                                                       lexlabels, mask[:, 1:])
            else:
                lexlabel_nll = None

            loss = mix_loss(self.loss_mixing["edge_existence"],
                            edge_existence_loss) + mix_loss(
                                self.loss_mixing["edge_label"],
                                edge_label_loss)

            if supertagging_nll is not None:
                loss += mix_loss(self.loss_mixing["supertagging"],
                                 supertagging_nll)
            if lexlabel_nll is not None:
                loss += mix_loss(self.loss_mixing["lexlabel"], lexlabel_nll)

            # Compute LAS/UAS/Supertagging acc/Lex label acc:
            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            if edge_existence_loss is not None and edge_label_loss is not None:
                self._attachment_scores(predicted_heads[:, 1:],
                                        predicted_edge_labels[:, 1:],
                                        head_indices[:, 1:], head_tags[:, 1:],
                                        evaluation_mask)
            evaluation_mask = mask[:, 1:].contiguous()
            if supertagging_nll is not None:
                self._top_6supertagging_acc(supertagger_logits, supertags,
                                            evaluation_mask)
                self._supertagging_acc(
                    supertagger_logits, supertags,
                    evaluation_mask)  # compare against gold data
            if lexlabel_nll is not None:
                self._lexlabel_acc(
                    lexlabel_logits, lexlabels,
                    evaluation_mask)  # compare against gold data

            output_dict["arc_loss"] = edge_existence_loss
            output_dict["edge_label_loss"] = edge_label_loss
            output_dict["supertagging_loss"] = supertagging_nll
            output_dict["lexlabel_loss"] = lexlabel_nll
            output_dict["loss"] = loss

        if self.compute_softmax_for_scores:
            # We don't use the results, but we want this computation to be included in the time measurement.
            # See dump_scores for which parts of the computation happen outside of the time measurement in forward().
            F.log_softmax(output_dict["full_label_logits"], 3)
            F.log_softmax(output_dict["supertag_scores"], 2)
            torch.argsort(output_dict["supertag_scores"],
                          descending=True,
                          dim=2)

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]):
        """
        In contrast to its name, this function does not perform the decoding but only prepares it.
        :param output_dict:
        :return:
        """
        if self.supertagger is not None and self.loss_mixing[
                "supertagging"] is not None:  #we have a supertagger, so this is proper AM dependency parsing
            return self.prepare_for_ftd(output_dict)
        else:  #we don't have a supertagger, perform good old dependency parsing
            return self.only_cle(output_dict)

    def only_cle(self, output_dict: Dict[str, torch.Tensor]):
        """
        Used when there is no supertagger: prepares the plain dependency parsing output.
        We take the result of forward and, for each sentence in the batch, remove the padding.
        :param output_dict: result of forward
        :return: output_dict with the following keys added:
            - predicted_heads: nested list: for each sentence, the predicted head index of each word (w/o artificial root)
            - label_logits, full_label_logits, edge_existence_scores: per-sentence score arrays with padding removed
        """
        full_label_logits = output_dict.pop("full_label_logits").cpu().detach(
        ).numpy()  #shape (batch size, seq len, seq len, num edge labels)
        edge_existence_scores = output_dict.pop(
            "edge_existence_scores").cpu().detach().numpy(
            )  #shape (batch size, seq len, seq len)
        heads = output_dict.pop("heads")
        heads_cpu = heads.cpu().detach().numpy()
        mask = output_dict.pop("mask")
        edge_label_logits = output_dict.pop("label_logits").cpu().detach(
        ).numpy()  # shape (batch_size, seq_len, num edge labels)

        output_dict.pop("encoded_text_parsing")
        output_dict.pop("encoded_text_tagging")  #don't need that

        lengths = get_lengths_from_binary_sequence_mask(mask)

        #here we collect things, in the end we will have one entry for each sentence:
        all_edge_label_logits = []
        head_indices = []
        all_full_label_logits = []
        all_edge_existence_scores = []

        for i, length in enumerate(lengths):
            instance_heads_cpu = list(heads_cpu[i, 1:length])
            #apply changes to instance_heads tensor:
            instance_heads = heads[i, :]
            for j, x in enumerate(instance_heads_cpu):
                instance_heads[j + 1] = torch.tensor(
                    x
                )  #+1 because we removed the first position from instance_heads_cpu

            all_edge_label_logits.append(edge_label_logits[i, 1:length, :])

            all_full_label_logits.append(
                full_label_logits[i, :length, :length, :])
            all_edge_existence_scores.append(
                edge_existence_scores[i, :length, :length])
            head_indices.append(instance_heads_cpu)

        output_dict["label_logits"] = all_edge_label_logits
        output_dict["predicted_heads"] = head_indices
        output_dict["full_label_logits"] = all_full_label_logits
        output_dict["edge_existence_scores"] = all_edge_existence_scores
        return output_dict

    def prepare_for_ftd(self, output_dict: Dict[str, torch.Tensor]):
        """
        This function does not perform the decoding but only prepares it.
        Therefore, we take the result of forward and perform the following steps (for each sentence in batch):
        - remove padding
        - identify the root of the sentence, group other root-candidates under the proper root
        - collect a selection of supertags to speed up computation (top k selection is done later)
        :param output_dict: result of forward
        :return: output_dict with the following keys added:
            - lexlabels: nested list: contains for each sentence, for each word the most likely lexical label (w/o artificial root)
            - supertags: nested list: contains for each sentence, for each word a list of (score, graph fragment, type) tuples for the selected supertags (w/o artificial root)
        """
        t0 = time()
        best_supertags = output_dict.pop(
            "best_supertags").cpu().detach().numpy()
        supertag_scores = output_dict.pop(
            "supertag_scores")  # shape (batch_size, seq_len, num supertags)
        full_label_logits = output_dict.pop("full_label_logits").cpu().detach(
        ).numpy()  #shape (batch size, seq len, seq len, num edge labels)
        edge_existence_scores = output_dict.pop(
            "edge_existence_scores").cpu().detach().numpy(
            )  #shape (batch size, seq len, seq len)
        k = 10
        if self.validation_evaluator:  #retrieve k supertags from validation evaluator.
            if isinstance(self.validation_evaluator.predictor,
                          AMconllPredictor):
                k = self.validation_evaluator.predictor.k
        k += 10  # some of the retrieved supertags may be ill-formed, so fetch a few extra to make it very unlikely that too few are left after filtering.
        top_k_supertags = Supertagger.top_k_supertags(
            supertag_scores,
            k).cpu().detach().numpy()  # shape (batch_size, seq_len, k)
        supertag_scores = supertag_scores.cpu().detach().numpy()
        lexlabels = output_dict.pop(
            "lexlabels").cpu().detach().numpy()  #shape (batch_size, seq_len)
        heads = output_dict.pop("heads")
        heads_cpu = heads.cpu().detach().numpy()
        mask = output_dict.pop("mask")
        edge_label_logits = output_dict.pop("label_logits").cpu().detach(
        ).numpy()  # shape (batch_size, seq_len, num edge labels)
        encoded_text_parsing = output_dict.pop("encoded_text_parsing")
        output_dict.pop("encoded_text_tagging")  #don't need that
        lengths = get_lengths_from_binary_sequence_mask(mask)

        #here we collect things, in the end we will have one entry for each sentence:
        all_edge_label_logits = []
        all_supertags = []
        head_indices = []
        roots = []
        all_predicted_lex_labels = []
        all_full_label_logits = []
        all_edge_existence_scores = []
        all_supertag_scores = []

        #we need the following to identify the root
        root_edge_label_id = self.vocab.get_token_index("ROOT",
                                                        namespace=self.name +
                                                        "_head_tags")
        bot_id = self.vocab.get_token_index(AMSentence.get_bottom_supertag(),
                                            namespace=self.name +
                                            "_supertag_labels")

        for i, length in enumerate(lengths):
            instance_heads_cpu = list(heads_cpu[i, 1:length])
            #Postprocess heads and find root of sentence:
            instance_heads_cpu, root = find_root(
                instance_heads_cpu,
                best_supertags[i, 1:length],
                edge_label_logits[i, 1:length, :],
                root_edge_label_id,
                bot_id,
                modify=True)
            roots.append(root)
            #apply changes to instance_heads tensor:
            instance_heads = heads[i, :]
            for j, x in enumerate(instance_heads_cpu):
                instance_heads[j + 1] = torch.tensor(
                    x
                )  #+1 because we removed the first position from instance_heads_cpu

            # re-calculate edge label logits since heads might have changed:
            label_logits = self.edge_model.label_scores(
                encoded_text_parsing[i].unsqueeze(0),
                instance_heads.unsqueeze(0)).squeeze(0).detach().cpu().numpy()
            #(un)squeeze: fake batch dimension
            all_edge_label_logits.append(label_logits[1:length, :])

            all_full_label_logits.append(
                full_label_logits[i, :length, :length, :])
            all_edge_existence_scores.append(
                edge_existence_scores[i, :length, :length])

            #calculate supertags for this sentence:
            all_supertag_scores.append(supertag_scores[
                i, 1:length, :])  #new shape (sent length, num supertags)
            supertags_for_this_sentence = []
            for word in range(1, length):
                supertags_for_this_word = []
                for top_k in top_k_supertags[i, word]:
                    fragment, typ = AMSentence.split_supertag(
                        self.vocab.get_token_from_index(top_k,
                                                        namespace=self.name +
                                                        "_supertag_labels"))
                    score = supertag_scores[i, word, top_k]
                    supertags_for_this_word.append((score, fragment, typ))
                if bot_id not in top_k_supertags[
                        i,
                        word]:  #\bot is not in the top k, but we have to add it anyway in order for the decoder to work properly.
                    fragment, typ = AMSentence.split_supertag(
                        AMSentence.get_bottom_supertag())
                    supertags_for_this_word.append(
                        (supertag_scores[i, word, bot_id], fragment, typ))
                supertags_for_this_sentence.append(supertags_for_this_word)
            all_supertags.append(supertags_for_this_sentence)
            all_predicted_lex_labels.append([
                self.vocab.get_token_from_index(label,
                                                namespace=self.name +
                                                "_lex_labels")
                for label in lexlabels[i, 1:length]
            ])
            head_indices.append(instance_heads_cpu)

        t1 = time()
        normalized_diff = (t1 - t0) / len(lengths)
        output_dict["normalized_prepare_ftd_time"] = [
            normalized_diff for _ in range(len(lengths))
        ]
        output_dict["lexlabels"] = all_predicted_lex_labels
        output_dict["supertags"] = all_supertags
        output_dict["root"] = roots
        output_dict["label_logits"] = all_edge_label_logits
        output_dict["predicted_heads"] = head_indices
        output_dict["full_label_logits"] = all_full_label_logits
        output_dict["edge_existence_scores"] = all_edge_existence_scores
        output_dict["supertag_scores"] = all_supertag_scores
        return output_dict

    def _greedy_decode_edge_labels(
            self, edge_label_logits: torch.Tensor) -> torch.Tensor:
        """
        Assigns edge labels according to (existing) edges.
        Parameters
        ----------
        edge_label_logits: ``torch.Tensor`` of shape (batch_size, sequence_length, num_head_tags)

        Returns
        -------
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded head tags (labels of incoming edges) of each word.
        """
        _, head_tags = edge_label_logits.max(dim=2)
        return head_tags

    def _greedy_decode_arcs(self, existence_scores: torch.Tensor,
                            mask: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head  predictions by decoding the unlabeled arcs
        independently for each word. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        Parameters
        ----------
        existence_scores : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        mask: torch.Tensor, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        existence_scores = existence_scores + torch.diag(
            existence_scores.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).byte().unsqueeze(2)
            existence_scores.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = existence_scores.max(dim=2)
        return heads

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    def metrics(self,
                parser_model,
                reset: bool = False,
                model_path=None) -> Dict[str, float]:
        """
        Is called by a GraphDependencyParser
        :param parser_model: a GraphDependencyParser
        :param reset:
        :return:
        """
        r = self.get_metrics(reset)
        if reset:  #epoch done
            if self.training:  #done on the training data
                self.current_epoch += 1
            else:  #done on dev/test data
                if self.validation_evaluator:
                    metrics = self.validation_evaluator.eval(
                        parser_model, self.current_epoch, model_path)
                    for name, val in metrics.items():
                        r[name] = val
        return r

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        r = self._attachment_scores.get_metric(reset)
        if self.loss_mixing["supertagging"] is not None:
            r["Constant_Acc"] = self._supertagging_acc.get_metric(reset)
            r["Constant_Acc_6_best"] = self._top_6supertagging_acc.get_metric(
                reset)
        if self.loss_mixing["lexlabel"] is not None:
            r["Label_Acc"] = self._lexlabel_acc.get_metric(reset)
        las = r["LAS"]
        if "Constant_Acc" in r:
            r["mean_constant_acc_las"] = (las + r["Constant_Acc"]) / 2
        return r
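
The mix_loss calls in forward above combine the per-task losses using the loss_mixing coefficients set up in the constructor. mix_loss itself is defined elsewhere in the repository; as a minimal sketch, assuming it simply scales each loss by its coefficient (the loss values below are toy numbers, not taken from the code):

import torch

def mix_loss(coefficient: float, loss: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: weight an individual loss term by its mixing coefficient.
    return coefficient * loss

loss_mixing = {"edge_existence": 1.0, "edge_label": 1.0,
               "supertagging": 0.5, "lexlabel": 0.5}

# Toy per-task losses standing in for the values computed in forward().
edge_existence_loss = torch.tensor(2.0)
edge_label_loss = torch.tensor(1.5)
supertagging_nll = torch.tensor(4.0)
lexlabel_nll = torch.tensor(3.0)

loss = mix_loss(loss_mixing["edge_existence"], edge_existence_loss) \
       + mix_loss(loss_mixing["edge_label"], edge_label_loss)
if supertagging_nll is not None:
    loss += mix_loss(loss_mixing["supertagging"], supertagging_nll)
if lexlabel_nll is not None:
    loss += mix_loss(loss_mixing["lexlabel"], lexlabel_nll)
print(loss)  # tensor(7.)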
Example 29
class AttachmentScoresTest(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.scorer = AttachmentScores()

        self.predictions = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                         [0, 3, 2, 1, 0, 0]])

        self.gold_indices = torch.Tensor([[0, 1, 3, 5, 2, 4],
                                          [0, 3, 2, 1, 0, 0]])

        self.label_predictions = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                               [0, 4, 8, 2, 0, 0]])

        self.gold_labels = torch.Tensor([[0, 5, 2, 1, 4, 2],
                                         [0, 4, 8, 2, 0, 0]])

        self.mask = torch.Tensor([[1, 1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 0, 0]])

    def test_perfect_scores(self):
        self.scorer(self.predictions, self.label_predictions,
                    self.gold_indices, self.gold_labels, self.mask)

        for value in self.scorer.get_metric().values():
            assert value == 1.0

    def test_unlabeled_accuracy_ignores_incorrect_labels(self):
        label_predictions = self.label_predictions
        # Change some values so that 4 of our label predictions are wrong.
        label_predictions[0, 3:] = 3
        label_predictions[1, 0] = 7
        self.scorer(self.predictions, label_predictions,
                    self.gold_indices, self.gold_labels, self.mask)

        metrics = self.scorer.get_metric()

        assert metrics["UAS"] == 1.0
        assert metrics["UEM"] == 1.0

        # 4 / 12 labels were wrong and 2 positions
        # are masked, so 6/10 = 0.6 LAS.
        assert metrics["LAS"] == 0.6
        # Neither should have labeled exact match.
        assert metrics["LEM"] == 0.0

    def test_labeled_accuracy_is_affected_by_incorrect_heads(self):
        predictions = self.predictions
        # Change some values so that 4 of our head predictions are wrong.
        predictions[0, 3:] = 3
        predictions[1, 0] = 7
        # This one is in the padded part, so it shouldn't affect anything.
        predictions[1, 5] = 7
        self.scorer(predictions, self.label_predictions,
                    self.gold_indices, self.gold_labels, self.mask)

        metrics = self.scorer.get_metric()

        # 4 heads are incorrect, so the unlabeled score should be
        # 6/10 = 0.6 UAS.
        assert metrics["UAS"] == 0.6
        # All the labels were correct, but some heads
        # were wrong, so the LAS should equal the UAS.
        assert metrics["LAS"] == 0.6

        # Neither batch element had a perfect labeled or unlabeled EM.
        assert metrics["LEM"] == 0.0
        assert metrics["UEM"] == 0.0

    def test_attachment_scores_can_ignore_labels(self):
        scorer = AttachmentScores(ignore_classes=[1])

        label_predictions = self.label_predictions
        # Change the predictions where the gold label is 1;
        # as we are ignoring 1, we should still get a perfect score.
        label_predictions[0, 3] = 2
        scorer(self.predictions, label_predictions,
               self.gold_indices, self.gold_labels, self.mask)

        for value in scorer.get_metric().values():
            assert value == 1.0
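
The numbers asserted in these tests can be reproduced by hand. A short sketch in plain PyTorch, independent of the AttachmentScores class, of how masked UAS and LAS reduce to accuracies over unmasked positions (it mirrors test_labeled_accuracy_is_affected_by_incorrect_heads; exact match and ignore_classes are not covered here):

import torch

predictions       = torch.tensor([[0, 1, 3, 3, 3, 3], [7, 3, 2, 1, 0, 7]])  # 4 unmasked heads wrong
gold_indices      = torch.tensor([[0, 1, 3, 5, 2, 4], [0, 3, 2, 1, 0, 0]])
label_predictions = torch.tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])
gold_labels       = torch.tensor([[0, 5, 2, 1, 4, 2], [0, 4, 8, 2, 0, 0]])
mask              = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]).bool()

correct_heads  = (predictions == gold_indices) & mask
correct_labels = (label_predictions == gold_labels) & mask
uas = correct_heads.sum().float() / mask.sum()
las = (correct_heads & correct_labels).sum().float() / mask.sum()
print(uas.item(), las.item())  # 0.6 0.6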
Example 30
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_embed_dim: int = None,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DependencyDecoder, self).__init__(vocab, regularizer)

        self.pos_tag_embedding = None
        if pos_embed_dim is not None:
            self.pos_tag_embedding = Embedding(self.vocab.get_vocab_size("upos"), pos_embed_dim)

        self.dropout = torch.nn.Dropout(p=dropout)

        self.encoder = encoder
        encoder_output_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_output_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_output_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._dropout = InputVariationalDropout(dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder_output_dim]))

        check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim", "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim", "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
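
The _head_sentinel registered above is the learned representation of an artificial ROOT token; during parsing it is prepended to every encoded sentence so that real words can select position 0 as their head (the _parse method towards the end of this listing performs the same step). A minimal sketch of that step, with toy tensors in place of the encoder output:

import torch

batch_size, seq_len, encoding_dim = 2, 4, 6                  # toy sizes
encoded_text = torch.randn(batch_size, seq_len, encoding_dim)
mask = torch.ones(batch_size, seq_len, dtype=torch.long)
head_sentinel = torch.nn.Parameter(torch.randn(1, 1, encoding_dim))

# Prepend the artificial ROOT representation and extend the mask accordingly.
sentinel = head_sentinel.expand(batch_size, 1, encoding_dim)
encoded_text = torch.cat([sentinel, encoded_text], dim=1)    # (batch, seq_len + 1, dim)
mask = torch.cat([mask.new_ones(batch_size, 1), mask], dim=1)
print(encoded_text.shape, mask.shape)  # torch.Size([2, 5, 6]) torch.Size([2, 5])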
Example 31
    def __init__(
            self,
            vocab: Vocabulary,
            # source-side
            bert_encoder: BaseBertWrapper,
            encoder_token_embedder: TextFieldEmbedder,
            encoder_pos_embedding: Embedding,
            encoder: Seq2SeqEncoder,
            # target-side
            decoder_token_embedder: TextFieldEmbedder,
            decoder_node_index_embedding: Embedding,
            decoder_pos_embedding: Embedding,
            decoder: RNNDecoder,
            extended_pointer_generator: ExtendedPointerGenerator,
            tree_parser: DecompTreeParser,
            node_attribute_module: NodeAttributeDecoder,
            edge_attribute_module: EdgeAttributeDecoder,
            # misc
            label_smoothing: LabelSmoothing,
            target_output_namespace: str,
            pos_tag_namespace: str,
            edge_type_namespace: str,
            syntax_edge_type_namespace: str = None,
            biaffine_parser: DeepTreeParser = None,
            syntactic_method: str = None,
            dropout: float = 0.0,
            beam_size: int = 5,
            max_decoding_steps: int = 50,
            eps: float = 1e-20,
            loss_mixer: LossMixer = None) -> None:

        super(DecompSyntaxOnlyParser, self).__init__(
            vocab=vocab,
            # source-side
            bert_encoder=bert_encoder,
            encoder_token_embedder=encoder_token_embedder,
            encoder_pos_embedding=encoder_pos_embedding,
            encoder=encoder,
            # target-side
            decoder_token_embedder=decoder_token_embedder,
            decoder_node_index_embedding=decoder_node_index_embedding,
            decoder_pos_embedding=decoder_pos_embedding,
            decoder=decoder,
            extended_pointer_generator=extended_pointer_generator,
            tree_parser=tree_parser,
            node_attribute_module=node_attribute_module,
            edge_attribute_module=edge_attribute_module,
            # misc
            label_smoothing=label_smoothing,
            target_output_namespace=target_output_namespace,
            pos_tag_namespace=pos_tag_namespace,
            edge_type_namespace=edge_type_namespace,
            syntax_edge_type_namespace=syntax_edge_type_namespace,
            syntactic_method=syntactic_method,
            dropout=dropout,
            beam_size=beam_size,
            max_decoding_steps=max_decoding_steps,
            eps=eps)

        self.syntactic_method = "encoder-side"
        self.biaffine_parser = biaffine_parser
        self.loss_mixer = None
        self._syntax_metrics = AttachmentScores()
        self.syntax_las = 0.0
        self.syntax_uas = 0.0
class BiaffineDependencyParser(Model):
    """
    This dependency parser follows the model of
    [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)]
    (https://arxiv.org/abs/1611.01734) .

    Word representations are generated using a bidirectional LSTM,
    followed by separate biaffine classifiers for pairs of words,
    predicting whether a directed arc exists between the two words
    and the dependency label the arc should have. Decoding can either
    be done greedily, or the optimal Minimum Spanning Tree can be
    decoded using Edmonds' algorithm by viewing the dependency tree as
    an MST on a fully connected graph, where nodes are words and edges
    are scored dependency arcs.

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : `TextFieldEmbedder`, required
        Used to embed the `tokens` `TextField` we get as input to the model.
    encoder : `Seq2SeqEncoder`
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    tag_representation_dim : `int`, required.
        The dimension of the MLPs used for dependency tag prediction.
    arc_representation_dim : `int`, required.
        The dimension of the MLPs used for head arc prediction.
    tag_feedforward : `FeedForward`, optional, (default = None).
        The feedforward network used to produce tag representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    arc_feedforward : `FeedForward`, optional, (default = None).
        The feedforward network used to produce arc representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    pos_tag_embedding : `Embedding`, optional.
        Used to embed the `pos_tags` `SequenceLabelField` we get as input to the model.
    use_mst_decoding_for_validation : `bool`, optional (default = True).
        Whether to use Edmonds' algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : `float`, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : `float`, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    regularizer : `RegularizerApplicator`, optional (default=`None`)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        tag_representation_dim: int,
        arc_representation_dim: int,
        model_name: str = None,
        tag_feedforward: FeedForward = None,
        arc_feedforward: FeedForward = None,
        pos_tag_embedding: Embedding = None,
        use_mst_decoding_for_validation: bool = True,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        word_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        if model_name:
            from src.data.token_indexers import PretrainedAutoTokenizer
            self._tokenizer = PretrainedAutoTokenizer.load(model_name)

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or FeedForward(
            encoder_dim, 1, arc_representation_dim,
            Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or FeedForward(
            encoder_dim, 1, tag_representation_dim,
            Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._word_dropout = word_dropout
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(
            representation_dim,
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )

        check_dimensions_match(
            tag_representation_dim,
            self.head_tag_feedforward.get_output_dim(),
            "tag representation dim",
            "tag feedforward output dim",
        )
        check_dimensions_match(
            arc_representation_dim,
            self.head_arc_feedforward.get_output_dim(),
            "arc representation dim",
            "arc feedforward output dim",
        )

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

    @overrides
    def forward(
        self,  # type: ignore
        words: TextFieldTensors,
        pos_tags: torch.LongTensor,
        metadata: List[Dict[str, Any]],
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
        lemmas: TextFieldTensors = None,
        feats: TextFieldTensors = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        words : TextFieldTensors, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
            tensors.  At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
            Tensor(batch_size, sequence_length)}`. This dictionary will have the same keys as were used
            for the `TokenIndexers` when you created the `TextField` representing your
            sequence.  The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : `torch.LongTensor`, required
            The output of a `SequenceLabelField` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        metadata : List[Dict[str, Any]], optional (default=None)
            A dictionary of metadata for each batch element which has keys:
                words : `List[str]`, required.
                    The tokens in the original sentence.
                pos : `List[str]`, required.
                    The POS tags for each word in the dependency parse.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape `(batch_size, sequence_length)`.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape `(batch_size, sequence_length)`.

        # Returns

        An output dictionary consisting of:
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        arc_loss : `torch.FloatTensor`
            The loss contribution from the unlabeled arcs.
        tag_loss : `torch.FloatTensor`
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : `torch.FloatTensor`
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : `torch.FloatTensor`
            The predicted head tags (dependency labels) for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : `torch.LongTensor`
            A mask denoting the padded elements in the batch.
        """
        mask = get_text_field_mask(words)
        words = self._apply_token_dropout(words)
        embedded_text_input = self.text_field_embedder(words)

        if pos_tags is not None and self._pos_tag_embedding is not None:
            pos_tags_dict = {"tokens": pos_tags, "mask": mask}
            self._apply_token_dropout(pos_tags_dict)
            pos_tags = pos_tags_dict["tokens"]
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll = self._parse(
            encoded_text, mask, head_tags, head_indices)

        loss = arc_nll + tag_nll

        if head_indices is not None and head_tags is not None:
            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(
                predicted_heads[:, 1:],
                predicted_head_tags[:, 1:],
                head_indices,
                head_tags,
                evaluation_mask,
            )

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
        }

        self._add_metadata_to_output_dict(metadata, output_dict)

        return output_dict

    def _add_metadata_to_output_dict(self, metadata, output_dict):
        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
            output_dict["upos"] = [x["upos"] for x in metadata]
            output_dict["xpos"] = [x["xpos"] for x in metadata]
            output_dict["feats"] = [x["feats"] for x in metadata]
            output_dict["lemmas"] = [x["lemmas"] for x in metadata]
            output_dict["ids"] = [x["ids"] for x in metadata if "ids" in x]
            output_dict["multiword_ids"] = [
                x["multiword_ids"] for x in metadata if "multiword_ids" in x
            ]
            output_dict["multiword_forms"] = [
                x["multiword_forms"] for x in metadata
                if "multiword_forms" in x
            ]

        return output_dict

    def _apply_token_dropout(self, words):
        # Word dropout
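        # Two indexer types are handled below: plain single-id indexers
        # ("tokens") have dropped words replaced by the OOV token, while
        # transformer subword indexers ("roberta") have every subword of a
        # dropped word replaced by the tokenizer's mask token.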

        def mask_words(tokens, drop_mask, drop_token):
            drop_fill = tokens.new_empty(
                tokens.size()).long().fill_(drop_token)
            return torch.where(drop_mask, drop_fill, tokens)

        if "tokens" in words:
            drop_mask = self._get_dropout_mask(words["mask"].bool(),
                                               p=self._word_dropout,
                                               training=self.training)
            drop_token = self.vocab.get_token_index(self.vocab._oov_token)
            words["tokens"] = mask_words(words["tokens"], drop_mask,
                                         drop_token)

        def mask_subwords(words, drop_mask, drop_token):
            token_ids = words["token_ids"]
            offsets = words["offsets"]
            subword_drop_mask = token_ids.new_zeros(token_ids.size()).bool()
            batch_size, seq_len, _ = offsets.size()
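            # Expand the word-level drop mask to subword level: every subword
            # position covered by a dropped word's (start, end) offsets is marked.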
            for i in range(batch_size):
                for j in range(seq_len):
                    start, end = offsets[i, j].tolist()
                    subword_drop_mask[i, start:(end + 1)] = drop_mask[i, j]
            drop_fill = token_ids.new_empty(
                token_ids.size()).long().fill_(drop_token)
            return torch.where(subword_drop_mask, drop_fill, token_ids)

        if "roberta" in words:
            drop_mask = self._get_dropout_mask(words["roberta"]["mask"].bool(),
                                               p=self._word_dropout,
                                               training=self.training)
            drop_token = self._tokenizer.encode("<mask>",
                                                add_special_tokens=False)[0]
            words["roberta"]["token_ids"] = mask_subwords(
                words["roberta"], drop_mask, drop_token)

    @staticmethod
    def _get_dropout_mask(mask: torch.Tensor,
                          p: float = 0.0,
                          training: bool = True) -> torch.Tensor:
        """
        During training, randomly selects non-padding tokens to be dropped with probability ``p``;
        the caller replaces the selected tokens by a mask/OOV token.

        :param mask: The padding mask of the current batch (nonzero for real tokens)
        :param p: The probability that a token is selected for dropout
        :param training: Applies the dropout only if set to ``True``
        :return: A boolean mask of the same shape as ``mask`` marking the tokens to drop
        """
        if training and p > 0:
            # Create a uniformly random mask selecting either the original words or OOV tokens
            dropout_mask = (mask.new_empty(mask.size()).float().uniform_() < p)
            drop_mask = dropout_mask & mask
            return drop_mask
        else:
            return mask.new_zeros(mask.size()).bool()

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(
                heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [
                self.vocab.get_token_from_index(label, "head_tags")
                for label in instance_tags
            ]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict

    def _parse(
        self,
        encoded_text: torch.Tensor,
        mask: torch.LongTensor,
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask,
            )
        else:
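            # No gold heads/tags were provided (e.g. at prediction time), so the
            # loss is computed against the model's own predictions, keeping a loss
            # value available in the output.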
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask,
            )

        return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll

    def _construct_loss(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        head_indices: torch.Tensor,
        head_tags: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : `torch.Tensor`, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        # Returns

        arc_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc loss.
        tag_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = (masked_log_softmax(attended_arcs, mask) *
                                 float_mask.unsqueeze(2) *
                                 float_mask.unsqueeze(1))

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = (timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long())
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll

    def _greedy_decode(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        # Returns

        heads : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(
            attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation, heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using Edmonds' algorithm
        for finding minimum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        # Returns

        heads : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size(
        )

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [
            batch_size, sequence_length, sequence_length,
            tag_representation_dim
        ]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(
            *expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(
            *expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation,
                                                 child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g. a pair which includes a padded element) anyway below.
        # Shape (batch, num_labels, sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits,
                                                        dim=3).permute(
                                                            0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = (1 - mask.float()) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs,
                                              dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        # This energy tensor expresses the following relation:
        # energy[i,j] = "Score that i is the head of j". In this
        # case, we have heads pointing to their children.
        batch_energy = torch.exp(
            normalized_arc_logits.unsqueeze(1) +
            normalized_pairwise_head_logits)
        return self._run_mst_decoding(batch_energy, lengths)

    @staticmethod
    def _run_mst_decoding(
            batch_energy: torch.Tensor,
            lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all word -> ROOT edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(),
                                           length,
                                           has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = []
            for child, parent in enumerate(instance_heads):
                instance_head_tags.append(tag_ids[parent, child].item())
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(
            numpy.stack(head_tags))

    def _get_head_tags(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        head_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        # Returns

        head_tag_logits : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[
            range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous(
        )
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        # Parameters

        mask : `torch.LongTensor`, required.
            The original mask.
        pos_tags : `torch.LongTensor`, required.
            The pos tags for the sequence.

        # Returns

        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return self._attachment_scores.get_metric(reset)
Esempio n. 33
0
    def __init__(self,
                 vocab,
                 text_field_embedder,
                 encoder,
                 tag_representation_dim,
                 arc_representation_dim,
                 tag_feedforward=None,
                 arc_feedforward=None,
                 pos_tag_embedding=None,
                 use_mst_decoding_for_validation=True,
                 dropout=0.0,
                 input_dropout=0.0,
                 initializer=InitializerApplicator(),
                 regularizer=None):
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or\
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name(u"elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size(u"head_tags")

        self.head_tag_feedforward = tag_feedforward or\
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name(u"elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               u"text field embedding dim",
                               u"encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               u"tag representation dim",
                               u"tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               u"arc representation dim",
                               u"arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary(u"pos")
        punctuation_tag_indices = dict((tag, index)
                                       for tag, index in list(tags.items())
                                       if tag in POS_TO_IGNORE)
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            u"Found POS tags corresponding to the following punctuation : {}. "
            u"Ignoring words with these POS tags for evaluation.".format(
                punctuation_tag_indices))

        self._attachment_scores = AttachmentScores()
        initializer(self)
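
# A minimal, self-contained sketch (not part of the examples on this page) of the
# greedy arc decoding performed in the ``_greedy_decode`` methods here: mask the
# diagonal and padded positions with -inf, then take an argmax over candidate heads
# for every word. The function name and the toy scores are made up for illustration.
def _greedy_heads_demo():
    import torch

    # Toy arc scores for one sentence of length 4, where position 0 is the ROOT
    # sentinel: attended_arcs[0, i, j] = score that word j is the head of word i.
    attended_arcs = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                                  [2.0, 0.0, 1.0, 0.5],
                                  [0.3, 1.5, 0.0, 0.2],
                                  [0.1, 0.4, 2.5, 0.0]]).unsqueeze(0)
    mask = torch.ones(1, 4)

    # A word cannot be its own head, so mask the diagonal.
    attended_arcs = attended_arcs + torch.diag(
        attended_arcs.new(mask.size(1)).fill_(-float("inf")))
    # Mask padded tokens (there are none in this toy batch).
    minus_mask = (1 - mask).bool().unsqueeze(2)
    attended_arcs = attended_arcs.masked_fill(minus_mask, -float("inf"))

    # shape (batch_size, sequence_length); the entry for the ROOT position is ignored.
    _, heads = attended_arcs.max(dim=2)
    return heads  # tensor([[3, 0, 1, 2]]) for the scores above
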
class BiaffineDependencyParser(Model):
    """
    This dependency parser follows the model of
    `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
    <https://arxiv.org/abs/1611.01734>`_.

    Word representations are generated using a bidirectional LSTM,
    followed by separate biaffine classifiers for pairs of words,
    predicting whether a directed arc exists between the two words
    and the dependency label the arc should have. Decoding can either
    be done greedily, or the optimal Minimum Spanning Tree can be
    decoded using Edmond's algorithm by viewing the dependency tree as
    a MST on a fully connected graph, where nodes are words and edges
    are scored dependency arcs.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    tag_representation_dim : ``int``, required.
        The dimension of the MLPs used for dependency tag prediction.
    arc_representation_dim : ``int``, required.
        The dimension of the MLPs used for head arc prediction.
    tag_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce tag representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    arc_feedforward : ``FeedForward``, optional, (default = None).
        The feedforward network used to produce arc representations.
        By default, a 1 layer feedforward network with an elu activation is used.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    use_mst_decoding_for_validation : ``bool``, optional (default = True).
        Whether to use Edmond's algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : ``float``, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : ``float``, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim", "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim", "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                pos_tags: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required.
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_loss : ``torch.FloatTensor``
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : ``torch.FloatTensor``
            The predicted head tags for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
                "heads": predicted_heads,
                "head_tags": predicted_head_tags,
                "arc_loss": arc_nll,
                "tag_loss": tag_nll,
                "loss": loss,
                "mask": mask,
                "words": [meta["words"] for meta in metadata],
                "pos": [meta["pos"] for meta in metadata]
                }

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [self.vocab.get_token_from_index(label, "head_tags")
                      for label in instance_tags]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict

    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll

    def _greedy_decode(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       attended_arcs: torch.Tensor,
                       mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).bool().unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using Edmonds' algorithm
        for finding minimum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g. a pair which includes a padded element) anyway below.
        # Shape (batch, num_labels, sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = (1 - mask.float()) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        # This energy tensor expresses the following relation:
        # energy[i,j] = "Score that i is the head of j". In this
        # case, we have heads pointing to their children.
        batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits)
        return self._run_mst_decoding(batch_energy, lengths)

    @staticmethod
    def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all word -> ROOT edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = []
            for child, parent in enumerate(instance_heads):
                instance_head_tags.append(tag_ids[parent, child].item())
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))

    def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self,
                           mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return self._attachment_scores.get_metric(reset)
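
# A minimal, self-contained sketch (not part of the examples on this page) of the
# advanced-indexing step used in ``_construct_loss``: after a log-softmax over the
# head dimension, the log-probability of each word's gold head is gathered with
# ``[range_vector, child_index, head_indices]`` and averaged into an NLL. The
# function name and the toy tensors are made up; padding is omitted so a plain
# log_softmax stands in for masked_log_softmax.
def _arc_nll_demo():
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    # One sentence of length 4 including the ROOT sentinel at position 0, no padding.
    attended_arcs = torch.randn(1, 4, 4)         # scores[batch, child, head]
    head_indices = torch.tensor([[0, 0, 1, 1]])  # gold head index for every position
    batch_size, sequence_length, _ = attended_arcs.size()

    # shape (batch_size, sequence_length, sequence_length): log P(head | child)
    normalised_arc_logits = F.log_softmax(attended_arcs, dim=2)

    # shape (batch_size, 1) and (1, sequence_length); broadcasting selects
    # normalised_arc_logits[b, child, gold_head_of_child] for every b and child.
    range_vector = torch.arange(batch_size).unsqueeze(1)
    child_index = torch.arange(sequence_length).unsqueeze(0)
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]

    # Drop the ROOT sentinel's entry and average over the real tokens.
    arc_loss = arc_loss[:, 1:]
    valid_positions = float(sequence_length - 1) * batch_size
    return -arc_loss.sum() / valid_positions
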
class GraphDependencyParser(Model):
    """
    This dependency parser is a blueprint for several graph-based dependency parsers.

    There are several possible edge models and loss functions.

    For decoding, the CLE algorithm is used (during training, attachment scores are usually based on greedy decoding).

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use to generate representations
        of tokens.
    edge_model: ``components.edge_models.EdgeModel``, required.
        The edge model to be used.
    loss_function: ``components.losses.EdgeLoss``, required.
        The loss function to be used.
    pos_tag_embedding : ``Embedding``, optional.
        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
    use_mst_decoding_for_validation : ``bool``, optional (default = True).
        Whether to use Edmond's algorithm to find the optimal minimum spanning tree during validation.
        If false, decoding is greedy.
    dropout : ``float``, optional, (default = 0.0)
        The variational dropout applied to the output of the encoder and MLP layers.
    input_dropout : ``float``, optional, (default = 0.0)
        The dropout applied to the embedded text input.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            encoder: Seq2SeqEncoder,
            edge_model: graph_dependency_parser.components.edge_models.
        EdgeModel,
            loss_function: graph_dependency_parser.components.losses.EdgeLoss,
            pos_tag_embedding: Embedding = None,
            use_mst_decoding_for_validation: bool = True,
            dropout: float = 0.0,
            input_dropout: float = 0.0,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None,
            validation_evaluator: Optional[ValidationEvaluator] = None
    ) -> None:
        super(GraphDependencyParser, self).__init__(vocab, regularizer)

        self.validation_evaluator = validation_evaluator

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(encoder.get_output_dim(),
                               edge_model.encoder_dim(), "encoder output dim",
                               "input dim edge model")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

        self.edge_model = edge_model
        self.loss_function = loss_function

        # Track what state of training we are in; probably not the best idea.
        self.current_epoch = 1
        self.pass_over_data_just_started = True

    @overrides
    def forward(
            self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        metadata : List[Dict[str, Any]], required
            A dictionary of metadata for each batch element which has keys:
                words : ``List[str]``, required.
                    The tokens in the original sentence.
                pos : ``List[str]``, required.
                    The POS tags for each word.
                position_in_corpus : required.
                    The position of the sentence in the corpus.
        head_tags : ``torch.LongTensor``, optional (default = None)
            The gold edge labels: a torch tensor representing the sequence of integer labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        edge_label_loss : ``torch.FloatTensor``
            The loss contribution from the edge labels.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        edge_labels : ``torch.FloatTensor``
            The predicted edge labels for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        encoded_text = self._dropout(encoded_text)

        edge_existence_scores = self.edge_model.edge_existence(
            encoded_text, mask)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads = self._greedy_decode_arcs(edge_existence_scores,
                                                       mask)
            edge_label_logits = self.edge_model.label_scores(
                encoded_text, predicted_heads)
            predicted_edge_labels = self._greedy_decode_edge_labels(
                edge_label_logits)
        else:
            # Find the best tree with CLE.
            predicted_heads = cle_decode(edge_existence_scores,
                                         mask.data.sum(dim=1).long())
            # With the tree structure fixed, compute edge label scores.
            edge_label_logits = self.edge_model.label_scores(
                encoded_text, predicted_heads)
            # Predict edge labels.
            predicted_edge_labels = self._greedy_decode_edge_labels(
                edge_label_logits)

        output_dict = {
            "heads": predicted_heads,
            "edge_labels": predicted_edge_labels,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            "pos": [meta["pos"] for meta in metadata],
            "position_in_corpus": [meta["position_in_corpus"] for meta in metadata],
        }

        if head_indices is not None and head_tags is not None:
            gold_edge_label_logits = self.edge_model.label_scores(
                encoded_text, head_indices)
            edge_label_loss = self.loss_function.label_loss(
                gold_edge_label_logits, mask, head_tags)

            arc_nll = self.loss_function.edge_existence_loss(
                edge_existence_scores, head_indices, mask)

            loss = arc_nll + edge_label_loss

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_edge_labels[:, 1:],
                                    head_indices[:, 1:], head_tags[:, 1:],
                                    evaluation_mask)
            output_dict["arc_loss"] = arc_nll
            output_dict["edge_label_loss"] = edge_label_loss
            output_dict["loss"] = loss

        if self.pass_over_data_just_started:
            # Here we could decide whether to start collecting info for the decoder.
            pass
        self.pass_over_data_just_started = False
        return output_dict

    @overrides
    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Takes the result of ``forward`` and creates human-readable, non-padded dependency trees.
        :param output_dict: the output dictionary returned by ``forward``.
        :return: output_dict with two new keys, "predicted_labels" and "predicted_heads", which are lists of lists.
        """

        head_tags = output_dict.pop("edge_labels").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        edge_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(
                heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [
                self.vocab.get_token_from_index(label, "head_tags")
                for label in instance_tags
            ]
            edge_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_labels"] = edge_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict

    def _greedy_decode_edge_labels(
            self, edge_label_logits: torch.Tensor) -> torch.Tensor:
        """
        Assigns edge labels greedily, given the (already predicted) edges.

        Parameters
        ----------
        edge_label_logits: ``torch.Tensor`` of shape (batch_size, sequence_length, num_head_tags)

        Returns
        -------
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded head tags (labels of incoming edges) of each word.
        """
        _, head_tags = edge_label_logits.max(dim=2)
        return head_tags

    def _greedy_decode_arcs(self, existence_scores: torch.Tensor,
                            mask: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head predictions by decoding the unlabeled arcs
        independently for each word. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        Parameters
        ----------
        existence_scores : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        mask: torch.Tensor, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        existence_scores = existence_scores + torch.diag(
            existence_scores.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).byte().unsqueeze(2)
            existence_scores.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = existence_scores.max(dim=2)
        return heads

    def _get_mask_for_eval(self, mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        r = self._attachment_scores.get_metric(reset)
        if reset:
            if self.training:  # done on the training data
                self.current_epoch += 1
                self.pass_over_data_just_started = True
            else:  # done on dev/test data
                if self.validation_evaluator:
                    metrics = self.validation_evaluator.eval(self)
                    for name, val in metrics.items():
                        r[name] = val
        return r
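
    # --- Minimal usage sketch (hypothetical, not part of the original file) ---
    # Assuming a trained ``parser`` and a ``batch`` produced by the matching
    # dataset reader, a forward/decode round trip could look like this:
    #
    #   output = parser(words=batch["words"], pos_tags=batch["pos_tags"],
    #                   metadata=batch["metadata"])
    #   decoded = parser.decode(output)
    #   print(decoded["predicted_heads"][0], decoded["predicted_labels"][0])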