Example #1
0
def test_quantized_module_user_naming_param(symbolic):
    """Check the input names of the dumped ``simple.linear.MatrixMul`` op.

    A small Quant -> Linear -> Dequant module is built with user-assigned
    names on the linear layer's parameters, converted with ``quantize_qat``
    and ``quantize``, then passed through ``_dump_and_load``.  Exactly one
    op named ``simple.linear.MatrixMul`` must exist, and each of its inputs
    must be named ``simple.quant.TypeCvt`` or ``simple.linear.user-weight``.

    Args:
        symbolic: forwarded unchanged to ``_dump_and_load``.
    """
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__(name=name)
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True)
            self.dequant = M.DequantStub()

            # Replace the parameters' names with user-chosen ones.
            self.linear.weight.name = "user-weight"
            self.linear.bias.name = "user-bias"

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    net = Simple("simple")
    quantize_qat(net)
    quantize(net)
    net.eval()

    dumped_ops = _dump_and_load(net, symbolic)

    matches = [op for op in dumped_ops if op.name == "simple.linear.MatrixMul"]
    # Single-element unpacking fails unless exactly one op matched.
    (matrix_mul_op,) = matches

    allowed_input_names = ("simple.quant.TypeCvt", "simple.linear.user-weight")
    for input_var in matrix_mul_op.inputs:
        assert input_var.name in allowed_input_names
Example #2
0
def test_load_quantized():
    """Verify that ``load_state_dict`` restores a quantized MLP's weights.

    The model's prediction is recorded, its state dict is saved to and
    loaded from an in-memory buffer, the dense weights are overwritten with
    differently scaled values, and the checkpoint is loaded back.  The
    prediction made afterwards must match the recorded one (``atol=5e-6``).
    """
    from megengine.core.tensor import dtype

    def _as_qint8_param(weight, scale):
        # Cast ``weight`` to qint8 with ``scale`` and wrap the result as a Parameter.
        return Parameter(weight.astype(dtype.qint8(scale)).numpy())

    inputs = tensor(np.random.random((2, 28)), dtype="float32")
    inputs = inputs.astype(dtype.qint8(0.1))

    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    mlp.dense0.weight = _as_qint8_param(mlp.dense0.weight, 0.001)
    mlp.dense1.weight = _as_qint8_param(mlp.dense1.weight, 0.0002)
    mlp.eval()
    expected = mlp(inputs)

    with BytesIO() as buffer:
        mge.save(mlp.state_dict(), buffer)
        buffer.seek(0)
        checkpoint = mge.load(buffer)
        # Overwrite the weights so that a successful reload is observable.
        mlp.dense0.weight = _as_qint8_param(mlp.dense0.weight, 0.00001)
        mlp.dense1.weight = _as_qint8_param(mlp.dense1.weight, 0.2)
        mlp.load_state_dict(checkpoint)
        restored = mlp(inputs)

    np.testing.assert_allclose(
        expected.astype("float32").numpy(),
        restored.astype("float32").numpy(),
        atol=5e-6,
    )
Example #3
0
def test_load_quantized():
    """Verify that ``load_state_dict`` restores a quantized MLP's weights.

    Same flow as saving/loading through an in-memory buffer: record a
    prediction, save and reload the state dict, overwrite the dense weights
    with differently scaled values, load the checkpoint back and compare
    the new prediction with the recorded one (``max_err=5e-6``).
    """
    def _as_qint8_param(weight, scale):
        # Cast ``weight`` to qint8 with ``scale`` and wrap the result as a Parameter.
        return Parameter(weight.astype(mgb.dtype.qint8(scale)).numpy())

    inputs = tensor(np.random.random((2, 28)), dtype="float32")
    inputs = inputs.astype(mgb.dtype.qint8(0.1))

    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    mlp.dense0.weight = _as_qint8_param(mlp.dense0.weight, 0.001)
    mlp.dense1.weight = _as_qint8_param(mlp.dense1.weight, 0.0002)
    mlp.eval()
    expected = mlp(inputs)

    with BytesIO() as buffer:
        mge.save(mlp.state_dict(), buffer)
        buffer.seek(0)
        checkpoint = mge.load(buffer)
        # Overwrite the weights so that a successful reload is observable.
        mlp.dense0.weight = _as_qint8_param(mlp.dense0.weight, 0.00001)
        mlp.dense1.weight = _as_qint8_param(mlp.dense1.weight, 0.2)
        mlp.load_state_dict(checkpoint)
        restored = mlp(inputs)

    assertTensorClose(
        expected.astype("float32").numpy(),
        restored.astype("float32").numpy(),
        max_err=5e-6,
    )
