Example #1
def prepare_model(config):
    opt = config['opt']
    emb_non_trainable = not opt.embedding_trainable
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'bilstm':
            model = GloveLSTMCRF(config,
                                 opt.embedding_path,
                                 opt.label_path,
                                 opt.pos_path,
                                 emb_non_trainable=emb_non_trainable,
                                 use_crf=opt.use_crf,
                                 use_char_cnn=opt.use_char_cnn)
        elif config['enc_class'] == 'densenet':
            model = GloveDensenetCRF(config,
                                     opt.embedding_path,
                                     opt.label_path,
                                     opt.pos_path,
                                     emb_non_trainable=emb_non_trainable,
                                     use_crf=opt.use_crf,
                                     use_char_cnn=opt.use_char_cnn)
    elif config['emb_class'] == 'elmo':
        from allennlp.modules.elmo import Elmo
        elmo_model = Elmo(opt.elmo_options_file,
                          opt.elmo_weights_file,
                          2,
                          dropout=0)
        model = ElmoLSTMCRF(config,
                            elmo_model,
                            opt.embedding_path,
                            opt.label_path,
                            opt.pos_path,
                            emb_non_trainable=emb_non_trainable,
                            use_crf=opt.use_crf,
                            use_char_cnn=opt.use_char_cnn)
    else:
        from transformers import AutoTokenizer, AutoModel
        bert_tokenizer = AutoTokenizer.from_pretrained(
            opt.bert_model_name_or_path)
        bert_model = AutoModel.from_pretrained(
            opt.bert_model_name_or_path,
            from_tf=bool(".ckpt" in opt.bert_model_name_or_path))
        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = BertLSTMCRF
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           opt.label_path,
                           opt.pos_path,
                           use_crf=opt.use_crf,
                           use_pos=opt.bert_use_pos,
                           disable_lstm=opt.bert_disable_lstm,
                           feature_based=opt.bert_use_feature_based)
    model.to(opt.device)
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model
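
A minimal usage sketch for this prepare_model. The config keys and the opt fields (embedding_path, label_path, pos_path, device, and the boolean flags) are inferred from the attribute accesses above, not from a documented CLI; the default values below are placeholders.

import argparse

# Hypothetical invocation -- every field name is inferred from prepare_model.
parser = argparse.ArgumentParser()
parser.add_argument('--embedding_path', default='embeddings/glove.6B.300d.txt')
parser.add_argument('--label_path', default='data/label.txt')
parser.add_argument('--pos_path', default='data/pos.txt')
parser.add_argument('--embedding_trainable', action='store_true')
parser.add_argument('--use_crf', action='store_true')
parser.add_argument('--use_char_cnn', action='store_true')
parser.add_argument('--device', default='cpu')
opt = parser.parse_args([])

config = {'opt': opt, 'emb_class': 'glove', 'enc_class': 'bilstm'}
model = prepare_model(config)  # builds a GloveLSTMCRF and moves it to opt.device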
Example #2
def load_model(config, checkpoint):
    opt = config['opt']
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'bilstm':
            model = GloveLSTMCRF(config,
                                 opt.embedding_path,
                                 opt.label_path,
                                 opt.pos_path,
                                 emb_non_trainable=True,
                                 use_crf=opt.use_crf,
                                 use_char_cnn=opt.use_char_cnn)
        elif config['enc_class'] == 'densenet':
            model = GloveDensenetCRF(config,
                                     opt.embedding_path,
                                     opt.label_path,
                                     opt.pos_path,
                                     emb_non_trainable=True,
                                     use_crf=opt.use_crf,
                                     use_char_cnn=opt.use_char_cnn)
    elif config['emb_class'] in [
            'bert', 'distilbert', 'albert', 'roberta', 'bart', 'electra'
    ]:
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        bert_config = AutoConfig.from_pretrained(opt.bert_output_dir)
        bert_tokenizer = AutoTokenizer.from_pretrained(opt.bert_output_dir)
        bert_model = AutoModel.from_config(bert_config)
        ModelClass = BertLSTMCRF
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           opt.label_path,
                           opt.pos_path,
                           use_crf=opt.use_crf,
                           use_pos=opt.bert_use_pos,
                           disable_lstm=opt.bert_disable_lstm,
                           feature_based=opt.bert_use_feature_based)
    elif config['emb_class'] == 'elmo':
        from allennlp.modules.elmo import Elmo
        elmo_model = Elmo(opt.elmo_options_file,
                          opt.elmo_weights_file,
                          2,
                          dropout=0)
        model = ElmoLSTMCRF(config,
                            elmo_model,
                            opt.embedding_path,
                            opt.label_path,
                            opt.pos_path,
                            emb_non_trainable=True,
                            use_crf=opt.use_crf,
                            use_char_cnn=opt.use_char_cnn)
    model.load_state_dict(checkpoint)
    model = model.to(opt.device)
    logger.info("[Loaded]")
    return model
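
load_model expects an already-deserialized state dict, not a file path. A typical call, assuming the checkpoint was written with the standard torch.save(model.state_dict(), path) convention (the path below is illustrative):

import torch

# Deserialize on CPU first; load_model moves the model to opt.device itself.
checkpoint = torch.load('checkpoint/model.pt', map_location='cpu')
model = load_model(config, checkpoint)
model.eval()  # inference mode for evaluation or prediction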
Example #3
def load_model(config, checkpoint):
    args = config['args']
    labels = load_dict(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    config['label_size'] = label_size
    glabels = load_dict(args.glabel_path)
    glabel_size = len(glabels)
    config['glabels'] = glabels
    config['glabel_size'] = glabel_size
    poss = load_dict(args.pos_path)
    pos_size = len(poss)
    config['poss'] = poss
    config['pos_size'] = pos_size
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'bilstm':
            model = GloveLSTMCRF(config,
                                 args.embedding_path,
                                 label_size,
                                 pos_size,
                                 emb_non_trainable=True,
                                 use_crf=args.use_crf,
                                 use_ncrf=args.use_ncrf,
                                 use_char_cnn=args.use_char_cnn,
                                 use_mha=args.use_mha)
        elif config['enc_class'] == 'densenet':
            model = GloveDensenetCRF(config,
                                     args.embedding_path,
                                     label_size,
                                     pos_size,
                                     emb_non_trainable=True,
                                     use_crf=args.use_crf,
                                     use_ncrf=args.use_ncrf,
                                     use_char_cnn=args.use_char_cnn,
                                     use_mha=args.use_mha)
    elif config['emb_class'] == 'elmo':
        from allennlp.modules.elmo import Elmo
        elmo_model = Elmo(args.elmo_options_file,
                          args.elmo_weights_file,
                          2,
                          dropout=0)
        model = ElmoLSTMCRF(config,
                            elmo_model,
                            args.embedding_path,
                            label_size,
                            pos_size,
                            emb_non_trainable=True,
                            use_crf=args.use_crf,
                            use_ncrf=args.use_ncrf,
                            use_char_cnn=args.use_char_cnn,
                            use_mha=args.use_mha)
    else:
        from transformers import AutoConfig, AutoTokenizer, AutoModel
        bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
        bert_tokenizer = AutoTokenizer.from_pretrained(args.bert_output_dir)
        bert_model = AutoModel.from_config(bert_config)
        ModelClass = BertLSTMCRF
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           glabel_size,
                           pos_size,
                           use_crf=args.use_crf,
                           use_ncrf=args.use_ncrf,
                           use_pos=args.bert_use_pos,
                           use_char_cnn=args.use_char_cnn,
                           use_mha=args.use_mha,
                           use_subword_pooling=args.bert_use_subword_pooling,
                           use_word_embedding=args.bert_use_word_embedding,
                           embedding_path=args.embedding_path,
                           emb_non_trainable=True,
                           use_doc_context=args.bert_use_doc_context,
                           disable_lstm=args.bert_disable_lstm,
                           feature_based=args.bert_use_feature_based,
                           use_mtl=args.bert_use_mtl)
    model.load_state_dict(checkpoint)
    model = model.to(args.device)
    logger.info("[Loaded]")
    return model
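
load_dict is project code; this example only takes len() of its result, so the exact mapping it returns is unknown. A plausible one-entry-per-line reading, purely for illustration:

def load_dict(path):
    # Illustrative sketch only: one label/POS tag per line -> {index: entry}.
    d = {}
    with open(path, encoding='utf-8') as f:
        for idx, line in enumerate(f):
            d[idx] = line.strip()
    return d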
Example #4
def prepare_model(config):
    opt = config['opt']
    emb_non_trainable = not opt.embedding_trainable
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'bilstm':
            model = GloveLSTMCRF(config,
                                 opt.embedding_path,
                                 opt.label_path,
                                 opt.pos_path,
                                 emb_non_trainable=emb_non_trainable,
                                 use_crf=opt.use_crf,
                                 use_char_cnn=opt.use_char_cnn)
        elif config['enc_class'] == 'densenet':
            model = GloveDensenetCRF(config,
                                     opt.embedding_path,
                                     opt.label_path,
                                     opt.pos_path,
                                     emb_non_trainable=emb_non_trainable,
                                     use_crf=opt.use_crf,
                                     use_char_cnn=opt.use_char_cnn)
    elif config['emb_class'] in [
            'bert', 'distilbert', 'albert', 'roberta', 'bart', 'electra'
    ]:
        from transformers import BertTokenizer, BertConfig, BertModel
        from transformers import DistilBertTokenizer, DistilBertConfig, DistilBertModel
        from transformers import AlbertTokenizer, AlbertConfig, AlbertModel
        from transformers import RobertaConfig, RobertaTokenizer, RobertaModel
        from transformers import BartConfig, BartTokenizer, BartModel
        from transformers import ElectraConfig, ElectraTokenizer, ElectraModel
        MODEL_CLASSES = {
            "bert": (BertConfig, BertTokenizer, BertModel),
            "distilbert":
            (DistilBertConfig, DistilBertTokenizer, DistilBertModel),
            "albert": (AlbertConfig, AlbertTokenizer, AlbertModel),
            "roberta": (RobertaConfig, RobertaTokenizer, RobertaModel),
            "bart": (BartConfig, BartTokenizer, BartModel),
            "electra": (ElectraConfig, ElectraTokenizer, ElectraModel),
        }
        Config, Tokenizer, Model = MODEL_CLASSES[config['emb_class']]
        bert_tokenizer = Tokenizer.from_pretrained(
            opt.bert_model_name_or_path, do_lower_case=opt.bert_do_lower_case)
        output_hidden_states = True
        bert_model = Model.from_pretrained(
            opt.bert_model_name_or_path,
            from_tf=bool(".ckpt" in opt.bert_model_name_or_path),
            output_hidden_states=output_hidden_states)
        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = BertLSTMCRF
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           opt.label_path,
                           opt.pos_path,
                           use_crf=opt.use_crf,
                           use_pos=opt.bert_use_pos,
                           disable_lstm=opt.bert_disable_lstm,
                           feature_based=opt.bert_use_feature_based)
    elif config['emb_class'] == 'elmo':
        from allennlp.modules.elmo import Elmo
        elmo_model = Elmo(opt.elmo_options_file,
                          opt.elmo_weights_file,
                          2,
                          dropout=0)
        model = ElmoLSTMCRF(config,
                            elmo_model,
                            opt.embedding_path,
                            opt.label_path,
                            opt.pos_path,
                            emb_non_trainable=emb_non_trainable,
                            use_crf=opt.use_crf,
                            use_char_cnn=opt.use_char_cnn)
    model.to(opt.device)
    logger.info("[model] :\n{}".format(model))
    logger.info("[model prepared]")
    return model
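
Examples #4 and #6 dispatch through an explicit MODEL_CLASSES table, whereas Examples #1-#3 use the transformers Auto* classes, which resolve the concrete config/tokenizer/model classes from the checkpoint name. A sketch of the equivalent Auto* form for the block above, assuming the same opt fields:

from transformers import AutoModel, AutoTokenizer

# Auto* resolves the concrete classes from the name, replacing MODEL_CLASSES.
bert_tokenizer = AutoTokenizer.from_pretrained(
    opt.bert_model_name_or_path, do_lower_case=opt.bert_do_lower_case)
bert_model = AutoModel.from_pretrained(
    opt.bert_model_name_or_path,
    from_tf=bool(".ckpt" in opt.bert_model_name_or_path),
    output_hidden_states=True)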
Example #5
def prepare_model(config):
    args = config['args']
    labels = load_dict(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    config['label_size'] = label_size
    glabels = load_dict(args.glabel_path)
    glabel_size = len(glabels)
    config['glabels'] = glabels
    config['glabel_size'] = glabel_size
    poss = load_dict(args.pos_path)
    pos_size = len(poss)
    config['poss'] = poss
    config['pos_size'] = pos_size
    emb_non_trainable = not args.embedding_trainable
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'bilstm':
            model = GloveLSTMCRF(config,
                                 args.embedding_path,
                                 label_size,
                                 pos_size,
                                 emb_non_trainable=emb_non_trainable,
                                 use_crf=args.use_crf,
                                 use_ncrf=args.use_ncrf,
                                 use_char_cnn=args.use_char_cnn,
                                 use_mha=args.use_mha)
        elif config['enc_class'] == 'densenet':
            model = GloveDensenetCRF(config,
                                     args.embedding_path,
                                     label_size,
                                     pos_size,
                                     emb_non_trainable=emb_non_trainable,
                                     use_crf=args.use_crf,
                                     use_ncrf=args.use_ncrf,
                                     use_char_cnn=args.use_char_cnn,
                                     use_mha=args.use_mha)
    elif config['emb_class'] == 'elmo':
        from allennlp.modules.elmo import Elmo
        elmo_model = Elmo(args.elmo_options_file,
                          args.elmo_weights_file,
                          2,
                          dropout=0)
        model = ElmoLSTMCRF(config,
                            elmo_model,
                            args.embedding_path,
                            label_size,
                            pos_size,
                            emb_non_trainable=emb_non_trainable,
                            use_crf=args.use_crf,
                            use_ncrf=args.use_ncrf,
                            use_char_cnn=args.use_char_cnn,
                            use_mha=args.use_mha)
    else:
        from transformers import AutoTokenizer, AutoModel
        bert_tokenizer = AutoTokenizer.from_pretrained(
            args.bert_model_name_or_path)
        bert_model = AutoModel.from_pretrained(
            args.bert_model_name_or_path,
            from_tf=bool(".ckpt" in args.bert_model_name_or_path))
        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = BertLSTMCRF
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           glabel_size,
                           pos_size,
                           use_crf=args.use_crf,
                           use_ncrf=args.use_ncrf,
                           use_pos=args.bert_use_pos,
                           use_char_cnn=args.use_char_cnn,
                           use_mha=args.use_mha,
                           use_subword_pooling=args.bert_use_subword_pooling,
                           use_word_embedding=args.bert_use_word_embedding,
                           embedding_path=args.embedding_path,
                           emb_non_trainable=emb_non_trainable,
                           use_doc_context=args.bert_use_doc_context,
                           disable_lstm=args.bert_disable_lstm,
                           feature_based=args.bert_use_feature_based,
                           use_mtl=args.bert_use_mtl)
    if args.restore_path:
        checkpoint = load_checkpoint(args.restore_path)
        model.load_state_dict(checkpoint)
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model
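
The restore_path branch implies a checkpoint is a plain state dict, deserialized by the project's load_checkpoint helper. Assuming that helper wraps the standard torch.load, a save/restore round trip would look like this (paths illustrative):

import torch

model = prepare_model(config)
# ... training loop ...
torch.save(model.state_dict(), 'checkpoint/model.pt')

# Later: equivalent to passing --restore_path when preparing the model.
checkpoint = torch.load('checkpoint/model.pt', map_location='cpu')
model.load_state_dict(checkpoint)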
Example #6
def load_model(config, checkpoint):
    opt = config['opt']
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'bilstm':
            model = GloveLSTMCRF(config,
                                 opt.embedding_path,
                                 opt.label_path,
                                 opt.pos_path,
                                 emb_non_trainable=True,
                                 use_crf=opt.use_crf,
                                 use_char_cnn=opt.use_char_cnn)
        elif config['enc_class'] == 'densenet':
            model = GloveDensenetCRF(config,
                                     opt.embedding_path,
                                     opt.label_path,
                                     opt.pos_path,
                                     emb_non_trainable=True,
                                     use_crf=opt.use_crf,
                                     use_char_cnn=opt.use_char_cnn)
    elif config['emb_class'] in [
            'bert', 'distilbert', 'albert', 'roberta', 'bart', 'electra'
    ]:
        from transformers import BertTokenizer, BertConfig, BertModel
        from transformers import DistilBertTokenizer, DistilBertConfig, DistilBertModel
        from transformers import AlbertTokenizer, AlbertConfig, AlbertModel
        from transformers import RobertaConfig, RobertaTokenizer, RobertaModel
        from transformers import BartConfig, BartTokenizer, BartModel
        from transformers import ElectraConfig, ElectraTokenizer, ElectraModel
        MODEL_CLASSES = {
            "bert": (BertConfig, BertTokenizer, BertModel),
            "distilbert":
            (DistilBertConfig, DistilBertTokenizer, DistilBertModel),
            "albert": (AlbertConfig, AlbertTokenizer, AlbertModel),
            "roberta": (RobertaConfig, RobertaTokenizer, RobertaModel),
            "bart": (BartConfig, BartTokenizer, BartModel),
            "electra": (ElectraConfig, ElectraTokenizer, ElectraModel),
        }
        Config, Tokenizer, Model = MODEL_CLASSES[config['emb_class']]
        bert_config = Config.from_pretrained(opt.bert_output_dir)
        bert_tokenizer = Tokenizer.from_pretrained(opt.bert_output_dir)
        # build the architecture only; the weights are restored from the
        # checkpoint below, so 'from_pretrained' is unnecessary here
        bert_model = Model(bert_config)
        ModelClass = BertLSTMCRF
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           opt.label_path,
                           opt.pos_path,
                           use_crf=opt.use_crf,
                           use_pos=opt.bert_use_pos,
                           disable_lstm=opt.bert_disable_lstm,
                           feature_based=opt.bert_use_feature_based)
    elif config['emb_class'] == 'elmo':
        from allennlp.modules.elmo import Elmo
        elmo_model = Elmo(opt.elmo_options_file,
                          opt.elmo_weights_file,
                          2,
                          dropout=0)
        model = ElmoLSTMCRF(config,
                            elmo_model,
                            opt.embedding_path,
                            opt.label_path,
                            opt.pos_path,
                            emb_non_trainable=True,
                            use_crf=opt.use_crf,
                            use_char_cnn=opt.use_char_cnn)
    model.load_state_dict(checkpoint)
    model = model.to(opt.device)
    logger.info("[Loaded]")
    return model
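
The design choice worth noting in this variant: the transformer is built from its config alone (Model(bert_config)) rather than with from_pretrained, because the following model.load_state_dict(checkpoint) overwrites every parameter anyway, so downloading pretrained weights would be wasted work. Example #2 expresses the same idea with the Auto* API:

from transformers import AutoConfig, AutoModel

# Architecture only -- the randomly initialized weights are replaced
# wholesale by load_state_dict(checkpoint).
bert_config = AutoConfig.from_pretrained(opt.bert_output_dir)
bert_model = AutoModel.from_config(bert_config)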