Exemplo n.º 1
0
def main():
    """Train and then test a ``BiafDependency`` parser from the command line.

    Command-line options are merged into a single parser from three sources:
    the shared ``get_parser()`` arguments, the model-specific arguments added
    by ``BiafDependency.add_model_specific_args``, and every ``pl.Trainer``
    option. After an optional warm start from ``args.pretrained``, the model
    is fit with checkpointing and early stopping on ``val_UAS`` (higher is
    better). It is then evaluated with ``trainer.test`` on the best
    checkpoint recorded during ``fit``.
    """
    # Fixed seed so that runs are reproducible.
    pl.seed_everything(1234)
    # ------------
    # args
    # ------------
    parser = get_parser()
    parser = BiafDependency.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # ------------
    # model
    # ------------
    model = BiafDependency(args)

    # load pretrained_model
    if args.pretrained:
        # map_location='cpu' lets a checkpoint saved on GPU be restored on any
        # machine; the weights are read from the checkpoint's "state_dict".
        # NOTE(review): torch.load unpickles the file -- only pass trusted
        # checkpoints here.
        model.load_state_dict(
            torch.load(args.pretrained,
                       map_location=torch.device('cpu'))["state_dict"])

    # call backs
    # Keep the 10 best checkpoints ranked by val_UAS, plus the latest one.
    checkpoint_callback = ModelCheckpoint(monitor='val_UAS',
                                          dirpath=args.default_root_dir,
                                          save_top_k=10,
                                          save_last=True,
                                          mode='max',
                                          verbose=True)

    # Stop training once val_UAS has not increased for 10 validation checks.
    early_stop_callback = early_stopping.EarlyStopping(monitor='val_UAS',
                                                       min_delta=0.00,
                                                       patience=10,
                                                       verbose=True,
                                                       mode='max')

    lr_monitor = LearningRateMonitor(logging_interval='step')
    print_model = ModelPrintCallback(print_modules=["model"])
    callbacks = [
        checkpoint_callback, lr_monitor, print_model, early_stop_callback
    ]
    if args.freeze_bert:
        # NOTE(review): presumably keeps the frozen "model.bert" submodule in
        # eval mode during training -- confirm against EvalCallback.
        callbacks.append(EvalCallback(["model.bert"]))

    # replace_sampler_ddp=False: Lightning does not inject its own
    # DistributedSampler; the dataloaders' samplers are used as provided.
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=callbacks,
                                            replace_sampler_ddp=False)
    trainer.fit(model)

    # Evaluate the best (by val_UAS) checkpoint saved during fit. Passing
    # ckpt_path explicitly pins this choice instead of relying on a default
    # that changed across Lightning 1.x releases (and warns in newer ones).
    trainer.test(ckpt_path='best')
Exemplo n.º 2
0
def main():
    """Command-line entry point that trains an ``MrcSpanProposal`` model.

    The argument parser is built in layers: the shared ``get_parser()``
    options, then the model's own options, then every ``pl.Trainer`` option.
    The parsed namespace configures both the model and the trainer.
    Checkpoints are ranked by ``val_top{k}_acc`` (higher is better), where
    ``k`` is ``MrcSpanProposal.acc_topk``.
    """
    pl.seed_everything(1234)

    # Compose the CLI: shared args -> model-specific args -> Trainer args.
    arg_parser = pl.Trainer.add_argparse_args(
        MrcSpanProposal.add_model_specific_args(get_parser())
    )
    hparams = arg_parser.parse_args()

    span_model = MrcSpanProposal(hparams)

    # Optional warm start. Weights are loaded onto the CPU and taken from the
    # checkpoint's "state_dict" entry.
    # NOTE(review): torch.load unpickles the file -- use trusted checkpoints.
    if hparams.pretrained:
        checkpoint = torch.load(hparams.pretrained, map_location=torch.device('cpu'))
        span_model.load_state_dict(checkpoint["state_dict"])

    # Keep the 10 best checkpoints by top-k validation accuracy, plus the last.
    tracked_metric = f'val_top{MrcSpanProposal.acc_topk}_acc'
    ckpt_cb = ModelCheckpoint(
        dirpath=hparams.default_root_dir,
        monitor=tracked_metric,
        mode='max',
        save_top_k=10,
        save_last=True,
        verbose=True,
    )
    trainer_callbacks = [
        ckpt_cb,
        LearningRateMonitor(logging_interval='step'),
        ModelPrintCallback(print_modules=["model"]),
    ]

    # NOTE(review): presumably keeps the frozen "model.bert" submodule in eval
    # mode while training -- confirm against EvalCallback.
    if hparams.freeze_bert:
        trainer_callbacks.append(EvalCallback(["model.bert"]))

    # The trainer's own sampler replacement is disabled; loaders keep their samplers.
    trainer = pl.Trainer.from_argparse_args(
        hparams, callbacks=trainer_callbacks, replace_sampler_ddp=False
    )
    trainer.fit(span_model)