def set_task(
        DATASET: str, BATCH_SIZE: int, path: str,
        N_WORKERS: int,
) -> "Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, List, List, List]":
    """Build data loaders, loss functions and model factories for a task.

    Args:
        DATASET: Dataset name; one of "CIFAR-10", "MNIST", "Cityscapes" or
            "NLP".
        BATCH_SIZE: Training batch size. Used by both loaders for MNIST and
            Cityscapes and by the CIFAR-10 training loader. The NLP branch
            uses a fixed batch size of 4.
        path: Path to the dataset folder. The NLP branch does not use it; it
            works relative to the current working directory instead.
        N_WORKERS: Number of DataLoader worker processes. Only the MNIST and
            Cityscapes branches use it.

    Returns:
        A 5-tuple ``(train_loader, val_loader, criterions, list_of_encoders,
        list_of_decoders)``:
        train_loader - loader for the training set
        val_loader - loader for the validation set
        criterions - one loss function per task head
        list_of_encoders - encoder classes/factories (for "NLP": an already
            constructed encoder module)
        list_of_decoders - one decoder factory (zero-arg callable) per head

    Raises:
        ValueError: if ``DATASET`` is not a supported name. The check runs
            before seeding or touching the filesystem.
    """
    supported_datasets = ("CIFAR-10", "MNIST", "Cityscapes", "NLP")
    if DATASET not in supported_datasets:
        # Without this guard an unknown name skipped every branch and then
        # failed at the final ``return`` with an UnboundLocalError.
        raise ValueError(
            f"Unknown DATASET {DATASET!r}; expected one of {supported_datasets}")

    # Fixed seed for reproducible runs (set_seed is defined elsewhere).
    set_seed(999)
    if DATASET == "CIFAR-10":
        train_dst = CIFAR10Loader(root=path, train=True)
        train_loader = train_dst.get_loader(batch_size=BATCH_SIZE,
                                            shuffle=True)

        val_dst = CIFAR10Loader(root=path, train=False)
        # NOTE(review): BATCH_SIZE / N_WORKERS are not passed here, so the
        # validation loader relies on CIFAR10Loader.get_loader's defaults.
        val_loader = val_dst.get_loader()

        list_of_encoders = [ResNet18]
        # One decoder factory and one loss per output head (10 heads); the
        # same loss instance is shared by every head.
        list_of_decoders = [MultiDec] * 10
        criterions = [torch.nn.BCEWithLogitsLoss()] * 10

    elif DATASET == "MNIST":
        train_dst = MNIST(root=path,
                          train=True,
                          download=True,
                          transform=global_transformer(),
                          multi=True)
        train_loader = torch.utils.data.DataLoader(train_dst,
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=N_WORKERS)

        val_dst = MNIST(root=path,
                        train=False,
                        download=True,
                        transform=global_transformer(),
                        multi=True)
        val_loader = torch.utils.data.DataLoader(val_dst,
                                                 batch_size=BATCH_SIZE,
                                                 num_workers=N_WORKERS)

        list_of_encoders = [MultiLeNetEnc]
        list_of_decoders = [MultiLeNetDec] * 2
        criterions = [torch.nn.NLLLoss()] * 2

    elif DATASET == "Cityscapes":
        # Augmentations are applied to the training split only.
        cityscapes_augmentations = Compose(
            [RandomRotate(10), RandomHorizontallyFlip()])
        img_rows = 256
        img_cols = 512

        train_dst = CITYSCAPES(root=path,
                               is_transform=True,
                               split=['train'],
                               img_size=(img_rows, img_cols),
                               augmentations=cityscapes_augmentations)
        train_loader = torch.utils.data.DataLoader(train_dst,
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=N_WORKERS)

        # is_transform=True so validation samples go through the same
        # transform (to img_size) as training samples; the original call
        # omitted it while still passing img_size.
        val_dst = CITYSCAPES(root=path,
                             is_transform=True,
                             split=['val'],
                             img_size=(img_rows, img_cols))
        val_loader = torch.utils.data.DataLoader(val_dst,
                                                 batch_size=BATCH_SIZE,
                                                 num_workers=N_WORKERS)

        list_of_encoders = [get_segmentation_encoder]
        # Heads: 19-class classification ("C"), then two regression ("R")
        # heads with 2 and 1 output channels.
        list_of_decoders = [
            partialclass(SegmentationDecoder, num_class=19, task_type="C"),
            partialclass(SegmentationDecoder, num_class=2, task_type="R"),
            partialclass(SegmentationDecoder, num_class=1, task_type="R")
        ]
        criterions = [cross_entropy2d, l1_loss_instance, l1_loss_depth]

    elif DATASET == 'NLP':
        # NOTE(review): this branch ignores ``path``, ``BATCH_SIZE`` and
        # ``N_WORKERS``. It writes to ./models and ./cache and reads
        # ./tasks/configs relative to the current working directory.
        nlp_task_names = ["rte", "stsb", "commonsenseqa"]

        export_model.export_model(
            hf_pretrained_model_name_or_path="bert-base-uncased",
            output_base_path="./models/bert-base-uncased",
        )

        for task_name in nlp_task_names:
            tokenize_and_cache.main(
                tokenize_and_cache.RunConfiguration(
                    task_config_path=f"./tasks/configs/{task_name}_config.json",
                    hf_pretrained_model_name_or_path="bert-base-uncased",
                    output_dir=f"./cache/{task_name}",
                    phases=["train", "val"],
                ))

        jiant_run_config = configurator.SimpleAPIMultiTaskConfigurator(
            task_config_base_path="./tasks/configs",
            task_cache_base_path="./cache",
            train_task_name_list=list(nlp_task_names),
            val_task_name_list=list(nlp_task_names),
            train_batch_size=4,
            eval_batch_size=8,
            epochs=0.5,
            num_gpus=1,
        ).create_config()

        jiant_task_container = container_setup.create_jiant_task_container_from_dict(
            jiant_run_config)

        jiant_model = jiant_model_setup.setup_jiant_model(
            hf_pretrained_model_name_or_path="bert-base-uncased",
            model_config_path="./models/bert-base-uncased/model/config.json",
            task_dict=jiant_task_container.task_dict,
            taskmodels_config=jiant_task_container.taskmodels_config,
        )

        # The loaders cover the STS-B task only. The original passed an
        # undefined name ``task`` here; the task object is taken from the
        # container's task_dict (presumably keyed by task name — confirm).
        stsb_task = jiant_task_container.task_dict['stsb']
        train_cache = jiant_task_container.task_cache_dict['stsb']["train"]
        val_cache = jiant_task_container.task_cache_dict['stsb']["val"]

        # Bind to the names returned below. The original assigned
        # train_dataloader / val_dataloader, so the return raised
        # UnboundLocalError.
        train_loader = get_train_dataloader_from_cache(
            train_cache, stsb_task, 4)
        val_loader = get_eval_dataloader_from_cache(val_cache, stsb_task, 4)

        list_of_encoders = [jiant_model.encoder]
        # Three independent copies of the STS-B head, each re-initialised
        # by ``reset`` (defined elsewhere).
        decoder1 = deepcopy(jiant_model.taskmodels_dict['stsb'].head)
        reset(decoder1)
        decoder2 = deepcopy(decoder1)
        reset(decoder2)
        decoder3 = deepcopy(decoder2)
        reset(decoder3)

        # Zero-arg factories so callers can treat these like the decoder
        # classes used by the other branches.
        list_of_decoders = [
            lambda: decoder1, lambda: decoder2, lambda: decoder3
        ]
        criterions = [torch.nn.MSELoss() for _ in range(3)]

    return train_loader, val_loader, criterions, list_of_encoders, list_of_decoders
