示例#1
0
def load_model(config, checkpoint):
    """Rebuild the classifier described by ``config`` and load saved weights.

    Args:
        config: dict with keys 'opt', 'emb_class' and 'enc_class'. Depending
            on the branch taken, ``config['opt']`` must provide
            ``embedding_path``, ``label_path``, ``bert_output_dir`` and
            ``device``.
        checkpoint: state dict passed to ``model.load_state_dict``.

    Returns:
        The model with ``checkpoint`` loaded, moved to ``opt.device``.

    Raises:
        ValueError: if 'emb_class' / 'enc_class' name an unsupported
            combination (previously this surfaced as an UnboundLocalError
            on ``model``).
    """
    opt = config['opt']
    emb_class = config['emb_class']
    enc_class = config['enc_class']
    if emb_class == 'glove':
        # emb_non_trainable is hard-coded to True when loading.
        if enc_class == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, opt.label_path)
        elif enc_class == 'cnn':
            model = TextGloveCNN(config,
                                 opt.embedding_path,
                                 opt.label_path,
                                 emb_non_trainable=True)
        elif enc_class == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         opt.embedding_path,
                                         opt.label_path,
                                         emb_non_trainable=True)
        elif enc_class == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         opt.embedding_path,
                                         opt.label_path,
                                         emb_non_trainable=True)
        else:
            raise ValueError(
                "unsupported enc_class for glove: {!r}".format(enc_class))
    elif emb_class in ('bert', 'distilbert', 'albert', 'roberta', 'bart',
                       'electra'):
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        # Only the architecture is built from the saved config
        # (from_config loads no weights); the trained weights are restored
        # from ``checkpoint`` below.
        bert_config = AutoConfig.from_pretrained(opt.bert_output_dir)
        bert_tokenizer = AutoTokenizer.from_pretrained(opt.bert_output_dir)
        bert_model = AutoModel.from_config(bert_config)
        ModelClass = TextBertCLS if enc_class == 'cls' else TextBertCNN
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           opt.label_path)
    else:
        raise ValueError("unsupported emb_class: {!r}".format(emb_class))
    model.load_state_dict(checkpoint)
    model = model.to(opt.device)
    logger.info("[Model loaded]")
    return model
示例#2
0
def prepare_model(config):
    """Build a new classifier described by ``config`` (for training).

    Args:
        config: dict with keys 'opt', 'emb_class' and 'enc_class'.
            ``config['opt']`` supplies ``embedding_trainable``, ``device``
            and, depending on the branch, ``embedding_path``, ``label_path``,
            ``bert_model_name_or_path``, ``bert_do_lower_case`` and
            ``bert_use_feature_based``.

    Returns:
        The model, moved to ``opt.device``.

    Raises:
        ValueError: if 'emb_class' / 'enc_class' name an unsupported
            combination (previously an UnboundLocalError on ``model``).
    """
    opt = config['opt']
    emb_class = config['emb_class']
    enc_class = config['enc_class']
    emb_non_trainable = not opt.embedding_trainable
    if emb_class == 'glove':
        if enc_class == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, opt.label_path)
        elif enc_class == 'cnn':
            model = TextGloveCNN(config, opt.embedding_path, opt.label_path,
                                 emb_non_trainable=emb_non_trainable)
        elif enc_class == 'densenet-cnn':
            model = TextGloveDensenetCNN(config, opt.embedding_path,
                                         opt.label_path,
                                         emb_non_trainable=emb_non_trainable)
        elif enc_class == 'densenet-dsa':
            model = TextGloveDensenetDSA(config, opt.embedding_path,
                                         opt.label_path,
                                         emb_non_trainable=emb_non_trainable)
        else:
            raise ValueError(
                "unsupported enc_class for glove: {!r}".format(enc_class))
    elif emb_class in ('bert', 'distilbert', 'albert', 'roberta', 'bart',
                       'electra'):
        from transformers import AutoTokenizer, AutoModel
        model_path = opt.bert_model_name_or_path
        bert_tokenizer = AutoTokenizer.from_pretrained(
            model_path, do_lower_case=opt.bert_do_lower_case)
        # A path containing '.ckpt' is treated as a TensorFlow checkpoint.
        bert_model = AutoModel.from_pretrained(
            model_path, from_tf='.ckpt' in model_path)
        bert_config = bert_model.config
        # Project helper; its return value is unused, so presumably it
        # modifies bert_model in place — confirm.
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = TextBertCLS if enc_class == 'cls' else TextBertCNN
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           opt.label_path,
                           feature_based=opt.bert_use_feature_based)
    else:
        raise ValueError("unsupported emb_class: {!r}".format(emb_class))
    model.to(opt.device)
    logger.info("[model] :\n{}".format(model))
    logger.info("[model prepared]")
    return model