Example #4
0
def test_quantized_module_user_naming(symbolic):
    """Check op names after dumping a module with a user-named sub-module.

    The linear layer is created with ``name="user-linear"``; after
    ``quantize_qat``, ``quantize`` and ``_dump_and_load`` the leading ops
    must carry the expected names, in order.  The comparison uses ``zip``,
    so it stops at the shorter of the two sequences.

    Args:
        symbolic: forwarded unchanged to ``_dump_and_load``.
    """
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__(name=name)
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True, name="user-linear")
            self.dequant = M.DequantStub()

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    net = Simple("simple")
    quantize_qat(net)
    quantize(net)
    net.eval()

    dumped_ops = _dump_and_load(net, symbolic)
    expected_names = (
        "x",
        "simple.quant.TypeCvt",
        "simple.user-linear.MatrixMul",
        "simple.user-linear.ADD",
        "simple.user-linear.TypeCvt",
        "simple.dequant.TypeCvt",
    )
    for dumped_op, expected in zip(dumped_ops, expected_names):
        assert dumped_op.name == expected
Example #5
0
def worker(world_size, args):
    """Evaluate a model on the ImageNet validation split and log accuracy.

    The model is built from ``models.__dict__[args.arch]``.  When
    ``args.mode`` is not ``"normal"`` it is converted with ``quantize_qat``
    (``Q.ema_fakequant_qconfig``); when ``args.checkpoint`` is set the
    weights are loaded non-strictly; when ``args.mode`` is ``"quantized"``
    it is further converted with ``quantize``.  Rank 0 logs the two
    accuracy values returned by ``infer``.

    Args:
        world_size: number of participating processes; only used for logging
            in this function.
        args: namespace providing ``arch``, ``mode``, ``checkpoint``,
            ``data`` and ``workers``; it is also handed to ``infer``.
    """
    # pylint: disable=too-many-statements

    rank = dist.get_rank()
    if world_size > 1:
        # Only a log line is emitted here; no process group is created in this function.
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))

    model = models.__dict__[args.arch]()

    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        state = mge.load(args.checkpoint)
        # Checkpoints may wrap the weights under a "state_dict" key.
        if "state_dict" in state:
            state = state["state_dict"]
        model.load_state_dict(state, strict=False)

    if args.mode == "quantized":
        quantize(model)

    def _all_reduce_mean(value):
        # Sum across processes, then divide by the process count.
        return dist.functional.all_reduce_sum(value) / dist.get_world_size()

    # Define valid graph
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        if dist.is_distributed():
            loss = _all_reduce_mean(loss)
            acc1 = _all_reduce_mean(acc1)
            acc5 = _all_reduce_mean(acc5)
        return loss, acc1, acc5

    # Build valid datasets
    logger.info("preparing dataset..")
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    eval_transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=128),
        T.ToMode("CHW"),
    ])
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=eval_transform,
        num_workers=args.workers,
    )

    _loss, top1_acc, top5_acc = infer(valid_func, valid_queue, args)
    if rank == 0:
        logger.info("TEST %f, %f", top1_acc, top5_acc)
Example #6
0
def test_convert_with_custom_mapping():
    """Check that ``quantize_qat`` honours a user-supplied ``mapping``.

    A float module type is mapped to a custom QAT module type; after
    conversion (not in place) the network's ``example`` attribute must be
    an instance of the custom QAT type.
    """
    class FloatExample(Float.Module):
        def forward(self, x):
            return x

    class QATExample(QAT.QATModule):
        def forward(self, x):
            return x

        @classmethod
        def from_float_module(cls, float_module):
            # Conversion hook: the float module's state is not used.
            return cls()

    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.example = FloatExample()

        def forward(self, x):
            return self.example(x)

    float_net = Net()
    custom_mapping = {FloatExample: QATExample}
    converted = quantize_qat(float_net, inplace=False, mapping=custom_mapping)
    assert isinstance(converted.example, QATExample)
