Example #1
0
 def build_transformer(self, training, transformer):
     """Instantiate a huggingface transformer backbone plus its tokenizer.

     Args:
         training: If truthy, record the form-vocab size into
             ``self.config.n_words`` before initializing the config.
         transformer: Either a checkpoint identifier string, or a
             2-sequence ``('AutoModelWithLMHead', checkpoint_id)``.

     Returns:
         The loaded TF transformer model. As a side effect, stores the
         matching tokenizer on ``self.transform.tokenizer`` and the model
         config on ``self.transform.transformer_config``.

     Raises:
         ValueError: If ``transformer`` is neither a string nor a
             sequence starting with ``'AutoModelWithLMHead'``.
     """
     if training:
         self.config.n_words = len(self.transform.form_vocab)
     self._init_config()
     if isinstance(transformer, str):
         if 'albert_chinese' in transformer:
             # Chinese ALBERT checkpoints are loaded with a BERT-style
             # tokenizer and converted from PyTorch weights (from_pt=True).
             tokenizer = BertTokenizerFast.from_pretrained(
                 transformer, add_special_tokens=False)
             transformer: TFPreTrainedModel = TFAutoModel.from_pretrained(
                 transformer, name=transformer, from_pt=True)
         elif transformer.startswith('albert') and transformer.endswith(
                 'zh'):
             # Delegates to the module-level build_transformer helper,
             # which returns (model, tokenizer, local_path).
             transformer, tokenizer, path = build_transformer(transformer)
             transformer.config = AlbertConfig.from_json_file(
                 os.path.join(path, "albert_config.json"))
             # NOTE(review): this overwrites the tokenizer returned just
             # above with one built from the bundled Chinese vocab file —
             # confirm the earlier tokenizer is intentionally discarded.
             tokenizer = BertTokenizer.from_pretrained(
                 os.path.join(path, "vocab_chinese.txt"),
                 add_special_tokens=False)
         elif 'chinese-roberta' in transformer:
             # chinese-roberta ships PyTorch weights only; convert on load.
             tokenizer = BertTokenizer.from_pretrained(transformer)
             transformer = TFBertModel.from_pretrained(transformer,
                                                       name=transformer,
                                                       from_pt=True)
         else:
             # Generic path: try native TF weights first, then fall back
             # to converting from PyTorch when TF weights are unavailable.
             tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
                 transformer)
             try:
                 transformer: TFPreTrainedModel = TFAutoModel.from_pretrained(
                     transformer, name=transformer)
             except (TypeError, OSError):
                 transformer: TFPreTrainedModel = TFAutoModel.from_pretrained(
                     transformer, name=transformer, from_pt=True)
     elif transformer[0] == 'AutoModelWithLMHead':
         # (kind, checkpoint_id) pair requesting an LM-head model.
         tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
             transformer[1])
         transformer: TFAutoModelWithLMHead = TFAutoModelWithLMHead.from_pretrained(
             transformer[1])
     else:
         raise ValueError(f'Unknown identifier {transformer}')
     self.transform.tokenizer = tokenizer
     if self.config.get('fp16', None) or self.config.get('use_amp', None):
         # Enable Keras mixed precision and cast the loaded weights down
         # to float16 to match the policy.
         policy = tf.keras.mixed_precision.experimental.Policy(
             'mixed_float16')
         tf.keras.mixed_precision.experimental.set_policy(policy)
         # tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
         transformer.set_weights(
             [w.astype('float16') for w in transformer.get_weights()])
     self.transform.transformer_config = transformer.config
     return transformer
Example #2
0
 def build_model(self, transformer, max_length, **kwargs):
     """Build a non-tagging (classification) transformer model.

     Stores the matching tokenizer on ``self.transform.tokenizer`` and
     returns the model. The label vocabulary size determines the output
     dimension passed to the transformer builder.
     """
     num_labels = len(self.transform.label_vocab)
     model, tokenizer = build_transformer(
         transformer, max_length, num_labels, tagging=False)
     self.transform.tokenizer = tokenizer
     return model
Example #3
0
 def load_transform(self, save_dir) -> Transform:
     """Load the transform from ``save_dir`` and rebuild its tokenizer.

     Only the tokenizer is constructed (``tokenizer_only=True``); no
     model weights are loaded here. Returns the restored transform.
     """
     super().load_transform(save_dir)
     num_tags = len(self.transform.tag_vocab)
     self.transform.tokenizer = build_transformer(
         self.config.transformer, self.config.max_seq_length, num_tags,
         tagging=True, tokenizer_only=True)
     return self.transform
Example #4
0
 def build_model(self, transformer, max_seq_length, **kwargs) -> tf.keras.Model:
     """Build a tagging transformer model sized by the tag vocabulary.

     The tokenizer produced alongside the model is stored directly on
     ``self.transform.tokenizer``; the Keras model is returned.
     """
     model, self.transform.tokenizer = build_transformer(
         transformer, max_seq_length, len(self.transform.tag_vocab),
         tagging=True)
     return model