Example #1
def quantitize_cfg(data,
                   scale,
                   bitwidth,
                   method,
                   range_method=RangeMethod.RANGE_MAX):
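    """Quantize `data` into fixed-point with a power-of-2 scale.

    `scale` is a 1-element tensor holding log2 of the dynamic range; FIX_AUTO
    updates it in place, FIX_FIXED uses it as given.
    """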
    if (not isinstance(method, torch.autograd.Variable)
            and not torch.is_tensor(method)
            and method == QuantizeMethod.FIX_NONE):
        return data, None

    method_v = get_int(method)

    if method_v == QuantizeMethod.FIX_NONE:
        return data, None
    elif method_v == QuantizeMethod.FIX_AUTO:
        EPS = 1e-5
        range_method_v = get_int(range_method)
        if range_method_v == RangeMethod.RANGE_MAX:
            new_scale = torch.ceil(
                torch.log(
                    torch.max(
                        torch.max(torch.abs(data)),
                        torch.tensor(EPS).float().to(data.device),
                    )) / np.log(2.0))
        elif range_method_v == RangeMethod.RANGE_3SIGMA:
            # TODO: Dr. Sun said he will implement this
            raise NotImplementedError()
        else:
            raise NotImplementedError()
        scale.data[0] = new_scale
        return _do_quantitize(data, scale, bitwidth)
    elif method_v == QuantizeMethod.FIX_FIXED:
        return _do_quantitize(data, scale, bitwidth)
    raise Exception("Quantitize method not legal: {}".format(method_v))
Example #2
def quantize_cfg(
    data,
    scale,
    bitwidth,
    method,
    range_method=RangeMethod.RANGE_MAX,
    stochastic=False,
    float_scale=False,
    zero_point=False,
    group=False,
):
    """
    stochastic - stochastic rounding
    range_method - how to decide dynamic range
    float_scale - whether the scale is chosen to be 2^K
    zero_point  - symm/asymm quantize
    """
    if (not isinstance(method, torch.autograd.Variable)
            and not torch.is_tensor(method)
            and method == QuantizeMethod.FIX_NONE):
        return data, None

    if group == "batch" and len(data.shape) == 4:
        # grouping applies only to 4-D (conv) tensors
        max_data = data.view(data.shape[0], -1).max(dim=1)[0]
        min_data = data.view(data.shape[0], -1).min(dim=1)[0]
    elif group == "channel" and len(data.shape) == 4:
        max_data = data.view(data.shape[1], -1).max(dim=1)[0]
        min_data = data.view(data.shape[1], -1).min(dim=1)[0]
    else:
        max_data = data.max()
        min_data = data.min()

    # Guard against an all-zero dynamic range
    EPS = torch.full_like(max_data, 1e-5)
    if not zero_point:
        max_data = torch.max(torch.max(max_data.abs(), min_data.abs()), EPS)

    method_v = get_int(method)

    if method_v == QuantizeMethod.FIX_NONE:
        return data, None
    elif method_v == QuantizeMethod.FIX_AUTO:
        range_method_v = get_int(range_method)
        if float_scale and range_method_v != RangeMethod.RANGE_MAX:
            raise NotImplementedError(
                "float_scale is currently only supported with RANGE_MAX")
        if range_method_v == RangeMethod.RANGE_MAX:
            if float_scale:
                if zero_point:
                    new_scale = torch.stack([min_data, max_data])
                    scale.data = new_scale
                    return _do_quantize(
                        data,
                        scale,
                        bitwidth,
                        stochastic=stochastic,
                        symmetric=False,
                        group=group,
                    )
                else:
                    scale.data = max_data
                    return _do_quantize(
                        data,
                        scale,
                        bitwidth,
                        stochastic=stochastic,
                        symmetric=True,
                        group=group,
                    )
            else:
                new_scale = torch.pow(
                    2, torch.ceil(torch.log(max_data) / np.log(2.0)))

                scale.data = new_scale
                return _do_quantize(data,
                                    scale,
                                    bitwidth,
                                    stochastic=stochastic,
                                    group=group)

        elif range_method_v == RangeMethod.RANGE_MAX_TENPERCENT:
            # FIXME: Too slow
            # Use the top-10%-magnitude element as the dynamic range
            # (alternative: torch.kthvalue(torch.abs(data.view(-1)), 9 * (data.nelement() // 10))[0])
            new_scale = torch.pow(
                2,
                torch.ceil(
                    torch.log(
                        torch.max(
                            torch.topk(torch.abs(data.view(-1)),
                                       data.nelement() // 10)[0][-1],
                            EPS,
                        )) / np.log(2.0)),
            )
            scale.data = new_scale
            return _do_quantize(data, scale, bitwidth, stochastic=stochastic)

        elif range_method_v == RangeMethod.RANGE_3SIGMA:
            new_boundary = torch.max(
                3 * torch.std(data) + torch.abs(torch.mean(data)), EPS)
            new_scale = torch.pow(
                2, torch.ceil(torch.log(new_boundary) / np.log(2.0)))
            scale.data = new_scale
            return _do_quantize(data,
                                scale,
                                bitwidth,
                                stochastic=stochastic,
                                symmetric=not zero_point)

        elif range_method_v == RangeMethod.RANGE_SWEEP:
            # Iterate over nearby scales and keep the one that minimizes the
            # total quantization error; candidates range over [MAX - SWEEP + 1, MAX]
            SWEEP = 3
            temp_scale = torch.ceil(
                torch.log(torch.max(torch.max(torch.abs(data)), EPS)) /
                np.log(2.0))
            errors = torch.zeros(SWEEP).float().to(data.device)
            for i in range(SWEEP):
                # keep each candidate in the same (power-of-2) domain as scale.data
                errors[i] = torch.abs(
                    _do_quantize(data, torch.pow(2, temp_scale - i),
                                 bitwidth)[0] - data).sum()
            new_scale = torch.pow(2, temp_scale - errors.argmin())
            scale.data = new_scale
            return _do_quantize(data, scale, bitwidth, stochastic=stochastic)

        else:
            raise NotImplementedError()

    elif method_v == QuantizeMethod.FIX_FIXED:

        if group == "batch" and len(data.shape) == 4:
            max_data = data.view(data.shape[0], -1).max(dim=1)[0]
            min_data = data.view(data.shape[0], -1).min(dim=1)[0]
        if group == "channel" and len(data.shape) == 4:
            max_data = data.view(data.shape[1], -1).max(dim=1)[0]
            min_data = data.view(data.shape[1], -1).min(dim=1)[0]
        else:
            max_data = data.max()
            min_data = data.min()

        EPS = torch.FloatTensor(max_data.shape).fill_(1e-5)
        EPS = EPS.to(max_data.device)
        if not zero_point:
            max_data = torch.max(torch.max(max_data.abs(), min_data.abs()),
                                 EPS)

        # TODO: check whether float_scale adjusts automatically during inference
        # If float_scale, behave as FIX_AUTO does
        if float_scale:
            if zero_point:
                new_scale = torch.stack([min_data, max_data])
                scale.data = new_scale
                return _do_quantize(data,
                                    scale,
                                    bitwidth,
                                    stochastic=stochastic,
                                    symmetric=False)
            else:
                scale.data = max_data
                return _do_quantize(data,
                                    scale,
                                    bitwidth,
                                    stochastic=stochastic,
                                    symmetric=True)
        else:
            # do not recompute the scale here: with a power-of-2 scale,
            # FIX_FIXED uses the provided scale as-is
            return _do_quantize(
                data,
                scale,
                bitwidth,
                stochastic=stochastic,
                symmetric=not zero_point,
                group=group,
            )

    raise Exception("Quantitize method not legal: {}".format(method_v))
Example #3
def quantitize_cfg(data,
                   scale,
                   bitwidth,
                   method,
                   range_method=RangeMethod.RANGE_MAX):
    if (not isinstance(method, torch.autograd.Variable)
            and not torch.is_tensor(method)
            and method == QuantizeMethod.FIX_NONE):
        return data, None

    method_v = get_int(method)

    if method_v == QuantizeMethod.FIX_NONE:
        return data, None
    elif method_v == QuantizeMethod.FIX_AUTO:
        EPS = 1e-5
        range_method_v = get_int(range_method)
        if range_method_v == RangeMethod.RANGE_MAX:
            new_scale = torch.ceil(
                torch.log(
                    torch.max(
                        torch.max(torch.abs(data)),
                        torch.tensor(EPS).float().to(data.device),
                    )) / np.log(2.0))
        elif range_method_v == RangeMethod.RANGE_MAX_TENPERCENT:
            # FIXME: Too slow
            new_scale = torch.ceil(
                torch.log(
                    torch.max(
                        # torch.kthvalue(torch.abs(data.view(-1)), 9 * (data.nelement() // 10))[0],
                        torch.topk(torch.abs(data.view(-1)),
                                   data.nelement() // 10)[0][-1],
                        torch.tensor(EPS).float().to(data.device))) /
                np.log(2.0))
        elif range_method_v == RangeMethod.RANGE_3SIGMA:
            new_boundary = torch.max(
                3 * torch.std(data) + torch.abs(torch.mean(data)),
                torch.tensor(EPS).float().to(data.device),
            )
            new_scale = torch.ceil(torch.log(new_boundary) / np.log(2.0))

        elif range_method_v == RangeMethod.RANGE_SWEEP:
            SWEEP = 3
            temp_scale = torch.ceil(
                torch.log(
                    torch.max(torch.max(abs(data)),
                              torch.tensor(EPS).float().to(data.device))) /
                np.log(2.0))
            errors = torch.zeros(SWEEP).float().to(data.device)
            for i in range(SWEEP):
                errors[i] = torch.abs(
                    _do_quantitize(data, temp_scale - i, bitwidth)[0] -
                    data).sum()
            new_scale = temp_scale - errors.argmin()
        else:
            raise NotImplementedError()
        scale.data[0] = new_scale
        return _do_quantitize(data, scale, bitwidth)
    elif method_v == QuantizeMethod.FIX_FIXED:
        return _do_quantitize(data, scale, bitwidth)
    raise Exception("Quantitize method not legal: {}".format(method_v))