Example #1
def __init__(self, token_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             span_decoder: Topdown_Span_Parser_Factory,
             span_extractor: SpanExtractor,
             remote_parser: Remote_Parser_Factory,
             evaluator: Any,
             vocab: Vocabulary) -> None:
    super().__init__(vocab)
    self.token_embedder = token_embedder
    self.encoder = encoder
    # Both factories are finished here, once the span extractor's
    # output dimension is known.
    self.span_decoder = span_decoder(span_extractor.get_output_dim(), vocab)
    self.span_extractor = span_extractor
    self.remote_parser = remote_parser(span_extractor.get_output_dim(), vocab)
    self.evaluator = evaluator
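The `span_decoder` and `remote_parser` arguments here are factories: callables invoked with the span extractor's output dimension and the vocabulary once both are known. A minimal sketch of that convention, assuming AllenNLP's `EndpointSpanExtractor`; `ToySpanDecoder` is a hypothetical stand-in for the project's real decoder:

import torch
from allennlp.data import Vocabulary
from allennlp.modules.span_extractors import EndpointSpanExtractor

class ToySpanDecoder(torch.nn.Module):
    """Hypothetical decoder; any callable of (representation_dim, vocab) fits."""
    def __init__(self, representation_dim: int, vocab: Vocabulary):
        super().__init__()
        self.scorer = torch.nn.Linear(representation_dim,
                                      vocab.get_vocab_size("labels"))

vocab = Vocabulary()
vocab.add_tokens_to_namespace(["NP", "VP"], namespace="labels")
extractor = EndpointSpanExtractor(input_dim=100)  # output dim = 2 * 100
decoder = ToySpanDecoder(extractor.get_output_dim(), vocab)  # the factory call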

def test_bidirectional_endpoint_span_extractor_can_build_from_params(self):
    params = Params({
        "type": "bidirectional_endpoint",
        "input_dim": 4,
        "num_width_embeddings": 5,
        "span_width_embedding_dim": 3,
    })
    extractor = SpanExtractor.from_params(params)
    assert isinstance(extractor, BidirectionalEndpointSpanExtractor)
    assert extractor.get_output_dim() == 2 + 2 + 3

def test_endpoint_span_extractor_can_build_from_params(self):
    params = Params({
        "type": "endpoint",
        "input_dim": 7,
        "num_width_embeddings": 5,
        "span_width_embedding_dim": 3,
    })
    extractor = SpanExtractor.from_params(params)
    assert isinstance(extractor, EndpointSpanExtractor)
    assert extractor.get_output_dim() == 10

def test_locally_normalised_span_extractor_can_build_from_params(self):
    params = Params({
        "type": "self_attentive",
        "input_dim": 7,
        "num_width_embeddings": 5,
        "span_width_embedding_dim": 3,
    })
    extractor = SpanExtractor.from_params(params)
    assert isinstance(extractor, SelfAttentiveSpanExtractor)
    assert extractor.get_output_dim() == 10  # input_dim + span_width_embedding_dim
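These `from_params` tests only check construction. Once built, an extractor maps inclusive `(start, end)` span indices over a sequence tensor to one vector per span. A short usage sketch, assuming the standard AllenNLP forward signature; shapes and span values are arbitrary:

import torch
from allennlp.common import Params
from allennlp.modules.span_extractors import SpanExtractor

extractor = SpanExtractor.from_params(
    Params({"type": "self_attentive", "input_dim": 7}))
sequence = torch.randn(2, 9, 7)           # (batch, seq_len, input_dim)
spans = torch.tensor([[[0, 2], [3, 5]],
                      [[1, 1], [4, 8]]])  # inclusive (start, end) indices
span_reps = extractor(sequence, spans)    # (batch, num_spans, output_dim)
assert span_reps.size(-1) == extractor.get_output_dim()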
Example #6
    def __init__(self,
                 subword_embeddings: TextFieldEmbedder,
                 subword_aggregator: SpanExtractor,
                 vocab: Vocabulary = None,
                 freeze_encoder: bool = True):
        super().__init__(vocab)

        self.subword_embeddings = subword_embeddings
        self._freeze_encoder = freeze_encoder
        # Turn off gradients if we don't want to fine-tune the encoder.
        for parameter in self.subword_embeddings.parameters():
            parameter.requires_grad = not self._freeze_encoder

        self.subword_aggregator = subword_aggregator

        self.classifier = TimeDistributed(
            torch.nn.Linear(in_features=subword_aggregator.get_output_dim(),
                            out_features=vocab.get_vocab_size("labels")))

        self.accuracy = CategoricalAccuracy()
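For context, a hedged sketch of the forward pass a tagger with this constructor might implement; the input names (`wordpieces`, `word_spans`) and the -1 span-padding convention are assumptions, not taken from the original code:

from allennlp.nn.util import sequence_cross_entropy_with_logits

def forward(self, wordpieces, word_spans, labels=None):
    embedded = self.subword_embeddings(wordpieces)      # (batch, num_subwords, dim)
    # Pool each word's subword vectors into one vector per word.
    word_reps = self.subword_aggregator(embedded, word_spans)
    logits = self.classifier(word_reps)                 # (batch, num_words, num_labels)
    output = {"logits": logits}
    if labels is not None:
        mask = (word_spans[:, :, 0] >= 0).long()        # assumes -1-padded spans
        output["loss"] = sequence_cross_entropy_with_logits(logits, labels, mask)
        self.accuracy(logits, labels, mask)
    return output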
Example #7
@classmethod
def from_params(cls, vocab, params: Params) -> 'SpanAe':
    source_embedder_params = params.pop("source_embedder")
    source_embedder = TextFieldEmbedder.from_params(
        vocab, source_embedder_params)
    encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
    max_decoding_steps = params.pop("max_decoding_steps")
    target_namespace = params.pop("target_namespace", "tokens")
    # If no attention function is specified, use no attention at all,
    # not attention with a default similarity function.
    attention_function_type = params.pop("attention_function", None)
    if attention_function_type is not None:
        attention_function = SimilarityFunction.from_params(
            attention_function_type)
    else:
        attention_function = None
    scheduled_sampling_ratio = params.pop_float("scheduled_sampling_ratio", 0.0)
    spans_extractor, spans_scorer_feedforward = None, None
    spans_extractor_params = params.pop("span_extractor", None)
    if spans_extractor_params is not None:
        spans_extractor = SpanExtractor.from_params(spans_extractor_params)
    spans_scorer_params = params.pop("span_scorer_feedforward", None)
    if spans_scorer_params is not None:
        spans_scorer_feedforward = FeedForward.from_params(spans_scorer_params)
    spans_per_word = params.pop_float("spans_per_word")
    params.assert_empty(cls.__name__)
    return cls(vocab,
               source_embedder=source_embedder,
               encoder=encoder,
               max_decoding_steps=max_decoding_steps,
               spans_per_word=spans_per_word,
               target_namespace=target_namespace,
               attention_function=attention_function,
               scheduled_sampling_ratio=scheduled_sampling_ratio,
               spans_extractor=spans_extractor,
               spans_scorer_feedforward=spans_scorer_feedforward)
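The keys popped above imply a configuration of roughly this shape. Only the top-level keys come from the method; the inner module configs are invented for illustration (the endpoint extractor doubles its input dim, so the hypothetical feedforward takes 400):

params = Params({
    "source_embedder": {"tokens": {"type": "embedding", "embedding_dim": 50}},
    "encoder": {"type": "lstm", "input_size": 50,
                "hidden_size": 100, "bidirectional": True},    # output dim 200
    "max_decoding_steps": 40,
    "scheduled_sampling_ratio": 0.2,
    "span_extractor": {"type": "endpoint", "input_dim": 200},  # output dim 400
    "span_scorer_feedforward": {"input_dim": 400, "num_layers": 1,
                                "hidden_dims": 150, "activations": "relu"},
    "spans_per_word": 0.4,
})
model = SpanAe.from_params(vocab, params)  # vocab: an existing Vocabulary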

def test_locally_normalised_span_extractor_can_build_from_params(self):
    params = Params({"type": "self_attentive", "input_dim": 5})
    extractor = SpanExtractor.from_params(params)
    assert isinstance(extractor, SelfAttentiveSpanExtractor)
Example #10
    def __init__(self,
                 vocab: Vocabulary,
                 params: Params,
                 regularizer: RegularizerApplicator = None):
        super(JointDCS, self).__init__(vocab=vocab, regularizer=regularizer)

        # Base text Field Embedder
        text_field_embedder_params = params.pop("text_field_embedder")
        text_field_embedder = BasicTextFieldEmbedder.from_params(
            vocab=vocab, params=text_field_embedder_params)
        self._text_field_embedder = text_field_embedder

        # Encoder
        encoder_params = params.pop("encoder")
        encoder = Seq2SeqEncoder.from_params(encoder_params)
        self._encoder = encoder

        self._tag_representation_dim = params.pop('tag_representation_dim')
        self._arc_representation_dim = params.pop('arc_representation_dim')

        self._dropout = params.pop('dropout')
        self._input_dropout = params.pop('input_dropout')

        ############
        # DSP components
        ############
        dsp_params = params.pop("dsp")

        init_params = dsp_params.pop("initializer", None)
        self._initializer = (InitializerApplicator.from_params(init_params)
                             if init_params is not None else
                             InitializerApplicator())
        pos_params = dsp_params.pop("pos_tag_embedding")
        self._pos_tag_embedding = Embedding.from_params(vocab, pos_params)

        # Tagger: DSP - Biaffine Tagger
        tagger_dsp = BiaffineParser(
            vocab=vocab,
            task_type='dsp',
            text_field_embedder=self._text_field_embedder,
            encoder=self._encoder,
            tag_representation_dim=self._tag_representation_dim,
            arc_representation_dim=self._arc_representation_dim,
            pos_tag_embedding=self._pos_tag_embedding,
            dropout=self._dropout,
            input_dropout=self._input_dropout,
            initializer=self._initializer)
        self._tagger_dsp = tagger_dsp

        # arc shared
        self._arc_attention = tagger_dsp.arc_attention
        self._head_arc_feedforward = tagger_dsp.head_arc_feedforward
        self._child_arc_feedforward = tagger_dsp.child_arc_feedforward

        ############
        # SRL components
        ############
        srl_params = params.pop("srl")

        # init_params = srl_params.pop("initializer", None)
        # self._initializer = (
        #     InitializerApplicator.from_params(init_params) if init_params is not None else InitializerApplicator()
        # )
        # pos_params = srl_params.pop("pos_tag_embedding")
        # self._pos_tag_embedding = Embedding.from_params(vocab, pos_params)

        # Tagger: SRL - Biaffine Tagger
        tagger_srl = BiaffineParser(
            vocab=vocab,
            task_type='srl',
            text_field_embedder=self._text_field_embedder,
            encoder=self._encoder,
            tag_representation_dim=self._tag_representation_dim,
            arc_representation_dim=self._arc_representation_dim,
            pos_tag_embedding=self._pos_tag_embedding,
            dropout=self._dropout,
            input_dropout=self._input_dropout,
            initializer=self._initializer)
        tagger_srl.arc_attention = self._arc_attention
        tagger_srl.head_arc_feedforward = self._head_arc_feedforward
        tagger_srl.child_arc_feedforward = self._child_arc_feedforward
        self._tagger_srl = tagger_srl

        ############
        # CSP components
        ############

        csp_params = params.pop("csp")
        init_params = csp_params.pop("initializer", None)
        self._initializer = (InitializerApplicator.from_params(init_params)
                             if init_params is not None else
                             InitializerApplicator())
        # pos_params = csp_params.pop("pos_tag_embedding")
        # self._pos_tag_embedding = Embedding.from_params(vocab, pos_params)

        span_params = csp_params.pop("span_extractor")
        self._span_extractor = SpanExtractor.from_params(span_params)

        feed_forward_params = csp_params.pop("feedforward")
        self._feed_forward = FeedForward.from_params(feed_forward_params)

        # Tagger: CSP - SpanConstituencyParser Tagger
        tagger_csp = SpanConstituencyParser(
            vocab=vocab,
            text_field_embedder=self._text_field_embedder,
            span_extractor=self._span_extractor,
            encoder=self._encoder,
            feedforward=self._feed_forward,
            pos_tag_embedding=self._pos_tag_embedding,
            initializer=self._initializer)

        self._tagger_csp = tagger_csp

        logger.info("Multi-Task Learning Model has been instantiated.")
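All three taggers share the embedder and encoder, and DSP and SRL additionally share the arc attention and arc feedforwards, so gradients from every task update the common trunk. A hedged sketch of how a forward pass might route batches; the `task` argument and the routing are assumptions for illustration:

def forward(self, task: str, **inputs):
    # Dispatch to the task-specific head; shared modules get gradients
    # from all three tasks.
    if task == "dsp":
        return self._tagger_dsp(**inputs)
    if task == "srl":
        return self._tagger_srl(**inputs)
    return self._tagger_csp(**inputs)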