Example #1
def _wrap_into_data_parallel_with_apex(
    model: RunnerModel, optimizer: RunnerOptimizer, distributed_params: Dict
):
    if isinstance(model, nn.Module):
        # wrap in Sequential for Apex initialization,
        # then unwrap the inner module for DataParallel
        model = nn.Sequential(model)
        model, optimizer = _initialize_apex(model, optimizer, **distributed_params)
        model = nn.DataParallel(model[0])
        model = _patch_forward(model)
    elif isinstance(model, dict):
        # same scheme, applied to every model in a key-value (KV) dict
        model = {k: nn.Sequential(v) for k, v in model.items()}
        model, optimizer = _initialize_apex(model, optimizer, **distributed_params)
        model = {k: nn.DataParallel(v[0]) for k, v in model.items()}
        model = {k: _patch_forward(v) for k, v in model.items()}
    else:
        raise NotImplementedError()

    return model, optimizer
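A minimal usage sketch for the helper above, assuming NVIDIA Apex and at least one CUDA device are available; the toy model, the optimizer, and the "opt_level" key (assumed here to be forwarded to apex.amp.initialize by _initialize_apex) are illustrative placeholders, not part of the original example:

import torch
from torch import nn

model = nn.Linear(10, 2).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# wraps the model so that Apex initialization and DataParallel coexist
model, optimizer = _wrap_into_data_parallel_with_apex(
    model, optimizer, distributed_params={"opt_level": "O1"}  # assumed Apex kwarg
)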
Example #2
def pack_checkpoint(
    model: RunnerModel = None,
    criterion: RunnerCriterion = None,
    optimizer: RunnerOptimizer = None,
    scheduler: RunnerScheduler = None,
    **kwargs,
) -> Dict:
    """
    Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
    and any extra info from ``**kwargs`` into a torch-based checkpoint.

    Args:
        model: torch model
        criterion: torch criterion
        optimizer: torch optimizer
        scheduler: torch scheduler
        **kwargs: some extra info to pack

    Returns:
        torch-based checkpoint with ``model_state_dict``,
        ``criterion_state_dict``, ``optimizer_state_dict``,
        ``scheduler_state_dict`` keys.
    """
    checkpoint = kwargs

    if isinstance(model, dict):
        for key, value in model.items():
            model_module = get_nn_from_ddp_module(value)
            checkpoint[f"model_{key}_state_dict"] = maybe_recursive_call(
                model_module, "state_dict")
    else:
        model_module = get_nn_from_ddp_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(
            model_module, "state_dict")

    for dict2save, name2save in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    # e.g. "optimizer_<key>_state_dict" for a KV optimizer
                    state_dict2save = f"{name2save}_{key}_state_dict"
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # e.g. "criterion_state_dict", "optimizer_state_dict", ...
            checkpoint[f"{name2save}_state_dict"] = dict2save.state_dict()
    return checkpoint
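A short sketch of how the packed checkpoint might be produced and saved; the toy model, the optimizer, and the extra "epoch" field are placeholders for illustration:

import torch
from torch import nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# extra kwargs such as "epoch" are stored in the checkpoint as-is
checkpoint = pack_checkpoint(model=model, optimizer=optimizer, epoch=3)
# expected keys: "model_state_dict", "optimizer_state_dict", "epoch"
torch.save(checkpoint, "checkpoint.pth")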
Example #3
def process_components(
    model: RunnerModel,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[RunnerModel, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device.

    Args:
        model: torch model
        criterion: criterion function
        optimizer: optimizer
        scheduler: scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 method
        device (Device, optional): device

    Returns:
        tuple with processed model, criterion, optimizer, scheduler and device.

    Raises:
        ValueError: if ``device`` is None while a TPU is available;
            to use a TPU you need to manually move the
            model/optimizer/scheduler to the TPU device and pass
            that device to this function.
        NotImplementedError: if the model is neither an ``nn.Module``
            nor a dict of modules for multi-GPU training;
            ``nn.ModuleDict`` is not supported for DataParallel yet.
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())

    if device is None and IS_XLA_AVAILABLE:
        raise ValueError(
            "TPU device is available. "
            "Please move model, optimizer and scheduler (if present) "
            "to TPU device manualy and specify a device or "
            "use CPU device.")

    if device is None:
        device = get_device()
    elif isinstance(device, str):
        device = torch.device(device)

    is_apex_enabled = (distributed_params.get("apex", False)
                       and check_apex_available())

    is_amp_enabled = (distributed_params.get("amp", False)
                      and check_amp_available())

    if is_apex_enabled and is_amp_enabled:
        raise ValueError("Both NVidia Apex and torch AMP are enabled. "
                         "You must choose only one mixed-precision backend.")

    model: Model = maybe_recursive_call(model, "to", device=device)

    if check_ddp_wrapped(model):
        pass
    # distributed data parallel run (ddp) (with apex support)
    elif get_rank() >= 0:
        assert isinstance(
            model,
            nn.Module), "Distributed training is not available for KV model"

        local_rank = distributed_params.pop("local_rank", 0) or 0
        device = f"cuda:{local_rank}"
        model = maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)

        if is_apex_enabled:
            import apex

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)

            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            if syncbn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
    # data parallel run (dp) (with apex support)
    else:
        is_data_parallel = (torch.cuda.device_count() > 1
                            and device.type != "cpu" and device.index is None)

        if is_apex_enabled and not is_data_parallel:
            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)

        elif not is_apex_enabled and is_data_parallel:
            if isinstance(model, nn.Module):
                model = nn.DataParallel(model)
            elif isinstance(model, dict):
                model = {k: nn.DataParallel(v) for k, v in model.items()}
            else:
                raise NotImplementedError()

        elif is_apex_enabled and is_data_parallel:
            model, optimizer = _wrap_into_data_parallel_with_apex(
                model, optimizer, distributed_params)

    model: Model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
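A single-process usage sketch (no DDP rank set, no "apex"/"amp" keys in distributed_params), so the function only moves the components to the detected device, or wraps the model in nn.DataParallel on a multi-GPU machine; the components below are placeholders:

import torch
from torch import nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    distributed_params={},  # plain FP32 path, no mixed precision
)
print(device)  # e.g. device(type="cuda") on a GPU machine, device(type="cpu") otherwise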