# Example #2
def setup_runner(
    args: RunConfiguration,
    jiant_task_container: container_setup.JiantTaskContainer,
    quick_init_out,
    verbose: bool = True,
) -> jiant_runner.JiantRunner:
    """Setup jiant model, optimizer, and runner, and return runner.

    The model is built from ``args.hf_pretrained_model_name_or_path`` and
    ``args.model_config_path``, loaded from ``args.model_path`` using
    ``args.model_load_mode``, and moved to ``quick_init_out.device``. It is
    then prepared with the optimizer for fp16 / multi-GPU training.

    Args:
        args (RunConfiguration): configuration carrying command line args specifying run params.
        jiant_task_container (container_setup.JiantTaskContainer): task and sampler configs.
        quick_init_out (QuickInitContainer): device (GPU/CPU) and logging configuration.
        verbose: If True, enables printing configuration info (to standard out).

    Returns:
        jiant_runner.JiantRunner

    """
    # TODO document why the distributed.only_first_process() context manager is being used here.
    # NOTE(review): it takes args.local_rank, which suggests model creation and
    # loading run in the first process before the others — confirm.
    with distributed.only_first_process(local_rank=args.local_rank):
        # load the model
        jiant_model = jiant_model_setup.setup_jiant_model(
            hf_pretrained_model_name_or_path=args.
            hf_pretrained_model_name_or_path,
            model_config_path=args.model_config_path,
            task_dict=jiant_task_container.task_dict,
            taskmodels_config=jiant_task_container.taskmodels_config,
        )
        jiant_model_setup.delegate_load_from_path(
            jiant_model=jiant_model,
            weights_path=args.model_path,
            load_mode=args.model_load_mode)
        jiant_model.to(quick_init_out.device)

    optimizer_scheduler = model_setup.create_optimizer(
        model=jiant_model,
        learning_rate=args.learning_rate,
        t_total=jiant_task_container.global_train_config.max_steps,
        warmup_steps=jiant_task_container.global_train_config.warmup_steps,
        warmup_proportion=None,
        optimizer_type=args.optimizer_type,
        verbose=verbose,
    )
    # May return a wrapped model/optimizer (fp16, multi-GPU); the runner must
    # use the returned objects, hence the reassignment below.
    jiant_model, optimizer = model_setup.raw_special_model_setup(
        model=jiant_model,
        optimizer=optimizer_scheduler.optimizer,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        n_gpu=quick_init_out.n_gpu,
        local_rank=args.local_rank,
    )
    if verbose:
        # Model dump; previously printed unconditionally, ignoring ``verbose``.
        print("-----------------", jiant_model)
    optimizer_scheduler.optimizer = optimizer
    rparams = jiant_runner.RunnerParameters(
        local_rank=args.local_rank,
        n_gpu=quick_init_out.n_gpu,
        fp16=args.fp16,
        max_grad_norm=args.max_grad_norm,
    )
    runner = jiant_runner.JiantRunner(
        jiant_task_container=jiant_task_container,
        jiant_model=jiant_model,
        optimizer_scheduler=optimizer_scheduler,
        device=quick_init_out.device,
        rparams=rparams,
        log_writer=quick_init_out.log_writer,
    )
    return runner
