def test_load_quantized():
    """Round-trip a quantized MLP's state dict and check predictions survive.

    The model is converted with ``quantize_qat`` and then ``quantize``. Its
    two dense weights are re-cast to ``qint8`` with fixed scales, and its
    eval-mode output is recorded. The state dict is then serialized to an
    in-memory buffer and loaded back. Next, the weights are overwritten using
    different scales. Loading the saved checkpoint must restore the original
    predictions to within ``max_err=5e-6``.
    """

    def _set_qint8_weight(layer, scale):
        # Re-cast the layer's current weight to qint8 at ``scale`` and
        # install the resulting array as a fresh Parameter.
        layer.weight = Parameter(
            layer.weight.astype(mgb.dtype.qint8(scale)).numpy()
        )

    shape = (2, 28)
    inp = tensor(np.random.random(shape), dtype="float32").astype(
        mgb.dtype.qint8(0.1)
    )

    net = MLP()
    quantize_qat(net)
    quantize(net)
    _set_qint8_weight(net.dense0, 0.001)
    _set_qint8_weight(net.dense1, 0.0002)
    net.eval()
    expected = net(inp)

    with BytesIO() as buf:
        mge.save(net.state_dict(), buf)
        buf.seek(0)
        saved = mge.load(buf)

    # Perturb the weights so that a successful reload is observable.
    _set_qint8_weight(net.dense0, 0.00001)
    _set_qint8_weight(net.dense1, 0.2)
    net.load_state_dict(saved)
    actual = net(inp)

    assertTensorClose(
        expected.astype("float32").numpy(),
        actual.astype("float32").numpy(),
        max_err=5e-6,
    )
def test_state_dict():
    """Check ``load_state_dict`` in both non-strict and strict modes.

    A fresh MLP must reproduce another MLP's predictions after loading the
    other's serialized state dict with ``strict=False``, even when an unknown
    ``"extra"`` key is present. A strict load must raise ``KeyError`` for
    that unexpected key. Once the key is removed, a strict load must also
    raise ``KeyError`` when the ``"dense0.bias"`` entry is missing.
    """
    inp = tensor()
    inp.set_value(np.random.random((2, 28)))

    source = MLP()
    expected = source(inp)

    with BytesIO() as buf:
        mge.save(source.state_dict(), buf)
        buf.seek(0)
        state = mge.load(buf)

    # An unknown key is tolerated when strict checking is turned off.
    state["extra"] = None
    target = MLP()
    target.load_state_dict(state, strict=False)
    actual = target(inp)
    assertTensorClose(expected.numpy(), actual.numpy(), max_err=5e-6)

    # ...but a strict load rejects it.
    with pytest.raises(KeyError):
        target.load_state_dict(state)

    # Drop the extra key and also remove a real parameter: a strict load
    # must reject the missing entry.
    for key in ("extra", "dense0.bias"):
        del state[key]
    with pytest.raises(KeyError):
        target.load_state_dict(state)