예제 #1
0
def setup_runner(
    args: RunConfiguration,
    jiant_task_container: container_setup.JiantTaskContainer,
    quick_init_out,
    verbose: bool = True,
) -> jiant_runner.JiantRunner:
    """Setup jiant model, optimizer, and runner, and return runner.

    Steps:
      1. Build the jiant model from ``args.hf_pretrained_model_name_or_path`` /
         ``args.model_config_path`` and the task container's task dict and
         taskmodels config.
      2. Load weights from ``args.model_path`` according to ``args.model_load_mode``.
      3. Move the model to ``quick_init_out.device``.
      4. Create the optimizer/scheduler.
      5. Pass model and optimizer through ``model_setup.raw_special_model_setup``.
         That call receives the fp16, n_gpu and local_rank settings and returns
         the model and optimizer that are used from then on.
      6. Bundle everything into a ``JiantRunner``.

    Args:
        args (RunConfiguration): configuration carrying command line args specifying run params.
        jiant_task_container (container_setup.JiantTaskContainer): task and sampler configs.
        quick_init_out (QuickInitContainer): device (GPU/CPU) and logging configuration.
        verbose: If True, enables printing configuration info (to standard out).
            Forwarded to ``model_setup.create_optimizer``.

    Returns:
        jiant_runner.JiantRunner: runner holding the (possibly wrapped) model, the
        optimizer/scheduler, the device, runner parameters and the log writer.

    """
    # NOTE(review): model construction/loading is wrapped in
    # distributed.only_first_process(local_rank=...). Presumably this lets the
    # first process do any one-time work (e.g. fetching/caching pretrained files)
    # before the other ranks proceed -- confirm against distributed.only_first_process.
    with distributed.only_first_process(local_rank=args.local_rank):
        # Build the model architecture for all tasks in the container.
        jiant_model = jiant_model_setup.setup_jiant_model(
            hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
            model_config_path=args.model_config_path,
            task_dict=jiant_task_container.task_dict,
            taskmodels_config=jiant_task_container.taskmodels_config,
        )
        # Load weights; how they are applied is decided by args.model_load_mode.
        jiant_model_setup.delegate_load_from_path(
            jiant_model=jiant_model,
            weights_path=args.model_path,
            load_mode=args.model_load_mode,
        )
        jiant_model.to(quick_init_out.device)

    train_config = jiant_task_container.global_train_config
    optimizer_scheduler = model_setup.create_optimizer(
        model=jiant_model,
        learning_rate=args.learning_rate,
        t_total=train_config.max_steps,
        # Warmup is given as an absolute step count, so no proportion is passed.
        warmup_steps=train_config.warmup_steps,
        warmup_proportion=None,
        optimizer_type=args.optimizer_type,
        verbose=verbose,
    )
    # raw_special_model_setup may return a different (wrapped) model and
    # optimizer; both returned objects replace the originals below.
    jiant_model, optimizer = model_setup.raw_special_model_setup(
        model=jiant_model,
        optimizer=optimizer_scheduler.optimizer,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        n_gpu=quick_init_out.n_gpu,
        local_rank=args.local_rank,
    )
    # Keep the scheduler container pointing at the optimizer actually in use.
    optimizer_scheduler.optimizer = optimizer
    rparams = jiant_runner.RunnerParameters(
        local_rank=args.local_rank,
        n_gpu=quick_init_out.n_gpu,
        fp16=args.fp16,
        max_grad_norm=args.max_grad_norm,
    )
    return jiant_runner.JiantRunner(
        jiant_task_container=jiant_task_container,
        jiant_model=jiant_model,
        optimizer_scheduler=optimizer_scheduler,
        device=quick_init_out.device,
        rparams=rparams,
        log_writer=quick_init_out.log_writer,
    )