示例#3
0
 def load_model(self, checkpoint):
     """Rebuild the classifier from ``self.config`` and load ``checkpoint``.

     Side effects: reads labels from ``opt.label_path`` and stores them in
     both ``config['labels']`` and ``self.labels``.

     Args:
         checkpoint: state dict passed to ``model.load_state_dict``.

     Returns:
         The model with ``checkpoint`` loaded, moved to ``opt.device``.

     Raises:
         ValueError: if emb_class is 'glove' and enc_class is not a
             supported glove encoder (previously an UnboundLocalError).
     """
     config = self.config
     opt = config['opt']
     labels = load_label(opt.label_path)
     label_size = len(labels)
     config['labels'] = labels
     self.labels = labels
     enc_class = config['enc_class']
     if config['emb_class'] == 'glove':
         # emb_non_trainable is hard-coded to True when loading.
         if enc_class == 'gnb':
             model = TextGloveGNB(config, opt.embedding_path, label_size)
         elif enc_class == 'cnn':
             model = TextGloveCNN(config,
                                  opt.embedding_path,
                                  label_size,
                                  emb_non_trainable=True)
         elif enc_class == 'densenet-cnn':
             model = TextGloveDensenetCNN(config,
                                          opt.embedding_path,
                                          label_size,
                                          emb_non_trainable=True)
         elif enc_class == 'densenet-dsa':
             model = TextGloveDensenetDSA(config,
                                          opt.embedding_path,
                                          label_size,
                                          emb_non_trainable=True)
         else:
             raise ValueError(
                 "unsupported enc_class for glove: {!r}".format(enc_class))
     else:
         from transformers import AutoTokenizer, AutoConfig, AutoModel
         # Architecture only (from_config loads no weights); the trained
         # weights come from ``checkpoint`` below.
         bert_config = AutoConfig.from_pretrained(opt.bert_output_dir)
         bert_tokenizer = AutoTokenizer.from_pretrained(opt.bert_output_dir)
         bert_model = AutoModel.from_config(bert_config)
         ModelClass = TextBertCLS if enc_class == 'cls' else TextBertCNN
         model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                            label_size)
     model.load_state_dict(checkpoint)
     model = model.to(opt.device)
     logger.info("[Model loaded]")
     return model
示例#4
0
def load_model(config, checkpoint):
    """Rebuild the classifier described by ``config`` and load ``checkpoint``.

    Reads labels from ``opt.label_path`` and stores them in
    ``config['labels']`` (side effect). When ``opt.enable_qat`` or
    ``opt.enable_qat_fx`` is set, the model is prepared and converted for
    quantization *before* ``checkpoint`` is loaded. ``checkpoint`` is
    therefore expected to be a state dict of the converted model.

    Args:
        config: dict with keys 'opt', 'emb_class' and 'enc_class'.
        checkpoint: state dict passed to ``model.load_state_dict``.

    Returns:
        The loaded model, moved to ``opt.device``.

    Raises:
        ValueError: if enc_class is not a supported glove encoder, or if
            ``opt.enable_qat`` is set and ``opt.device`` is not 'cpu'.
    """
    opt = config['opt']
    labels = load_label(opt.label_path)
    label_size = len(labels)
    config['labels'] = labels
    enc_class = config['enc_class']
    if config['emb_class'] == 'glove':
        # emb_non_trainable is hard-coded to True when loading.
        if enc_class == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, label_size)
        elif enc_class == 'cnn':
            model = TextGloveCNN(config,
                                 opt.embedding_path,
                                 label_size,
                                 emb_non_trainable=True)
        elif enc_class == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        elif enc_class == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        else:
            raise ValueError(
                "unsupported enc_class for glove: {!r}".format(enc_class))
    else:
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        # Architecture only (from_config loads no weights); the trained
        # weights come from ``checkpoint`` below.
        bert_config = AutoConfig.from_pretrained(opt.bert_output_dir)
        bert_tokenizer = AutoTokenizer.from_pretrained(opt.bert_output_dir)
        bert_model = AutoModel.from_config(bert_config)
        ModelClass = TextBertCLS if enc_class == 'cls' else TextBertCNN
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size)
    if opt.enable_qat:
        # Eager-mode QAT with the 'fbgemm' (x86 CPU) backend; the converted
        # model is placed on CPU, so only device 'cpu' is accepted.
        # Explicit check instead of `assert`, which is stripped under -O.
        if opt.device != 'cpu':
            raise ValueError(
                "enable_qat requires device 'cpu', got {!r}".format(
                    opt.device))
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        # Fuse modules here if applicable, e.g.
        # model = torch.quantization.fuse_modules(model, [['']])
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if opt.enable_qat_fx:
        # NOTE(review): if enable_qat is also set, this runs on the model
        # already converted above; presumably the flags are exclusive.
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)

    model.load_state_dict(checkpoint)
    model = model.to(opt.device)
    # Debug aid: iterate model.named_parameters() to inspect each
    # parameter's data, device and requires_grad.
    logger.info("[model] :\n{}".format(model))
    logger.info("[Model loaded]")
    return model