Example #7
0
def test_disable_fake_quant():
    """Check the effect of ``disable_fake_quant`` on a QAT network's output.

    With fake quantization disabled the QAT network must reproduce the
    float network's output, and that output must differ from the one
    produced while fake quantization was still active.
    """
    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.quant = Float.QuantStub()
            self.linear = Float.Linear(3, 3)
            self.dequant = Float.DequantStub()
            self.linear.bias.set_value(np.random.rand(3))

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    sample = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
    net = Net()
    float_out = net(sample).numpy()

    net = quantize_qat(net, min_max_fakequant_qconfig)
    fake_quant_out = net(sample).numpy()

    disable_fake_quant(net)
    disabled_out = net(sample).numpy()

    np.testing.assert_allclose(float_out, disabled_out)
    # The fake-quantized output is expected NOT to match.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(fake_quant_out, disabled_out)
Example #8
0
def test_qat_convbn2d():
    """Compare a float ``ConvBn2d`` with its QAT copy (fake quant disabled).

    For every combination of ``groups`` in ``[1, 4]`` and ``bias`` in
    ``[True, False]`` the outputs and the batch-norm running statistics
    must agree in training mode, and the outputs must agree in eval mode.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    for groups, bias in product([1, 4], [True, False]):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)

        batch = tensor(
            np.random.randn(4, in_channels, 32, 32).astype(np.float32)
        )

        # Training mode: outputs and running statistics.
        float_out = float_module(batch)
        qat_out = qat_module(batch)
        assertTensorClose(float_out, qat_out, max_err=5e-6)
        assertTensorClose(
            float_module.bn.running_mean,
            qat_module.bn.running_mean,
            max_err=5e-8,
        )
        assertTensorClose(
            float_module.bn.running_var,
            qat_module.bn.running_var,
            max_err=5e-7,
        )

        # Eval mode: outputs only.
        float_module.eval()
        float_out = float_module(batch)
        qat_module.eval()
        qat_out = qat_module(batch)
        assertTensorClose(float_out, qat_out, max_err=5e-6)
Example #9
0
def test_quantize_qat():
    """Check the module types produced by ``quantize_qat`` on ``FloatNet``.

    After a non-in-place conversion with ``min_max_fakequant_qconfig`` the
    stubs and the first two entries of ``linear`` must be QAT modules.
    """
    float_net = FloatNet()
    converted = quantize_qat(
        float_net, inplace=False, qconfig=min_max_fakequant_qconfig
    )
    assert isinstance(converted.quant, QAT.QuantStub)
    for index in (0, 1):
        assert isinstance(converted.linear[index], QAT.Linear)
    assert isinstance(converted.dequant, QAT.DequantStub)
Example #10
0
def test_quantize_batchmatmul_activation():
    """Exercise ``BatchMatMulActivation`` through float, QAT, quantized and dumped forms.

    For ``bias`` in ``(True, False)``:

    * the QAT copy with fake quant disabled must match the float network in
      both train and eval mode;
    * with fake quant re-enabled, the QAT output must match the output of
      the network produced by ``quantize`` (``atol=1e-6``);
    * the traced-and-dumped quantized network, run through
      ``cgtools.load_and_inference``, must match the quantized output
      (``atol=1e-6``).
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            quantized = self.quant(inp)
            multiplied = self.batch_mm(quantized)
            expanded = expand_dims(multiplied, -1)
            return self.dequant(expanded)

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode, fake quant off: float and QAT must agree.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode, fake quant off: float and QAT must agree.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Fake quant on: QAT must agree with the quantized network.
        enable_fake_quant(qat_net)
        qat_out = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantized_out = qnet(inputs)
        np.testing.assert_allclose(
            qat_out.numpy(), quantized_out.numpy(), atol=1e-6
        )

        @jit.trace(capture_as_const=True)
        def f(x):
            qnet.eval()
            return qnet(x)

        # Run once so the trace is recorded before dumping.
        f(inputs)
        dump_buffer = io.BytesIO()
        f.dump(dump_buffer, enable_nchw4=True)
        dump_buffer.seek(0)
        dumped_out = cgtools.load_and_inference(dump_buffer, [inputs])[0]
        np.testing.assert_allclose(
            quantized_out.numpy(), dumped_out, atol=1e-6
        )