# Example #3
def setup_runner(
    args: RunConfiguration,
    jiant_task_container: container_setup.JiantTaskContainer,
    quick_init_out,
    verbose: bool = True,
) -> jiant_runner.JiantRunner:
    """Setup jiant model (with shared adapters), optimizer, and runner, and return runner.

    Steps:
        1. Build the model and load encoder weights from ``args.model_path``.
        2. Add shared adapters, configured from ``args.adapter_config_path``
           or a default ``AdapterConfig``.
        3. If both ``args.adapters_load_mode`` and ``args.adapters_load_path``
           are set, load those adapters.
        4. Move the model to the device and build the optimizer. The
           optimizer receives only the parameters selected by
           ``get_optimized_named_parameters_for_jiant_model_with_adapters``.

    Args:
        args (RunConfiguration): configuration carrying command line args specifying run params.
        jiant_task_container (container_setup.JiantTaskContainer): task and sampler configs.
        quick_init_out (QuickInitContainer): device (GPU/CPU) and logging configuration.
        verbose: If True, enables printing configuration info (to standard out).

    Returns:
        jiant_runner.JiantRunner

    """
    jiant_model = jiant_model_setup.setup_jiant_model(
        model_type=args.model_type,
        model_config_path=args.model_config_path,
        tokenizer_path=args.model_tokenizer_path,
        task_dict=jiant_task_container.task_dict,
        taskmodels_config=jiant_task_container.taskmodels_config,
    )
    # Checkpoints are loaded onto the CPU. The model is only moved to its
    # target device further below, and this avoids failures when a
    # GPU-saved checkpoint is opened on a CPU-only or different-GPU host.
    # SECURITY NOTE: torch.load unpickles its input; only load trusted files.
    weights_dict = torch.load(args.model_path, map_location="cpu")
    jiant_model_setup.load_encoder_from_transformers_weights(
        encoder=jiant_model.encoder, weights_dict=weights_dict,
    )
    # Drop this reference to the raw checkpoint once it has been loaded.
    del weights_dict

    if args.adapter_config_path:
        adapter_config = adapters_modeling.AdapterConfig.from_dict(
            py_io.read_json(args.adapter_config_path),
        )
    else:
        adapter_config = adapters_modeling.AdapterConfig()
    adapters_modeling.add_shared_adapters_to_jiant_model(
        jiant_model=jiant_model, adapter_config=adapter_config,
    )
    # Adapter weights are restored only when both a mode and a path are given.
    if args.adapters_load_mode and args.adapters_load_path:
        adapters_modeling.delegate_load_for_shared_adapters(
            jiant_model=jiant_model,
            state_dict=torch.load(args.adapters_load_path, map_location="cpu"),
            load_mode=args.adapters_load_mode,
        )
    jiant_model.to(quick_init_out.device)

    # Only the selected parameters are optimized; the second return value
    # is not needed here.
    (
        optimized_named_parameters,
        _,
    ) = adapters_modeling.get_optimized_named_parameters_for_jiant_model_with_adapters(
        jiant_model=jiant_model,
    )
    optimizer_scheduler = model_setup.create_optimizer_from_params(
        named_parameters=optimized_named_parameters,
        learning_rate=args.learning_rate,
        t_total=jiant_task_container.global_train_config.max_steps,
        warmup_steps=jiant_task_container.global_train_config.warmup_steps,
        warmup_proportion=None,
        verbose=verbose,
    )
    # May return a wrapped model/optimizer (fp16, multi-GPU); keep using the
    # returned objects.
    jiant_model, optimizer = model_setup.raw_special_model_setup(
        model=jiant_model,
        optimizer=optimizer_scheduler.optimizer,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        n_gpu=quick_init_out.n_gpu,
        local_rank=args.local_rank,
    )
    optimizer_scheduler.optimizer = optimizer
    rparams = jiant_runner.RunnerParameters(
        local_rank=args.local_rank,
        n_gpu=quick_init_out.n_gpu,
        fp16=args.fp16,
        max_grad_norm=args.max_grad_norm,
    )
    runner = jiant_runner.JiantRunner(
        jiant_task_container=jiant_task_container,
        jiant_model=jiant_model,
        optimizer_scheduler=optimizer_scheduler,
        device=quick_init_out.device,
        rparams=rparams,
        log_writer=quick_init_out.log_writer,
    )
    return runner