Beispiel #1
0
def get_network(
    name: str,
    input_shape: List[int],
    *,
    cache_dir: Optional[str] = None,
) -> Tuple[IRModule, Dict[str, NDArray], Tuple[str, List[int], str]]:
    """Get the symbol definition and random weight of a network.

    The cache entry is looked up with ``_load_cache`` under the name
    ``relay-<name>-<comma-joined input_shape>.json``. On a miss, the network
    is produced by ``_get_network`` inside a single-worker process pool.
    The result is then stored with ``_save_cache``. The serialized parameters
    are decoded with ``load_param_dict`` before returning.

    Parameters
    ----------
    name : str
        The name of the network.
    input_shape : List[int]
        The shape of the input tensor.
    cache_dir : Optional[str], optional
        The directory to cache the generated network.
        If not specified, the cache will be disabled.

    Returns
    -------
    mod : IRModule
        The IRModule representing the network.
    params : Dict[str, NDArray]
        The parameters of the network.
    inputs : Tuple[str, List[int], str]
        The name, shape and dtype of the input tensor.
    """
    shape_key = ",".join(str(dim) for dim in input_shape)
    cache_name = f"relay-{name}-{shape_key}.json"

    entry = _load_cache(cache_dir, cache_name)
    if entry is None:
        # Generation happens in a separate one-worker process; presumably this
        # isolates the build from the current process (NOTE(review): confirm).
        with multiprocessing.Pool(processes=1) as worker_pool:
            outputs = worker_pool.map(_get_network, [(name, input_shape)])
        # Exactly one job was submitted, so exactly one 3-tuple comes back.
        [(built_mod, built_params, built_inputs)] = outputs
        entry = [built_mod, built_params, built_inputs]
        _save_cache(cache_dir, cache_name, entry)

    mod, raw_params, inputs = entry
    return mod, load_param_dict(raw_params), inputs


def extract_and_save_tasks(cache_file):
    """Extract tuning tasks and cache the non-spatial ones in the given directory.

    The cached model ``cache_file`` is loaded from ``args.model_cache_dir``.
    Tuning tasks are extracted for ``args.target`` with
    ``ms.extract_task_from_relay``. A task is kept when the first PrimFunc of
    its first dispatched module is not spatial, according to
    ``tir.schedule.IsSpatialPrimFunc``.

    Each kept task is written to
    ``<args.task_cache_dir>/<stem>_extracted_tasks.json``, where ``<stem>`` is
    the part of ``cache_file`` before its first ``"."``. The file holds one
    JSON record per line, ``[task_name, <subgraph as produced by save_json>]``.
    Records are newline-separated with no trailing newline after the last one.

    If task extraction raises ``tvm.error.TVMError``, the error is printed and
    the function returns without writing a file.

    Parameters
    ----------
    cache_file : str
        The filename of the cached model.

    Returns
    -------
    None
    """

    mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file)
    params = load_param_dict(params_bytearray)
    try:
        extracted_tasks = ms.extract_task_from_relay(mod,
                                                     target=args.target,
                                                     params=params)
    except tvm.error.TVMError as error:
        # Best-effort: report the failure and skip this model.
        print(str(error))
        return
    task_cache_path = os.path.join(
        args.task_cache_dir,
        cache_file.split(".")[0] + "_extracted_tasks.json")
    is_spatial = tvm.get_global_func("tir.schedule.IsSpatialPrimFunc")
    with open(task_cache_path, "w", encoding="utf8") as file:
        records = []
        for task in extracted_tasks:
            subgraph = task.dispatched[0]
            prim_func = subgraph[subgraph.get_global_vars()[0]]
            if is_spatial(prim_func):
                continue
            json_obj = [task.task_name, json.loads(save_json(subgraph))]
            json_str = json.dumps(json_obj)
            # Each record must occupy exactly one line of the output file.
            assert "\n" not in json_str, "Failed to generate single line string."
            records.append(json_str)
        # Joining the kept records keeps "no trailing newline" consistent even
        # when the last extracted task is spatial (and therefore skipped).
        file.write("\n".join(records))
Beispiel #3
0
def _deserialize_params(
        params: Optional[bytearray]) -> Optional[Dict[str, NDArray]]:
    if params is None:
        return None
    return load_param_dict(params)