Example #11
0
def get_qat_net(inp_dtype, net, num_inp=1, shape=(1, 16, 32, 32)):
    """Convert ``net`` with ``quantize_qat`` and build matching random inputs.

    Args:
        inp_dtype: quantized dtype the random data is cast to; its scale is
            attached to every generated input.
        net: module handed to ``quantize_qat``.
        num_inp: number of input tensors to generate.
        shape: shape of every generated input.

    Returns:
        A ``(qat_net, inps)`` tuple where ``inps`` is a list of ``num_inp``
        tensors whose ``qparams`` carry the scale and the builtin ``qint8``
        dtype meta.
    """
    qat_net = quantize_qat(net)

    def _make_input():
        # Random values scaled by 16, cast to the quantized dtype and then
        # converted back with ``convert_from_qint8``.
        raw = mge.tensor(np.random.random(shape)) * 16
        raw = raw.astype(inp_dtype)
        inp = mge.tensor(dtype.convert_from_qint8(raw.numpy()))
        inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
        inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]
        return inp

    inps = [_make_input() for _ in range(num_inp)]
    return qat_net, inps
Example #12
0
def test_disable_quantize():
    """Check that a module marked with ``disable_quantize`` stays a float module.

    After ``quantize_qat`` (not in place) the ``conv`` attribute and its
    inner ``conv`` must still be instances of the float module classes.
    """
    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.conv = Float.ConvBnRelu2d(3, 3, 3)
            # Opt this sub-module out of conversion.
            self.conv.disable_quantize()

        def forward(self, x):
            return self.conv(x)

    float_net = Net()
    converted = quantize_qat(float_net, inplace=False)

    fused = converted.conv
    assert isinstance(fused, Float.ConvBnRelu2d)
    assert isinstance(fused.conv, Float.Conv2d)
