def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2seq", target_size=0): """ model_path: 模型位置 这是个统一的接口,用来加载模型的 model_class : seq2seq or encoder """ if model_class == "seq2seq": bert_model = Seq2SeqModel(word2ix, model_name=model_name, tokenizer=tokenizer) return bert_model elif model_class == "cls": if target_size == 0: raise Exception("必须传入参数 target_size,才能确定预测多少分类") bert_model = BertClsClassifier(word2ix, target_size, model_name=model_name) return bert_model elif model_class == "sequence_labeling": ## 序列标注模型 if target_size == 0: raise Exception("必须传入参数 target_size,才能确定预测多少分类") bert_model = BertSeqLabeling(word2ix, target_size, model_name=model_name) return bert_model elif model_class == "sequence_labeling_crf": # 带有crf层的序列标注模型 if target_size == 0: raise Exception("必须传入参数 target_size,才能确定预测多少分类") bert_model = BertSeqLabelingCRF(word2ix, target_size, model_name=model_name) return bert_model elif model_class == "relation_extrac": if target_size == 0: raise Exception("必须传入参数 target_size 表示预测predicate的种类") bert_model = BertRelationExtrac(word2ix, target_size, model_name=model_name) return bert_model elif model_class == "simbert": bert_model = SimBertModel(word2ix, model_name=model_name) return bert_model else: raise Exception("model_name_err")
def load_bert(vocab_path, model_name="roberta", model_class="seq2seq", target_size=0, simplfied=False): """ model_path: 模型位置 这是个统一的接口,用来加载模型的 model_class : seq2seq or encoder """ if model_class == "seq2seq": bert_model = Seq2SeqModel(vocab_path, model_name=model_name, simplfied=simplfied) return bert_model elif model_class == "cls": if target_size == 0: raise Exception("必须传入参数 target_size,才能确定预测多少分类") bert_model = BertClsClassifier(vocab_path, target_size, model_name=model_name) return bert_model elif model_class == "sequence_labeling": ## 序列标注模型 if target_size == 0: raise Exception("必须传入参数 target_size,才能确定预测多少分类") bert_model = BertSeqLabeling(vocab_path, target_size, model_name=model_name) return bert_model elif model_class == "sequence_labeling_crf": # 带有crf层的序列标注模型 if target_size == 0: raise Exception("必须传入参数 target_size,才能确定预测多少分类") bert_model = BertSeqLabelingCRF(vocab_path, target_size, model_name=model_name) return bert_model elif model_class == "relation_extrac": if target_size == 0: raise Exception("必须传入参数 target_size 表示预测predicate的种类") bert_model = BertRelationExtrac(vocab_path, target_size, model_name=model_name) return bert_model else: raise Exception("model_name_err")