Example #1
    def __init__(self, hparams, **kwargs):
        super(KoBARTConditionalGeneration, self).__init__(hparams, **kwargs)
        self.model = BartForConditionalGeneration.from_pretrained(get_pytorch_kobart_model())
        self.model.train()
        self.bos_token = '<s>'
        self.eos_token = '</s>'
        self.pad_token_id = 0
        self.tokenizer = get_kobart_tokenizer()
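
The class above wraps KoBART for conditional generation inside a training module whose base class is not shown in the snippet. For orientation, here is a minimal stand-alone inference sketch that uses only the kobart and transformers calls already appearing above; the input text and generation settings are placeholder assumptions, not values from the repository.

# Minimal inference sketch (assumption): load KoBART and generate from a placeholder input.
import torch
from transformers import BartForConditionalGeneration
from kobart import get_kobart_tokenizer, get_pytorch_kobart_model

tokenizer = get_kobart_tokenizer()
model = BartForConditionalGeneration.from_pretrained(get_pytorch_kobart_model())
model.eval()

input_ids = tokenizer.encode("입력 문장", return_tensors="pt")  # placeholder input text
with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=64, num_beams=4)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))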
Example #2
    def __init__(self, hparam=None, text_logger=None):
        super(BART, self).__init__()

        self._model = BartForConditionalGeneration.from_pretrained(
            get_pytorch_kobart_model())
        self._model.train()
        self.tokenizer = get_kobart_tokenizer()

        self._hparams = hparam

        self._text_logger = text_logger
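
As in Example #1, the training logic itself lives outside this constructor. Below is a rough sketch of how a seq2seq loss could be computed with the model/tokenizer pair built above, using only standard transformers forward arguments; the helper name and texts are illustrative, not from the source.

# Hypothetical helper (assumption): compute a seq2seq loss with the KoBART pair above.
def compute_loss(model, tokenizer, source_text, target_text):
    src = tokenizer(source_text, return_tensors="pt")
    tgt = tokenizer(target_text, return_tensors="pt")
    labels = tgt["input_ids"].clone()
    # Mask padding positions in the loss (a no-op for unpadded single sentences);
    # assumes the tokenizer's pad token is set, as Examples #6/#7 do explicitly.
    labels[labels == tokenizer.pad_token_id] = -100
    # BartForConditionalGeneration derives decoder_input_ids from labels internally.
    outputs = model(input_ids=src["input_ids"],
                    attention_mask=src["attention_mask"],
                    labels=labels)
    return outputs.loss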
Example #3
train_dataset = KGBDDataset(train_dev['train'])
valid_dataset = KGBDDataset(train_dev['dev'])
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True)
valid_dataloader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=False)

from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
from transformers import BartForSequenceClassification

model = BartForSequenceClassification.from_pretrained(
    get_pytorch_kobart_model()).cuda()

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01,
    },
    {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, correct_bias=False)
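
get_cosine_schedule_with_warmup is imported above but not used in this excerpt; a typical continuation derives the schedule from the optimizer and the dataloader length. The epoch count and warmup ratio below are assumed values, not taken from the snippet.

# Sketch only: epochs and warmup_ratio are assumed hyperparameters.
epochs = 3
warmup_ratio = 0.1
total_steps = len(train_dataloader) * epochs
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(total_steps * warmup_ratio),
    num_training_steps=total_steps)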
Example #4
def main():
    # Get ArgParse
    args = get_args()
    if args.checkpoint:
        # Strip a trailing slash before prepending the checkpoint root directory.
        args.checkpoint = (
            "./model_checkpoint/" + args.checkpoint[:-1]
            if args.checkpoint[-1] == "/"
            else "./model_checkpoint/" + args.checkpoint
        )
    else:
        args.checkpoint = "./model_checkpoint/" + gen_checkpoint_id(args)


    # If checkpoint path exists, load the last model
    if os.path.isdir(args.checkpoint):
        # EXAMPLE: "{engine_name}_{task_name}_{timestamp}/saved_checkpoint_1"
        args.checkpoint_count = checkpoint_count(args.checkpoint)
        logger = get_logger(args)
        logger.info("Checkpoint path directory exists")
        logger.info(f"Loading model from saved_checkpoint_{args.checkpoint_count}")
        model = torch.load(f"{args.checkpoint}/saved_checkpoint_{args.checkpoint_count}")

        args.checkpoint_count += 1  # next checkpoint will be written under a fresh index
    # If there is none, create a checkpoint folder and train from scratch
    else:
        try:
            os.makedirs(args.checkpoint)
        except FileExistsError:
            print("Ignoring Existing File Path ...")

#         model = BartModel.from_pretrained(get_pytorch_kobart_model())
        model = AutoModelForSeq2SeqLM.from_pretrained(get_pytorch_kobart_model())
        
        args.checkpoint_count = 0
        logger = get_logger(args)

        logger.info(f"Creating a new directory for {args.checkpoint}")
    
    args.logger = logger
    
    model.to(args.device)
    
    # Define Tokenizer
    tokenizer = get_kobart_tokenizer()

    # Add Additional Special Tokens 
    #special_tokens_dict = {"sep_token": "<sep>"}
    #tokenizer.add_special_tokens(special_tokens_dict)
    #model.resize_token_embeddings(new_num_tokens=len(tokenizer))

    # Define Optimizer
    optimizer_class = getattr(transformers, args.optimizer_class)
    optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)

    logger.info(f"Loading data from {args.data_dir} ...")
    with open("data/Brunch_accm_20210328_train.json", 'r') as f:
        train_data = json.load(f)
    train_context = [data['text'] for data in train_data]
    train_tag = [data['tag'] for data in train_data]
    with open("data/Brunch_accm_20210328_test.json", 'r') as f:
        test_data = json.load(f)
    test_context = [data['text'] for data in test_data]
    test_tag = [data['tag'] for data in test_data]
    
    train_dataset = SummaryDataset(train_context, train_tag, tokenizer, args.enc_max_len, args.dec_max_len, ignore_index=-100)    
    test_dataset = SummaryDataset(test_context, test_tag, tokenizer, args.enc_max_len, args.dec_max_len, ignore_index=-100)    
#     train_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "train.json"))
#     valid_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "valid.json"))
#     test_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "test.json"))
    

    batch_generator = SummaryBatchGenerator(tokenizer)
    
    train_loader = get_dataloader(
        train_dataset, 
        batch_generator=batch_generator,
        batch_size=args.train_batch_size,
        shuffle=True,
    )
    
    test_loader = get_dataloader(
        test_dataset, 
        batch_generator=batch_generator,
        batch_size=args.eval_batch_size,
        shuffle=False,
    )
    
#     test_loader = get_dataloader(
#         test_dataset, 
#         batch_generator=batch_generator,
#         batch_size=args.eval_batch_size,
#         shuffle=False,
#     )
    

    train(model, optimizer, tokenizer, train_loader, test_loader, test_tag, args)
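
train itself is defined elsewhere; for the resume logic above to work, it presumably writes whole-model checkpoints that mirror the torch.load call. A hypothetical counterpart, with the helper name and behaviour assumed rather than taken from the repository:

# Hypothetical save routine (assumption): matches the naming scheme used by torch.load above.
def save_checkpoint(model, args):
    path = f"{args.checkpoint}/saved_checkpoint_{args.checkpoint_count}"
    torch.save(model, path)  # whole-model save, as expected by the loader in main()
    args.logger.info(f"Saved checkpoint to {path}")
    args.checkpoint_count += 1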
Example #5
    def __init__(self, hparams, **kwargs):
        super(KoBARTClassification, self).__init__(hparams, **kwargs)
        self.model = BartForSequenceClassification.from_pretrained(get_pytorch_kobart_model())
        self.model.train()
        self.metric_acc = pl.metrics.classification.Accuracy()
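
The training and validation steps of this Lightning-style module are not shown; the sketch below shows one way the accuracy metric above could be applied in a validation step. Batch keys and logging calls assume a standard pytorch-lightning module and are not the repository's actual code.

    # Sketch only (assumption): batch keys are illustrative placeholders.
    def validation_step(self, batch, batch_idx):
        outputs = self.model(input_ids=batch['input_ids'],
                             attention_mask=batch['attention_mask'],
                             labels=batch['labels'])
        preds = outputs.logits.argmax(dim=-1)
        acc = self.metric_acc(preds, batch['labels'])
        self.log('val_loss', outputs.loss)
        self.log('val_acc', acc)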
Example #6
def load_model(config, checkpoint):
    args = config['args']
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=True)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
    else:
        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
            bert_config = bert_model.config
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_pretrained(args.bert_output_dir)
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = T5EncoderModel(bert_config)
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_config(bert_config)

        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls': ModelClass = TextBertCLS
        if config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN

        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size)

    if args.enable_qat:
        assert args.device == 'cpu'
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)

    if args.enable_diffq:
        quantizer = DiffQuantizer(model)
        config['quantizer'] = quantizer
        quantizer.restore_quantized_state(checkpoint)
    else:
        model.load_state_dict(checkpoint)

    model = model.to(args.device)
    ''' 
    for name, param in model.named_parameters():
        print(name, param.data, param.device, param.requires_grad)
    '''
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[Model loaded]")
    return model
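
load_model expects an already-deserialized checkpoint state rather than a path; a typical call site looks roughly like the following. The checkpoint filename is a placeholder, and config is assumed to come from the project's own config loader.

# Sketch of a call site (assumptions: placeholder path, config built elsewhere).
checkpoint = torch.load('pytorch-model.pt', map_location=torch.device('cpu'))
model = load_model(config, checkpoint)
model.eval()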
Example #7
def prepare_model(config, bert_model_name_or_path=None):
    args = config['args']
    emb_non_trainable = not args.embedding_trainable
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    # prepare model
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
    else:
        model_name_or_path = args.bert_model_name_or_path
        if bert_model_name_or_path:
            model_name_or_path = bert_model_name_or_path

        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))
            # 3 new tokens added
            bert_model.resize_token_embeddings(len(bert_tokenizer))
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = T5EncoderModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))

        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))

        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls': ModelClass = TextBertCLS
        if config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN

        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           feature_based=args.bert_use_feature_based,
                           finetune_last=args.bert_use_finetune_last)
    if args.restore_path:
        checkpoint = load_checkpoint(args.restore_path)
        model.load_state_dict(checkpoint)
    if args.enable_qat:
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)

    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model
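
Unlike load_model in Example #6, prepare_model returns a freshly initialized (or restored) module ready for training. A minimal continuation might attach an optimizer and move the model to the training device; the optimizer choice and learning rate below are assumptions, not values from the snippet.

# Sketch only: attach an optimizer and move the prepared model to the training device.
model = prepare_model(config)
model.to(config['args'].device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)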