Example #1
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: MTLWeightSharer,
                 tasks: List[AMTask],
                 pos_tag_embedding: Embedding = None,
                 lemma_embedding: Embedding = None,
                 ne_embedding: Embedding = None,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 tok2vec: Optional[TokenToVec] = None) -> None:
        super(GraphDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.tok2vec = tok2vec

        self._pos_tag_embedding = pos_tag_embedding or None
        self._lemma_embedding = lemma_embedding
        self._ne_embedding = ne_embedding

        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        if self._lemma_embedding is not None:
            representation_dim += lemma_embedding.get_output_dim()
        if self._ne_embedding is not None:
            representation_dim += ne_embedding.get_output_dim()

        assert len(tasks) > 0, "List of tasks must not be empty"
        self.tasks: Dict[str, AMTask] = {t.name: t for t in tasks}

        if self.tok2vec:
            representation_dim += self.tok2vec.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        for t in tasks:
            t.check_all_dimensions_match(encoder.get_output_dim())

        for formalism, task in sorted(self.tasks.items(),
                                      key=lambda nt: nt[0]):
            # sort by name of formalism for consistent ordering
            self.add_module(formalism, task)
        initializer(self)
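
A quick illustration of the dimension bookkeeping this constructor performs: every optional embedding widens the token representation, and the total has to match the encoder's expected input size, which is what check_dimensions_match verifies. A minimal sketch (the helper name and the concrete dimensions are made up for illustration, not taken from the example):

from typing import Optional

def expected_encoder_input_dim(text_dim: int,
                               pos_dim: Optional[int] = None,
                               lemma_dim: Optional[int] = None,
                               ne_dim: Optional[int] = None) -> int:
    # Mirrors how representation_dim is accumulated above.
    return text_dim + sum(d for d in (pos_dim, lemma_dim, ne_dim) if d)

# With a 100-dim text embedder plus 32-dim POS and 16-dim lemma embeddings,
# the encoder must be configured with an input dim of 148.
assert expected_encoder_input_dim(100, pos_dim=32, lemma_dim=16) == 148
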
Example #2
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            encoder: Seq2SeqEncoder,
            edge_model: graph_dependency_parser.components.edge_models.EdgeModel,
            loss_function: graph_dependency_parser.components.losses.EdgeLoss,
            pos_tag_embedding: Embedding = None,
            use_mst_decoding_for_validation: bool = True,
            dropout: float = 0.0,
            input_dropout: float = 0.0,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None,
            validation_evaluator: Optional[ValidationEvaluator] = None
    ) -> None:
        super(GraphDependencyParser, self).__init__(vocab, regularizer)

        self.validation_evaluator = validation_evaluator

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(encoder.get_output_dim(),
                               edge_model.encoder_dim(), "encoder output dim",
                               "input dim edge model")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

        self.edge_model = edge_model
        self.loss_function = loss_function

        # Being able to detect what state we are in is probably not the best idea.
        self.current_epoch = 1
        self.pass_over_data_just_started = True
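
The POS_TO_IGNORE filtering above maps punctuation tags to their vocabulary indices so that the corresponding tokens can be skipped when computing attachment scores. A toy sketch (the tag set and vocabulary below are invented for illustration):

POS_TO_IGNORE = {"``", "''", ":", ",", "."}          # illustrative punctuation tags

pos_vocab = {"NOUN": 0, "VERB": 1, ",": 2, ".": 3}   # toy token -> index mapping
punctuation_tag_indices = {tag: index for tag, index in pos_vocab.items()
                           if tag in POS_TO_IGNORE}
pos_to_ignore = set(punctuation_tag_indices.values())
print(pos_to_ignore)  # {2, 3}
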
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        self.head_arc_projection = torch.nn.Linear(encoder_dim,
                                                   arc_representation_dim)
        self.child_arc_projection = torch.nn.Linear(encoder_dim,
                                                    arc_representation_dim)
        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")
        self.head_tag_projection = torch.nn.Linear(encoder_dim,
                                                   tag_representation_dim)
        self.child_tag_projection = torch.nn.Linear(encoder_dim,
                                                    tag_representation_dim)
        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
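
The head/child projections together with BilinearMatrixAttention implement biaffine arc scoring: every token is scored against every candidate head. A minimal sketch with assumed shapes (the explicit bilinear product below stands in for BilinearMatrixAttention(..., use_input_biases=True); it is not the library call itself):

import torch

batch, seq_len, encoder_dim, arc_dim = 2, 5, 8, 4
encoded = torch.randn(batch, seq_len, encoder_dim)   # encoder output

head_proj = torch.nn.Linear(encoder_dim, arc_dim)
child_proj = torch.nn.Linear(encoder_dim, arc_dim)
bilinear = torch.randn(arc_dim + 1, arc_dim + 1)     # +1 for the input biases

heads = torch.cat([head_proj(encoded), torch.ones(batch, seq_len, 1)], dim=-1)
children = torch.cat([child_proj(encoded), torch.ones(batch, seq_len, 1)], dim=-1)

# arc_scores[b, i, j]: score for token j attaching to head i.
arc_scores = heads @ bilinear @ children.transpose(1, 2)
print(arc_scores.shape)  # torch.Size([2, 5, 5])
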
Example #4
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        self.head_arc_projection = torch.nn.Linear(encoder_dim, arc_representation_dim)
        self.child_arc_projection = torch.nn.Linear(encoder_dim, arc_representation_dim)
        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")
        self.head_tag_projection = torch.nn.Linear(encoder_dim, tag_representation_dim)
        self.child_tag_projection = torch.nn.Linear(encoder_dim, tag_representation_dim)
        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags correspoding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
Example #5
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 label_namespace: str = "pos",
                 treebank_embedding: Embedding = None,
                 use_treebank_embedding: bool = True,
                 langs_for_early_stop: List[str] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(PosTaggerTbemb, self).__init__(vocab, regularizer)

        self.label_namespace = label_namespace
        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size(label_namespace)
        self.encoder = encoder
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._langs_for_early_stop = langs_for_early_stop or []
        self._treebank_embedding = treebank_embedding or None
        self._use_treebank_embedding = use_treebank_embedding
        self._lang_accuracy_scores: Dict[
            str, CategoricalAccuracy] = defaultdict(CategoricalAccuracy)

        self.tag_projection_layer = TimeDistributed(
            Linear(self.encoder.get_output_dim(), self.num_classes))

        representation_dim = text_field_embedder.get_output_dim()

        if treebank_embedding is not None:
            representation_dim += treebank_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        if self._use_treebank_embedding:
            tbids = self.vocab.get_token_to_index_vocabulary("tbids")
            tbid_indices = {tb: index for tb, index in tbids.items()}
            self._tbids = set(tbid_indices.values())
            logger.info(
                f"Found TBIDs corresponding to the following treebanks : {tbid_indices}. "
                "Embedding these as additional features.")

        initializer(self)
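
The tag_projection_layer above applies the same Linear layer at every time step of the encoder output, producing one logit vector per token; for a Linear layer, AllenNLP's TimeDistributed wrapper gives the same result. A short sketch with toy dimensions:

import torch

batch, seq_len, encoder_dim, num_classes = 2, 6, 8, 5
encoded = torch.randn(batch, seq_len, encoder_dim)

tag_projection = torch.nn.Linear(encoder_dim, num_classes)
logits = tag_projection(encoded)      # Linear acts on the last dimension, i.e. per token
print(logits.shape)                   # torch.Size([2, 6, 5])
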
Example #6
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        self.head_arc_projection = torch.nn.Linear(encoder_dim,
                                                   arc_representation_dim)
        self.child_arc_projection = torch.nn.Linear(encoder_dim,
                                                    arc_representation_dim)
        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")
        self.head_tag_projection = torch.nn.Linear(encoder_dim,
                                                   tag_representation_dim)
        self.child_tag_projection = torch.nn.Linear(encoder_dim,
                                                    tag_representation_dim)
        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        self._attachment_scores = AttachmentScores()
        initializer(self)
Example #7
class JointSentimentClassifier(Model):
    """
    Parameters
    ----------
    vocab: ``allennlp.data.Vocabulary``, required.
        The vocabulary fitted on the data.
    params: ``allennlp.common.Params``, required
        Configuration parameters for the multi-tasks model.
    regularizer: ``allennlp.nn.RegularizerApplicator``, optional (default = None)
        A reguralizer to apply to the model's layers.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 share_encoder: Seq2VecEncoder = None,
                 private_encoder: Seq2VecEncoder = None,
                 dropout: float = None,
                 input_dropout: float = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None) -> None:
        super(JointSentimentClassifier, self).__init__(vocab=vocab, regularizer=regularizer)

        self._text_field_embedder = text_field_embedder
        self._domain_embeddings = Embedding(len(TASKS_NAME), 50)
        if share_encoder is None and private_encoder is None:
            share_rnn = nn.LSTM(input_size=self._text_field_embedder.get_output_dim(),
                                hidden_size=150,
                                batch_first=True,
                                dropout=dropout,
                                bidirectional=True)
            share_encoder = PytorchSeq2SeqWrapper(share_rnn)
            private_rnn = nn.LSTM(input_size=self._text_field_embedder.get_output_dim(),
                                  hidden_size=150,
                                  batch_first=True,
                                  dropout=dropout,
                                  bidirectional=True)
            private_encoder = PytorchSeq2SeqWrapper(private_rnn)
            logger.info("Using LSTM as encoder")
            self._domain_embeddings = Embedding(len(TASKS_NAME), self._text_field_embedder.get_output_dim())
        self._share_encoder = share_encoder

        self._s_domain_discriminator = Discriminator(share_encoder.get_output_dim(), len(TASKS_NAME))

        self._p_domain_discriminator = Discriminator(private_encoder.get_output_dim(), len(TASKS_NAME))

        # TODO individual valid discriminator
        self._valid_discriminator = Discriminator(self._domain_embeddings.get_output_dim(), 2)

        for task in TASKS_NAME:
            tagger = SentimentClassifier(
                vocab=vocab,
                text_field_embedder=self._text_field_embedder,
                share_encoder=self._share_encoder,
                private_encoder=copy.deepcopy(private_encoder),
                domain_embeddings=self._domain_embeddings,
                s_domain_discriminator=self._s_domain_discriminator,
                p_domain_discriminator=self._p_domain_discriminator,
                valid_discriminator=self._valid_discriminator,
                dropout=dropout,
                input_dropout=input_dropout,
                label_smoothing=0.1,
                initializer=initializer
            )
            self.add_module("_tagger_{}".format(task), tagger)

        logger.info("Multi-Task Learning Model has been instantiated.")

    @overrides
    def forward(self, tensor_batch, task_name: str, epoch_trained=None, reverse=False, for_training=False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        task_tagger = getattr(self, "_tagger_%s" % task_name)
        task_index = TASKS_NAME.index(task_name)
        tensor_batch['task_index'] = torch.tensor(task_index)
        tensor_batch["reverse"] = torch.tensor(reverse)
        tensor_batch['for_training'] = torch.tensor(for_training)
        tensor_batch['epoch_trained'] = epoch_trained
        tensor_batch = move_to_device(tensor_batch, 0)
        return task_tagger.forward(**tensor_batch)

    @overrides
    def get_metrics(self, task_name: str, reset: bool = False) -> Dict[str, float]:
        task_tagger = getattr(self, "_tagger_" + task_name)
        return task_tagger.get_metrics(reset)
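
JointSentimentClassifier registers one tagger per task with add_module and routes every forward call to it by name via getattr. A stripped-down sketch of that dispatch pattern (the MultiHead class and task names here are hypothetical):

import torch

class MultiHead(torch.nn.Module):
    def __init__(self, tasks):
        super().__init__()
        # Register one sub-module per task, like add_module("_tagger_<task>", ...) above.
        for task in tasks:
            self.add_module("_tagger_{}".format(task), torch.nn.Linear(4, 2))

    def forward(self, inputs, task_name):
        # Look the task-specific module up by name, as in the forward() above.
        return getattr(self, "_tagger_%s" % task_name)(inputs)

model = MultiHead(["books", "dvd"])
print(model(torch.randn(3, 4), "books").shape)  # torch.Size([3, 2])
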
Example #8
class TagDecoder(Model):
    """
    A basic sequence tagger that decodes from inputs of word embeddings
    """
    def __init__(self,
                 vocab: Vocabulary,
                 task: str,
                 encoder: Seq2SeqEncoder,
                 lang_embed_dim: int = None,
                 use_lang_feedforward: bool = False,
                 lang_feedforward: FeedForward = None,
                 label_smoothing: float = 0.0,
                 dropout: float = 0.0,
                 adaptive: bool = False,
                 features: List[str] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(TagDecoder, self).__init__(vocab, regularizer)

        self.lang_embedding = None
        if lang_embed_dim is not None:
            self.lang_embedding = Embedding(self.vocab.get_vocab_size("langs"), lang_embed_dim)

        self.dropout = torch.nn.Dropout(p=dropout)

        self.task = task
        self.encoder = encoder
        self.output_dim = encoder.get_output_dim()
        self.label_smoothing = label_smoothing
        self.num_classes = self.vocab.get_vocab_size(task)
        self.adaptive = adaptive
        self.features = features if features else []

        self.use_lang_feedforward = use_lang_feedforward
        if self.lang_embedding is not None and use_lang_feedforward:
            self.lang_feedforward = lang_feedforward or \
                                     FeedForward(self.output_dim, 1,
                                                 self.output_dim,
                                                 Activation.by_name("elu")())

        self.metrics = {
            "acc": CategoricalAccuracy(),
            # "acc3": CategoricalAccuracy(top_k=3)
        }

        if self.adaptive:
            # TODO
            adaptive_cutoffs = [round(self.num_classes / 15), 3 * round(self.num_classes / 15)]
            self.task_output = AdaptiveLogSoftmaxWithLoss(self.output_dim,
                                                          self.num_classes,
                                                          cutoffs=adaptive_cutoffs,
                                                          div_value=4.0)
        else:
            self.task_output = TimeDistributed(Linear(self.output_dim, self.num_classes))

        self.feature_outputs = torch.nn.ModuleDict()
        self.features_metrics = {}
        for feature in self.features:
            self.feature_outputs[feature] = TimeDistributed(Linear(self.output_dim,
                                                                   vocab.get_vocab_size(feature)))
            self.features_metrics[feature] = {
                "acc": CategoricalAccuracy(),
            }

        initializer(self)

    @overrides
    def forward(self,
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                gold_tags: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self.lang_embedding is not None:
            batch_size, _, _ = encoded_text.size()
            lang_embed_size = self.lang_embedding.get_output_dim()
            embedded_lang = self.dropout(self.lang_embedding(gold_tags['langs']))
            embedded_lang = embedded_lang.view(batch_size, -1, lang_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_lang], -1)

        if self.lang_embedding is not None and self.use_lang_feedforward:
            encoded_text = self.lang_feedforward(encoded_text)

        hidden = encoded_text
        hidden = self.encoder(hidden, mask)

        batch_size, sequence_length, _ = hidden.size()
        output_dim = [batch_size, sequence_length, self.num_classes]

        loss_fn = self._adaptive_loss if self.adaptive else self._loss

        output_dict = loss_fn(hidden, mask, gold_tags[self.task], output_dim)
        self._features_loss(hidden, mask, gold_tags, output_dict)

        return output_dict

    def _adaptive_loss(self, hidden, mask, gold_tags, output_dim):
        logits = hidden
        reshaped_log_probs = logits.view(-1, logits.size(2))

        class_probabilities = self.task_output.log_prob(reshaped_log_probs).view(output_dim)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if gold_tags is not None:
            output_dict["loss"] = sequence_cross_entropy(class_probabilities,
                                                         gold_tags,
                                                         mask,
                                                         label_smoothing=self.label_smoothing)
            for metric in self.metrics.values():
                metric(class_probabilities, gold_tags, mask.float())

        return output_dict

    def _loss(self, hidden, mask, gold_tags, output_dim):
        logits = self.task_output(hidden)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(output_dim)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if gold_tags is not None:
            output_dict["loss"] = sequence_cross_entropy_with_logits(logits,
                                                                     gold_tags,
                                                                     mask,
                                                                     label_smoothing=self.label_smoothing)
            for metric in self.metrics.values():
                metric(logits, gold_tags, mask.float())

        return output_dict

    def _features_loss(self, hidden, mask, gold_tags, output_dict):
        if gold_tags is None:
            return

        for feature in self.features:
            logits = self.feature_outputs[feature](hidden)
            loss = sequence_cross_entropy_with_logits(logits,
                                                      gold_tags[feature],
                                                      mask,
                                                      label_smoothing=self.label_smoothing)
            loss /= len(self.features)
            output_dict["loss"] += loss

            for metric in self.features_metrics[feature].values():
                metric(logits, gold_tags[feature], mask.float())

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        all_words = output_dict["words"]

        all_predictions = output_dict["class_probabilities"][self.task].cpu().data.numpy()
        if all_predictions.ndim == 3:
            predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        for predictions, words in zip(predictions_list, all_words):
            argmax_indices = numpy.argmax(predictions, axis=-1)
            tags = [self.vocab.get_token_from_index(x, namespace=self.task)
                    for x in argmax_indices]

            # TODO: specific task
            if self.task == "lemmas":
                def decode_lemma(word, rule):
                    if rule == "_":
                        return "_"
                    if rule == "@@UNKNOWN@@":
                        return word
                    return apply_lemma_rule(word, rule)
                tags = [decode_lemma(word, rule) for word, rule in zip(words, tags)]

            all_tags.append(tags)
        output_dict[self.task] = all_tags

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        main_metrics = {
            f".run/{self.task}/{metric_name}": metric.get_metric(reset)
            for metric_name, metric in self.metrics.items()
        }

        features_metrics = {
            f"_run/{self.task}/{feature}/{metric_name}": metric.get_metric(reset)
            for feature in self.features
            for metric_name, metric in self.features_metrics[feature].items()
        }

        return {**main_metrics, **features_metrics}
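
TagDecoder.decode turns the per-token class probabilities into tag strings by taking the argmax over the class dimension and mapping each index back through the vocabulary. A toy walk-through (hand-made probabilities and a stand-in index-to-tag map instead of the Vocabulary lookup):

import numpy

index_to_tag = {0: "NOUN", 1: "VERB", 2: "DET"}   # stand-in for the task namespace
predictions = numpy.array([[0.1, 0.7, 0.2],       # one sentence, three tokens
                           [0.8, 0.1, 0.1],
                           [0.2, 0.2, 0.6]])

argmax_indices = numpy.argmax(predictions, axis=-1)
tags = [index_to_tag[int(i)] for i in argmax_indices]
print(tags)  # ['VERB', 'NOUN', 'DET']
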
Example #9
class MultiTagDecoder(Model):
    """
    A basic sequence tagger that decodes from inputs of word embeddings
    """
    def __init__(self,
                 vocab: Vocabulary,
                 task: str,
                 encoder: Seq2SeqEncoder,
                 prev_task: str,
                 prev_task_embed_dim: int = None,
                 label_smoothing: float = 0.0,
                 dropout: float = 0.0,
                 adaptive: bool = False,
                 features: List[str] = None,
                 metric: str = "acc",
                 loss_weight: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 threshold: float = 0.5,
                 max_heads: int = 2,
                 focal_gamma: float = None,
                 focal_alpha: float = None) -> None:
        super(MultiTagDecoder, self).__init__(vocab, regularizer)

        self.task = task
        self.dropout = torch.nn.Dropout(p=dropout)
        self.encoder = encoder
        self.output_dim = encoder.get_output_dim()
        self.label_smoothing = label_smoothing
        self.num_classes = self.vocab.get_vocab_size(task)
        self.adaptive = adaptive
        #self.features = features if features else []
        self.metric = metric

        self._loss3 = torch.nn.BCEWithLogitsLoss()

        self.threshold = threshold
        self.max_heads = max_heads
        self.gamma = focal_gamma
        self.alpha = focal_alpha
        self.loss_weight = loss_weight

        # A: add all possible relative encoding to vocabulary
        if self.vocab.get_token_index('100,root') == 1:
            for head in self.vocab.get_token_to_index_vocabulary('head_tags').keys():
                all_encodings = get_all_relative_encodings(head)
                self.vocab.add_tokens_to_namespace(tokens=all_encodings, namespace='dep_encoded')
            # make sure to put end token '100,root'
            self.vocab.add_token_to_namespace(token='100,root', namespace='dep_encoded')

        self.prev_task_tag_embedding = None
        if prev_task_embed_dim is not None and prev_task_embed_dim != 0 and prev_task is not None:
            if prev_task != 'rependency':
                self.prev_task_tag_embedding = Embedding(self.vocab.get_vocab_size(prev_task), prev_task_embed_dim)
            else:
                self.prev_task_tag_embedding = Embedding(self.vocab.get_vocab_size('dep_encoded'), prev_task_embed_dim)

        # Choose the metric to use for the evaluation (from the defined
        # "metric" value of the task). If not specified, default to accuracy.
        if self.metric == "acc":
            self.metrics = {"acc": CategoricalAccuracy()}
        elif self.metric == "multi_span_f1":
            self.metrics = {"multi_span_f1": MultiSpanBasedF1Measure(
                self.vocab, tag_namespace=self.task, label_encoding="BIO", threshold=self.threshold, max_heads=self.max_heads)}
        else:
            logger.warning(f"ERROR. Metric: {self.metric} unrecognized. Using accuracy instead.")
            self.metrics = {"acc": CategoricalAccuracy()}

        if self.adaptive:
            # TODO
            adaptive_cutoffs = [round(self.num_classes / 15), 3 * round(self.num_classes / 15)]
            self.task_output = AdaptiveLogSoftmaxWithLoss(self.output_dim,
                                                          self.num_classes,
                                                          cutoffs=adaptive_cutoffs,
                                                          div_value=4.0)
        else:
            self.task_output = TimeDistributed(Linear(self.output_dim, self.num_classes))

        # self.feature_outputs = torch.nn.ModuleDict()
        # self.features_metrics = {}
        # for feature in self.features:
        #     self.feature_outputs[feature] = TimeDistributed(Linear(self.output_dim,
        #                                                            vocab.get_vocab_size(feature)))
        #     self.features_metrics[feature] = {
        #         "acc": CategoricalAccuracy(),
        #     }

        initializer(self)

    @overrides
    def forward(self,
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                gold_tags: Dict[str, torch.LongTensor],
                prev_task_classes: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        if prev_task_classes is not None and self.prev_task_tag_embedding is not None:
            if prev_task_classes[1]:
                embedded_tags = torch.matmul(prev_task_classes[0], self.prev_task_tag_embedding.weight)
            else:
                prev_embed_size = self.prev_task_tag_embedding.get_output_dim()
                embedded_tags = self.dropout(self.prev_task_tag_embedding(prev_task_classes[0]))
                embedded_tags = embedded_tags.view(batch_size, -1, prev_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_tags], -1)

        hidden = encoded_text
        hidden = self.encoder(hidden, mask)

        batch_size, sequence_length, _ = hidden.size()
        output_dim = [batch_size, sequence_length, self.num_classes]

        #loss_fn = self._adaptive_loss if self.adaptive else self._loss2#self._loss
        loss_fn = self._adaptive_loss if self.adaptive else self._loss

        output_dict = loss_fn(hidden, mask, gold_tags.get(self.task, None), output_dim)
        #self._features_loss(hidden, mask, gold_tags, output_dict)

        return output_dict

    def _adaptive_loss(self, hidden, mask, gold_tags, output_dim):
        logits = hidden
        reshaped_log_probs = logits.view(-1, logits.size(2))

        class_probabilities = self.task_output.log_prob(reshaped_log_probs).view(output_dim)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if gold_tags is not None:
            output_dict["loss"] = sequence_cross_entropy(class_probabilities,
                                                         gold_tags,
                                                         mask,
                                                         label_smoothing=self.label_smoothing)
            for metric in self.metrics.values():
                metric(class_probabilities, gold_tags, mask.float())

        return output_dict

    def _loss2(self, hidden, mask, gold_tags, output_dim):
        logits = self.task_output(hidden)
        reshaped_log_probs = logits.view(-1, self.num_classes)

        # Use the sigmoid for class_probabilities instead of the softmax
        #class_probabilities = torch.sigmoid(reshaped_log_probs).view(output_dim) #logits)
        class_probabilities = torch.sigmoid(logits)
        # class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(output_dim)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if gold_tags is not None:
            # Compute the loss
            output_dict["loss"] = self.multi_class_cross_entropy_loss(
                scores=logits, labels=gold_tags, mask=mask
            )

            for metric in self.metrics.values():
                # metric(logits, gold_tags, mask.float())
                metric(class_probabilities, gold_tags, mask.float())

        return output_dict

    def multi_class_cross_entropy_loss(self, scores, labels, mask):
        """
        Compute the loss from
        """
        # Compute the mask before computing the loss
        # Transform the mask that is at the sentence level (#Size: n_batches x padded_document_length)
        # to a suitable format for the relation labels level
        #mask (2x3)
        padded_document_length = mask.size(1)  # take the second dimension (3)
        mask = mask.float()  # Size: n_batches x padded_document_length (2x3)
        # [e.view(padded_document_length, 1) * e for e in mask] ([3x3, 3x3])
        #squared_mask = torch.stack([e.view(padded_document_length, 1) * e for e in mask], dim=0) (2x3x3)
        #squared_mask = squared_mask.unsqueeze(-1).repeat(
        squared_mask = mask.unsqueeze(-1).repeat(
            #1, 1, 1, self._n_classes
            1, 1, scores.size(-1)
        )  # Size: n_batches x padded_document_length x padded_document_length x n_classes (2x3x3x5)

        # The scores (and gold labels) are flattened before using
        # the binary cross entropy loss.
        # We thus transform
        flat_size = scores.size()
        scores = scores * squared_mask  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        scores_flat = scores.view(
            flat_size[0], flat_size[1] * scores.size(-1)
        #    flat_size[0], flat_size[1], flat_size[2] * self._n_classes
        )  # Size: n_batches x padded_document_length x (padded_document_length x n_classes)
        labels = labels * squared_mask  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        labels_flat = labels.view(
            flat_size[0], flat_size[1] * scores.size(-1)
        #    flat_size[0], flat_size[1], flat_size[2] * self._n_classes
        )  # Size: n_batches x padded_document_length x (padded_document_length x n_classes)

        #loss = self._loss_fn(scores_flat, labels_flat)
        loss = self._loss3(scores_flat, labels_flat)

        # Amplify the loss to actually see something...
        return 100 * loss

    def _loss(self, hidden, mask, gold_tags, output_dim):
        logits = self.task_output(hidden)
        reshaped_log_probs = logits.view(-1, self.num_classes)

        # Use the sigmoid for class_probabilities instead of the softmax
        #class_probabilities = torch.sigmoid(reshaped_log_probs).view(output_dim) #logits)
        class_probabilities = torch.sigmoid(logits)
        # class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(output_dim)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if gold_tags is not None:
            # Compute the loss
            #output_dict["loss"] = self.multi_class_cross_entropy_loss(
            #    scores=logits, labels=gold_tags, mask=mask
            #)
            output_dict["loss"] = self.loss_weight * self.sequence_cross_entropy_with_logits(logits,
                                                                     gold_tags,
                                                                     mask,
                                                                     label_smoothing=self.label_smoothing,
                                                                     gamma=self.gamma,
                                                                     alpha=self.alpha)

            for metric in self.metrics.values():
                # metric(logits, gold_tags, mask.float())
                metric(class_probabilities, gold_tags, mask.float())

        return output_dict

    def sequence_cross_entropy_with_logits(self,
                                       logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       average: str = "batch",
                                       label_smoothing: float = None,
                                       gamma: float = None,
                                       alpha: Union[float, List[float], torch.FloatTensor] = None
                                      ) -> torch.FloatTensor:
        """
        Computes the cross entropy loss of a sequence, weighted with respect to
        some user provided weights. Note that the weighting here is not the same as
        in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
        classes; here we are weighting the loss contribution from particular elements
        in the sequence. This allows loss computations for models which use padding.

        Parameters
        ----------
        logits : ``torch.FloatTensor``, required.
            A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
            which contains the unnormalized probability for each class.
        targets : ``torch.LongTensor``, required.
            A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
            index of the true class for each corresponding step.
        weights : ``torch.FloatTensor``, required.
            A ``torch.FloatTensor`` of size (batch, sequence_length)
        average: str, optional (default = "batch")
            If "batch", average the loss across the batches. If "token", average
            the loss across each item in the input. If ``None``, return a vector
            of losses per batch element.
        label_smoothing : ``float``, optional (default = None)
            Whether or not to apply label smoothing to the cross-entropy loss.
            For example, with a label smoothing value of 0.2, a 4 class classification
            target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
            the correct label.
        gamma : ``float``, optional (default = None)
            Focal loss[*] focusing parameter ``gamma`` to reduce the relative loss for
            well-classified examples and put more focus on hard ones. The greater the
            value of ``gamma``, the more focus is put on hard examples.
        alpha : ``float`` or ``List[float]``, optional (default = None)
            Focal loss[*] weighting factor ``alpha`` to balance between classes. Can be
            used independently of ``gamma``. If a single ``float`` is provided, a binary
            case is assumed, with ``alpha`` and ``1 - alpha`` used for the positive and
            negative classes respectively. If a list of ``float`` with the same length
            as the number of classes is provided, the weights match the classes.
            [*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
            Dense Object Detection," 2017 IEEE International Conference on Computer
            Vision (ICCV), Venice, 2017, pp. 2999-3007.

        Returns
        -------
        A torch.FloatTensor representing the cross entropy loss.
        If ``average=="batch"`` or ``average=="token"``, the returned loss is a scalar.
        If ``average is None``, the returned loss is a vector of shape (batch_size,).

        """
        if average not in {None, "token", "batch"}:
            raise ValueError("Got average f{average}, expected one of "
                             "None, 'token', or 'batch'")


        # NOTE: label smoothing is disabled here unconditionally, so the
        # label-smoothing branch further down is never taken.
        label_smoothing = None

        # make sure weights are float
        # weights = weights.float()

        # Compute the mask before computing the loss
        # Transform the mask that is at the sentence level (#Size: n_batches x padded_document_length)
        # to a suitable format for the relation labels level
        #mask (2x3)
        padded_document_length = weights.size(1)  # take the second dimension (3)
        weights = weights.float()  # Size: n_batches x padded_document_length (2x3)



        # Make weights be of the right shape (i.e., extend a dimension to NUM_CLASSES)
        NUM_CLASSES = logits.size(-1)
        #weights = weights.unsqueeze_(-1)
        #weights = weights.expand(weights.shape[0], weights.shape[1], NUM_CLASSES)
        #weights = weights.unsqueeze(2).expand(weights.shape[0], weights.shape[1], NUM_CLASSES)

        # [e.view(padded_document_length, 1) * e for e in mask] ([3x3, 3x3])
        #squared_mask = torch.stack([e.view(padded_document_length, 1) * e for e in mask], dim=0) (2x3x3)
        #squared_mask = squared_mask.unsqueeze(-1).repeat(
        weights = weights.unsqueeze(-1).repeat(
            #1, 1, 1, self._n_classes
            1, 1, logits.size(-1)
        )  # Size: n_batches x padded_document_length x padded_document_length x n_classes (2x3x3x5)


        # sum all dim except batch
        non_batch_dims = tuple(range(1, len(weights.shape)))

        # shape : (batch_size,)
        weights_batch_sum = weights.sum(dim=non_batch_dims)
        weights_batch_sum2 = weights.sum(dim=(1,))[:,0]

        # shape : (batch * sequence_length, num_classes)
        # logits_flat = logits.view(-1, logits.size(-1))

        # Use log_sigmoid instead of log_softmax
        # log_probs_flat = torch.nn.functional.logsigmoid(logits_flat)
        # shape : (batch * sequence_length, num_classes)
        # log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)

        # Make the target handle NUM_CLASSES instead of one-best
        # shape : (batch * max_len, NUM_CLASSES)
        # targets_flat = targets.view(-1, NUM_CLASSES)
        # shape : (batch * max_len, 1)
        # targets_flat = targets.view(-1, 1).long()


        # The scores (and gold labels) are flattened before using
        # the binary cross entropy loss.
        # We thus transform
        flat_size = logits.size()
        logits = logits * weights  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        logits_flat = logits.view(
            flat_size[0], flat_size[1] * logits.size(-1)
        #    flat_size[0], flat_size[1], flat_size[2] * self._n_classes
        )  # Size: n_batches x padded_document_length x (padded_document_length x n_classes)
        targets = targets * weights  # Size: n_batches x padded_document_length x padded_document_length x n_classes
        targets_flat = targets.view(
            flat_size[0], flat_size[1] * logits.size(-1)
        #    flat_size[0], flat_size[1], flat_size[2] * self._n_classes
        )  # Size: n_batches x padded_document_length x (padded_document_length x n_classes)



        # focal loss coefficient
        # if gamma:
        #     # shape : (batch * sequence_length, num_classes)
        #     probs_flat = log_probs_flat.exp()
        #     # shape : (batch * sequence_length,)
        #     probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
        #     # shape : (batch * sequence_length,)
        #     focal_factor = (1. - probs_flat) ** gamma
        #     # shape : (batch, sequence_length)
        #     focal_factor = focal_factor.view(*targets.size())
        #     weights = weights * focal_factor

        if alpha is not None:
            # shape : () / (num_classes,)
            if isinstance(alpha, (float, int)):
                # pylint: disable=not-callable
                # shape : (2,)
                alpha_factor = torch.tensor([1. - float(alpha), float(alpha)],
                                            dtype=weights.dtype, device=weights.device)
                # pylint: enable=not-callable
            elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)):
                # pylint: disable=not-callable
                # shape : (c,)
                alpha_factor = torch.tensor(alpha, dtype=weights.dtype, device=weights.device)
                # pylint: enable=not-callable
                if not alpha_factor.size():
                    # shape : (1,)
                    alpha_factor = alpha_factor.view(1)
                    # shape : (2,)
                    alpha_factor = torch.cat([1 - alpha_factor, alpha_factor])
            else:
                raise TypeError(('alpha must be float, list of float, or torch.FloatTensor, '
                                 '{} provided.').format(type(alpha)))
            # shape : (batch, max_len)
            #alpha_factor = torch.gather(alpha_factor, dim=0, index=targets_flat.view(-1)).view(*targets.size())
            #weights = weights * alpha_factor

        if label_smoothing is not None and label_smoothing > 0.0:
            negative_log_likelihood_ = torch.nn.functional.binary_cross_entropy_with_logits(logits_flat, targets_flat, reduction='none') 

            num_classes = logits.size(-1)
            smoothing_value = label_smoothing / num_classes
            # Fill all the correct indices with 1 - smoothing value.

            #one_hot_targets = torch.zeros_like(negative_log_likelihood_).scatter_(-1, targets_flat.long(), 1.0 - label_smoothing)
            one_hot_targets = targets_flat.clone()
            one_hot_targets[one_hot_targets==1] = 1.0 - label_smoothing
            smoothed_targets = one_hot_targets + smoothing_value
            #negative_log_likelihood_flat = - logits_flat * smoothed_targets
            negative_log_likelihood_ = negative_log_likelihood_ * smoothed_targets

            # Keep all the classes instead of only the best one
            # negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
        else:
            # Contribution to the negative log likelihood only comes from the exact indices
            # of the targets, as the target distributions are one-hot. Here we use torch.gather
            # to extract the indices of the num_classes dimension which contribute to the loss.
            # shape : (batch * sequence_length, 1)
            # negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
            # negative_log_likelihood_flat = - log_probs_flat
            negative_log_likelihood_ = torch.nn.functional.binary_cross_entropy_with_logits(logits_flat, targets_flat, reduction='none') #self._loss3(logits_new, targets_new)
        # shape : (batch, sequence_length)
        # negative_log_likelihood = negative_log_likelihood_.view(*targets.size())
        # negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
        # shape : (batch, sequence_length)
        #negative_log_likelihood = negative_log_likelihood * weights

        if gamma:
            # shape : (batch * sequence_length, num_classes)
            # probs_flat = log_probs_flat.exp()
            probs_flat = negative_log_likelihood_.exp()
            # shape : (batch * sequence_length,)
            # probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
            # shape : (batch * sequence_length,)
            focal_factor = (1. - probs_flat) ** gamma
            # shape : (batch, sequence_length)
            focal_factor = focal_factor.view(*targets.size())
            weights = weights * focal_factor

        if alpha is not None:
            # shape : (batch, max_len)
            alpha_factor = torch.gather(alpha_factor, dim=0, index=targets_flat.long().view(-1)).view(*targets.size())
            weights = weights * alpha_factor

        negative_log_likelihood = negative_log_likelihood_.view(*targets.size())
        negative_log_likelihood = negative_log_likelihood * weights


        if average == "batch":
            # shape : (batch_size,)
            per_token_loss = negative_log_likelihood.sum((2,)) / NUM_CLASSES
            #print(per_token_loss, per_token_loss.shape)
            
            #per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (weights_batch_sum + 1e-13)
            per_batch_loss = per_token_loss.sum((1,)) / (weights_batch_sum2 + 1e-13)

            num_non_empty_sequences = ((weights_batch_sum2 > 0).float().sum() + 1e-13)

            return (per_batch_loss.sum() / num_non_empty_sequences) * 100 # amplify it to see something
        elif average == "token":
            return negative_log_likelihood.sum() / (weights_batch_sum.sum() + 1e-13)
        else:
            # shape : (batch_size,)
            per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (weights_batch_sum + 1e-13)
            return per_batch_loss


    # def _features_loss(self, hidden, mask, gold_tags, output_dict):
    #     if gold_tags is None:
    #         return

    #     for feature in self.features:
    #         logits = self.feature_outputs[feature](hidden)
    #         loss = sequence_cross_entropy_with_logits(logits,
    #                                                   gold_tags[feature],
    #                                                   mask,
    #                                                   label_smoothing=self.label_smoothing)
    #         loss /= len(self.features)
    #         output_dict["loss"] += loss

    #         for metric in self.features_metrics[feature].values():
    #             metric(logits, gold_tags[feature], mask.float())

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        all_words = output_dict["words"]

        all_predictions = output_dict["class_probabilities"][self.task].cpu().data.numpy()
        if all_predictions.ndim == 3:
            predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        for predictions, words in zip(predictions_list, all_words):

            # Hard-coded parameters for now
            THRESH = self.threshold
            k = self.max_heads
            outside_index = self.vocab.get_token_index("O", namespace=self.task)

            # Get the thresholded matrix and prepare the prediction sequence
            pred_over_thresh = (predictions >= THRESH) * predictions
            sequence_token_labels = []
            maxxx = numpy.argmax(predictions, axis=-1).tolist()

            # For each label set, decide whether to apply argmax or the sigmoid threshold
            j = 0
            for pred in pred_over_thresh:
                num_pred_over_thresh = numpy.count_nonzero(pred)

                if (num_pred_over_thresh == 0) or (num_pred_over_thresh == 1):
                    pred_idx_list = [maxxx[j]]

                elif num_pred_over_thresh <= k:
                    pred_idx_list = list(numpy.argpartition(pred, -num_pred_over_thresh)[-num_pred_over_thresh:])

                    outside_position = -1
                    try:
                        outside_position = pred_idx_list.index(outside_index)
                    except ValueError:
                        outside_position = -1
                    # outside_position = None
                    # for el_i in range(len(pred_idx_list)):
                    #     if pred_idx_list[el_i] == outside_index:
                    #         outside_position = el_i
                    #         break
                    if outside_position != -1:
                        pred_len = len(pred_idx_list)-1
                        # If the last (i.e., the best) is "O", ignore/remove the others
                        if outside_position == pred_len:
                            pred_idx_list = [pred_idx_list[-1]]
                        # O.w. get only from the last before the "O"
                        else:
                            # del pred_idx_list[outside_position]
                            pred_idx_list = pred_idx_list[outside_position+1:]

                else:
                    pred_idx_list = list(numpy.argpartition(pred, -k)[-k:])

                    outside_position = -1
                    try:
                        outside_position = pred_idx_list.index(outside_index)
                    except ValueError:
                        outside_position = -1
                    # outside_position = None
                    # for el_i in range(len(pred_idx_list)):
                    #     if pred_idx_list[el_i] == outside_index:
                    #         outside_position = el_i
                    #         break
                    if outside_position != -1:
                        pred_len = len(pred_idx_list)-1
                        # If the last (i.e., the best) is "O", ignore/remove the others
                        if outside_position == pred_len:
                            pred_idx_list = [pred_idx_list[-1]]
                        # O.w. get only from the last before the "O"
                        else:
                            # del pred_idx_list[outside_position]
                            pred_idx_list = pred_idx_list[outside_position+1:]


                # if num_pred_over_thresh < k:
                #     pred_idx_list = [maxxx[j]]
                #     # print("argmax  ->", pred_idx_list)
                # else:
                #     #pred_idx_list = [maxxx[j]]
                #     pred_idx_list = list(numpy.argpartition(pred, -k)[-k:])
                #     # # print("sigmoid ->", pred_idx_list)

                #     # # If the first (i.e., second best) is "O", ignore/remove it
                #     if pred_idx_list[0] == outside_index:
                #         pred_idx_list = pred_idx_list[1:]
                #     # If the second (i.e., the best) is "O", ignore/remove the first
                #     elif pred_idx_list[1] == outside_index:
                #         pred_idx_list = pred_idx_list[1:]
                #     else:
                #         pass

                sequence_token_labels.append(pred_idx_list)
                j += 1

            # Create the list of tags to append for the output
            tags = []
            for token_labels in sequence_token_labels:
                curr_labels = []
                for token_label in token_labels:
                    curr_labels.append(
                        self.vocab.get_token_from_index(token_label, namespace=self.task))
                tags.append(curr_labels)

            all_tags.append(tags)
        output_dict[self.task] = all_tags

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        main_metrics = {
            f".run/{self.task}/{metric_name}": metric.get_metric(reset)
            for metric_name, metric in self.metrics.items()
        }

        return {**main_metrics}
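
A minimal sketch of the label-selection step used in the decode above: scores over the threshold are kept (capped at k) and sorted so the best label comes last, and "O" plus everything scored below it is dropped unless "O" itself is the best label. The threshold, k and outside_index values below are illustrative assumptions, not the model's actual configuration.

import numpy

def select_labels(pred, outside_index, threshold=0.5, k=2):
    """Pick the label indices for one token from its score vector ``pred``."""
    num_over = int((pred > threshold).sum())
    if num_over <= 1:
        return [int(numpy.argmax(pred))]
    n = min(num_over, k)
    # Indices of the n best scores, sorted ascending so the last one is the best.
    idx = list(numpy.argsort(pred)[-n:])
    if outside_index in idx:
        pos = idx.index(outside_index)
        return [idx[-1]] if pos == len(idx) - 1 else idx[pos + 1:]
    return idx

# Example with three labels where index 0 is "O":
print(select_labels(numpy.array([0.1, 0.8, 0.7]), outside_index=0))  # -> [2, 1]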
Example No. 10
class TagDecoder(Model):
    """
    A basic sequence tagger that decodes from inputs of word embeddings
    """
    def __init__(self,
                 vocab: Vocabulary,
                 task: str,
                 encoder: Seq2SeqEncoder,
                 prev_task: str,
                 prev_task_embed_dim: int = None,
                 label_smoothing: float = 0.0,
                 dropout: float = 0.0,
                 adaptive: bool = False,
                 features: List[str] = None,
                 metric: str = "acc",
                 loss_weight: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(TagDecoder, self).__init__(vocab, regularizer)

        self.task = task
        self.dropout = torch.nn.Dropout(p=dropout)
        self.encoder = encoder
        self.output_dim = encoder.get_output_dim()
        self.label_smoothing = label_smoothing
        self.num_classes = self.vocab.get_vocab_size(task)
        self.adaptive = adaptive
        self.features = features if features else []
        self.metric = metric
        self.loss_weight = loss_weight

        # Add all possible relative head encodings to the 'dep_encoded' namespace
        if self.vocab.get_token_index('100,root') == 1:
            for head in self.vocab.get_token_to_index_vocabulary('head_tags').keys():
                all_encodings = get_all_relative_encodings(head)
                self.vocab.add_tokens_to_namespace(tokens=all_encodings, namespace='dep_encoded')
            # make sure to put end token '100,root'
            self.vocab.add_token_to_namespace(token='100,root', namespace='dep_encoded')

        self.prev_task_tag_embedding = None
        if prev_task_embed_dim is not None and prev_task_embed_dim != 0 and prev_task is not None:
            if prev_task != 'rependency':
                self.prev_task_tag_embedding = Embedding(self.vocab.get_vocab_size(prev_task), prev_task_embed_dim)
            else:
                self.prev_task_tag_embedding = Embedding(self.vocab.get_vocab_size('dep_encoded'), prev_task_embed_dim)

        # Choose the metric to use for the evaluation (from the defined
        # "metric" value of the task). If not specified, default to accuracy.
        if self.metric == "acc":
            self.metrics = {"acc": CategoricalAccuracy()}
        elif self.metric == "span_f1":
            self.metrics = {"span_f1": SpanBasedF1Measure(
                self.vocab, tag_namespace=self.task, label_encoding="BIO")}
        else:
            logger.warning(f"Unrecognized metric '{self.metric}'; falling back to accuracy.")
            self.metrics = {"acc": CategoricalAccuracy()}

        if self.adaptive:
            # TODO
            adaptive_cutoffs = [round(self.num_classes / 15), 3 * round(self.num_classes / 15)]
            self.task_output = AdaptiveLogSoftmaxWithLoss(self.output_dim,
                                                          self.num_classes,
                                                          cutoffs=adaptive_cutoffs,
                                                          div_value=4.0)
        else:
            self.task_output = TimeDistributed(Linear(self.output_dim, self.num_classes))

        self.feature_outputs = torch.nn.ModuleDict()
        self.features_metrics = {}
        for feature in self.features:
            self.feature_outputs[feature] = TimeDistributed(Linear(self.output_dim,
                                                                   vocab.get_vocab_size(feature)))
            self.features_metrics[feature] = {
                "acc": CategoricalAccuracy(),
            }

        initializer(self)

    @overrides
    def forward(self,
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                gold_tags: Dict[str, torch.LongTensor],
                prev_task_classes: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        if prev_task_classes is not None and self.prev_task_tag_embedding is not None:
            if prev_task_classes[1]:
                embedded_tags = torch.matmul(prev_task_classes[0], self.prev_task_tag_embedding.weight)
            else:
                prev_embed_size = self.prev_task_tag_embedding.get_output_dim()
                embedded_tags = self.dropout(self.prev_task_tag_embedding(prev_task_classes[0]))
                embedded_tags = embedded_tags.view(batch_size, -1, prev_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_tags], -1)

        hidden = encoded_text
        hidden = self.encoder(hidden, mask)

        batch_size, sequence_length, _ = hidden.size()
        output_dim = [batch_size, sequence_length, self.num_classes]

        loss_fn = self._adaptive_loss if self.adaptive else self._loss

        output_dict = loss_fn(hidden, mask, gold_tags.get(self.task, None), output_dim)
        self._features_loss(hidden, mask, gold_tags, output_dict)

        return output_dict

    def _adaptive_loss(self, hidden, mask, gold_tags, output_dim):
        logits = hidden
        reshaped_log_probs = logits.view(-1, logits.size(2))

        class_probabilities = self.task_output.log_prob(reshaped_log_probs).view(output_dim)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if gold_tags is not None:
            output_dict["loss"] = sequence_cross_entropy(class_probabilities,
                                                         gold_tags,
                                                         mask,
                                                         label_smoothing=self.label_smoothing)
            for metric in self.metrics.values():
                metric(class_probabilities, gold_tags, mask.float())

        return output_dict

    def _loss(self, hidden, mask, gold_tags, output_dim):
        logits = self.task_output(hidden)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(output_dim)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if gold_tags is not None:
            output_dict["loss"] = self.loss_weight * sequence_cross_entropy_with_logits(logits,
                                                                     gold_tags,
                                                                     mask,
                                                                     label_smoothing=self.label_smoothing)
            for metric in self.metrics.values():
                metric(logits, gold_tags, mask.float())

        return output_dict

    def _features_loss(self, hidden, mask, gold_tags, output_dict):
        if gold_tags is None:
            return

        for feature in self.features:
            logits = self.feature_outputs[feature](hidden)
            loss = sequence_cross_entropy_with_logits(logits,
                                                      gold_tags[feature],
                                                      mask,
                                                      label_smoothing=self.label_smoothing)
            loss /= len(self.features)
            output_dict["loss"] += loss

            for metric in self.features_metrics[feature].values():
                metric(logits, gold_tags[feature], mask.float())

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        all_words = output_dict["words"]

        all_predictions = output_dict["class_probabilities"][self.task].cpu().data.numpy()
        if all_predictions.ndim == 3:
            predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        for predictions, words in zip(predictions_list, all_words):
            argmax_indices = numpy.argmax(predictions, axis=-1)
            tags = [self.vocab.get_token_from_index(x, namespace=self.task)
                    for x in argmax_indices]

            # TODO: specific task
            if self.task == "lemmas":
                def decode_lemma(word, rule):
                    if rule == "_":
                        return "_"
                    if rule == "@@UNKNOWN@@":
                        return word
                    return apply_lemma_rule(word, rule)
                tags = [decode_lemma(word, rule) for word, rule in zip(words, tags)]

            all_tags.append(tags)
        output_dict[self.task] = all_tags

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        main_metrics = {
            f".run/{self.task}/{metric_name}": metric.get_metric(reset)
            for metric_name, metric in self.metrics.items()
        }

        features_metrics = {
            f"_run/{self.task}/{feature}/{metric_name}": metric.get_metric(reset)
            for feature in self.features
            for metric_name, metric in self.features_metrics[feature].items()
        }

        return {**main_metrics, **features_metrics}
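
As a rough illustration of what the non-adaptive path of TagDecoder computes, the sketch below projects encoder states to per-token class logits and averages a masked cross-entropy over the unpadded positions. It is a simplification of AllenNLP's sequence_cross_entropy_with_logits (no label smoothing, no loss weight), and all dimensions are made up.

import torch
import torch.nn.functional as F

batch, seq_len, hidden, num_classes = 2, 5, 8, 4
hidden_states = torch.randn(batch, seq_len, hidden)
gold = torch.randint(num_classes, (batch, seq_len))
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.float)

# Applied position-wise, like TimeDistributed(Linear(...)) above.
projection = torch.nn.Linear(hidden, num_classes)
logits = projection(hidden_states)                  # (batch, seq_len, num_classes)

# Per-token cross-entropy, zeroed on padding and averaged over real tokens.
token_loss = F.cross_entropy(logits.view(-1, num_classes), gold.view(-1), reduction="none")
token_loss = token_loss.view(batch, seq_len) * mask
loss = token_loss.sum() / mask.sum()
print(loss)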
Example No. 11
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 edge_prediction_threshold: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(GraphParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.edge_prediction_threshold = edge_prediction_threshold
        if not 0 < edge_prediction_threshold < 1:
            raise ConfigurationError(f"edge_prediction_threshold must be between "
                                     f"0 and 1 (exclusive) but found {edge_prediction_threshold}.")

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("labels")
        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = BilinearMatrixAttention(tag_representation_dim,
                                                    tag_representation_dim,
                                                    label_dim=num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim", "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim", "arc feedforward output dim")

        self._unlabelled_f1 = F1Measure(positive_label=1)
        self._arc_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        self._tag_loss = torch.nn.CrossEntropyLoss(reduction='none')
        initializer(self)
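
The snippet above only shows the constructor, so as a hedged sketch of how an edge_prediction_threshold like this is typically consumed at decode time: arc scores are passed through a sigmoid and any probability above the threshold becomes a predicted edge. The score tensor here is random and purely illustrative.

import torch

batch, seq_len = 1, 4
arc_scores = torch.randn(batch, seq_len, seq_len)     # e.g. the bilinear arc attention logits
edge_prediction_threshold = 0.5

arc_probs = torch.sigmoid(arc_scores)                 # independent probability per candidate edge
predicted_edges = (arc_probs > edge_prediction_threshold).long()
print(predicted_edges[0])                             # adjacency matrix of predicted arcs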
Example No. 12
class DependencyDecoder(Model):
    """
    Modifies BiaffineDependencyParser, removing the input TextFieldEmbedder dependency to allow the model to
    essentially act as a decoder when given intermediate word embeddings instead of as a standalone model.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 pos_embed_dim: int = None,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DependencyDecoder, self).__init__(vocab, regularizer)

        self.pos_tag_embedding = None
        if pos_embed_dim is not None:
            self.pos_tag_embedding = Embedding(self.vocab.get_vocab_size("upos"), pos_embed_dim)

        self.dropout = torch.nn.Dropout(p=dropout)

        self.encoder = encoder
        encoder_output_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_output_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_output_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._dropout = InputVariationalDropout(dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder_output_dim]))

        check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim", "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim", "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags corresponding to the following punctuation: {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                # words: Dict[str, torch.LongTensor],
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                pos_logits: torch.LongTensor = None,  # predicted
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        pos_tags = None
        if pos_logits is not None and self.pos_tag_embedding is not None:
            # Embed the predicted POS tags and concatenate the embeddings to the input
            num_pos_classes = pos_logits.size(-1)
            pos_logits = pos_logits.view(-1, num_pos_classes)
            _, pos_tags = pos_logits.max(-1)

            pos_embed_size = self.pos_tag_embedding.get_output_dim()
            embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
            embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

        encoded_text = self.encoder(encoded_text, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            # "pos": [meta["pos"] for meta in metadata]
        }

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

        head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
        heads = output_dict.pop("heads").cpu().detach().numpy()
        mask = output_dict.pop("mask")
        lengths = get_lengths_from_binary_sequence_mask(mask)
        head_tag_labels = []
        head_indices = []
        for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
            instance_heads = list(instance_heads[1:length])
            instance_tags = instance_tags[1:length]
            labels = [self.vocab.get_token_from_index(label, "head_tags")
                      for label in instance_tags]
            head_tag_labels.append(labels)
            head_indices.append(instance_heads)

        output_dict["predicted_dependencies"] = head_tag_labels
        output_dict["predicted_heads"] = head_indices
        return output_dict

    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.
        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll

    def _greedy_decode(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       attended_arcs: torch.Tensor,
                       mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).byte().unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags

    def _mst_decode(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using Edmonds' algorithm
        for finding minimum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g. a pair which includes a padded element) anyway below.
        # Shape (batch_size, num_labels, sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = (1 - mask.float()) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        # This energy tensor expresses the following relation:
        # energy[i,j] = "Score that i is the head of j". In this
        # case, we have heads pointing to their children.
        batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits)
        return self._run_mst_decoding(batch_energy, lengths)

    @staticmethod
    def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all word -> ROOT
            # edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = []
            for child, parent in enumerate(instance_heads):
                instance_head_tags.append(tag_ids[parent, child].item())
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))

    def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.
        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.
        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits

    def _get_mask_for_eval(self,
                           mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words that are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.
        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.
        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {f".run/deps/{metric_name}": metric
                for metric_name, metric in self._attachment_scores.get_metric(reset).items()}
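
The greedy decoding path of the parser above can be reproduced in isolation: block the diagonal and the rows of padded positions in the arc scores, then take an argmax per word. This is only a toy re-run of that logic on random scores, not the model's actual tensors.

import torch

batch, seq_len = 1, 5
attended_arcs = torch.randn(batch, seq_len, seq_len)  # score that position j is the head of position i
mask = torch.tensor([[1, 1, 1, 1, 0]])                # the last position is padding

# A word cannot be its own head.
attended_arcs = attended_arcs + torch.diag(attended_arcs.new_full((seq_len,), float("-inf")))
# Fill the rows of padded positions, mirroring the masking in _greedy_decode.
minus_mask = (1 - mask).bool().unsqueeze(2)
attended_arcs = attended_arcs.masked_fill(minus_mask, float("-inf"))

_, heads = attended_arcs.max(dim=2)   # (batch, seq_len): greedily chosen head index per word
print(heads)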
Example No. 13
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            encoder: Seq2SeqEncoder,
            arc_representation_dim: int,
            tag_representation_dim: int,
            rank: int,
            capsule_dim: int,
            iter_num: int,
            arc_feedforward: FeedForward = None,
            tag_feedforward: FeedForward = None,
            pos_tag_embedding: Embedding = None,
            #dep_tag_embedding: Embedding = None,
            predicate_embedding: Embedding = None,
            delta_type: str = "hinge_ce",
            subtract_gold: bool = False,
            dropout: float = 0.0,
            input_dropout: float = 0.0,
            edge_prediction_threshold: float = 0.5,
            gumbel_t: float = 1,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None,
            double_loss: bool = True,
            base_average: bool = False,
            bilinear_matrix_capsule: bool = True,
            using_global: bool = False,
            passing_type: str = 'plain',
            global_node: bool = False,
            comments: str = "") -> None:
        super(SRLGraphParserBase, self).__init__(vocab, regularizer)
        self.capsule_dim = capsule_dim
        num_labels = self.vocab.get_vocab_size("arc_types")
        # print("num_labels", num_labels)

        if global_node:
            self.get_global_layer = Plain_Feedforward(
                (num_labels + 1) * capsule_dim, capsule_dim,
                Activation.by_name('relu')())
            self.bilinear_matrix_capsule_layer_for_global_node = BilinearMatrix(
                capsule_dim, capsule_dim)
        self.global_node = global_node

        if using_global:
            self.capsule_dim = int(self.capsule_dim / 2)
            if passing_type == 'plain':
                self.get_global_layer = Plain_Feedforward(
                    (num_labels + 1) * capsule_dim,
                    (num_labels + 1) * self.capsule_dim,
                    Activation.by_name('relu')())
            elif passing_type == 'attention':
                self.get_global_layer = Attention_Feedforward(
                    self.capsule_dim, capsule_dim, self.capsule_dim)
            else:
                self.get_global_layer = None
        self.using_global = using_global
        self.passing_type = passing_type

        self.iter_num = iter_num
        self.double_loss = double_loss
        self.base_average = base_average
        self.bilinear_matrix_capsule = bilinear_matrix_capsule

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.subtract_gold = subtract_gold
        self.edge_prediction_threshold = edge_prediction_threshold
        if not 0 < edge_prediction_threshold < 1:
            raise ConfigurationError(
                f"edge_prediction_threshold must be between "
                f"0 and 1 (exclusive) but found {edge_prediction_threshold}.")

        self.delta_type = delta_type

        self.gumbel_t = gumbel_t
        node_dim = predicate_embedding.get_output_dim()
        encoder_dim = encoder.get_output_dim()

        self.arg_tag_feedforward = tag_feedforward or \
                                   FeedForward(encoder_dim, 1,
                                               tag_representation_dim,
                                               Activation.by_name("elu")())
        self.pred_tag_feedforward = copy.deepcopy(self.arg_tag_feedforward)

        self.tag_bilinear = BilinearMatrixAttention_Lowrank(
            tag_representation_dim,
            tag_representation_dim,
            rank,
            label_dim=(num_labels + 1) * self.capsule_dim,
            use_input_biases=True)  #,activation=Activation.by_name("tanh")()
        if self.bilinear_matrix_capsule:
            self.bilinear_matrix_capsule_layer = BilinearMatrix(
                capsule_dim, capsule_dim)
        self.predicte_feedforward = FeedForward(encoder_dim, 1, node_dim,
                                                Activation.by_name("elu")())
        self._pos_tag_embedding = pos_tag_embedding or None
        #self._dep_tag_embedding = dep_tag_embedding or None
        self._pred_embedding = predicate_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)

        #   check_dimensions_match(representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim")

        self._labelled_f1 = IterativeLabeledF1Measure(
            negative_label=0,
            negative_pred=0,
            selected_metrics=["F", "l_F", "p_F", "u_F"])
        self._tag_loss = torch.nn.NLLLoss(reduction="none")  # ,ignore_index=-1
        self._sense_loss = torch.nn.NLLLoss(
            reduction="none")  # ,ignore_index=-1
        initializer(self)
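
The scoring backbone shared by these graph parsers is a label-aware bilinear product between predicate (head) and argument (dependent) representations. The einsum below shows the basic shape bookkeeping with random tensors and a made-up label count; the low-rank factorisation and the capsule routing used in the model above are deliberately left out.

import torch

batch, seq_len, dim, num_labels = 2, 6, 16, 5
heads = torch.randn(batch, seq_len, dim)        # e.g. output of the predicate feed-forward
deps = torch.randn(batch, seq_len, dim)         # e.g. output of the argument feed-forward
weight = torch.randn(num_labels, dim, dim)      # one bilinear form per label

# scores[b, l, i, j] = heads[b, i] @ weight[l] @ deps[b, j]
scores = torch.einsum("bid,ldk,bjk->blij", heads, weight, deps)
print(scores.shape)                             # torch.Size([2, 5, 6, 6])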
Example No. 14
    def __init__(self,
                 options,
                 tag_dim,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(None, regularizer)

        self.device = options.device

        encoder = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(tag_dim,
                          options.lstm_dims,
                          batch_first=True,
                          bidirectional=True))
        # encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(tag_dim, options.lstm_dims, batch_first=True))
        self.encoder = encoder
        # TODO: IMPORTANT
        num_labels = options.num_labels
        self.ablation = options.ablation
        tag_representation_dim = options.tag_representation_dim  # 100
        arc_representation_dim = options.arc_representation_dim  # 200

        encoder_dim = self.encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)


        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = tag_dim  #text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        # No vocabulary is available here (the super constructor was called with None),
        # so no punctuation POS tags are excluded from evaluation.
        self._pos_to_ignore = set()

        self._attachment_scores = AttachmentScores()
        initializer(self)
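
For reference, the AttachmentScores metric used above boils down to masked UAS/LAS counts; a minimal stand-alone version with toy predictions is sketched here (the real AllenNLP metric additionally tracks exact-match scores).

import torch

pred_heads = torch.tensor([[2, 0, 2, 3]])
pred_labels = torch.tensor([[1, 0, 1, 2]])
gold_heads = torch.tensor([[2, 0, 2, 2]])
gold_labels = torch.tensor([[1, 0, 1, 1]])
mask = torch.tensor([[1, 1, 1, 0]]).float()     # last token excluded (e.g. punctuation)

correct_heads = (pred_heads == gold_heads).float() * mask
correct_labeled = correct_heads * (pred_labels == gold_labels).float()
uas = correct_heads.sum() / mask.sum()
las = correct_labeled.sum() / mask.sum()
print(uas.item(), las.item())                   # 1.0 1.0 on the unmasked positions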
Example No. 15
class SentimentClassifier(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 label_smoothing: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SentimentClassifier, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder

        share_rnn = nn.LSTM(input_size=self._text_field_embedder.get_output_dim(),
                            hidden_size=150,
                            batch_first=True,
                            # dropout=dropout,
                            bidirectional=True)
        share_encoder = PytorchSeq2SeqWrapper(share_rnn)

        self._encoder = RNNEncoder(vocab, share_encoder, input_dropout, regularizer)
        self._seq_vec = CnnEncoder(self._encoder.get_output_dim(), 25)
        self._de_dim = len(TASKS_NAME)
        weight = torch.empty(self._de_dim, self._text_field_embedder.get_output_dim())
        torch.nn.init.orthogonal_(weight)
        self._domain_embeddings = Embedding(self._de_dim, self._text_field_embedder.get_output_dim(), weight=weight)
        self._de_attention = BilinearAttention(self._seq_vec.get_output_dim(),
                                               self._domain_embeddings.get_output_dim())
        self._de_feedforward = FeedForward(self._domain_embeddings.get_output_dim(), 1,
                                           self._seq_vec.get_output_dim(), Activation.by_name("elu")())

        self._num_classes = self.vocab.get_vocab_size("label")
        self._sentiment_discriminator = Discriminator(self._seq_vec.get_output_dim(), self._num_classes)
        self._s_domain_discriminator = Discriminator(self._seq_vec.get_output_dim(), len(TASKS_NAME))
        self._valid_discriminator = Discriminator(self._domain_embeddings.get_output_dim(), 2)
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._label_smoothing = label_smoothing

        self.metrics = {
            "s_domain_acc": CategoricalAccuracy(),
            "valid_acc": CategoricalAccuracy()
        }
        for task_name in TASKS_NAME:
            self.metrics["{}_stm_acc".format(task_name)] = CategoricalAccuracy()

        self._loss = torch.nn.CrossEntropyLoss()
        self._domain_loss = torch.nn.CrossEntropyLoss()
        # TODO torch.nn.BCELoss
        self._valid_loss = torch.nn.BCEWithLogitsLoss()

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                task_index: torch.IntTensor,
                reverse: torch.ByteTensor,
                for_training: torch.ByteTensor,
                train_stage: torch.IntTensor,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        :param task_index: index of the current task/domain in ``TASKS_NAME``
        :param reverse: flag forwarded to the discriminators for the adversarial stages
        :param for_training: whether this is a training pass (at inference all stages run)
        :param train_stage: one of ["share_senti", "share_classify",
        "share_classify_adversarial", "domain_valid", "domain_valid_adversarial"]
        :param tokens: the input token ids
        :param label: the gold sentiment label, if available
        :return: a dict containing the loss of the active stage
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()
        embed_tokens = self._encoder(embedded_text, mask)
        batch_size = get_batch_size(embed_tokens)
        # bs * (25*4)
        seq_vec = self._seq_vec(embed_tokens, mask)
        # TODO add linear layer

        domain_embeddings = self._domain_embeddings(torch.arange(self._de_dim).cuda())

        de_scores = F.softmax(
            self._de_attention(seq_vec, domain_embeddings.expand(batch_size, *domain_embeddings.size())), dim=1)
        de_valid = False
        if np.random.rand() < 0.3:
            de_valid = True
            noise = 0.01 * torch.normal(mean=0.5,
                                        # std=torch.std(domain_embeddings).sign_())
                                        std=torch.empty(*de_scores.size()).fill_(1.0))
            de_scores = de_scores + noise.cuda()
        domain_embedding = torch.matmul(de_scores, domain_embeddings)
        domain_embedding = self._de_feedforward(domain_embedding)
        # train sentiment classify
        if train_stage.cpu() == torch.tensor(0) or not for_training:

            de_representation = torch.tanh(torch.add(domain_embedding, seq_vec))

            sentiment_logits = self._sentiment_discriminator(de_representation)
            if label is not None:
                loss = self._loss(sentiment_logits, label)
                self.metrics["{}_stm_acc".format(TASKS_NAME[task_index.cpu()])](sentiment_logits, label)

        if train_stage.cpu() == torch.tensor(1) or not for_training:
            s_domain_logits = self._s_domain_discriminator(seq_vec, reverse=reverse)
            task_index = task_index.expand(batch_size)
            loss = self._domain_loss(s_domain_logits, task_index)
            self.metrics["s_domain_acc"](s_domain_logits, task_index)

        if train_stage.cpu() == torch.tensor(2) or not for_training:
            valid_logits = self._valid_discriminator(domain_embedding, reverse=reverse)
            valid_label = torch.ones(batch_size).cuda()
            if de_valid:
                valid_label = torch.zeros(batch_size).cuda()
            if self._label_smoothing is not None and self._label_smoothing > 0.0:
                loss = sequence_cross_entropy_with_logits(valid_logits,
                                                          valid_label.unsqueeze(0).cuda(),
                                                          torch.tensor(1).unsqueeze(0).cuda(),
                                                          average="token",
                                                          label_smoothing=self._label_smoothing)
            else:
                loss = self._valid_loss(valid_logits,
                                        torch.zeros(2).scatter_(0, valid_label, torch.tensor(1.0)).cuda())
            self.metrics["valid_acc"](valid_logits, valid_label)
        # TODO add orthogonal loss
        output_dict = {"loss": loss}

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the class probabilities, converts indices to string labels, and
        adds a ``"label"`` key to the dictionary with the result.
        """
        class_probabilities = F.softmax(output_dict['logits'], dim=-1)
        output_dict['class_probabilities'] = class_probabilities

        predictions = class_probabilities.cpu().data.numpy()
        argmax_indices = np.argmax(predictions, axis=-1)
        labels = [self.vocab.get_token_from_index(x, namespace="label")
                  for x in argmax_indices]
        output_dict['label'] = labels
        return output_dict

    @overrides
    def get_metrics(self, task_name: str, reset: bool = False) -> Dict[str, float]:
        return {metric_name: metric.get_metric(reset)
                for metric_name, metric in self.metrics.items()
                if task_name in metric_name or metric_name in ("s_domain_acc", "valid_acc")}
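A note on the adversarial branches above: the discriminators are called with a `reverse` flag, which in adversarial multi-task setups usually routes the input through a gradient reversal layer. Below is a minimal sketch of such a layer, assuming that is what `reverse=True` toggles here; the names `_GradientReversal` and `grad_reverse` are illustrative and not part of the snippet above.
import torch

class _GradientReversal(torch.autograd.Function):
    """Identity on the forward pass; negates (and scales) the gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd: float = 1.0):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient so the shared encoder is trained
        # to *confuse* the domain discriminator.
        return -ctx.lambd * grad_output, None

def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    return _GradientReversal.apply(x, lambd)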
Ejemplo n.º 16
0
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        tag_representation_dim: int,
        arc_representation_dim: int,
        activation=Activation.by_name("tanh")(),
        lemma_tag_embedding: Embedding = None,
        upos_tag_embedding: Embedding = None,
        xpos_tag_embedding: Embedding = None,
        feats_tag_embedding: Embedding = None,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        edge_prediction_threshold: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.activation = activation
        self.edge_prediction_threshold = edge_prediction_threshold
        if not 0 < edge_prediction_threshold < 1:
            raise ConfigurationError(
                f"edge_prediction_threshold must be between "
                f"0 and 1 (exclusive) but found {edge_prediction_threshold}.")

        encoder_dim = encoder.get_output_dim()

        # These two matrices together form the feed-forward network that takes the vectors
        # of the two words in question and makes predictions from them.
        # This is the trick described by Kiperwasser and Goldberg to make training faster.
        self.edge_head = Linear(encoder_dim, arc_representation_dim)
        self.edge_dep = Linear(
            encoder_dim, arc_representation_dim,
            bias=False)  # bias is already added by edge_head

        self.tag_head = Linear(encoder_dim, tag_representation_dim)
        self.tag_dep = Linear(encoder_dim, tag_representation_dim, bias=False)

        num_labels = self.vocab.get_vocab_size("deps")

        self.arc_out_layer = Linear(
            arc_representation_dim, 1,
            bias=False)  # no bias in output layer of K&G model
        self.tag_out_layer = Linear(tag_representation_dim, num_labels)

        self._lemma_tag_embedding = lemma_tag_embedding or None
        self._upos_tag_embedding = upos_tag_embedding or None
        self._xpos_tag_embedding = xpos_tag_embedding or None
        self._feats_tag_embedding = feats_tag_embedding or None

        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)

        # add a head sentinel to accommodate the extra root token
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if lemma_tag_embedding is not None:
            representation_dim += lemma_tag_embedding.get_output_dim()
        if upos_tag_embedding is not None:
            representation_dim += upos_tag_embedding.get_output_dim()
        if xpos_tag_embedding is not None:
            representation_dim += xpos_tag_embedding.get_output_dim()
        if feats_tag_embedding is not None:
            representation_dim += feats_tag_embedding.get_output_dim()

        check_dimensions_match(
            representation_dim,
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )

        self._enhanced_attachment_scores = EnhancedAttachmentScores()
        self._arc_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self._tag_loss = torch.nn.CrossEntropyLoss(reduction="none")
        initializer(self)
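For context, a hedged sketch of how the factorised scorer above is typically applied in the forward pass: the two linear maps are applied to the encoded sequence once, and every (head, dependent) pair is scored by broadcasting, which is the Kiperwasser & Goldberg speed trick the comment refers to. `encoded` is an assumed (batch_size, seq_len, encoder_dim) tensor of encoder outputs, not a name from the snippet.
# Sketch only, inside the model's forward pass.
head = self.edge_head(encoded)          # (batch_size, seq_len, arc_representation_dim)
dep = self.edge_dep(encoded)            # (batch_size, seq_len, arc_representation_dim)
# Broadcast-add to score every (head, dependent) pair in one shot.
pair = self.activation(head.unsqueeze(2) + dep.unsqueeze(1))   # (batch, seq_len, seq_len, arc_dim)
arc_scores = self.arc_out_layer(pair).squeeze(-1)              # (batch, seq_len, seq_len)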
Ejemplo n.º 17
0
Archivo: ud.py Proyecto: lgessler/embur
    def __init__(
        self,
        vocab: Vocabulary,
        embedding_dim: int,
        tag_representation_dim: int,
        arc_representation_dim: int,
        encoder: Seq2SeqEncoder = None,
        tag_feedforward: FeedForward = None,
        arc_feedforward: FeedForward = None,
        pos_tag_embedding: Embedding = None,
        use_mst_decoding_for_validation: bool = True,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        pos_namespace: str = "xpos_tags",
        deprel_namespace: str = "deprel_labels",
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self.encoder = encoder

        encoder_dim = (
            encoder.get_output_dim() if encoder is not None
            else embedding_dim + pos_tag_embedding.get_output_dim()
        )

        self.head_arc_feedforward = arc_feedforward or FeedForward(
            encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()
        )
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(
            arc_representation_dim, arc_representation_dim, use_input_biases=True
        )

        self.pos_namespace = pos_namespace
        self.deprel_namespace = deprel_namespace
        num_labels = self.vocab.get_vocab_size(deprel_namespace)

        self.head_tag_feedforward = tag_feedforward or FeedForward(
            encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()
        )
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(
            tag_representation_dim, tag_representation_dim, num_labels
        )

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder_dim]))

        representation_dim = embedding_dim
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        if self.encoder is not None:
            check_dimensions_match(
                representation_dim,
                encoder.get_input_dim(),
                "text field embedding dim",
                "encoder input dim",
            )

        check_dimensions_match(
            tag_representation_dim,
            self.head_tag_feedforward.get_output_dim(),
            "tag representation dim",
            "tag feedforward output dim",
        )
        check_dimensions_match(
            arc_representation_dim,
            self.head_arc_feedforward.get_output_dim(),
            "arc representation dim",
            "arc feedforward output dim",
        )

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary(self.pos_namespace)
        punctuation_tag_indices = {
            tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation."
        )

        self._attachment_scores = AttachmentScores()
        initializer(self)
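A hedged sketch of how the biaffine components above usually combine in the forward pass (mirroring AllenNLP's biaffine parser, not code taken from ud.py); `encoded_text` is an assumed (batch_size, seq_len, encoder_dim) tensor.
# Sketch only, inside the model's forward pass.
head_arc = self._dropout(self.head_arc_feedforward(encoded_text))    # (batch, seq_len, arc_dim)
child_arc = self._dropout(self.child_arc_feedforward(encoded_text))  # (batch, seq_len, arc_dim)
# Bilinear attention gives a score for attaching token j to head i.
attended_arcs = self.arc_attention(head_arc, child_arc)              # (batch, seq_len, seq_len)

head_tag = self._dropout(self.head_tag_feedforward(encoded_text))    # (batch, seq_len, tag_dim)
child_tag = self._dropout(self.child_tag_feedforward(encoded_text))  # (batch, seq_len, tag_dim)
# self.tag_bilinear then scores dependency labels for each selected (head, child) pair.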
Ejemplo n.º 18
0
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 arc_representation_dim: int,
                 tag_representation_dim: int,
                 r_lambda: float = 1e-2,
                 normalize: bool = False,
                 arc_feedforward: FeedForward = None,
                 tag_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 dep_tag_embedding: Embedding = None,
                 predicate_embedding: Embedding = None,
                 delta_type: str = "hinge_ce",
                 subtract_gold: float = 0.0,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 gumbel_t: float = 0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SRLGraphParserBase, self).__init__(vocab, regularizer)
        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.r_lambda = r_lambda
        self.normalize = normalize
        self.as_base = False
        #   print ("predicates",self.vocab._index_to_token["predicates"])
        #   print ("tags",self.vocab._index_to_token["tags"])
        self.subtract_gold = subtract_gold
        self.delta_type = delta_type
        num_labels = self.vocab.get_vocab_size("tags")
        print("num_labels", num_labels)
        self.gumbel_t = gumbel_t
        node_dim = predicate_embedding.get_output_dim()
        encoder_dim = encoder.get_output_dim()
        self.arg_arc_feedforward = arc_feedforward or \
                                   FeedForward(encoder_dim, 1,
                                               arc_representation_dim,
                                               Activation.by_name("elu")())
        self.pred_arc_feedforward = copy.deepcopy(self.arg_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        self.arg_tag_feedforward = tag_feedforward or \
                                   FeedForward(encoder_dim, 1,
                                               tag_representation_dim,
                                               Activation.by_name("elu")())
        self.pred_tag_feedforward = copy.deepcopy(self.arg_tag_feedforward)

        self.tag_bilinear = BilinearMatrixAttention(
            tag_representation_dim,
            tag_representation_dim,
            label_dim=num_labels,
            use_input_biases=True)  #,activation=Activation.by_name("tanh")()

        self.predicte_feedforward = FeedForward(encoder_dim, 1, node_dim,
                                                Activation.by_name("elu")())
        self._pos_tag_embedding = pos_tag_embedding or None
        self._dep_tag_embedding = dep_tag_embedding or None
        self._pred_embedding = predicate_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)

        #   check_dimensions_match(representation_dim, encoder.get_input_dim(), "text field embedding dim", "encoder input dim")

        self._labelled_f1 = IterativeLabeledF1Measure(
            negative_label=0,
            negative_pred=0,
            selected_metrics=["F", "p_F", "l_P", "l_R"])
        self._tag_loss = torch.nn.NLLLoss(reduction="none")  # ,ignore_index=-1
        self._sense_loss = torch.nn.NLLLoss(
            reduction="none")  # ,ignore_index=-1
        initializer(self)
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        tag_representation_dim: int,
        arc_representation_dim: int,
        model_name: str = None,
        tag_feedforward: FeedForward = None,
        arc_feedforward: FeedForward = None,
        pos_tag_embedding: Embedding = None,
        use_mst_decoding_for_validation: bool = True,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        word_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        if model_name:
            from src.data.token_indexers import PretrainedAutoTokenizer
            self._tokenizer = PretrainedAutoTokenizer.load(model_name)

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or FeedForward(
            encoder_dim, 1, arc_representation_dim,
            Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or FeedForward(
            encoder_dim, 1, tag_representation_dim,
            Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._word_dropout = word_dropout
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(
            representation_dim,
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )

        check_dimensions_match(
            tag_representation_dim,
            self.head_tag_feedforward.get_output_dim(),
            "tag representation dim",
            "tag feedforward output dim",
        )
        check_dimensions_match(
            arc_representation_dim,
            self.head_arc_feedforward.get_output_dim(),
            "arc representation dim",
            "arc feedforward output dim",
        )

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
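A hedged sketch of how `_pos_to_ignore` is usually consumed when updating `AttachmentScores`: an evaluation mask is built that drops punctuation tokens, along the lines of AllenNLP's `_get_mask_for_eval` helper (shown here as an assumption, not code from this example).
import torch

# Sketch only: `mask` is an assumed boolean token mask, `pos_tags` the gold POS indices.
def _get_mask_for_eval(self, mask: torch.BoolTensor, pos_tags: torch.LongTensor) -> torch.Tensor:
    new_mask = mask.detach()
    for label in self._pos_to_ignore:
        # Zero out every position whose POS tag is punctuation.
        new_mask = new_mask & ~pos_tags.eq(label)
    return new_mask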
Ejemplo n.º 20
0
class MSPointerNetwork(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder_1: TextFieldEmbedder,
                 source_encoder_1: Seq2SeqEncoder,
                 beam_size: int,
                 max_decoding_steps: int,
                 decoder_output_dim: int,
                 target_embedding_dim: int = 30,
                 namespace: str = "tokens",
                 tensor_based_metric: Metric = None,
                 align_embeddings: bool = True,
                 source_embedder_2: TextFieldEmbedder = None,
                 source_encoder_2: Seq2SeqEncoder = None) -> None:
        super().__init__(vocab)
        self._source_embedder_1 = source_embedder_1
        self._source_embedder_2 = source_embedder_2 or self._source_embedder_1
        self._source_encoder_1 = source_encoder_1
        self._source_encoder_2 = source_encoder_2 or self._source_encoder_1

        self._source_namespace = namespace
        self._target_namespace = namespace

        self.encoder_output_dim_1 = self._source_encoder_1.get_output_dim()
        self.encoder_output_dim_2 = self._source_encoder_2.get_output_dim()
        self.cated_encoder_out_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2
        self.decoder_output_dim = decoder_output_dim

        # TODO: AllenNLP's AdditiveAttention implementation may not include a bias term
        self._attention_1 = AdditiveAttention(self.decoder_output_dim,
                                              self.encoder_output_dim_1)
        self._attention_2 = AdditiveAttention(self.decoder_output_dim,
                                              self.encoder_output_dim_2)

        if not align_embeddings:
            self.target_embedding_dim = target_embedding_dim
            self._target_vocab_size = self.vocab.get_vocab_size(
                namespace=self._target_namespace)
            self._target_embedder = Embedding(self._target_vocab_size,
                                              target_embedding_dim)
        else:
            self._target_embedder = self._source_embedder_1._token_embedders[
                "tokens"]
            self._target_vocab_size = self.vocab.get_vocab_size(
                namespace=self._target_namespace)
            self.target_embedding_dim = self._target_embedder.get_output_dim()

        self.decoder_input_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2 + \
                                 self.target_embedding_dim

        self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                      self.decoder_output_dim)

        # Projects the concatenated final hidden states of the two encoders to the decoder's initial state
        self._encoder_out_projection_layer = torch.nn.Linear(
            in_features=self.cated_encoder_out_dim,
            out_features=self.decoder_output_dim
        )  # TODO: bias - true or false?

        # Soft gating parameters, used to compute lambda
        self._gate_projection_layer = torch.nn.Linear(
            in_features=self.decoder_output_dim + self.decoder_input_dim,
            out_features=1,
            bias=False)

        self._start_index = self.vocab.get_token_index(START_SYMBOL, namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        self._tensor_based_metric = tensor_based_metric or \
            BLEU(exclude_indices={self._pad_index, self._end_index, self._start_index})

    def _encode(
            self, source_tokens_1: Dict[str, torch.Tensor],
            source_tokens_2: Dict[str,
                                  torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode the token ids of source1 and source2 with their respective encoders and return
        each source's mask and encoder output. The token ids themselves are also passed along.
        """

        # 1. Encode source1
        # shape: (batch_size, seq_max_len_1)
        source_mask_1 = util.get_text_field_mask(source_tokens_1)
        # shape: (batch_size, seq_max_len_1, encoder_input_dim_1)
        embedder_out_1 = self._source_embedder_1(source_tokens_1)
        # shape: (batch_size, seq_max_len_1, encoder_output_dim_1)
        encoder_out_1 = self._source_encoder_1(embedder_out_1, source_mask_1)

        # 2. Encode source2
        # shape: (batch_size, seq_max_len_2)
        source_mask_2 = util.get_text_field_mask(source_tokens_2)
        # shape: (batch_size, seq_max_len_2, encoder_input_dim_2)
        embedder_out_2 = self._source_embedder_2(source_tokens_2)
        # shape: (batch_size, seq_max_len_2, encoder_output_dim_2)
        encoder_out_2 = self._source_encoder_2(embedder_out_2, source_mask_2)

        return {
            "source_mask_1": source_mask_1,
            "source_mask_2": source_mask_2,
            "source_token_ids_1": source_tokens_1["tokens"],
            "source_token_ids_2": source_tokens_2["tokens"],
            "encoder_out_1": encoder_out_1,
            "encoder_out_2": encoder_out_2,
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Initialise the decoder: update the given state so that it carries the decoder's
        context and hidden vectors. The hidden vector (h_0) is obtained by projecting the
        final hidden states of the two encoders; the context is initialised to a zero vector.
        """
        batch_size = state["encoder_out_1"].size()[0]

        # Use each sequence's mask to select the final RNN hidden state
        # shape: (batch_size, encoder_output_dim_1)
        encoder_final_output_1 = util.get_final_encoder_states(
            state["encoder_out_1"], state["source_mask_1"],
            self._source_encoder_1.is_bidirectional())
        # shape: (batch_size, encoder_output_dim_2)
        encoder_final_output_2 = util.get_final_encoder_states(
            state["encoder_out_2"], state["source_mask_2"],
            self._source_encoder_2.is_bidirectional())

        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = torch.relu(
            self._encoder_out_projection_layer(
                torch.cat([encoder_final_output_1, encoder_final_output_2],
                          dim=-1)))
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["decoder_hidden"].new_zeros(
            batch_size, self.decoder_output_dim)

        return state

    @overrides
    def forward(
        self,
        source_tokens_1: Dict[str, torch.LongTensor],
        source_tokens_2: Dict[str, torch.LongTensor],
        metadata: List[Dict[str, Any]],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:

        # Three cases are handled separately: training, validation/testing, and prediction.

        # 1. Training: target_tokens is always provided as ground truth.
        #    Only the loss is needed; no beam search is run.
        if self.training:
            assert target_tokens is not None

            state = self._encode(source_tokens_1, source_tokens_2)
            state["target_token_ids"] = target_tokens["tokens"]
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, state)
            output_dict["metadata"] = metadata
            return output_dict  # contains loss and metadata

        # 2. Validation/testing: self.training is False but target_tokens is provided.
        #    Compute the loss, run beam search, and compute evaluation metrics.
        elif target_tokens:

            # Compute the loss
            state = self._encode(source_tokens_1, source_tokens_2)
            state["target_token_ids"] = target_tokens["tokens"]
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, state)

            # Run beam search
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            # Compute the evaluation metric (BLEU)
            if self._tensor_based_metric is not None:
                # shape: (batch_size, beam_size, max_decoding_steps)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_decoding_steps)
                best_predictions = top_k_predictions[:, 0, :]
                # shape: (batch_size, target_seq_len)
                gold_tokens = target_tokens["tokens"]
                self._tensor_based_metric(best_predictions, gold_tokens)
            output_dict["metadata"] = metadata
            return output_dict  # contains loss, metadata, top-k predictions, and top-k log probs

        # 3. Prediction: self.training is False and target_tokens is not provided.
        #    Only beam search is run to produce top-k predictions.
        else:
            state = self._encode(source_tokens_1, source_tokens_2)
            state = self._init_decoder_state(state)
            output_dict = {"metadata": metadata}
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            return output_dict  # contains metadata, top-k predictions, and top-k log probs

    def _forward_loss(
            self, target_tokens: Dict[str, torch.Tensor],
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Compute the loss for one batch of inputs (called whenever gold target tokens are provided).
        """
        batch_size, target_seq_len = target_tokens["tokens"].size()

        # shape: (batch_size, seq_max_len_1)
        source_mask_1 = state["source_mask_1"]
        # shape: (batch_size, seq_max_len_2)
        source_mask_2 = state["source_mask_2"]

        # The number of decoding steps is always one less than the max target length (<start> ... <end>)
        num_decoding_steps = target_seq_len - 1

        step_log_likelihoods = []  # log-likelihood of the gold token at each time step
        for timestep in range(num_decoding_steps):  # t: 0..T

            # token id fed to the decoder at this step, shape: (batch_size,)
            input_choices = target_tokens["tokens"][:, timestep]

            # Advance the decoder one step (computes intermediate values such as attention scores and the gate score)
            state = self._decoder_step(input_choices, state)

            # Attention scores of decoder_hidden over the two sources
            # shape: (batch_size, seq_max_len_1)
            attentive_weights_1 = state["attentive_weights_1"]
            # shape: (batch_size, seq_max_len_2)
            attentive_weights_2 = state["attentive_weights_2"]

            # target_to_source marks whether the current gold target token appears at each source position
            # shape: (batch_size, seq_max_len_1)
            target_to_source_1 = (state["source_token_ids_1"] ==
                                  state["target_token_ids"][:, timestep +
                                                            1].unsqueeze(-1))
            # shape: (batch_size, seq_max_len_2)
            target_to_source_2 = (state["source_token_ids_2"] ==
                                  state["target_token_ids"][:, timestep +
                                                            1].unsqueeze(-1))

            # Use the quantities above to compute the log-likelihood of the gold token at this step
            step_log_likelihood = self._get_ll_contrib(
                attentive_weights_1, attentive_weights_2, source_mask_1,
                source_mask_2, target_to_source_1, target_to_source_2,
                state["target_token_ids"][:,
                                          timestep + 1], state["gate_score"])
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        # Concatenate the per-step log-likelihoods into a single tensor
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        log_likelihoods = torch.cat(step_log_likelihoods, 1)

        # Target mask, including START and END
        # shape: (batch_size, target_seq_len)
        target_mask = util.get_text_field_mask(target_tokens)

        # Drop the first position: START is never a prediction target
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        target_mask = target_mask[:, 1:].float()

        # Sum the masked per-step log-likelihoods to get the log-likelihood of the whole sequence
        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)

        loss = -log_likelihood.sum() / batch_size

        return {"loss": loss}

    def _decoder_step(
            self, last_predictions: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Advance the decoder state by one step.
        """

        # shape: (group_size, seq_max_len_1)
        source_mask_1 = state["source_mask_1"].float()
        # shape: (group_size, seq_max_len_2)
        source_mask_2 = state["source_mask_2"].float()
        # y_{t-1}, shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # a_t, shape: (group_size, seq_max_len_1)
        state["attentive_weights_1"] = self._attention_1(
            state["decoder_hidden"], state["encoder_out_1"], source_mask_1)
        # a'_t, shape: (group_size, seq_max_len_2)
        state["attentive_weights_2"] = self._attention_2(
            state["decoder_hidden"], state["encoder_out_2"], source_mask_2)

        # c_t, shape: (group_size, encoder_output_dim_1)
        attentive_read_1 = util.weighted_sum(state["encoder_out_1"],
                                             state["attentive_weights_1"])
        # c'_t, shape: (group_size, encoder_output_dim_2)
        attentive_read_2 = util.weighted_sum(state["encoder_out_2"],
                                             state["attentive_weights_2"])

        # Compute the soft gating value lambda
        # shape: (group_size, target_embedding_dim + encoder_output_dim_1 + encoder_output_dim_2 + decoder_output_dim)
        gate_input = torch.cat((embedded_input, attentive_read_1,
                                attentive_read_2, state["decoder_hidden"]),
                               dim=-1)
        # shape: (group_size,)
        gate_projected = self._gate_projection_layer(gate_input).squeeze(-1)
        # shape: (group_size,)
        state["gate_score"] = torch.sigmoid(gate_projected)

        # shape: (group_size, target_embedding_dim + encoder_output_dim_1 + encoder_output_dim_2)
        decoder_input = torch.cat(
            (embedded_input, attentive_read_1, attentive_read_2), dim=-1)

        # Update the decoder state (hidden and context/cell)
        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
            decoder_input, (state["decoder_hidden"], state["decoder_context"]))

        return state

    def _get_ll_contrib(self, copy_scores_1: torch.Tensor,
                        copy_scores_2: torch.Tensor,
                        source_mask_1: torch.Tensor,
                        source_mask_2: torch.Tensor,
                        target_to_source_1: torch.Tensor,
                        target_to_source_2: torch.Tensor,
                        target_tokens: torch.Tensor,
                        gate_score: torch.Tensor) -> torch.Tensor:
        """
        Compute the log-likelihood of the gold token for one time step from the
        attention scores.

        Args:
            - copy_scores_1: attention scores over the first source.
                    shape: (batch_size, seq_max_len_1)
            - copy_scores_2: attention scores over the second source.
                    shape: (batch_size, seq_max_len_2)
            - source_mask_1: mask for the first source.
                    shape: (batch_size, seq_max_len_1)
            - source_mask_2: mask for the second source.
                    shape: (batch_size, seq_max_len_2)
            - target_to_source_1: whether the gold token matches each position of the first source.
                    shape: (batch_size, seq_max_len_1)
            - target_to_source_2: whether the gold token matches each position of the second source.
                    shape: (batch_size, seq_max_len_2)
            - target_tokens: gold tokens at the current time step.
                    shape: (batch_size,)
            - gate_score: probability of copying from the first source (between 0 and 1).
                    shape: (batch_size,)

        Returns:
            The log-likelihood of generating the gold token at the current time step.
                    shape: (batch_size,)
        """
        # Log-probability contribution from the first source
        # shape: (batch_size, seq_max_len_1)
        combined_log_probs_1 = (copy_scores_1 + 1e-45).log() + (
            target_to_source_1.float() +
            1e-45).log() + (source_mask_1.float() + 1e-45).log()
        # shape: (batch_size,)
        log_probs_1 = util.logsumexp(
            combined_log_probs_1)  # log(exp(a[0]) + ... + exp(a[L]))

        # Log-probability contribution from the second source
        # shape: (batch_size, seq_max_len_2)
        combined_log_probs_2 = (copy_scores_2 + 1e-45).log() + (
            target_to_source_2.float() +
            1e-45).log() + (source_mask_2.float() + 1e-45).log()
        # shape: (batch_size,)
        log_probs_2 = util.logsumexp(
            combined_log_probs_2)  # log(exp(a[0]) + ... + exp(a[L]))

        # Compute log(p1 * gate + p2 * (1 - gate))
        log_gate_score_1 = gate_score.log()  # shape: (batch_size,)
        log_gate_score_2 = (1 - gate_score).log()  # shape: (batch_size,)
        item_1 = (log_gate_score_1 + log_probs_1).unsqueeze(
            -1)  # shape: (batch_size, 1)
        item_2 = (log_gate_score_2 + log_probs_2).unsqueeze(
            -1)  # shape: (batch_size, 1)
        step_log_likelihood = util.logsumexp(torch.cat(
            (item_1, item_2), -1))  # shape: (batch_size,)
        return step_log_likelihood

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask_1"].size()[0]
        start_predictions = state["source_mask_1"].new_full(
            (batch_size, ), fill_value=self._start_index)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_search_step)
        return {
            "predicted_log_probs": log_probabilities,
            "predictions": all_top_k_predictions
        }

    def take_search_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Step function used by beam search.

        Args:
            - last_predictions: predictions from the previous time step.
                    shape: (group_size,)
            - state: the decoding state.

        Returns:
            - final_log_probs: log probabilities over the full target vocabulary.
                    shape: (group_size, target_vocab_size)
            - state: the updated state.

        Note: this function is passed to BeamSearch. Its input is the predicted ids from the
              previous time step (last_predictions, initialised to start_index) and its output
              is the log probabilities over the full target vocabulary (final_log_probs).

        TODO: handle OOV tokens (would require substantial changes).
        """
        # Advance the decoder one step
        state = self._decoder_step(last_predictions, state)

        # copy probabilities over the first source, shape: (group_size, seq_max_len_1)
        copy_scores_1 = state["attentive_weights_1"]
        # copy probabilities over the second source, shape: (group_size, seq_max_len_2)
        copy_scores_2 = state["attentive_weights_2"]
        # gate between the two copy distributions, shape: (group_size,)
        gate_score = state["gate_score"]

        # Compute log probabilities over the full vocabulary
        final_log_probs = self._gather_final_log_probs(copy_scores_1,
                                                       copy_scores_2,
                                                       gate_score, state)

        return final_log_probs, state

    def _gather_final_log_probs(
            self, copy_scores_1: torch.Tensor, copy_scores_2: torch.Tensor,
            gate_score: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combine the three probability terms into log probabilities over the full vocabulary.

        Args:
            - copy_scores_1: copy probabilities for the first source (normalised).
                    shape: (group_size, seq_max_len_1)
            - copy_scores_2: copy probabilities for the second source (normalised).
                    shape: (group_size, seq_max_len_2)
            - gate_score: gate deciding how much source1 contributes (source2 contributes 1 - gate_score).
                    shape: (group_size,)
            - state: the decoding state updated at the current time step.

        Returns:
            - final_log_probs: log probabilities over the full vocabulary.
                    shape: (group_size, target_vocab_size)
        """
        # group_size and the two sequence lengths
        group_size, seq_max_len_1 = copy_scores_1.size()
        group_size, seq_max_len_2 = copy_scores_2.size()

        # TODO: this assumes that source and target share the same vocabulary mapping;
        #       otherwise a source-to-target mapping (the index of each source word in the
        #       target vocabulary) would be needed for the matching below.
        # shape: (group_size, seq_max_len_1)
        source_token_ids_1 = state["source_token_ids_1"]
        # shape: (group_size, seq_max_len_2)
        source_token_ids_2 = state["source_token_ids_2"]

        # Expand gate_score along the sequence dimension
        # gate probability multiplied into source1, shape: (group_size, seq_max_len_1)
        gate_1 = gate_score.expand(seq_max_len_1, -1).t()
        # gate probability multiplied into source2, shape: (group_size, seq_max_len_2)
        gate_2 = (1 - gate_score).expand(seq_max_len_2, -1).t()

        # gated source1 scores, shape: (group_size, seq_max_len_1)
        copy_scores_1 = copy_scores_1 * gate_1
        # gated source2 scores, shape: (group_size, seq_max_len_2)
        copy_scores_2 = copy_scores_2 * gate_2

        # shape: (group_size, seq_max_len_1)
        log_probs_1 = (copy_scores_1 + 1e-45).log()
        # shape: (group_size, seq_max_len_2)
        log_probs_2 = (copy_scores_2 + 1e-45).log()

        # Initialise the full-vocabulary probabilities to (near) zero, shape: (group_size, target_vocab_size)
        final_log_probs = (state["decoder_hidden"].new_zeros(
            (group_size, self._target_vocab_size)) + 1e-45).log()

        for i in range(seq_max_len_1):  # iterate over every position of source1
            # copy probability at this position, shape: (group_size, 1)
            log_probs_slice = log_probs_1[:, i].unsqueeze(-1)
            # token ids at this position, shape: (group_size, 1)
            source_to_target_slice = source_token_ids_1[:, i].unsqueeze(-1)

            # existing vocabulary log-probs at the positions to update, shape: (group_size, 1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            # combined value (existing plus new probability, mixed in log space), shape: (group_size, 1)
            combined_scores = util.logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            # write combined_scores back into final_log_probs
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        # do the same for source2
        for i in range(seq_max_len_2):
            log_probs_slice = log_probs_2[:, i].unsqueeze(-1)
            source_to_target_slice = source_token_ids_2[:, i].unsqueeze(-1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            combined_scores = util.logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        return final_log_probs

    def _get_predicted_tokens(
            self,
            predicted_indices: Union[torch.Tensor, numpy.ndarray],
            batch_metadata: List[Any],
            n_best: int = None) -> List[Union[List[List[str]], List[str]]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        predicted_tokens: List[Union[List[List[str]], List[str]]] = []
        for top_k_predictions, metadata in zip(predicted_indices,
                                               batch_metadata):
            batch_predicted_tokens: List[List[str]] = []
            for indices in top_k_predictions[:n_best]:
                tokens: List[str] = []
                indices = list(indices)
                if self._end_index in indices:
                    indices = indices[:indices.index(self._end_index)]
                for index in indices:
                    token = self.vocab.get_token_from_index(
                        index, self._target_namespace)
                    tokens.append(token)
                batch_predicted_tokens.append(tokens)
            if n_best == 1:
                predicted_tokens.append(batch_predicted_tokens[0])
            else:
                predicted_tokens.append(batch_predicted_tokens)
        return predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Decode the prediction tensors into token sequences.
        """
        predicted_tokens = self._get_predicted_tokens(
            output_dict["predictions"], output_dict["metadata"])
        output_dict["predicted_tokens"] = predicted_tokens
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset))
        return all_metrics
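A tiny numeric check of the gated copy likelihood computed in `_get_ll_contrib` above, i.e. p(y_t) = gate * p_copy1(y_t) + (1 - gate) * p_copy2(y_t), combined in log space; the numbers are made up purely for illustration.
import torch

gate = torch.tensor([0.7])   # probability of copying from source 1
p1 = torch.tensor([0.20])    # attention mass source 1 puts on the gold token
p2 = torch.tensor([0.05])    # attention mass source 2 puts on the gold token

item_1 = (gate.log() + p1.log()).unsqueeze(-1)
item_2 = ((1 - gate).log() + p2.log()).unsqueeze(-1)
step_log_likelihood = torch.logsumexp(torch.cat((item_1, item_2), dim=-1), dim=-1)
# equals log(0.7 * 0.20 + 0.3 * 0.05) = log(0.155)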
Ejemplo n.º 21
0
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        tag_representation_dim: int,
        arc_representation_dim: int,
        activation = Activation.by_name("tanh")(),
        tag_feedforward: FeedForward = None,
        arc_feedforward: FeedForward = None,
        pos_tag_embedding: Embedding = None,
        use_mst_decoding_for_validation: bool = False,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        edge_prediction_threshold: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        
        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.activation = activation

        encoder_dim = encoder.get_output_dim()

        # edge FeedForward
        self.head_arc_feedforward = arc_feedforward or FeedForward(
            encoder_dim, 1, arc_representation_dim, Activation.by_name("tanh")()
        )
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)
        
        # label FeedForward
        self.head_tag_feedforward = tag_feedforward or FeedForward(
            encoder_dim, 1, tag_representation_dim, Activation.by_name("tanh")()
        )
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)
        
        self.arc_out_layer = Linear(arc_representation_dim, 1)

        num_labels = self.vocab.get_vocab_size("head_tags")
        self.tag_out_layer = Linear(tag_representation_dim, num_labels)
    
        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        
        # add a head sentinel to accommodate the extra root token
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()]))
            
        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        
        check_dimensions_match(
            representation_dim,
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )
        
        check_dimensions_match(
            tag_representation_dim,
            self.head_tag_feedforward.get_output_dim(),
            "tag representation dim",
            "tag feedforward output dim",
        )
        
        check_dimensions_match(
            arc_representation_dim,
            self.head_arc_feedforward.get_output_dim(),
            "arc representation dim",
            "arc feedforward output dim",
        )
        
        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation
        
        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation."
        )

        self._attachment_scores = AttachmentScores()    
        initializer(self)
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        tag_representation_dim: int,
        arc_representation_dim: int,
        tag_feedforward: FeedForward = None,
        arc_feedforward: FeedForward = None,
        pos_tag_embedding: Embedding = None,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        edge_prediction_threshold: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.edge_prediction_threshold = edge_prediction_threshold
        if not 0 < edge_prediction_threshold < 1:
            raise ConfigurationError(
                f"edge_prediction_threshold must be between "
                f"0 and 1 (exclusive) but found {edge_prediction_threshold}.")

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or FeedForward(
            encoder_dim, 1, arc_representation_dim,
            Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("labels")
        self.head_tag_feedforward = tag_feedforward or FeedForward(
            encoder_dim, 1, tag_representation_dim,
            Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = BilinearMatrixAttention(tag_representation_dim,
                                                    tag_representation_dim,
                                                    label_dim=num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(
            representation_dim,
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )
        check_dimensions_match(
            tag_representation_dim,
            self.head_tag_feedforward.get_output_dim(),
            "tag representation dim",
            "tag feedforward output dim",
        )
        check_dimensions_match(
            arc_representation_dim,
            self.head_arc_feedforward.get_output_dim(),
            "arc representation dim",
            "arc feedforward output dim",
        )

        self._unlabelled_f1 = F1Measure(positive_label=1)
        self._arc_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self._tag_loss = torch.nn.CrossEntropyLoss(reduction="none")
        initializer(self)
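Since this model scores edges independently (per-edge BCE loss, `F1Measure` over unlabelled edges), `edge_prediction_threshold` is presumably applied to the sigmoid of the arc scores at prediction time. A hedged sketch, where `arc_scores` is an assumed (batch_size, seq_len, seq_len) tensor of arc logits:
# Sketch only, inside the model's forward/decode step.
edge_probs = torch.sigmoid(arc_scores)
predicted_edges = (edge_probs > self.edge_prediction_threshold).long()   # 1 = keep the edge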
Ejemplo n.º 23
0
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder_0: Seq2SeqEncoder,
                 encoder_1: Seq2SeqEncoder,
                 encoder_2: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 use_layer_normalization: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParser, self).__init__(vocab, regularizer)

        a = vocab.get_index_to_token_vocabulary(namespace='tokens')
        # glyph_config['idx2word'] = {k: v for k, v in a.items()}

        # self.glyph = GlyphEmbedding(glyph_config)

        self.text_field_embedder = text_field_embedder

        self.encoder_0 = encoder_0
        self.encoder_1 = encoder_1
        self.encoder_2 = encoder_2

        encoder_dim = self.encoder_2.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        # self._dropout = Dropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, self.encoder_2.get_output_dim()]))

        self.use_layer_normalization = use_layer_normalization

        if use_layer_normalization:
            self.norm_input = torch.nn.LayerNorm(
                self.encoder_0.get_input_dim())
            self.norm_hidden = torch.nn.LayerNorm(
                self.encoder_0.get_output_dim())

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        # check_dimensions_match(representation_dim, encoder.get_input_dim(),
        #                        "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 treebank_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 use_treebank_embedding: bool = False,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineDependencyParserMonolingual, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._treebank_embedding = treebank_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()
        if treebank_embedding is not None:
            representation_dim += treebank_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim", "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim", "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation
        self.use_treebank_embedding = use_treebank_embedding

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE}
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
                    "Ignoring words with these POS tags for evaluation.")
        
        if self.use_treebank_embedding:
            tbids = self.vocab.get_token_to_index_vocabulary("tbids")
            tbid_indices = {tb: index for tb, index in tbids.items()}
            self._tbids = set(tbid_indices.values())
            logger.info(f"Found TBIDs corresponding to the following treebanks : {tbid_indices}. "
                        "Embedding these as additional features.")

        self._attachment_scores = AttachmentScores()
        initializer(self)
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 edge_prediction_threshold: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EnhancedParser, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.edge_prediction_threshold = edge_prediction_threshold
        if not 0 < edge_prediction_threshold < 1:
            raise ConfigurationError(
                f"edge_prediction_threshold must be between "
                f"0 and 1 (exclusive) but found {edge_prediction_threshold}.")

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

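        # Bilinear attention between head-role and child-role token representations,
        # producing a seq_len x seq_len matrix of arc scores per sentence.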
        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("labels")
        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

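        # With label_dim set, the bilinear attention scores every candidate arc
        # against every dependency label in the "labels" namespace.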
        self.tag_bilinear = BilinearMatrixAttention(tag_representation_dim,
                                                    tag_representation_dim,
                                                    label_dim=num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)

        # add a head sentinel to accommodate the extra root token
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        # unlabelled_f1 has been confirmed to give the same result in both classes
        self._unlabelled_f1 = F1Measure(positive_label=1)
        self._enhanced_attachment_scores = EnhancedAttachmentScores()

        self._arc_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        self._tag_loss = torch.nn.CrossEntropyLoss(reduction='none')
        initializer(self)
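As a hedged illustration of how the arc machinery above is typically used at prediction time, the sketch below thresholds sigmoid arc probabilities, which is why edge_prediction_threshold must lie strictly between 0 and 1; the function and argument names are assumptions, not part of this class:

import torch

def greedy_decode_arcs(arc_scores, threshold):
    # arc_scores: (batch, seq_len, seq_len) logits from the bilinear arc attention,
    # where entry [b, i, j] scores token j taking token i as a head (assumed layout)
    arc_probs = torch.sigmoid(arc_scores)
    # an edge is kept wherever its probability clears the threshold
    return (arc_probs > threshold).long()
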
Example No. 26
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        tag_representation_dim: int,
        arc_representation_dim: int,
        tag_feedforward: FeedForward = None,
        arc_feedforward: FeedForward = None,
        lemma_tag_embedding: Embedding = None,
        upos_tag_embedding: Embedding = None,
        xpos_tag_embedding: Embedding = None,
        feats_tag_embedding: Embedding = None,
        head_information_embedding: Embedding = None,
        head_tag_embedding: Embedding = None,
        dropout: float = 0.0,
        input_dropout: float = 0.0,
        edge_prediction_threshold: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.edge_prediction_threshold = edge_prediction_threshold
        if not 0 < edge_prediction_threshold < 1:
            raise ConfigurationError(
                f"edge_prediction_threshold must be between "
                f"0 and 1 (exclusive) but found {edge_prediction_threshold}.")

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or FeedForward(
            encoder_dim, 1, arc_representation_dim,
            Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("deps")
        self.head_tag_feedforward = tag_feedforward or FeedForward(
            encoder_dim, 1, tag_representation_dim,
            Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = BilinearMatrixAttention(tag_representation_dim,
                                                    tag_representation_dim,
                                                    label_dim=num_labels)

        self._lemma_tag_embedding = lemma_tag_embedding or None
        self._upos_tag_embedding = upos_tag_embedding or None
        self._xpos_tag_embedding = xpos_tag_embedding or None
        self._feats_tag_embedding = feats_tag_embedding or None
        self._head_tag_embedding = head_tag_embedding or None
        self._head_information_embedding = head_information_embedding or None

        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)

        # add a head sentinel to accommodate the extra root token in EUD graphs
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if lemma_tag_embedding is not None:
            representation_dim += lemma_tag_embedding.get_output_dim()
        if upos_tag_embedding is not None:
            representation_dim += upos_tag_embedding.get_output_dim()
        if xpos_tag_embedding is not None:
            representation_dim += xpos_tag_embedding.get_output_dim()
        if feats_tag_embedding is not None:
            representation_dim += feats_tag_embedding.get_output_dim()
        if head_tag_embedding is not None:
            representation_dim += head_tag_embedding.get_output_dim()
        if head_information_embedding is not None:
            representation_dim += head_information_embedding.get_output_dim()

        check_dimensions_match(
            representation_dim,
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )
        check_dimensions_match(
            tag_representation_dim,
            self.head_tag_feedforward.get_output_dim(),
            "tag representation dim",
            "tag feedforward output dim",
        )
        check_dimensions_match(
            arc_representation_dim,
            self.head_arc_feedforward.get_output_dim(),
            "arc representation dim",
            "arc feedforward output dim",
        )

        self._enhanced_attachment_scores = EnhancedAttachmentScores()
        self._arc_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self._tag_loss = torch.nn.CrossEntropyLoss(reduction="none")
        initializer(self)
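A minimal sketch, under assumed field names and call signatures, of how the optional embeddings declared above could be concatenated into the token representation before encoding:

import torch

def concat_optional_embeddings(embedded_text, optional_pairs, input_dropout):
    # optional_pairs: list of (embedder_or_None, indices_or_None) covering the
    # lemma, upos, xpos, feats, head and head-information features above
    pieces = [embedded_text]
    for embedder, indices in optional_pairs:
        if embedder is not None and indices is not None:
            pieces.append(embedder(indices))
    # the widened feature dimension mirrors the representation_dim checks in __init__
    return input_dropout(torch.cat(pieces, dim=-1))
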
Example No. 27
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 lemmatize_helper: LemmatizeHelper,
                 task_config: TaskConfig,
                 morpho_vector_dim: int = 0,
                 gram_val_representation_dim: int = -1,
                 lemma_representation_dim: int = -1,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(DependencyParser, self).__init__(vocab, regularizer)

        self.TopNCnt = 3

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.lemmatize_helper = lemmatize_helper
        self.task_config = task_config

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    arc_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                        FeedForward(encoder_dim, 1,
                                                    tag_representation_dim,
                                                    Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        assert self.task_config.params.get("use_pos_tag", False) == \
            (self._pos_tag_embedding is not None)

        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

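        # Grammatical-value (morphological tag) head: a single linear layer when
        # gram_val_representation_dim <= 0, otherwise a small dropout-regularised MLP.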
        if gram_val_representation_dim <= 0:
            self._gram_val_output = torch.nn.Linear(
                encoder_dim, self.vocab.get_vocab_size("grammar_value_tags"))
        else:
            self._gram_val_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(encoder_dim, gram_val_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(
                    gram_val_representation_dim,
                    self.vocab.get_vocab_size("grammar_value_tags")))

        if lemma_representation_dim <= 0:
            self._lemma_output = torch.nn.Linear(encoder_dim,
                                                 len(lemmatize_helper))
        else:
            # Feed the predicted grammatical-value (morphological tag) output into
            # the lemmatizer input as well -- EXPERIMENTAL
            # actual_input_dim = encoder_dim
            actual_input_dim = encoder_dim + self.vocab.get_vocab_size(
                "grammar_value_tags")
            self._lemma_output = torch.nn.Sequential(
                Dropout(dropout),
                torch.nn.Linear(actual_input_dim, lemma_representation_dim),
                Dropout(dropout),
                torch.nn.Linear(lemma_representation_dim,
                                len(lemmatize_helper)))

        representation_dim = text_field_embedder.get_output_dim() + morpho_vector_dim
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        check_dimensions_match(representation_dim, encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation: {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()
        self._gram_val_prediction_accuracy = CategoricalAccuracy()
        self._lemma_prediction_accuracy = CategoricalAccuracy()

        initializer(self)
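To make the experimental lemmatizer wiring above concrete, here is a hedged sketch of the branch where lemma_representation_dim > 0 and the grammatical-value logits are fed into the lemma head; the names and shapes are assumptions, not code from this parser:

import torch

def predict_gram_vals_and_lemmas(encoded_text, gram_val_output, lemma_output):
    # encoded_text: (batch, seq_len, encoder_dim) from the Seq2SeqEncoder
    gram_val_logits = gram_val_output(encoded_text)
    # experimental variant: the lemma classifier sees both the contextual encoding
    # and the predicted grammatical-value distribution (actual_input_dim above)
    lemma_input = torch.cat([encoded_text, gram_val_logits], dim=-1)
    return gram_val_logits, lemma_output(lemma_input)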