Example #1
import logging

import torch

# MSMarcoConfigDict, load_states_from_checkpoint, and get_model_obj are
# project-level helpers; their import paths depend on the repository layout.

logger = logging.getLogger(__name__)


def load_model(args, checkpoint_path):
    label_list = ["0", "1"]  # binary relevance labels; unused below
    num_labels = len(label_list)
    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]
    args.model_name_or_path = checkpoint_path

    model = configObj.model_class(args)

    saved_state = load_states_from_checkpoint(checkpoint_path)
    model_to_load = get_model_obj(model)
    logger.info('Loading saved model state ...')
    model_to_load.load_state_dict(saved_state.model_dict)

    model.to(args.device)
    logger.info("Inference parameters %s", args)
    if args.local_rank != -1:
        # wrap for distributed inference when launched with torch.distributed
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    return model
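
Every example here unwraps the model with get_model_obj before reading or writing its state dict. The helper itself is not shown in these snippets; a minimal sketch, assuming the usual DPR-style definition:

import torch.nn as nn

def get_model_obj(model: nn.Module) -> nn.Module:
    # DistributedDataParallel keeps the real model in .module; unwrap it so
    # state_dict keys carry no 'module.' prefix
    return model.module if hasattr(model, 'module') else model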
Example #2
# assumes the logger and get_model_obj setup from Example #1
def _load_saved_state(model, optimizer, scheduler,
                      saved_state: CheckpointState):
    epoch = saved_state.epoch
    step = saved_state.offset
    logger.info('Loading checkpoint @ step=%s', step)

    model_to_load = get_model_obj(model)
    logger.info('Loading saved model state ...')
    model_to_load.load_state_dict(
        saved_state.model_dict)  # set strict=False if you use extra projection

    return step
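
The trailing comment above refers to partially matching checkpoints. A sketch of that variant, assuming the running model adds weights (e.g. a projection head) that the saved state lacks:

    # tolerate missing/unexpected keys instead of raising an error
    model_to_load.load_state_dict(saved_state.model_dict, strict=False)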
Example #3
# assumes import os, import torch, and the logger/get_model_obj setup from Example #1
def _save_checkpoint(args, model, optimizer, scheduler, step: int) -> str:
    offset = step
    epoch = 0  # epoch is not tracked here; saved as a placeholder
    model_to_save = get_model_obj(model)
    cp = os.path.join(args.output_dir, 'checkpoint-' + str(offset))

    meta_params = {}

    state = CheckpointState(model_to_save.state_dict(), optimizer.state_dict(),
                            scheduler.state_dict(), offset, epoch, meta_params)
    torch.save(state._asdict(), cp)
    logger.info('Saved checkpoint at %s', cp)
    return cp
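
CheckpointState is built positionally above and read by field name in the loading examples. A minimal sketch consistent with that usage, assuming a plain namedtuple in the style of the DPR codebase:

import collections

CheckpointState = collections.namedtuple(
    'CheckpointState',
    ['model_dict', 'optimizer_dict', 'scheduler_dict',
     'offset', 'epoch', 'meta_params'])

Because it is a namedtuple, state._asdict() in _save_checkpoint serializes it as a plain dict, which can later be unpacked back into CheckpointState(**state_dict).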
Example #4
def _load_saved_state(model, optimizer, scheduler,
                      saved_state: CheckpointState,
                      load_optimizer_scheduler=False):
    epoch = saved_state.epoch
    step = saved_state.offset
    logger.info('Loading checkpoint @ step=%s', step)

    model_to_load = get_model_obj(model)
    logger.info('Loading saved model state ...')
    model_to_load.load_state_dict(saved_state.model_dict)  # set strict=False if you use extra projection

    if load_optimizer_scheduler:
        optimizer.load_state_dict(saved_state.optimizer_dict)
        scheduler.load_state_dict(saved_state.scheduler_dict)
        logger.info('Loading the optimizer and scheduler to resume training')

    return step, model, optimizer, scheduler
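
Examples #1 and #5 obtain the CheckpointState via load_states_from_checkpoint. A minimal sketch, assuming the checkpoint file was written by _save_checkpoint above so its keys match the namedtuple fields:

import torch
from torch.serialization import default_restore_location

def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    # restore every tensor to CPU first; callers move the model to args.device
    state_dict = torch.load(
        model_file,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    return CheckpointState(**state_dict)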
Example #5
from collections import OrderedDict

import torch
from torch.serialization import default_restore_location

# MSMarcoConfigDict, load_states_from_checkpoint, get_model_obj, and logger
# follow the setup shown in Example #1


def load_model(args, checkpoint_path, load_flag=False):
    label_list = ["0", "1"]
    num_labels = len(label_list)
    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]
    args.model_name_or_path = checkpoint_path

    model = configObj.model_class(args)

    if args.init_from_fp16_ckpt:
        # compare the numeric step suffixes ('.../checkpoint-<step>') as ints;
        # string comparison would rank '9' above '10'
        checkpoint_step = int(checkpoint_path.split('-')[-1].replace('/', ''))
        init_step = int(args.pretrained_checkpoint_dir.split('-')[-1].replace('/', ''))
        load_flag = checkpoint_step > init_step

    if args.fp16 and load_flag:
        # the fp16 checkpoint is a raw torch save whose parameter names carry
        # the 'module.' prefix added by DistributedDataParallel
        checkpoint = torch.load(checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model'].items():
            name = k[7:]  # strip the leading 'module.' (7 characters)
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    else:
        saved_state = load_states_from_checkpoint(checkpoint_path)
        model_to_load = get_model_obj(model)
        logger.info('Loading saved model state ...')
        model_to_load.load_state_dict(saved_state.model_dict)
    
    model.is_representation_l2_normalization = args.representation_l2_normalization
    
    model.to(args.device)
    logger.info("Inference parameters %s", args)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    return model
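
A hypothetical invocation (the checkpoint name and output_dir layout are assumptions, following the 'checkpoint-<step>' convention from Example #3):

# args comes from the repo's argparse setup; the path below is illustrative only
model = load_model(args, os.path.join(args.output_dir, 'checkpoint-20000'))
model.eval()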