def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True):
    """Assert that ``oup.dtype`` matches the expected quantized ``dtype``.

    Args:
        oup: object under test; only its ``dtype`` attribute is inspected.
        dtype: expected quantized dtype whose scale (and, optionally, zero
            point) ``oup.dtype`` must reproduce.
        dtype_str: key into ``_metadata_dict``; the entry's ``name`` must equal
            ``oup.dtype.metadata["mgb_dtype"]["name"]``.
        is_unsigned: when true, zero points are compared as well.
    """
    expected_meta = _metadata_dict[dtype_str]
    actual = oup.dtype

    # The result must be tagged as an mgb quantized dtype at all.
    assert "mgb_dtype" in actual.metadata
    assert is_quantize(actual)

    np.testing.assert_equal(actual.metadata["mgb_dtype"]["name"], expected_meta.name)
    np.testing.assert_allclose(get_scale(actual), get_scale(dtype))

    # Zero points are only compared for the unsigned case.
    if not is_unsigned:
        return
    np.testing.assert_equal(get_zero_point(actual), get_zero_point(dtype))
def test_as_type():
    """Chain ``astype`` through qint8/quint8 dtypes and verify the qparams.

    NOTE(review): a later ``test_as_type(is_varnode)`` in this file reuses this
    name; if both live in the same module, this definition is shadowed and
    pytest will not collect it.
    """

    def _expect(tensor, scale, zero_point=None):
        # Check scale, then (optionally) zero point of ``tensor.dtype``.
        np.testing.assert_almost_equal(get_scale(tensor.dtype), scale)
        if zero_point is not None:
            np.testing.assert_equal(get_zero_point(tensor.dtype), zero_point)

    source = TensorWrapper([1, 2, 3], dtype=np.float32)

    # float32 -> qint8(0.1)
    step1 = source.astype(qint8(0.1))
    _expect(step1, 0.1)

    # qint8 -> qint8 with a different scale
    step2 = step1.astype(qint8(0.2))
    _expect(step2, 0.2)

    # qint8 -> quint8 with an explicit zero point
    step3 = step2.astype(quint8(0.3, 127))
    _expect(step3, 0.3, 127)

    # quint8 -> quint8 changing only the zero point
    step4 = step3.astype(quint8(0.3, 128))
    _expect(step4, 0.3, 128)
def test_as_type(is_varnode):
    """Chain ``astype`` through qint8/quint8 dtypes and verify the qparams.

    Args:
        is_varnode: when true, the input is built inside a fresh ``Network``;
            otherwise ``make_tensor`` receives ``None`` as its network.
    """

    def _expect(tensor, scale, zero_point=None):
        # Check scale, then (optionally) zero point of ``tensor.dtype``.
        np.testing.assert_almost_equal(get_scale(tensor.dtype), scale)
        if zero_point is not None:
            np.testing.assert_equal(get_zero_point(tensor.dtype), zero_point)

    network = Network() if is_varnode else None
    source = make_tensor(np.array([1, 2, 3], dtype=np.float32), network)

    # float32 -> qint8(0.1)
    step1 = source.astype(qint8(0.1))
    _expect(step1, 0.1)

    # qint8 -> qint8 with a different scale
    step2 = step1.astype(qint8(0.2))
    _expect(step2, 0.2)

    # qint8 -> quint8 with an explicit zero point
    step3 = step2.astype(quint8(0.3, 127))
    _expect(step3, 0.3, 127)

    # quint8 -> quint8 changing only the zero point
    step4 = step3.astype(quint8(0.3, 128))
    _expect(step4, 0.3, 128)
def get_qat_inputs_quint8(inp_dtype, num_inp=1, shape=(1, 16, 384, 512)):
    """Create ``num_inp`` tensors annotated with quint8 QAT qparams.

    Each input is uniform random data in [0, 16) of the given ``shape``, cast
    to ``inp_dtype``, converted back with ``dtype.convert_from_quint8`` and
    wrapped in a new ``mge.tensor``. Its ``qparams`` then receive the scale and
    zero point of ``inp_dtype`` plus the builtin ``"quint8"`` dtype meta.

    Args:
        inp_dtype: quantized dtype used for the cast and as the qparams source.
        num_inp: number of inputs to generate.
        shape: shape of each generated input.

    Returns:
        list of ``num_inp`` tensors.
    """

    def _build_input():
        # Quantize random data, then bring it back out of quint8 storage.
        scaled = mge.tensor(np.random.random(shape)) * 16
        quantized = scaled.astype(inp_dtype)
        tensor = mge.tensor(dtype.convert_from_quint8(quantized.numpy()))
        tensor.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
        tensor.qparams.zero_point = mge.tensor(dtype.get_zero_point(inp_dtype))
        tensor.qparams.dtype_meta = dtype._builtin_quant_dtypes["quint8"]
        return tensor

    return [_build_input() for _ in range(num_inp)]
def test_dtype_quint8():
    """Check quint8 argument validation and the metadata of a valid dtype."""
    # A non-integral zero point, a too-large one, and a negative one
    # must each be rejected with ValueError.
    for invalid_args in ((0.05, 0.233), (0.02, 777), (0.02, -1)):
        with pytest.raises(ValueError):
            quint8(*invalid_args)

    qdtype = quint8(0.01, 135)
    assert isinstance(qdtype, np.dtype)
    assert "mgb_dtype" in qdtype.metadata

    # Raw metadata stored on the numpy dtype.
    mgb_meta = qdtype.metadata["mgb_dtype"]
    np.testing.assert_allclose(mgb_meta["scale"], 0.01)
    np.testing.assert_equal(mgb_meta["zero_point"], 135)

    # The public accessors must agree with the raw metadata.
    assert is_quantize(qdtype)
    np.testing.assert_allclose(get_scale(qdtype), 0.01)
    np.testing.assert_equal(get_zero_point(qdtype), 135)