# Example no. 1
# 0
def test_quint8_typecvt():
    """Round-trip float32 -> quint8 -> float32 through the TypeCvt operator.

    Each conversion leg is compiled and executed on ``"xpux"`` via
    ``_get_compiled_result`` and compared exactly against the host-side
    reference converters ``convert_to_quint8`` / ``convert_from_quint8``.
    """
    target_device = "xpux"
    tensor_shape = (3, 3, 3)
    # Random float32 values in [-1, 4); the RNG is not seeded here, so each
    # run exercises different data.
    src = np.random.random(tensor_shape).astype(np.float32) * 5 - 1

    def cast_to(var, dt=None):
        # TypeCvt is expected to produce exactly one output var; unpack it.
        (out_var,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), var)
        return out_var

    # Leg 1: float32 -> quint8.
    # NOTE(review): presumably quint8(scale=0.01, zero_point=135) — confirm.
    q_dtype = quint8(0.01, 135)
    quantized = _get_compiled_result(
        src,
        np.float32,
        tensor_shape,
        target_device,
        calc_func=partial(cast_to, dt=q_dtype),
    )
    _check_result_attr(quantized, q_dtype, "quint8")
    np.testing.assert_equal(quantized, convert_to_quint8(src, q_dtype))

    # Leg 2: feed the quantized result back in and cast it to float32.
    dequantized = _get_compiled_result(
        quantized,
        q_dtype,
        tensor_shape,
        target_device,
        calc_func=partial(cast_to, dt=np.float32),
    )
    assert dequantized.dtype == np.float32
    # Reference: quantize then dequantize on the host side.
    reference = convert_from_quint8(convert_to_quint8(src, q_dtype))
    np.testing.assert_equal(dequantized, reference)
# Example no. 2
# 0
def test_dtype_int8_ffi_handle():
    """Pass quantized inputs through an identity graph and check they survive.

    The check runs for both ``quint8`` and ``qint8``. The compiled output must
    carry the expected quantized dtype attributes and must dequantize to values
    close to the dequantized input.
    """
    target_device = "xpux"
    tensor_shape = (3, 3, 3)
    # Random float32 values in [-1, 4); unseeded.
    src = np.random.random(tensor_shape).astype(np.float32) * 5 - 1

    def passthrough(var):
        return var

    def check_roundtrip(q_dtype, to_q, from_q, dtype_name, **attr_kwargs):
        # Quantize on the host and run the data through the identity graph.
        # Then compare the dequantized output against the dequantized input.
        quantized_in = to_q(src, q_dtype)
        quantized_out = _get_compiled_result(
            quantized_in,
            q_dtype,
            tensor_shape,
            target_device,
            calc_func=passthrough,
        )
        _check_result_attr(quantized_out, q_dtype, dtype_name, **attr_kwargs)
        np.testing.assert_allclose(from_q(quantized_out), from_q(quantized_in))

    # Unsigned case: _check_result_attr is called with its default flags.
    check_roundtrip(
        quint8(0.01, 127),
        convert_to_quint8,
        convert_from_quint8,
        "quint8",
    )
    # Signed case: tell the attribute check the dtype is not unsigned.
    check_roundtrip(
        qint8(0.01),
        convert_to_qint8,
        convert_from_qint8,
        "qint8",
        is_unsigned=False,
    )
# Example no. 3
# 0
def get_qat_inputs_quint8(inp_dtype, num_inp=1, shape=(1, 16, 384, 512)):
    """Build ``num_inp`` random input tensors tagged with quint8 qparams.

    For each input, random data of ``shape`` (values in [0, 16)) is cast to
    ``inp_dtype``. It is then passed through ``dtype.convert_from_quint8``
    (presumably dequantizing it back to float) and wrapped in a fresh tensor.
    That tensor's ``qparams`` scale, zero point and dtype metadata are then
    filled in from ``inp_dtype``.

    NOTE(review): assumes ``dtype`` is MegEngine's quantized-dtype helper
    module and that ``inp_dtype`` is a quint8 dtype — confirm.

    Args:
        inp_dtype: Quantized dtype the random data is cast to.
        num_inp: Number of independent inputs to create.
        shape: Shape of the random data for each input.

    Returns:
        A list of ``num_inp`` tensors.
    """

    def make_input():
        # New scale/zero-point tensors are created per input, so inputs never
        # share qparam tensor objects.
        raw = mge.tensor(np.random.random(shape)) * 16
        quantized = raw.astype(inp_dtype)
        inp = mge.tensor(dtype.convert_from_quint8(quantized.numpy()))
        inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
        inp.qparams.zero_point = mge.tensor(dtype.get_zero_point(inp_dtype))
        inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["quint8"]
        return inp

    return [make_input() for _ in range(num_inp)]