Example #13
0
def test_qat_conv(padding, padding_mode):
    """Compare a float conv network with its QAT copy (fake quant disabled).

    The network chains ``Conv2d`` (using the given padding settings) and
    ``ConvRelu2d`` between quant/dequant stubs.  For every combination of
    ``groups`` in ``[1, 4]`` and ``bias`` in ``[True, False]`` the float and
    QAT outputs must agree in both train and eval mode.

    Args:
        padding: forwarded to ``Conv2d`` as ``padding``.
        padding_mode: forwarded to ``Conv2d`` as ``padding_mode``.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                groups=groups,
                bias=bias,
                padding=padding,
                padding_mode=padding_mode,
            )
            # Second stage maps the channels back: out_channels -> in_channels.
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, groups=groups, bias=bias
            )

        def forward(self, inp):
            quantized = self.quant(inp)
            features = self.conv_relu(self.conv(quantized))
            return self.dequant(features)

    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for groups, bias in product([1, 4], [True, False]):
        net = TestNet(groups, bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
Example #14
0
def test_enable_and_disable_all():
    """Check that fake quantization can be switched off and back on.

    Outputs are captured for the float network, the QAT network, the QAT
    network with fake quant disabled and the QAT network with fake quant
    re-enabled.  Disabled must equal float, re-enabled must equal the first
    QAT output, and the QAT output must differ from the disabled one.
    """
    sample = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
    net = Net()
    float_out = net(sample).numpy()

    net = quantize_qat(net, min_max_fakequant_qconfig)
    init_observer(net, sample)

    fake_quant_out = net(sample).numpy()
    disable_fake_quant(net)
    disabled_out = net(sample).numpy()
    enable_fake_quant(net)
    reenabled_out = net(sample).numpy()

    np.testing.assert_allclose(float_out, disabled_out)
    np.testing.assert_allclose(fake_quant_out, reenabled_out)
    # The fake-quantized output is expected NOT to match the disabled one.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(fake_quant_out, disabled_out)
Example #15
0
def test_qat_conv_qint8():
    """Trace a QAT ``Conv2d`` on a qint8-described input and check its conversion.

    The convolution gets a random float32 bias, the network is converted
    with ``quantize_qat``, traced via ``get_traced_module`` and the result
    is validated by ``_test_convert_result`` using the output scale read
    from the traced graph.
    """
    class QConvOpr(M.Module):
        def __init__(self):
            super().__init__()
            self.normal_conv = M.Conv2d(
                3, 30, 3, stride=(2, 3), padding=(3, 1), dilation=(2, 2)
            )
            bias_shape = self.normal_conv.bias.shape
            self.normal_conv.bias = mge.Parameter(
                np.random.random(bias_shape).astype(np.float32)
            )

        def forward(self, x):
            return self.normal_conv(x)

    qat_net = quantize_qat(QConvOpr())

    # Build an input whose qparams describe a qint8 tensor with scale 16/128.
    inp_dtype = dtype.qint8(16.0 / 128)
    raw = mge.tensor(np.random.random((1, 3, 224, 224))) * 16
    raw = raw.astype(inp_dtype)
    inp = mge.tensor(dtype.convert_from_qint8(raw.numpy()))
    inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
    inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    traced_module, tm_result = get_traced_module(qat_net, inp)
    print(traced_module.flatten().graph)

    inp = inp.astype(inp_dtype)
    output_qparams = traced_module.graph.outputs[0].qparams
    output_scale = output_qparams.scale.numpy()
    _test_convert_result(
        inp,
        traced_module,
        tm_result,
        scale=output_scale,
        require_quantize=True,
        max_err=max_error,
    )
Example #16
0
def test_qat_convrelu():
    """Trace a QAT ``ConvRelu2dOpr`` on a qint8-described input and check its conversion.

    The network is converted with ``quantize_qat``, traced via
    ``get_traced_module`` and validated by ``_test_convert_result`` using
    the output scale read from the traced graph.
    """
    qat_net = quantize_qat(ConvRelu2dOpr())

    # Build an input whose qparams describe a qint8 tensor with scale 16/128.
    inp_dtype = dtype.qint8(16.0 / 128)
    raw = mge.tensor(np.random.random((1, 3, 224, 224))) * 16
    raw = raw.astype(inp_dtype)
    inp = mge.tensor(dtype.convert_from_qint8(raw.numpy()))
    inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
    inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    traced_module, tm_result = get_traced_module(qat_net, inp)

    inp = inp.astype(inp_dtype)
    output_qparams = traced_module.graph.outputs[0].qparams
    output_scale = output_qparams.scale.numpy()
    _test_convert_result(
        inp,
        traced_module,
        tm_result,
        scale=output_scale,
        require_quantize=True,
        max_err=max_error,
    )
Example #17
0
def test_qat_convbn2d():
    """Compare a float ``ConvBn2d`` with its QAT copy (fake quant disabled).

    For every combination of ``groups`` in ``[1, 4]`` and ``bias`` in
    ``[True, False]`` the outputs and the batch-norm running statistics
    must agree in training mode, and the outputs must agree in eval mode.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    for groups, bias in product([1, 4], [True, False]):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)

        batch = tensor(
            np.random.randn(4, in_channels, 32, 32).astype(np.float32)
        )

        # Training mode: outputs and running statistics.
        float_out = float_module(batch)
        qat_out = qat_module(batch)
        np.testing.assert_allclose(
            float_out.numpy(), qat_out.numpy(), atol=5e-6
        )
        np.testing.assert_allclose(
            float_module.bn.running_mean.numpy(),
            qat_module.bn.running_mean.numpy(),
            atol=5e-8,
        )
        np.testing.assert_allclose(
            float_module.bn.running_var.numpy(),
            qat_module.bn.running_var.numpy(),
            atol=5e-7,
        )

        # Eval mode: outputs only.
        float_module.eval()
        float_out = float_module(batch)
        qat_module.eval()
        qat_out = qat_module(batch)
        np.testing.assert_allclose(
            float_out.numpy(), qat_out.numpy(), atol=5e-6
        )
Example #18
0
def test_qat_batchmatmul_activation():
    """Compare a float ``BatchMatMulActivation`` network with its QAT copy.

    With fake quantization disabled, the QAT copy must produce the same
    output as the float network in both train and eval mode, for ``bias``
    set to ``True`` and to ``False``.
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.batch_mm(self.quant(inp)))

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
Example #19
0
def test_qat_conv_transpose2d():
    """Compare a float ``ConvTranspose2d`` network with its QAT copy.

    With fake quantization disabled, the QAT copy must produce the same
    output as the float network in both train and eval mode, for ``bias``
    set to ``True`` and to ``False``.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = ConvTranspose2d(
                in_channels, out_channels, kernel_size, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.conv(self.quant(inp)))

    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
