Example #1
        eval_dataset,
        batch_size=args.test_batch_size,
        sampler=eval_sampler,
        num_workers=args.data_workers,
        collate_fn=batchify_features_for_test,
        pin_memory=args.cuda,
    )
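
    # A minimal sketch of what a collate function like
    # `batchify_features_for_test` might do, assuming each feature is a dict
    # carrying an "input_ids" list (hypothetical structure; the real BERT-KPE
    # batchifier returns more fields):
    def _batchify_sketch(features):
        max_len = max(len(f["input_ids"]) for f in features)
        input_ids = torch.zeros(len(features), max_len, dtype=torch.long)
        attention_mask = torch.zeros(len(features), max_len, dtype=torch.long)
        for i, f in enumerate(features):
            seq = torch.as_tensor(f["input_ids"], dtype=torch.long)
            input_ids[i, : len(seq)] = seq
            attention_mask[i, : len(seq)] = 1
        return input_ids, attention_mask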

    # -------------------------------------------------------------------------------------------
    # Prepare Model & Optimizer
    # -------------------------------------------------------------------------------------------
    logger.info(
        " ************************** Initilize Model ************************** "
    )
    try:
        model, checkpoint_epoch = KeyphraseSpanExtraction.load_checkpoint(
            args.eval_checkpoint, args)
        model.set_device()
    except ValueError:
        logger.error("Couldn't load pretrained model %s" % args.eval_checkpoint)
        raise  # `model` would be undefined below if loading failed

    if args.local_rank == 0:
        torch.distributed.barrier()

    if args.n_gpu > 1:
        model.parallelize()

    if args.local_rank != -1:
        model.distribute()
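
    # Hypothetical sketch of the two helpers called above (the real
    # KeyphraseSpanExtraction methods in BERT-KPE may differ): `parallelize`
    # wraps the network for single-node multi-GPU, `distribute` for one
    # process per GPU.
    def _parallelize_sketch(self):
        self.network = torch.nn.DataParallel(self.network)

    def _distribute_sketch(self):
        self.network = torch.nn.parallel.DistributedDataParallel(
            self.network,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
        )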

    # -------------------------------------------------------------------------------------------
    # Method Select
Example #2
            train_data_loader
        ) // args.gradient_accumulation_steps * args.max_train_epochs
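    # Worked example: with len(train_data_loader) = 12000 batches per epoch,
    # gradient accumulation over 4 batches, and 5 training epochs, this gives
    # t_total = 12000 // 4 * 5 = 15000 optimizer steps for the LR scheduler.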

    # -------------------------------------------------------------------------------------------
    # Prepare Model & Optimizer
    # -------------------------------------------------------------------------------------------
    # 7. Initialize the model and optimizer
    logger.info(
        " ************************** Initialize Model & Optimizer ************************** "
    )
    """
    `args.checkpoint_file`, loaded checkpoint model continue training.
    `args.load_checkpoint`, default=False
    """
    if args.load_checkpoint and os.path.isfile(args.checkpoint_file):
        model, checkpoint_epoch = KeyphraseSpanExtraction.load_checkpoint(
            args.checkpoint_file, args)
    else:
        logger.info('Training model from scratch...')
        model = KeyphraseSpanExtraction(args)
    model.init_optimizer(num_total_steps=t_total)
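
    # A minimal sketch of what an `init_optimizer` like the one called above
    # might set up for BERT fine-tuning (hypothetical: the real BERT-KPE
    # implementation may group parameters for weight decay; `learning_rate`
    # and `warmup_proportion` are assumed args):
    def _init_optimizer_sketch(self, num_total_steps):
        from transformers import get_linear_schedule_with_warmup

        self.optimizer = torch.optim.AdamW(
            self.network.parameters(), lr=self.args.learning_rate
        )
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(num_total_steps * self.args.warmup_proportion),
            num_training_steps=num_total_steps,
        )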

    if args.local_rank == 0:
        torch.distributed.barrier()

    model.set_device()
    if args.n_gpu > 1:
        model.parallelize()

    if args.local_rank != -1:
        model.distribute()
Example #3
File: train.py  Project: thunlp/BERT-KPE
        t_total = (len(train_data_loader) // args.gradient_accumulation_steps *
                   args.max_train_epochs)

    # -------------------------------------------------------------------------------------------
    # Prepare Model & Optimizer
    # -------------------------------------------------------------------------------------------
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    logger.info(
        " ************************** Initilize Model & Optimizer ************************** "
    )

    if args.load_checkpoint and os.path.isfile(args.checkpoint_file):
        model, checkpoint_epoch = KeyphraseSpanExtraction.load_checkpoint(
            args.checkpoint_file, args)
    else:
        logger.info("Training model from scratch...")
        model = KeyphraseSpanExtraction(args)

    # initialize the optimizer
    model.init_optimizer(num_total_steps=t_total)

    # -------------------------------------------------------------------------------------------
    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
    # -------------------------------------------------------------------------------------------

    # set model device
    model.set_device()
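
The two barrier calls above implement the usual "rank 0 goes first" hand-off, so that only one process downloads the model and vocab while the others wait. A condensed sketch of the pattern (`build_model` is a hypothetical stand-in for the download/construction step):

    import torch.distributed as dist

    if args.local_rank not in [-1, 0]:
        dist.barrier()  # non-zero ranks wait here ...
    model = build_model(args)  # ... while rank 0 downloads weights & vocab
    if args.local_rank == 0:
        dist.barrier()  # rank 0 releases the waiting ranks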