コード例 #1
0
def save_benchmark_results(
    model: Any,
    data: Any,
    batch_size: int,
    iterations: int,
    warmup_iterations: int,
    framework: Optional[str],
    provider: Optional[str] = None,
    device: Optional[str] = None,
    save_path: Optional[str] = None,
    framework_args: Optional[Dict[str, Any]] = None,
    show_progress: bool = False,
):
    """
    Run a benchmark in a specific framework and save or print the results.

    If ``save_path`` is provided, the results are written as indented JSON to
    a file at that path (parent directories are created as needed).
    If ``save_path`` is not provided, the JSON is printed to stdout instead.

    If no framework is provided, the model itself is handed to
    ``execute_in_sparseml_framework`` so the framework can be detected from it.

    :param model: model to benchmark
    :param data: data to benchmark
    :param batch_size: batch size
    :param iterations: number of iterations
    :param warmup_iterations: number of warmup iterations
    :param framework: the specific framework to run the benchmark in, or None
        to detect it from the model
    :param provider: the specific inference provider to use
    :param device: the specific device to use
    :param save_path: path to save the benchmark results to; if falsy the
        results are printed instead
    :param framework_args: additional framework specific arguments to
        pass to the runner; None (the default) means no extra arguments
    :param show_progress: True to show a tqdm bar when running, False otherwise
    :return: None
    """
    # Build a fresh dict per call: a shared `{}` default would carry any
    # mutation made by the runner over into every later call.
    if framework_args is None:
        framework_args = {}

    results = execute_in_sparseml_framework(
        framework if framework is not None else model,
        "run_benchmark",
        model,
        data,
        batch_size=batch_size,
        iterations=iterations,
        warmup_iterations=warmup_iterations,
        provider=provider,
        device=device,
        framework_args=framework_args,
        show_progress=show_progress,
    )

    if save_path:
        save_path = clean_path(save_path)
        create_parent_dirs(save_path)

        # Explicit encoding so the output does not depend on the platform locale.
        with open(save_path, "w", encoding="utf-8") as file:
            file.write(results.json(indent=4))

        _LOGGER.info("saved benchmark results in file at %s", save_path)
    else:
        print(results.json(indent=4))
        _LOGGER.info("printed out benchmark results")
コード例 #2
0
ファイル: info.py プロジェクト: PIlotcnc/neural
def sparsification_info(framework: Any) -> SparsificationInfo:
    """
    Look up the sparsification setup available for a given framework.

    The lookup is delegated to ``execute_in_sparseml_framework`` under the
    function name ``"sparsification_info"``; the result is logged and returned.

    :param framework: The item to detect the ML framework for.
        See :func:`detect_framework` for more information.
    :type framework: Any
    :return: The sparsification info for the given framework
    :rtype: SparsificationInfo
    """
    _LOGGER.debug("getting sparsification info for framework %s", framework)

    sparse_info: SparsificationInfo = execute_in_sparseml_framework(
        framework, "sparsification_info"
    )
    _LOGGER.info(
        "retrieved sparsification info for framework %s: %s",
        framework,
        sparse_info,
    )
    return sparse_info
コード例 #3
0
    def load_data(self, data: Any, **kwargs) -> Iterable[Any]:
        """
        Load ``data`` into an iterable for benchmarking, using the framework's
        ``load_data`` function.

        The model and batch size stored on this instance are forwarded.
        ``total_iterations`` is forwarded as the sum of the warmup and measured
        iteration counts.

        :param data: data to load
        :param kwargs: additional arguments to pass to the framework's load_data method
        :return: an iterable of the loaded data
        """
        total_iterations = self.warmup_iterations + self.iterations
        return execute_in_sparseml_framework(
            self.framework,
            "load_data",
            data=data,
            model=self.model,
            batch_size=self.batch_size,
            total_iterations=total_iterations,
            **kwargs,
        )
コード例 #4
0
ファイル: info.py プロジェクト: PIlotcnc/neural
def framework_info(framework: Any) -> FrameworkInfo:
    """
    Collect information about the given ML framework.

    Per the contract of the delegated ``"framework_info"`` call, this covers
    things such as package versions, availability of core actions such as
    training and inference, sparsification support, and inference provider
    support. The result is logged and returned.

    :param framework: The item to detect the ML framework for.
        See :func:`detect_framework` for more information.
    :type framework: Any
    :return: The framework info for the given framework
    :rtype: FrameworkInfo
    """
    _LOGGER.debug("getting system info for framework %s", framework)

    fw_info: FrameworkInfo = execute_in_sparseml_framework(
        framework, "framework_info"
    )
    _LOGGER.info(
        "retrieved system info for framework %s: %s",
        framework,
        fw_info,
    )
    return fw_info
コード例 #5
0
ファイル: test_base.py プロジェクト: kevinaer/sparseml
def test_execute_in_sparseml_framework():
    # execute_in_sparseml_framework is expected to raise ValueError both for
    # Framework.unknown and for the function name "unknown" on Framework.onnx
    # (presumably because no such function is registered there).
    for target_framework in (Framework.unknown, Framework.onnx):
        with pytest.raises(ValueError):
            execute_in_sparseml_framework(target_framework, "unknown")