Example #20
0
from test.utils import ConvOpr, dump_mge_model

import megengine as mge
import numpy as np
from megengine.core.tensor import dtype
from megengine.quantization.quantize import quantize_qat
from megengine.traced_module import trace_module

if __name__ == "__main__":
    # --- Float model: trace it, save the traced module, dump the model. ---
    net = ConvOpr("normal")
    float_input = mge.tensor(net.data)
    traced_module = trace_module(net, float_input)
    mge.save(traced_module, "float_model.tm")
    dump_mge_model(net, net.data, "float_model")

    # --- QAT model: convert, then trace on a qint8-described input. ---
    qat_net = quantize_qat(net)

    # Random data scaled by 16, cast to qint8 with scale 16/128 and converted
    # back with ``convert_from_qint8``; the qparams record scale and dtype meta.
    inp_dtype = dtype.qint8(16.0 / 128)
    raw = mge.tensor(np.random.random((1, 3, 224, 224))) * 16
    raw = raw.astype(inp_dtype)
    inp = mge.tensor(dtype.convert_from_qint8(raw.numpy()))
    inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
    inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    qat_module = trace_module(qat_net, inp)
    mge.save(qat_module, "qat_model.tm")
Example #21
0
def worker(rank, world_size, args):
    """Calibrate a pretrained model, quantize it, evaluate it and save it.

    Steps, in order:

    1. When ``world_size > 1``, initialise the distributed process group.
    2. Load the weights from ``args.checkpoint`` (non-strict).
    3. Exclude ``model.fc`` from quantization and convert the model with
       ``quantize_qat`` using ``Q.calibration_qconfig``.
    4. Run ``infer`` over the ImageNet validation split with observers
       enabled (``calculate_scale``).
    5. Convert the model with ``quantize`` and evaluate it again.
    6. Save ``{"step": -1, "state_dict": ...}`` as
       ``checkpoint-calibration.pkl`` under
       ``<args.save>/<args.arch>.<args.mode>``.

    Args:
        rank: index of this process; also passed as ``dev`` to
            ``dist.init_process_group``.
        world_size: total number of processes.
        args: namespace providing ``save``, ``arch``, ``mode``,
            ``checkpoint``, ``data`` and ``workers``; it is also handed to
            ``infer``.

    Raises:
        AssertionError: if ``args.checkpoint`` is falsy.
    """
    # pylint: disable=too-many-statements

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    # ``exist_ok=True`` already tolerates an existing directory; a separate
    # existence check would only add a race between concurrent workers.
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_finetune_config(args.arch)

    # NOTE(review): ``cfg`` is not read again in this function.  The scaling
    # is kept because the config object may be shared with other code --
    # TODO confirm and drop if it is not.
    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training

    # load calibration model
    assert args.checkpoint, "a pretrained checkpoint is required for calibration"
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)

    # Build valid datasets
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )

    # calibration: ``fc`` is opted out before the conversion
    model.fc.disable_quantize()
    model = quantize_qat(model, qconfig=Q.calibration_qconfig)

    # calculate scale: a forward pass with observers enabled
    @jit.trace(symbolic=True)
    def calculate_scale(image, label):
        model.eval()
        enable_observer(model)
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    infer(calculate_scale, valid_queue, args)

    # quantized
    model = quantize(model)

    # eval quantized model
    @jit.trace(symbolic=True)
    def eval_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)
    logger.info("TEST %f, %f", valid_acc, valid_acc5)

    # save quantized model
    checkpoint_path = os.path.join(save_dir, "checkpoint-calibration.pkl")
    mge.save(
        {"step": -1, "state_dict": model.state_dict()},
        checkpoint_path,
    )
    logger.info("save in {}".format(checkpoint_path))
