Example 1
def load_traced_model(
    model_path: Union[str, Path],
    device: Device = "cpu",
    opt_level: str = None,
) -> ScriptModule:
    """
    Loads a traced model

    Args:
        model_path: Path to traced model
        device (str): Torch device
        opt_level (str): Apex FP16 init level, optional

    Returns:
        (ScriptModule): Traced model
    """
    # torch.jit.load does not work with pathlib.Path
    model_path = str(model_path)

    if opt_level is not None:
        device = "cuda"

    model = torch.jit.load(model_path, map_location=device)

    if opt_level is not None:
        utils.assert_fp16_available()
        from apex import amp

        model = amp.initialize(model, optimizers=None, opt_level=opt_level)

    return model
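A minimal usage sketch for this helper; the checkpoint path, input shape, and opt_level below are placeholders, not taken from the example above.

import torch

# Hypothetical usage: load a traced checkpoint with Apex O1 mixed precision.
# When opt_level is given, load_traced_model forces the model onto CUDA.
traced = load_traced_model("traced_model.pth", opt_level="O1")
with torch.no_grad():
    output = traced(torch.randn(1, 3, 224, 224, device="cuda"))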
Example 2
    def get_model(self, stage: str) -> _Model:
        model_params = self._config["model_params"]
        fp16 = model_params.pop("fp16", False)

        model = MODELS.get_from_params(**model_params)

        if fp16:
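            # FP16 requested in the config: verify FP16 support before wrapping the model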
            utils.assert_fp16_available()
            model = Fp16Wrap(model)

        model = self._preprocess_model_for_stage(stage, model)
        model = self._postprocess_model_for_stage(stage, model)
        return model
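For reference, a config fragment that would take the FP16 branch might look like the following; the model name and its arguments are hypothetical, and only the fp16 flag is consumed by this method.

config = {
    "model_params": {
        "model": "SimpleNet",  # hypothetical key for the MODELS registry
        "num_classes": 10,     # hypothetical model argument
        "fp16": True,          # popped by get_model and triggers Fp16Wrap
    },
}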
Example 3
def process_components(
    model: _Model,
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    distributed_params: Dict = None
) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    device = utils.get_device()

    model = maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        assert isinstance(model, nn.Module)
        utils.assert_fp16_available()
        from apex import amp
        from apex.parallel import convert_syncbn_model

        distributed_rank = distributed_params.pop("rank", -1)
        syncbn = distributed_params.pop("syncbn", False)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        model, optimizer = amp.initialize(
            model, optimizer, **distributed_params
        )

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)

            if syncbn:
                model = convert_syncbn_model(model)
        elif torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
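A hedged call sketch for the distributed branch; the criterion, optimizer, rank, and opt_level values are illustrative, and the rank would normally come from the launch utility.

# Illustrative: Apex AMP (O1) plus DistributedDataParallel with SyncBatchNorm.
# "rank" and "syncbn" are popped first; the rest is forwarded to amp.initialize.
model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    distributed_params={"rank": 0, "syncbn": True, "opt_level": "O1"},
)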
Example 4
def process_components(
    model: _Model,
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    distributed_params: Dict = None
) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    device = utils.get_device()

    if torch.cuda.is_available():
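        # cuDNN autotuning is on by default; set CUDNN_BENCHMARK to any other value to disable it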
        benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
        cudnn.benchmark = benchmark

    model = model.to(device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        utils.assert_fp16_available()
        from apex import amp

        distributed_rank = distributed_params.pop("rank", -1)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(
                backend="nccl", init_method="env://"
            )

        model, optimizer = amp.initialize(
            model, optimizer, **distributed_params
        )

        if distributed_rank > -1:
            from apex.parallel import DistributedDataParallel
            model = DistributedDataParallel(model)
        elif torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model = model.to(device)

    return model, criterion, optimizer, scheduler, device
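A hedged call for this simpler variant; the optimizer and opt_level are illustrative.

# Illustrative: single-node mixed-precision setup. Without a "rank" key the model is not
# wrapped with DistributedDataParallel and falls back to DataParallel on multi-GPU machines.
model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    distributed_params={"opt_level": "O1"},
)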
Example 5
def process_components(
    model: Model,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device

    Args:
        model (Model): torch model
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 methods
        device (Device, optional): device
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    if device is None:
        device = utils.get_device()

    model: Model = maybe_recursive_call(model, "to", device=device)

    if utils.is_wrapped_with_ddp(model):
        pass
    elif len(distributed_params) > 0:
        assert isinstance(model, nn.Module)
        distributed_rank = distributed_params.pop("rank", -1)
        syncbn = distributed_params.pop("syncbn", False)

        if distributed_rank > -1:
            torch.cuda.set_device(distributed_rank)
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

        if "opt_level" in distributed_params:
            utils.assert_fp16_available()
            from apex import amp

            amp_result = amp.initialize(model, optimizer, **distributed_params)
            if optimizer is not None:
                model, optimizer = amp_result
            else:
                model = amp_result

            if distributed_rank > -1:
                from apex.parallel import DistributedDataParallel
                model = DistributedDataParallel(model)

                if syncbn:
                    from apex.parallel import convert_syncbn_model
                    model = convert_syncbn_model(model)

        if distributed_rank <= -1 and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    elif torch.cuda.device_count() > 1:
        if isinstance(model, nn.Module):
            model = torch.nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: torch.nn.DataParallel(v) for k, v in model.items()}

    model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
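Unlike the earlier variants, this version also handles amp.initialize when no optimizer is passed; a hedged sketch of that inference-only path.

# Illustrative: FP16 inference-only setup; amp.initialize returns just the model here,
# and criterion, optimizer, and scheduler come back as None.
model, _, _, _, device = process_components(
    model=model,
    distributed_params={"opt_level": "O1"},
)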
Example 6
def trace_model(
    model: Model,
    runner: Runner,
    batch=None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """
    Traces model using runner and batch

    Args:
        model: Model to trace
        runner: Model's native runner that was used to train model
        batch: Batch to trace the model
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): Apex FP16 init level, optional
        device (str): Torch device
        predict_params (dict): additional parameters for model forward

    Returns:
        (ScriptModule): Traced model
    """
    if batch is None or runner is None:
        raise ValueError("Both batch and runner must be specified.")

    if mode not in ["train", "eval"]:
        raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")

    predict_params = predict_params or {}

    tracer = _TracingModelWrapper(model, method_name)
    if opt_level is not None:
        utils.assert_fp16_available()
        # If tracing under AMP, the model has to be initialized with amp
        # before calling jit
        # https://github.com/NVIDIA/apex/issues/303#issuecomment-493142950
        from apex import amp

        model = model.to(device)
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)
        # TODO: remove `check_trace=False`
        # after fixing this bug https://github.com/pytorch/pytorch/issues/23993
        params = {**predict_params, "check_trace": False}
    else:
        params = predict_params

    getattr(model, mode)()
    utils.set_requires_grad(model, requires_grad=requires_grad)

    _runner_model, _runner_device = runner.model, runner.device

    runner.model, runner.device = tracer, device
    runner.predict_batch(batch, **params)
    result: ScriptModule = tracer.tracing_result

    runner.model, runner.device = _runner_model, _runner_device
    return result
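A hedged invocation sketch; runner and batch are assumed to come from the existing training setup and are not defined here.

# Hypothetical invocation: trace the forward pass on CPU in eval mode and save the result.
traced = trace_model(
    model=model,
    runner=runner,
    batch=batch,
    method_name="forward",
    mode="eval",
    device="cpu",
)
torch.jit.save(traced, "traced_model.pth")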