示例#5
0
def load_model(config, checkpoint):
    """Rebuild the classifier described by ``config`` and load ``checkpoint``.

    Reads labels from ``args.label_path`` and stores them in
    ``config['labels']`` (side effect).

    Quantization handling:
    - QAT (``args.enable_qat`` / ``args.enable_qat_fx``) is applied before
      the weights are restored, so ``checkpoint`` is expected to match the
      converted model.
    - With ``args.enable_diffq``, weights are restored through a
      ``DiffQuantizer``, which is also stored in ``config['quantizer']``.

    Args:
        config: dict with keys 'args', 'emb_class', 'enc_class' (and
            'use_kobart' when emb_class is 'bart').
        checkpoint: saved state passed to ``model.load_state_dict`` or
            ``quantizer.restore_quantized_state``.

    Returns:
        The loaded model, moved to ``args.device``.

    Raises:
        ValueError: if enc_class is not a supported glove encoder, or if
            ``args.enable_qat`` is set and ``args.device`` is not 'cpu'.
    """
    args = config['args']
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    enc_class = config['enc_class']
    if config['emb_class'] == 'glove':
        # emb_non_trainable is hard-coded to True when loading.
        if enc_class == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        elif enc_class == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=True)
        elif enc_class == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        elif enc_class == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        else:
            raise ValueError(
                "unsupported enc_class for glove: {!r}".format(enc_class))
    else:
        # Local import so the block does not depend on module-level imports
        # (a no-op if the module already imports these names).
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        emb_class = config['emb_class']
        if emb_class == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
            bert_config = bert_model.config
        elif emb_class == 'gpt':
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_pretrained(args.bert_output_dir)
        elif emb_class == 't5':
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = T5EncoderModel(bert_config)
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            # Architecture only; weights come from ``checkpoint`` below.
            bert_model = AutoModel.from_config(bert_config)

        if enc_class == 'cls':
            ModelClass = TextBertCLS
        elif enc_class == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN
        else:
            ModelClass = TextBertCNN

        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size)

    if args.enable_qat:
        # Eager-mode QAT with the 'fbgemm' (x86 CPU) backend; the converted
        # model is placed on CPU, so only device 'cpu' is accepted.
        # Explicit check instead of `assert`, which is stripped under -O.
        if args.device != 'cpu':
            raise ValueError(
                "enable_qat requires device 'cpu', got {!r}".format(
                    args.device))
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        # Fuse modules here if applicable, e.g.
        # model = torch.quantization.fuse_modules(model, [['']])
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if args.enable_qat_fx:
        # NOTE(review): if enable_qat is also set, this runs on the model
        # already converted above; presumably the flags are exclusive.
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)

    if args.enable_diffq:
        # DiffQ checkpoints are restored through the quantizer instead of
        # model.load_state_dict; the quantizer is exposed via config.
        quantizer = DiffQuantizer(model)
        config['quantizer'] = quantizer
        quantizer.restore_quantized_state(checkpoint)
    else:
        model.load_state_dict(checkpoint)

    model = model.to(args.device)
    # Debug aid: iterate model.named_parameters() to inspect each
    # parameter's data, device and requires_grad.
    logger.info("[model] :\n{}".format(model))
    logger.info("[Model loaded]")
    return model