예제 #2
0
def setup_runner(
    args: RunConfiguration,
    jiant_task_container: container_setup.JiantTaskContainer,
    quick_init_out,
    verbose: bool = True,
) -> jiant_runner.JiantRunner:
    """Setup jiant model (with shared adapters), optimizer, and runner, and return runner.

    Steps:
      1. Build the jiant model from ``args.model_type``, ``args.model_config_path``
         and ``args.model_tokenizer_path``.
      2. Load encoder weights from the ``args.model_path`` checkpoint.
      3. Add shared adapters. The adapter config is read from
         ``args.adapter_config_path`` when set; otherwise the default
         ``AdapterConfig()`` is used.
      4. Optionally load adapter weights. This happens only when both
         ``args.adapters_load_mode`` and ``args.adapters_load_path`` are truthy.
      5. Move the model to ``quick_init_out.device``.
      6. Build the optimizer over the parameters selected by
         ``get_optimized_named_parameters_for_jiant_model_with_adapters``.
      7. Pass model and optimizer through ``model_setup.raw_special_model_setup``
         and bundle everything into a ``JiantRunner``.

    Args:
        args (RunConfiguration): configuration carrying command line args specifying run params.
        jiant_task_container (container_setup.JiantTaskContainer): task and sampler configs.
        quick_init_out (QuickInitContainer): device (GPU/CPU) and logging configuration.
        verbose: If True, enables printing configuration info (to standard out).
            Forwarded to ``model_setup.create_optimizer_from_params``.

    Returns:
        jiant_runner.JiantRunner: runner holding the (possibly wrapped) model, the
        optimizer/scheduler, the device, runner parameters and the log writer.

    """
    jiant_model = jiant_model_setup.setup_jiant_model(
        model_type=args.model_type,
        model_config_path=args.model_config_path,
        tokenizer_path=args.model_tokenizer_path,
        task_dict=jiant_task_container.task_dict,
        taskmodels_config=jiant_task_container.taskmodels_config,
    )
    # Checkpoints are deserialized onto the CPU. The model is only moved to
    # quick_init_out.device further below, so this avoids failing (or spiking
    # GPU memory) when a checkpoint was saved from a GPU that is not present.
    # NOTE: torch.load unpickles its input -- only load trusted checkpoints.
    weights_dict = torch.load(args.model_path, map_location="cpu")
    jiant_model_setup.load_encoder_from_transformers_weights(
        encoder=jiant_model.encoder, weights_dict=weights_dict,
    )

    if args.adapter_config_path:
        adapter_config = adapters_modeling.AdapterConfig.from_dict(
            py_io.read_json(args.adapter_config_path),
        )
    else:
        adapter_config = adapters_modeling.AdapterConfig()
    adapters_modeling.add_shared_adapters_to_jiant_model(
        jiant_model=jiant_model, adapter_config=adapter_config,
    )
    # Adapter weights are loaded only when both a mode and a path are given.
    if args.adapters_load_mode and args.adapters_load_path:
        adapters_modeling.delegate_load_for_shared_adapters(
            jiant_model=jiant_model,
            state_dict=torch.load(args.adapters_load_path, map_location="cpu"),
            load_mode=args.adapters_load_mode,
        )
    jiant_model.to(quick_init_out.device)

    # Only the parameters chosen by the adapters helper are handed to the
    # optimizer; the second return value is not needed here.
    (
        optimized_named_parameters,
        _,
    ) = adapters_modeling.get_optimized_named_parameters_for_jiant_model_with_adapters(
        jiant_model=jiant_model,
    )
    train_config = jiant_task_container.global_train_config
    optimizer_scheduler = model_setup.create_optimizer_from_params(
        named_parameters=optimized_named_parameters,
        learning_rate=args.learning_rate,
        t_total=train_config.max_steps,
        # Warmup is given as an absolute step count, so no proportion is passed.
        warmup_steps=train_config.warmup_steps,
        warmup_proportion=None,
        verbose=verbose,
    )
    # raw_special_model_setup may return a different (wrapped) model and
    # optimizer; both returned objects replace the originals below.
    jiant_model, optimizer = model_setup.raw_special_model_setup(
        model=jiant_model,
        optimizer=optimizer_scheduler.optimizer,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        n_gpu=quick_init_out.n_gpu,
        local_rank=args.local_rank,
    )
    # Keep the scheduler container pointing at the optimizer actually in use.
    optimizer_scheduler.optimizer = optimizer
    rparams = jiant_runner.RunnerParameters(
        local_rank=args.local_rank,
        n_gpu=quick_init_out.n_gpu,
        fp16=args.fp16,
        max_grad_norm=args.max_grad_norm,
    )
    return jiant_runner.JiantRunner(
        jiant_task_container=jiant_task_container,
        jiant_model=jiant_model,
        optimizer_scheduler=optimizer_scheduler,
        device=quick_init_out.device,
        rparams=rparams,
        log_writer=quick_init_out.log_writer,
    )