Example #22
0
def worker(world_size, args):
    """Train a model (optionally with QAT) on ImageNet and save checkpoints.

    The model is built from ``models.__dict__[args.arch]`` and, unless
    ``args.mode`` is ``"normal"``, converted with ``quantize_qat`` using
    ``Q.ema_fakequant_qconfig``.  Training runs for
    ``(1280000 // (cfg.BATCH_SIZE * world_size)) * cfg.EPOCHS`` steps with
    SGD.  Every 10000 steps (except step 0) rank 0 writes
    ``checkpoint.pkl`` and every process runs validation.  After the loop
    ``checkpoint-final.pkl`` is written and a final validation is logged.

    Args:
        world_size: number of participating processes; scales the learning
            rate and the total batch size.
        args: namespace providing ``save``, ``arch``, ``mode``, ``data``,
            ``workers`` and ``report_freq``; it is also handed to ``infer``.

    Raises:
        ValueError: if ``cfg.SCHEDULER`` is neither ``"Linear"`` nor
            ``"Multistep"``.
    """
    # pylint: disable=too-many-statements

    rank = dist.get_rank()
    if world_size > 1:
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    # ``exist_ok=True`` already tolerates an existing directory; a separate
    # existence check would only add a race between concurrent workers.
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if world_size > 1:
        # Sync parameters
        dist.bcast_list_(model.parameters(), dist.WORLD)

    # Autodiff gradient manager
    gm = autodiff.GradManager().attach(
        model.parameters(),
        callbacks=dist.make_allreduce_cb("MEAN") if world_size > 1 else None,
    )

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    def train_func(image, label):
        with gm:
            model.train()
            logits = model(image)
            loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
            acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
            gm.backward(loss)
            optimizer.step().clear_grad()
        return loss, acc1, acc5

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=cfg.BATCH_SIZE,
                           drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            cfg.COLOR_JITTOR,
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW")
        ]),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        # Compute the learning rate for this step and push it into every
        # optimizer parameter group.
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA**bisect.bisect_right(
                cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    # ``step`` is read after the loop for the final checkpoint; give it a
    # value up front so an empty loop (total_steps == 0) does not raise
    # NameError.  -1 means "no step completed".
    step = -1
    t = time.time()
    for step in range(0, total_steps):
        # Learning rate schedule for the current step
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = mge.tensor(image, dtype="float32")
        label = mge.tensor(label, dtype="int32")

        n = image.shape[0]

        loss, acc1, acc5 = train_func(image, label)

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN e%d %06d %f %s %s %s %s",
                epoch,
                step,
                learning_rate,
                objs,
                top1,
                top5,
                total_time,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step != 0 and step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {
                    "step": step,
                    "state_dict": model.state_dict()
                },
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    mge.save(
        {
            "step": step,
            "state_dict": model.state_dict()
        },
        os.path.join(save_dir, "checkpoint-final.pkl"),
    )
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
Example #23
0
def worker(world_size, args):
    """Calibrate a pretrained model, quantize it, evaluate it and save it.

    The steps, in order: load weights from ``args.checkpoint``, wrap the model
    with ``Q.calibration_qconfig``, run one pass of ``infer`` over the
    validation queue with observers enabled, convert with ``quantize``, run a
    second ``infer`` pass to measure accuracy, and save the state dict as
    ``checkpoint-calibration.pkl`` under ``<args.save>/<args.arch>.calibration``.

    Args:
        world_size: Number of participating processes; in this function it
            only decides whether the distributed log line is emitted.
        args: Parsed options. ``save``, ``arch``, ``checkpoint``, ``data`` and
            ``workers`` are read here, and the object is forwarded to ``infer``.
    """
    # pylint: disable=too-many-statements

    rank = dist.get_rank()
    if world_size > 1:
        # Only a log line is produced here; this function makes no explicit
        # process-group setup call.
        logger.info("init distributed process group {} / {}".format(rank, world_size))

    save_dir = os.path.join(args.save, args.arch + "." + "calibration")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    log_path = os.path.join(save_dir, "log.txt")
    mge.set_log_file(log_path)

    build_model = models.__dict__[args.arch]
    model = build_model()

    # A checkpoint is mandatory: calibration starts from pretrained weights.
    assert args.checkpoint
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    if "state_dict" in ckpt:
        # Checkpoints may wrap the weights under a "state_dict" key.
        ckpt = ckpt["state_dict"]
    model.load_state_dict(ckpt, strict=False)

    # Validation data pipeline.
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    preprocess = T.Compose(
        [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=preprocess,
        num_workers=args.workers,
    )

    # The `fc` submodule is excluded from quantization before wrapping the
    # model with the calibration qconfig.
    model.fc.disable_quantize()
    model = quantize_qat(model, qconfig=Q.calibration_qconfig)

    def _mean_over_workers(value):
        # all-reduce sum divided by the world size, i.e. an all-reduce mean.
        return dist.functional.all_reduce_sum(value) / dist.get_world_size()

    def _forward_metrics(image, label):
        # Shared forward pass; reads whatever `model` is bound to at call time.
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        if dist.is_distributed():
            loss, acc1, acc5 = [
                _mean_over_workers(metric) for metric in (loss, acc1, acc5)
            ]
        return loss, acc1, acc5

    def calculate_scale(image, label):
        # Calibration pass: eval mode with observers switched on.
        model.eval()
        enable_observer(model)
        return _forward_metrics(image, label)

    infer(calculate_scale, valid_queue, args)

    # Convert the calibrated model to its quantized form.
    model = quantize(model)

    def eval_func(image, label):
        # Evaluation pass on the quantized model.
        model.eval()
        return _forward_metrics(image, label)

    _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)
    logger.info("TEST %f, %f", valid_acc, valid_acc5)

    # Persist the quantized weights; step is recorded as -1.
    out_path = os.path.join(save_dir, "checkpoint-calibration.pkl")
    mge.save({"step": -1, "state_dict": model.state_dict()}, out_path)
    logger.info("save in {}".format(out_path))
Example #24
0
def main():
    """Classify a single image and print its top-5 predictions.

    Command-line options:
        -a/--arch: key looked up in ``models.__dict__`` (default ``resnet18``).
        -c/--checkpoint: optional checkpoint file, loaded with ``strict=False``.
        -i/--image: image path; when omitted, ``assets/cat.jpg`` under the
            directory resolved from this file's location is used.
        -m/--mode: ``normal``, ``qat`` or ``quantized`` (default).
        --dump: additionally dump the traced inference function and save the
            model state dict next to it.

    Raises:
        OSError: if ``cv2.imread`` cannot read the selected image.
    """
    # RawTextHelpFormatter keeps the explicit line breaks in the --mode help;
    # the default formatter would re-wrap them into a single paragraph.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)

    parser.add_argument(
        "-m",
        "--mode",
        default="quantized",
        type=str,
        choices=["normal", "qat", "quantized"],
        help="Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert mode to int8 quantized, inference only",
    )
    parser.add_argument("--dump", action="store_true", help="Dump quantized model")
    args = parser.parse_args()

    model = models.__dict__[args.arch]()

    # The model is converted before the checkpoint is loaded, so the
    # checkpoint is loaded into the already-converted module structure.
    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if args.mode == "quantized":
        quantize(model)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    rpath = os.path.realpath(__file__ + "/../../")
    path = rpath + "/assets/cat.jpg" if args.image is None else args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside the
        # transform pipeline.
        raise OSError(
            "Failed to read image '{}' (missing file or unsupported format)".format(
                path
            )
        )

    transform = T.Compose(
        [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
    )

    @trace(symbolic=True, capture_as_const=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis before converting to a float32 tensor.
    processed_img = transform.apply(image)[np.newaxis, :]
    processed_img = mge.tensor(processed_img, dtype="float32")

    probs = infer_func(processed_img)

    top_probs, classes = F.topk(probs, k=5, descending=True)

    if args.dump:
        output_file = ".".join([args.arch, args.mode, "megengine"])
        # Build the state-dict name explicitly: str.replace("megengine", "pkl")
        # would also rewrite a "megengine" substring inside the arch name.
        state_file = ".".join([args.arch, args.mode, "pkl"])
        logger.info("Dump to %s", output_file)
        infer_func.dump(output_file, arg_names=["data"])
        mge.save(model.state_dict(), state_file)

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(rpath + "/assets/imagenet_class_info.json", encoding="utf-8") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
        zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))
    ):
        print(
            "{}: class = {:20s} with probability = {:4.1f} %".format(
                rank, imagenet_class_index[str(classid)][1], 100 * prob
            )
        )
Example #25
0
def quantized_resnet18(**kwargs):
    """Build a ``resnet18`` and run both quantization passes over it.

    Keyword arguments are forwarded unchanged to ``resnet18``. The return
    values of ``quantize_qat`` and ``quantize`` are discarded, so the
    conversion is relied upon to modify the network in place.
    """
    net = resnet18(**kwargs)
    # QAT wrapping first, then conversion to the quantized form.
    for convert in (quantize_qat, quantize):
        convert(net)
    return net