示例#6
0
def prepare_model(config, bert_model_name_or_path=None):
    """Build a classifier for training as described by ``config``.

    Reads labels from ``args.label_path`` and stores them in
    ``config['labels']`` (side effect). Optionally:
    - restores weights from ``args.restore_path``;
    - prepares the model for quantization-aware training (eager or FX mode).

    Args:
        config: dict with keys 'args', 'emb_class', 'enc_class' (and
            'use_kobart' when emb_class is 'bart').
        bert_model_name_or_path: optional override for
            ``args.bert_model_name_or_path``; unused for glove models.

    Returns:
        The prepared model. This function does not move it to
        ``args.device``.

    Raises:
        ValueError: if emb_class is 'glove' and enc_class is not a
            supported glove encoder (previously an UnboundLocalError).
    """
    args = config['args']
    emb_non_trainable = not args.embedding_trainable
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    enc_class = config['enc_class']
    if config['emb_class'] == 'glove':
        if enc_class == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        elif enc_class == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=emb_non_trainable)
        elif enc_class == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        elif enc_class == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        else:
            raise ValueError(
                "unsupported enc_class for glove: {!r}".format(enc_class))
    else:
        # Local import so the block does not depend on module-level imports
        # (a no-op if the module already imports these names).
        from transformers import AutoTokenizer, AutoModel
        emb_class = config['emb_class']
        # The explicit argument, when truthy, overrides the configured path.
        model_name_or_path = (bert_model_name_or_path
                              or args.bert_model_name_or_path)

        if emb_class == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
        elif emb_class == 'gpt':
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            # A path containing '.ckpt' is treated as a TF checkpoint.
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf='.ckpt' in model_name_or_path)
            # Match the embedding matrix to the tokenizer's vocabulary size.
            # NOTE(review): assigning special-token attributes does not add
            # tokens to the vocabulary in transformers; if these tokens are
            # new, tokenizer.add_special_tokens(...) is needed for this
            # resize to grow the embeddings — confirm.
            bert_model.resize_token_embeddings(len(bert_tokenizer))
        elif emb_class == 't5':
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = T5EncoderModel.from_pretrained(
                model_name_or_path,
                from_tf='.ckpt' in model_name_or_path)
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf='.ckpt' in model_name_or_path)

        bert_config = bert_model.config
        # Project helper; its return value is unused, so presumably it
        # modifies bert_model in place — confirm.
        reduce_bert_model(config, bert_model, bert_config)

        if enc_class == 'cls':
            ModelClass = TextBertCLS
        elif enc_class == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN
        else:
            ModelClass = TextBertCNN

        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           feature_based=args.bert_use_feature_based,
                           finetune_last=args.bert_use_finetune_last)
    if args.restore_path:
        checkpoint = load_checkpoint(args.restore_path)
        model.load_state_dict(checkpoint)
    if args.enable_qat:
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        # Fuse modules here if applicable, e.g.
        # model = torch.quantization.fuse_modules(model, [['']])
        model = torch.quantization.prepare_qat(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        # prepare_qat_fx expects the model in training mode.
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)

    logger.info("[model] :\n{}".format(model))
    logger.info("[model prepared]")
    return model
示例#7
0
def prepare_model(config, bert_model_name_or_path=None):
    """Build a classifier for training as described by ``config``.

    Reads labels from ``opt.label_path`` and stores them in
    ``config['labels']`` (side effect). Optionally:
    - restores weights from ``opt.restore_path``;
    - prepares the model for quantization-aware training (eager or FX mode).

    Args:
        config: dict with keys 'opt', 'emb_class' and 'enc_class'.
        bert_model_name_or_path: optional override for
            ``opt.bert_model_name_or_path``; unused for glove models.

    Returns:
        The prepared model, moved to ``opt.device``.

    Raises:
        ValueError: if emb_class is 'glove' and enc_class is not a
            supported glove encoder (previously an UnboundLocalError).
    """
    opt = config['opt']
    emb_non_trainable = not opt.embedding_trainable
    labels = load_label(opt.label_path)
    label_size = len(labels)
    config['labels'] = labels
    enc_class = config['enc_class']
    if config['emb_class'] == 'glove':
        if enc_class == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, label_size)
        elif enc_class == 'cnn':
            model = TextGloveCNN(config,
                                 opt.embedding_path,
                                 label_size,
                                 emb_non_trainable=emb_non_trainable)
        elif enc_class == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        elif enc_class == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        else:
            raise ValueError(
                "unsupported enc_class for glove: {!r}".format(enc_class))
    else:
        # The explicit argument, when truthy, overrides the configured path.
        model_name_or_path = (bert_model_name_or_path
                              or opt.bert_model_name_or_path)
        from transformers import AutoTokenizer, AutoModel
        bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # A path containing '.ckpt' is treated as a TensorFlow checkpoint.
        bert_model = AutoModel.from_pretrained(
            model_name_or_path, from_tf='.ckpt' in model_name_or_path)
        bert_config = bert_model.config
        # Project helper; its return value is unused, so presumably it
        # modifies bert_model in place — confirm.
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = TextBertCLS if enc_class == 'cls' else TextBertCNN
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           feature_based=opt.bert_use_feature_based)
    if opt.restore_path:
        checkpoint = load_checkpoint(opt.restore_path, device=opt.device)
        model.load_state_dict(checkpoint)
    if opt.enable_qat:
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        # Fuse modules here if applicable, e.g.
        # model = torch.quantization.fuse_modules(model, [['']])
        model = torch.quantization.prepare_qat(model)
    if opt.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        # prepare_qat_fx expects the model in training mode.
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)

    model.to(opt.device)
    logger.info("[model] :\n{}".format(model))
    logger.info("[model prepared]")
    return model