コード例 #1
0
def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True):
    """Assert that ``oup.dtype`` carries the quantization attributes of ``dtype``.

    Args:
        oup: object whose ``dtype`` attribute is inspected.
        dtype: reference dtype supplying the expected scale (and zero point).
        dtype_str: key into ``_metadata_dict`` selecting the expected
            metadata entry; its ``name`` is compared with the output's.
        is_unsigned: when true, the zero point is compared as well.

    Raises:
        AssertionError: if any of the checks fails.
    """
    expected_meta = _metadata_dict[dtype_str]
    out_dtype = oup.dtype

    # The dtype must be tagged with "mgb_dtype" metadata and be quantized.
    assert "mgb_dtype" in out_dtype.metadata
    assert is_quantize(out_dtype)

    np.testing.assert_equal(
        out_dtype.metadata["mgb_dtype"]["name"], expected_meta.name
    )
    np.testing.assert_allclose(get_scale(out_dtype), get_scale(dtype))

    if not is_unsigned:
        return
    np.testing.assert_equal(get_zero_point(out_dtype), get_zero_point(dtype))
コード例 #2
0
def test_as_type():
    """Check that chained ``astype`` casts report the requested quantization params.

    A float32 ``TensorWrapper`` is cast to ``qint8`` with scales 0.1 and 0.2,
    then to ``quint8`` with scale 0.3 and zero points 127 and 128; after each
    cast the resulting dtype's scale (and zero point, for ``quint8``) is
    verified.
    """
    tensor = TensorWrapper([1, 2, 3], dtype=np.float32)

    # Signed casts: only the scale is checked.
    for scale in (0.1, 0.2):
        tensor = tensor.astype(qint8(scale))
        np.testing.assert_almost_equal(get_scale(tensor.dtype), scale)

    # Unsigned casts: same scale, differing zero points.
    for zero_point in (127, 128):
        tensor = tensor.astype(quint8(0.3, zero_point))
        np.testing.assert_almost_equal(get_scale(tensor.dtype), 0.3)
        np.testing.assert_equal(get_zero_point(tensor.dtype), zero_point)
コード例 #3
0
def test_as_type(is_varnode):
    """Check that chained ``astype`` casts report the requested quantization params.

    Args:
        is_varnode: when true, a fresh ``Network`` is passed to
            ``make_tensor``; otherwise ``None`` is passed.
    """
    network = Network() if is_varnode else None

    def assert_qparams(tensor, scale, zero_point=None):
        # Verify the scale, and the zero point when one is expected.
        np.testing.assert_almost_equal(get_scale(tensor.dtype), scale)
        if zero_point is not None:
            np.testing.assert_equal(get_zero_point(tensor.dtype), zero_point)

    source = make_tensor(np.array([1, 2, 3], dtype=np.float32), network)

    signed_first = source.astype(qint8(0.1))
    assert_qparams(signed_first, 0.1)

    signed_second = signed_first.astype(qint8(0.2))
    assert_qparams(signed_second, 0.2)

    unsigned_first = signed_second.astype(quint8(0.3, 127))
    assert_qparams(unsigned_first, 0.3, 127)

    unsigned_second = unsigned_first.astype(quint8(0.3, 128))
    assert_qparams(unsigned_second, 0.3, 128)
コード例 #4
0
def get_qat_inputs_quint8(inp_dtype, num_inp=1, shape=(1, 16, 384, 512)):
    """Build a list of random input tensors annotated with quint8 qparams.

    Args:
        inp_dtype: dtype the random data is cast to; its scale and zero
            point are copied onto each tensor's ``qparams``.
        num_inp: number of tensors to generate.
        shape: shape passed to ``np.random.random`` for each tensor.

    Returns:
        A list of ``num_inp`` tensors.
    """

    def _make_input():
        # Uniform samples from [0, 1) scaled by 16, then cast to inp_dtype.
        quantized = (mge.tensor(np.random.random(shape)) * 16).astype(inp_dtype)
        # Re-wrap the values produced by convert_from_quint8 as a new tensor.
        tensor = mge.tensor(dtype.convert_from_quint8(quantized.numpy()))
        tensor.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
        tensor.qparams.zero_point = mge.tensor(dtype.get_zero_point(inp_dtype))
        tensor.qparams.dtype_meta = dtype._builtin_quant_dtypes["quint8"]
        return tensor

    return [_make_input() for _ in range(num_inp)]
コード例 #5
0
def test_dtype_quint8():
    """Check ``quint8`` argument validation and the metadata of a valid dtype."""
    # Each (scale, zero_point) pair below is expected to be rejected.
    invalid_args = ((0.05, 0.233), (0.02, 777), (0.02, -1))
    for scale, zero_point in invalid_args:
        with pytest.raises(ValueError):
            quint8(scale, zero_point)

    qdtype = quint8(0.01, 135)
    assert isinstance(qdtype, np.dtype)
    assert "mgb_dtype" in qdtype.metadata

    # Parameters are stored directly in the dtype metadata...
    mgb_meta = qdtype.metadata["mgb_dtype"]
    np.testing.assert_allclose(mgb_meta["scale"], 0.01)
    np.testing.assert_equal(mgb_meta["zero_point"], 135)

    # ...and are also reachable through the accessor helpers.
    assert is_quantize(qdtype)
    np.testing.assert_allclose(get_scale(qdtype), 0.01)
    np.testing.assert_equal(get_zero_point(qdtype), 135)