Exemplo n.º 1
0
 def finish_deserializing(self):
     """Populate the derived attributes computed from the stored fields.

     Reads ``self._bpe_ranks``, ``self.encoder`` and ``self._regex_pattern``
     and sets ``bpe_ranks``, ``decoder``, ``byte_encoder``, ``byte_decoder``,
     ``cache`` and ``pat`` on the instance.
     """
     self.bpe_ranks = deserialize_bpe_ranks(self._bpe_ranks)

     # Inverse of ``encoder``: maps each encoder value back to its key.
     self.decoder = {
         encoded: token for token, encoded in self.encoder.items()
     }

     # Byte <-> unicode tables; the decoder is the inverse of the encoder.
     byte_table = bytes_to_unicode()
     self.byte_encoder = byte_table
     self.byte_decoder = {
         char: byte for byte, char in self.byte_encoder.items()
     }

     # Start with an empty cache.
     self.cache = {}

     # Compile the stored pattern with the ``regex`` module's V0 flag.
     self.pat = regex.compile(self._regex_pattern, flags=regex.V0)
Exemplo n.º 2
0
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: Optional[str] = None,
                 requires_grad: bool = True,
                 transformer_weights_model: Optional[str] = None,
                 layer_freeze_regexes: Optional[List[str]] = None,
                 on_load: bool = False,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Build the RoBERTa encoder, the start/end output layer and the metrics.

        Parameters
        ----------
        vocab:
            Forwarded to the parent constructor.
        pretrained_model:
            Identifier passed to ``RobertaConfig.from_pretrained`` (when
            ``on_load`` is true) or ``RobertaModel.from_pretrained`` (default
            branch). Not used when ``transformer_weights_model`` is selected.
        requires_grad:
            Value assigned to ``requires_grad`` of every transformer
            parameter, unless the parameter is frozen by
            ``layer_freeze_regexes``.
        transformer_weights_model:
            If given (and ``on_load`` is false), passed to ``load_archive``
            and the ``_transformer_model`` of the loaded model is reused.
        layer_freeze_regexes:
            Regular expressions searched against parameter names; when
            ``requires_grad`` is true, any parameter whose name matches at
            least one of them gets ``requires_grad = False``.
        on_load:
            If true, only the configuration is fetched and the transformer is
            constructed from it without ``from_pretrained`` weights. Takes
            precedence over ``transformer_weights_model``.
        regularizer:
            Forwarded to the parent constructor.
        """
        super().__init__(vocab, regularizer)

        if on_load:
            logging.info("Skipping loading of initial Transformer weights")
            transformer_config = RobertaConfig.from_pretrained(
                pretrained_model)
            self._transformer_model = RobertaModel(transformer_config)
        elif transformer_weights_model:
            logging.info("Loading Transformer weights model from %s",
                         transformer_weights_model)
            transformer_model_loaded = load_archive(transformer_weights_model)
            self._transformer_model = transformer_model_loaded.model._transformer_model
        else:
            self._transformer_model = RobertaModel.from_pretrained(
                pretrained_model)

        # Compile the freeze patterns once instead of once per parameter.
        # They are only consulted when parameters would otherwise be
        # trainable, so skip compilation entirely when requires_grad is off.
        freeze_patterns = []
        if requires_grad and layer_freeze_regexes:
            freeze_patterns = [
                re.compile(pattern) for pattern in layer_freeze_regexes
            ]

        for name, param in self._transformer_model.named_parameters():
            grad = requires_grad
            if freeze_patterns:
                # Generator form lets any() stop at the first matching regex.
                grad = not any(
                    pattern.search(name) for pattern in freeze_patterns)
            param.requires_grad = grad

        transformer_config = self._transformer_model.config
        num_labels = 2  # For start/end
        self.qa_outputs = Linear(transformer_config.hidden_size, num_labels)

        # Import GPT-2 machinery to get from tokens to actual text
        self.byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        # NOTE(review): meaning of this debug level is not visible here -
        # presumably consumed elsewhere in the class; confirm.
        self._debug = 2
        self._padding_value = 1  # The index of the RoBERTa padding token