Example 1
import torch
from typing import Any, Callable, Dict, Optional, Tuple, Union

# is_per_channel is a small qscheme helper; in recent PyTorch it lives in
# torch.ao.quantization.utils
from torch.ao.quantization.utils import is_per_channel

def get_quantize_node_info(activation_post_process: Callable) -> Tuple[str, Union[Callable, str], Dict[str, Any]]:
    ''' Given an activation_post_process module,
    return the node_type (e.g. call_function), the quantize op (e.g. torch.quantize_per_tensor),
    and a dictionary of qparams extracted from the module.
    '''
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
    quantize_op: Optional[Union[Callable, str]] = None
    if dtype in [torch.quint8, torch.qint8]:
        node_type = "call_function"
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined]
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
            quantize_op = torch.quantize_per_channel
        else:
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor
    elif dtype == torch.float16:
        node_type = "call_method"
        quantize_op = "to"
        qparams = {"_dtype_": dtype}
    else:
        raise Exception("Unsupported dtype in get_quantize_node_info: " + str(dtype))
    assert quantize_op is not None
    return node_type, quantize_op, qparams
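
A minimal usage sketch for the per-tensor path above. The calibration data is hypothetical; MinMaxObserver comes from torch.ao.quantization.observer and defaults to a per-tensor qscheme:

import torch
from torch.ao.quantization.observer import MinMaxObserver

# Calibrate a per-tensor observer so calculate_qparams() has statistics to work with.
obs = MinMaxObserver(dtype=torch.quint8)
obs(torch.randn(4, 4))

node_type, quantize_op, qparams = get_quantize_node_info(obs)
# node_type   == "call_function"
# quantize_op is torch.quantize_per_tensor
# qparams     == {"_scale_": <float>, "_zero_point_": <int>, "_dtype_": torch.quint8}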
Example 2
import warnings

import torch
from typing import Any, Callable, Dict, Optional, Tuple, Union

# is_per_channel is a small qscheme helper; in recent PyTorch it lives in
# torch.ao.quantization.utils
from torch.ao.quantization.utils import is_per_channel

def get_quantize_node_info(
    activation_post_process: Callable
) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
    ''' Given an activation_post_process module,
    return the node_type (e.g. call_function), the quantize op (e.g. torch.quantize_per_tensor),
    and a dictionary of qparams extracted from the module.
    '''
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
    compute_dtype = None
    if hasattr(activation_post_process, "compute_dtype"):
        compute_dtype = activation_post_process.compute_dtype  # type: ignore[attr-defined]
    quantize_op: Optional[Union[Callable, str]] = None
    if dtype in [torch.quint8, torch.qint8]:
        node_type = "call_function"
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined]
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_dtype_": dtype
            }
            quantize_op = torch.quantize_per_channel
        else:
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_dtype_": dtype
            }
            quantize_op = torch.quantize_per_tensor
    elif dtype == torch.float16:
        node_type = "call_method"
        quantize_op = "to"
        qparams = {"_dtype_": dtype}
    elif dtype == torch.float32 and compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
        # dynamic quantization
        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # TODO: get reduce range from observer
        # reduce_range = activation_post_process.reduce_range
        reduce_range = torch.backends.quantized.engine == "fbgemm"
        qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range}
    else:
        warnings.warn(
            f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}"
        )
        return None
    return node_type, quantize_op, qparams
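
A sketch of the dynamic-quantization branch this second variant adds. It assumes a PyTorch release (around 1.11/1.12) in which PlaceholderObserver still accepts a compute_dtype argument, which is the attribute this version of the function checks for:

import torch
from torch.ao.quantization.observer import PlaceholderObserver

# A float32 observer dtype plus an integer compute_dtype selects the dynamic path.
dynamic_obs = PlaceholderObserver(dtype=torch.float32, compute_dtype=torch.quint8)

info = get_quantize_node_info(dynamic_obs)
if info is not None:  # this variant returns None instead of raising on unsupported input
    node_type, quantize_op, qparams = info
    # node_type   == "call_function"
    # quantize_op is torch.quantize_per_tensor_dynamic
    # qparams     == {"_dtype_": torch.quint8, "_reduce_range_": <bool>}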