Example #1
def trace_model(
    model: Model,
    predict_fn: Callable,
    batch=None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> jit.ScriptModule:
    """Traces model using a predict function and a batch.

    Args:
        model: Model to trace
        predict_fn: Function that runs prediction with the given model;
            takes ``model`` and ``inputs`` as parameters
        batch: Batch used to trace the model
        method_name (str): Model's method name that will be
            used as the entrypoint during tracing
        mode (str): Mode in which to trace the model (``train`` or ``eval``)
        requires_grad (bool): Flag that sets ``requires_grad`` on the
            model's parameters before tracing
        opt_level (str): Apex FP16 init level, optional
        device (str): Torch device
        predict_params (dict): Additional parameters for the model's forward

    Returns:
        jit.ScriptModule: Traced model

    Raises:
        ValueError: if ``batch`` or ``predict_fn`` is not specified,
            or if ``mode`` is not ``'eval'`` or ``'train'``.
    """
    if batch is None or predict_fn is None:
        raise ValueError("Both batch and predict_fn must be specified.")

    if mode not in ["train", "eval"]:
        raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")

    predict_params = predict_params or {}

    tracer = _TracingModelWrapper(model, method_name)
    if opt_level is not None:
        assert_fp16_available()
        # If traced in AMP we need to initialize the model before calling
        # the jit
        # https://github.com/NVIDIA/apex/issues/303#issuecomment-493142950
        from apex import amp

        model = model.to(device)
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)

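    # switch the model into the requested mode (model.train() or model.eval())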
    getattr(model, mode)()
    set_requires_grad(model, requires_grad=requires_grad)

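    # running the batch through the wrapper executes the model under
    # jit.trace and stores the traced module as ``tracer.tracing_result``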
    predict_fn(tracer, batch, **predict_params)

    return tracer.tracing_result
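
A minimal usage sketch for this variant (not part of the source above; the catalyst.utils import path and the toy model and batch are illustrative assumptions):

import torch
import torch.nn as nn
from catalyst.utils import trace_model  # assumed import path

model = nn.Linear(10, 2)
batch = torch.randn(4, 10)

def predict_fn(model, inputs):
    # trace_model passes its tracing wrapper here in place of the raw model
    return model(inputs)

traced = trace_model(model=model, predict_fn=predict_fn, batch=batch)
traced.save("traced_linear.pt")  # a ScriptModule can be serialized directly
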
Example #2
def trace_model(
    model: Model,
    runner: "Runner",
    batch=None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """Traces model using runner and batch.

    Args:
        model: Model to trace
        runner: Model's native runner that was used to train the model
        batch: Batch used to trace the model
        method_name (str): Model's method name that will be
            used as the entrypoint during tracing
        mode (str): Mode in which to trace the model (``train`` or ``eval``)
        requires_grad (bool): Flag that sets ``requires_grad`` on the
            model's parameters before tracing
        opt_level (str): Apex FP16 init level, optional
        device (str): Torch device
        predict_params (dict): Additional parameters for the model's forward

    Returns:
        ScriptModule: Traced model

    Raises:
        ValueError: if ``batch`` or ``runner`` is not specified,
            or if ``mode`` is not ``'eval'`` or ``'train'``.
    """
    if batch is None or runner is None:
        raise ValueError("Both batch and runner must be specified.")

    if mode not in ["train", "eval"]:
        raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")

    predict_params = predict_params or {}

    tracer = _TracingModelWrapper(model, method_name)
    if opt_level is not None:
        assert_fp16_available()
        # If traced in AMP we need to initialize the model before calling
        # the jit
        # https://github.com/NVIDIA/apex/issues/303#issuecomment-493142950
        from apex import amp

        model = model.to(device)
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)
        # TODO: remove `check_trace=False`
        # after fixing this bug https://github.com/pytorch/pytorch/issues/23993
        params = {**predict_params, "check_trace": False}
    else:
        params = predict_params

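    # switch the model into the requested mode (model.train() or model.eval())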
    getattr(model, mode)()
    set_requires_grad(model, requires_grad=requires_grad)

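    # temporarily swap the runner's model and device for the tracing wrapper,
    # run one prediction to capture the trace, then restore the originals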
    _runner_model, _runner_device = runner.model, runner.device

    runner.model, runner.device = tracer, device
    runner.predict_batch(batch, **params)
    result: ScriptModule = tracer.tracing_result

    runner.model, runner.device = _runner_model, _runner_device
    return result
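
A minimal usage sketch for the runner-based variant (illustrative assumptions: SupervisedRunner from catalyst.dl, its default "features" input key, and the trace_model import path):

import torch
import torch.nn as nn
from catalyst.dl import SupervisedRunner
from catalyst.utils import trace_model  # assumed import path

model = nn.Linear(10, 2)
runner = SupervisedRunner()  # runner.predict_batch drives the forward pass
batch = {"features": torch.randn(4, 10)}  # assumed default input_key

traced = trace_model(model=model, runner=runner, batch=batch)
torch.jit.save(traced, "traced_linear.pt")
# the runner's original model and device are restored before trace_model returns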