Example #1
0
 def __init__(self, token_embedder: TextFieldEmbedder,
              encoder: Seq2SeqEncoder,
              span_decoder: Topdown_Span_Parser_Factory,
              span_extractor: SpanExtractor,
              remote_parser: Remote_Parser_Factory,
              evaluator: object,  # any value is accepted and stored unchanged
              vocab: Vocabulary) -> None:
     """
     Store the model's sub-modules and build the span decoder and remote parser.

     Parameters
     ----------
     token_embedder : TextFieldEmbedder
         Stored as ``self.token_embedder``.
     encoder : Seq2SeqEncoder
         Stored as ``self.encoder``.
     span_decoder : Topdown_Span_Parser_Factory
         Factory called as ``span_decoder(span_dim, vocab)``, where ``span_dim``
         is ``span_extractor.get_output_dim()``; the result is stored as
         ``self.span_decoder``.
     span_extractor : SpanExtractor
         Stored as ``self.span_extractor``; its output dimension is passed to
         both factories.
     remote_parser : Remote_Parser_Factory
         Factory called the same way as ``span_decoder``; the result is stored
         as ``self.remote_parser``.
     evaluator : object
         Stored as ``self.evaluator``; no particular interface is required here.
     vocab : Vocabulary
         Passed to the base-class constructor and to both factories.
     """
     super().__init__(vocab)
     self.token_embedder = token_embedder
     self.encoder = encoder
     # Both factory-built components consume span representations, so they are
     # sized from the same extractor output dimension.
     span_dim = span_extractor.get_output_dim()
     self.span_decoder = span_decoder(span_dim, vocab)
     self.span_extractor = span_extractor
     self.remote_parser = remote_parser(span_dim, vocab)
     self.evaluator = evaluator
Example #2
0
    def __init__(self,
                 subword_embeddings: TextFieldEmbedder,
                 subword_aggregator: SpanExtractor,
                 vocab: Vocabulary = None,
                 freeze_encoder: bool = True) -> None:
        """
        Build the subword embedder, subword aggregator, label classifier and metric.

        Parameters
        ----------
        subword_embeddings : TextFieldEmbedder
            Embeds the subword tokens. The ``requires_grad`` flag of every one of
            its parameters is overwritten according to ``freeze_encoder``.
        subword_aggregator : SpanExtractor
            Aggregates subword embeddings; its ``get_output_dim()`` is used as the
            classifier's input size.
        vocab : Vocabulary
            Required despite the ``None`` default: the size of its ``"labels"``
            namespace is used as the classifier's output size.
        freeze_encoder : bool, optional (default = True)
            If True, every parameter of ``subword_embeddings`` gets
            ``requires_grad = False``. If False, every one gets
            ``requires_grad = True``, which also re-enables any parameters the
            embedder had frozen on its own.

        Raises
        ------
        ValueError
            If ``vocab`` is None.
        """
        # The None default is kept for signature compatibility, but the classifier
        # below cannot be sized without a vocabulary. Fail early with a clear
        # message instead of an AttributeError on ``None.get_vocab_size``.
        if vocab is None:
            raise ValueError("vocab is required to size the label classifier "
                             "(vocab.get_vocab_size('labels')), but got None")
        super().__init__(vocab)

        self.subword_embeddings = subword_embeddings
        self._freeze_encoder = freeze_encoder
        # Turn off gradients if the encoder should not be fine-tuned; otherwise
        # explicitly turn them on for every embedder parameter.
        for parameter in self.subword_embeddings.parameters():
            parameter.requires_grad = not self._freeze_encoder

        self.subword_aggregator = subword_aggregator

        # Linear projection from aggregated subword features to label scores,
        # wrapped in TimeDistributed (presumably applied per position over an
        # extra leading dimension — confirm against forward()).
        self.classifier = TimeDistributed(
            torch.nn.Linear(in_features=subword_aggregator.get_output_dim(),
                            out_features=vocab.get_vocab_size("labels")))

        self.accuracy = CategoricalAccuracy()