Example #1
def main():
    """main"""
    parser = get_parser()

    # add model specific args
    parser = BertLabeling.add_model_specific_args(parser)

    # add all the available trainer options to argparse
    # i.e. now --gpus --num_nodes ... --fast_dev_run all work in the CLI
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    model = BertLabeling(args)
    if args.pretrained_checkpoint:
        model.load_state_dict(
            torch.load(args.pretrained_checkpoint,
                       map_location=torch.device('cpu'))["state_dict"])

    checkpoint_callback = ModelCheckpoint(
        filepath=args.default_root_dir,
        save_top_k=10,
        verbose=True,
        monitor="span_f1",
        period=-1,
        mode="max",
    )
    trainer = Trainer.from_argparse_args(
        args, checkpoint_callback=checkpoint_callback)

    trainer.fit(model)
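
All of these examples build on a shared get_parser() helper that is never shown. Below is a minimal sketch of what it could look like, assuming plain argparse; the flag names are the ones the snippets actually read, while the defaults and help strings are guesses:

import argparse

def get_parser() -> argparse.ArgumentParser:
    """Hypothetical reconstruction of the shared base parser."""
    parser = argparse.ArgumentParser(description="training/eval entry point")
    # flags observed in the examples; defaults are assumptions
    parser.add_argument("--pretrained_checkpoint", type=str, default="",
                        help="optional checkpoint to warm-start from")
    parser.add_argument("--output_dir", type=str, default="./output",
                        help="directory for ModelCheckpoint outputs")
    parser.add_argument("--max_keep_ckpt", type=int, default=10,
                        help="save_top_k passed to ModelCheckpoint")
    parser.add_argument("--bert_config_dir", type=str, default="",
                        help="directory containing vocab.txt")
    return parser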
Example #2
def main():
    """main"""
    parser = get_parser()
    # add model specific arguments.
    parser = BertForQA.add_model_specific_args(parser)
    # add all the available trainer options to argparse
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    print(args, "init")

    model = BertForQA(args)

    # a non-empty path means a warm-start checkpoint was supplied
    if len(args.pretrained_checkpoint) > 1:
        model.load_state_dict(
            torch.load(args.pretrained_checkpoint,
                       map_location=torch.device('cpu'))["state_dict"])
    if args.load_ner_bert:
        model.model.bert.load_state_dict(
            torch.load("./cached_models/ner_bert"), strict=False)

    checkpoint_callback = ModelCheckpoint(filepath=args.output_dir,
                                          save_top_k=args.max_keep_ckpt,
                                          verbose=True,
                                          period=-1,
                                          mode="auto")

    trainer = Trainer.from_argparse_args(
        args, checkpoint_callback=checkpoint_callback, deterministic=True)

    trainer.fit(model)
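
The NER-BERT warm start above relies on strict=False, which lets a state dict covering only part of the module load without raising. In plain PyTorch the call also returns the keys it skipped, which is worth logging; a small self-contained illustration:

import torch
from torch import nn

# toy stand-in for the BERT encoder the example warm-starts
encoder = nn.Linear(4, 4)

# a partial state dict: weight only, no bias
partial = {"weight": torch.zeros(4, 4)}

# strict=False tolerates missing/unexpected keys and reports them
result = encoder.load_state_dict(partial, strict=False)
print("missing keys:", result.missing_keys)        # ['bias']
print("unexpected keys:", result.unexpected_keys)  # []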
Example #3
def run_dataloader():
    """test dataloader"""
    parser = get_parser()

    # add model specific args
    parser = BertLabeling.add_model_specific_args(parser)

    # add all the available trainer options to argparse
    # i.e. now --gpus --num_nodes ... --fast_dev_run all work in the CLI
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    args.workers = 0
    args.default_root_dir = "/scratch/shravya.k/train_logs/debug"

    model = BertLabeling(args)
    from tokenizers import BertWordPieceTokenizer
    tokenizer = BertWordPieceTokenizer(
        os.path.join(args.bert_config_dir, "vocab.txt"))

    loader = model.get_dataloader("dev", limit=1000)
    for d in loader:
        input_ids = d[0][0].tolist()
        match_labels = d[-1][0]
        start_positions, end_positions = torch.where(match_labels > 0)
        start_positions = start_positions.tolist()
        end_positions = end_positions.tolist()
        if not start_positions:
            continue
        print("=" * 20)
        print(tokenizer.decode(input_ids, skip_special_tokens=False))
        for start, end in zip(start_positions, end_positions):
            print(tokenizer.decode(input_ids[start:end + 1]))
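
The decoding loop works because match_labels is a (seq_len, seq_len) matrix whose (i, j) entry is positive exactly when tokens i through j form a gold span, so a single torch.where yields the start and end indices as parallel tensors. A toy illustration of that behavior:

import torch

# 5x5 match matrix: entry (i, j) > 0 marks a gold span from token i to j
match_labels = torch.zeros(5, 5, dtype=torch.long)
match_labels[1, 2] = 1   # span covering tokens 1..2
match_labels[4, 4] = 1   # single-token span at position 4

starts, ends = torch.where(match_labels > 0)
print(list(zip(starts.tolist(), ends.tolist())))  # [(1, 2), (4, 4)]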
Example #4
def main():
    parser = get_parser()
    parser = TNewsClassificationTask.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    task_model = TNewsClassificationTask(args)

    checkpoint_callback = ModelCheckpoint(filepath=args.output_dir,
                                          save_top_k=args.max_keep_ckpt,
                                          save_last=False,
                                          monitor="val_f1",
                                          verbose=True,
                                          mode='max',
                                          period=-1)

    task_trainer = Trainer.from_argparse_args(
        args, checkpoint_callback=checkpoint_callback, deterministic=True)

    task_trainer.fit(task_model)

    # after training, locate the checkpoint that achieves the best F1 on the dev set.
    best_f1_on_dev, path_to_best_checkpoint = find_best_checkpoint_on_dev(
        args.output_dir,
        only_keep_the_best_ckpt=args.only_keep_the_best_ckpt_after_training)
    task_model.result_logger.info("=&" * 20)
    task_model.result_logger.info(f"Best F1 on DEV is {best_f1_on_dev}")
    task_model.result_logger.info(
        f"Best checkpoint on DEV set is {path_to_best_checkpoint}")
    task_model.result_logger.info("=&" * 20)
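
find_best_checkpoint_on_dev is another helper these examples share without showing. A plausible sketch, assuming the dev F1 is embedded in each checkpoint's filename (as ModelCheckpoint does when the monitored metric is part of the filename template); the filename pattern here is an assumption, not the repository's actual implementation:

import os
import re

def find_best_checkpoint_on_dev(output_dir, only_keep_the_best_ckpt=False):
    """Hypothetical sketch: pick the .ckpt whose filename embeds the
    highest dev F1, e.g. 'epoch=3-val_f1=91.25.ckpt'."""
    best_f1, best_path = -1.0, None
    for name in os.listdir(output_dir):
        match = re.search(r"val_f1=([0-9]+(?:\.[0-9]+)?)", name)
        if name.endswith(".ckpt") and match is not None:
            f1 = float(match.group(1))
            if f1 > best_f1:
                best_f1, best_path = f1, os.path.join(output_dir, name)
    if only_keep_the_best_ckpt and best_path is not None:
        # delete every other checkpoint to save disk space
        for name in os.listdir(output_dir):
            path = os.path.join(output_dir, name)
            if name.endswith(".ckpt") and path != best_path:
                os.remove(path)
    return best_f1, best_path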
Example #5
def main():
    """main"""
    parser = get_parser()
    # add model specific arguments.
    parser = BertForQA.add_model_specific_args(parser)
    # add all the available trainer options to argparse
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    model = BertForQA(args)

    if len(args.pretrained_checkpoint) > 1:
        model.load_state_dict(
            torch.load(args.pretrained_checkpoint,
                       map_location=torch.device('cpu'))["state_dict"])

    # print(args.output_dir)
    checkpoint_callback = ModelCheckpoint(dirpath=args.output_dir,
                                          filename='{epoch}-{val_loss:.2f}',
                                          verbose=True,
                                          period=-1,
                                          mode="auto")

    trainer = Trainer.from_argparse_args(
        args,
        checkpoint_callback=checkpoint_callback,
        accelerator="ddp",
        deterministic=True)

    trainer.fit(model)
Example #6
def main():
    """evaluate model checkpoints on the dev set. """
    eval_parser = init_evaluate_parser(get_parser())
    eval_parser = BertForQA.add_model_specific_args(eval_parser)
    eval_parser = Trainer.add_argparse_args(eval_parser)
    args = eval_parser.parse_args()
    print("here",args.path_to_model_checkpoint)
    if len(args.path_to_model_hparams_file) == 0:
        args.path_to_model_hparams_file = os.path.join(
            os.path.dirname(args.path_to_model_checkpoint),
            "lightning_logs", "version_0", "hparams.yaml")

    evaluate(args)
Example #7
def main():
    eval_parser = get_parser()
    eval_parser = init_evaluate_parser(eval_parser)
    eval_parser = BertForNERTask.add_model_specific_args(eval_parser)
    eval_parser = Trainer.add_argparse_args(eval_parser)
    args = eval_parser.parse_args()

    if len(args.path_to_model_hparams_file) == 0:
        eval_output_dir = os.path.dirname(args.path_to_model_checkpoint)
        args.path_to_model_hparams_file = os.path.join(
            eval_output_dir, "lightning_logs", "version_0", "hparams.yaml")

    evaluate(args)
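
Examples #6 and #7 both hand off to an external evaluate(args). One way it could work with standard Lightning APIs: load_from_checkpoint accepts a separate hparams_file, which is why both examples go to the trouble of locating hparams.yaml. The body below is a sketch, not the repository's code, and assumes the model class is in scope:

from pytorch_lightning import Trainer

def evaluate(args):
    """Hypothetical sketch of the shared evaluation entry point."""
    # restore the model from the checkpoint, with constructor arguments
    # taken from the separately located hparams.yaml
    model = BertForQA.load_from_checkpoint(
        checkpoint_path=args.path_to_model_checkpoint,
        hparams_file=args.path_to_model_hparams_file)
    trainer = Trainer.from_argparse_args(args)
    # in these scripts the test loader is typically pointed at the dev set
    trainer.test(model)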
Example #8
def main():
    parser = get_parser()
    parser = BertForNERTask.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    task_model = BertForNERTask(args)

    if len(args.pretrained_checkpoint) > 1:
        task_model.load_state_dict(
            torch.load(args.pretrained_checkpoint,
                       map_location=torch.device("cpu"))["state_dict"])

    checkpoint_callback = ModelCheckpoint(filepath=args.output_dir,
                                          save_top_k=args.max_keep_ckpt,
                                          save_last=False,
                                          monitor="val_f1",
                                          verbose=True,
                                          mode='max',
                                          period=-1)

    task_trainer = Trainer.from_argparse_args(
        args,
        checkpoint_callback=checkpoint_callback,
        deterministic=True,
        gradient_clip_val=args.gradient_clip_val)

    task_trainer.fit(task_model)

    # after training, reload the checkpoint with the best dev-set F1 and evaluate it on the test set.
    best_f1_on_dev, path_to_best_checkpoint = find_best_checkpoint_on_dev(
        args.output_dir,
        only_keep_the_best_ckpt=args.only_keep_the_best_ckpt_after_training)
    task_model.result_logger.info("=&" * 20)
    task_model.result_logger.info(f"Best F1 on DEV is {best_f1_on_dev}")
    task_model.result_logger.info(
        f"Best checkpoint on DEV set is {path_to_best_checkpoint}")
    checkpoint = torch.load(path_to_best_checkpoint)
    task_model.load_state_dict(checkpoint['state_dict'])
    task_trainer.test(task_model)
    task_model.result_logger.info("=&" * 20)
Example #9
def main():
    """main"""
    parser = get_parser()

    # add model specific args
    parser = BertNerTagger.add_model_specific_args(parser)

    # add all the available trainer options to argparse
    # i.e. now --gpus --num_nodes ... --fast_dev_run all work in the CLI
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    # begin: add a label2idx argument to args.
    label2idx = {}
    if 'conll' in args.dataname:
        label2idx = {"O": 0, "ORG": 1, "PER": 2, "LOC": 3, "MISC": 4}
    elif 'note' in args.dataname:
        label2idx = {
            'O': 0,
            'PERSON': 1,
            'ORG': 2,
            'GPE': 3,
            'DATE': 4,
            'NORP': 5,
            'CARDINAL': 6,
            'TIME': 7,
            'LOC': 8,
            'FAC': 9,
            'PRODUCT': 10,
            'WORK_OF_ART': 11,
            'MONEY': 12,
            'ORDINAL': 13,
            'QUANTITY': 14,
            'EVENT': 15,
            'PERCENT': 16,
            'LAW': 17,
            'LANGUAGE': 18
        }
    elif args.dataname == 'wnut16':
        label2idx = {
            'O': 0,
            'loc': 1,
            'facility': 2,
            'movie': 3,
            'company': 4,
            'product': 5,
            'person': 6,
            'other': 7,
            'tvshow': 8,
            'musicartist': 9,
            'sportsteam': 10
        }
    elif args.dataname == 'wnut17':
        label2idx = {
            'O': 0,
            'location': 1,
            'group': 2,
            'corporation': 3,
            'person': 4,
            'creative-work': 5,
            'product': 6
        }

    # store the mapping as a list of (label, idx) pairs so it fits in the Namespace
    args.label2idx_list = list(label2idx.items())
    # end: add a label2idx argument to args.

    # begin: add a morph2idx (word-casing feature) argument to args.
    morph2idx = {
        'isupper': 1,
        'islower': 2,
        'istitle': 3,
        'isdigit': 4,
        'other': 5
    }
    args.morph2idx_list = list(morph2idx.items())
    # end: add a morph2idx argument to args.

    args.default_root_dir = args.default_root_dir + '_' + args.random_int

    os.makedirs(args.default_root_dir, exist_ok=True)

    args.fp_epoch_result = os.path.join(args.default_root_dir,
                                        'epoch_results.txt')

    # dump the hyperparameters one per line, to stdout and to a text file
    text = '\n'.join(
        str(args).replace('Namespace(', '').replace(')', '').split(', '))
    print(text)

    fn_path = os.path.join(args.default_root_dir, args.param_name + '.txt')
    with open(fn_path, mode='w') as text_file:
        text_file.write(text)

    model = BertNerTagger(args)
    if args.pretrained_checkpoint:
        model.load_state_dict(
            torch.load(args.pretrained_checkpoint,
                       map_location=torch.device('cpu'))["state_dict"])

    # save the best model
    checkpoint_callback = ModelCheckpoint(
        filepath=args.default_root_dir,
        save_top_k=1,
        verbose=True,
        monitor="span_f1",
        period=-1,
        mode="max",
    )
    trainer = Trainer.from_argparse_args(
        args, checkpoint_callback=checkpoint_callback)

    trainer.fit(model)
    trainer.test()
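
The dict-to-pair-list conversions at the top of Example #9 presumably exist so the mappings can ride along inside the argparse Namespace (and Lightning's hparams logging) without type trouble; inside the model they round-trip back to dicts in one call:

# e.g. inside BertNerTagger.__init__, once args is available
label2idx = dict(args.label2idx_list)
morph2idx = dict(args.morph2idx_list)

# inverse mapping for decoding predicted label indices
idx2label = {idx: lab for lab, idx in args.label2idx_list}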
Example #10
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from models.model_config import BertForQAConfig
from models.bert_qa import BertForQuestionAnswering
from loss import DiceLoss, FocalLoss
from utils.get_parser import get_parser
from task_datasets.squad_dataset import SquadDataset
from task_datasets.truncate_dataset import TruncateDataset
from metrics.squad_em_f1 import SquadEvalMetric

# from models.model_config import BertForQueryNERConfig
# from models.bert_query_ner import BertForQueryNER

from tasks.squad.train import BertForQA
"""main"""
parser = get_parser()
# add model specific arguments.
parser = BertForQA.add_model_specific_args(parser)
# add all the available trainer options to argparse
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

qa_model = BertForQA(args)

if len(args.pretrained_checkpoint) > 1:
    qa_model.load_state_dict(
        torch.load(args.pretrained_checkpoint,
                   map_location=torch.device('cpu'))["state_dict"])
# if args.load_ner_bert:
#     qa_model.model.bert.load_state_dict(
#         torch.load("./cached_models/ner_bert"), strict=False)

checkpoint_callback = ModelCheckpoint(
    filepath=args.output_dir,
    # the source snippet is truncated at this point; the arguments and
    # calls below are an assumption, mirroring Example #2's setup
    save_top_k=args.max_keep_ckpt,
    verbose=True,
    period=-1,
    mode="auto")

trainer = Trainer.from_argparse_args(
    args, checkpoint_callback=checkpoint_callback)

trainer.fit(qa_model)