def test_quantized_module_user_naming_param(symbolic):
    """User-assigned parameter names must survive QAT conversion, quantization and dump/load.

    Args:
        symbolic: forwarded unchanged to ``_dump_and_load``.
    """

    class Simple(M.Module):
        def __init__(self, name):
            super().__init__(name=name)
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True)
            self.dequant = M.DequantStub()
            # Explicit, user-chosen names for the linear layer's parameters.
            self.linear.weight.name = "user-weight"
            self.linear.bias.name = "user-bias"

        def forward(self, x):
            out = self.quant(x)
            out = self.linear(out)
            out = self.dequant(out)
            return out

    model = Simple("simple")
    quantize_qat(model)
    quantize(model)
    model.eval()

    dumped_ops = _dump_and_load(model, symbolic)
    matmul_ops = [op for op in dumped_ops if op.name == "simple.linear.MatrixMul"]
    # Exactly one MatrixMul op is expected; unpacking fails otherwise.
    (matmul,) = matmul_ops

    allowed_input_names = ("simple.quant.TypeCvt", "simple.linear.user-weight")
    for input_var in matmul.inputs:
        assert input_var.name in allowed_input_names
def test_load_quantized():
    """A quantized MLP must reproduce its predictions after a state_dict round-trip."""
    from megengine.core.tensor import dtype

    def requantize(weight, scale):
        # Re-wrap a weight as a Parameter holding qint8 data with the given scale.
        return Parameter(weight.astype(dtype.qint8(scale)).numpy())

    data = tensor(np.random.random((2, 28)), dtype="float32")
    data = data.astype(dtype.qint8(0.1))

    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    mlp.dense0.weight = requantize(mlp.dense0.weight, 0.001)
    mlp.dense1.weight = requantize(mlp.dense1.weight, 0.0002)
    mlp.eval()
    expected = mlp(data)

    with BytesIO() as buffer:
        mge.save(mlp.state_dict(), buffer)
        buffer.seek(0)
        checkpoint = mge.load(buffer)

    # Perturb the weights so that only load_state_dict can bring them back.
    mlp.dense0.weight = requantize(mlp.dense0.weight, 0.00001)
    mlp.dense1.weight = requantize(mlp.dense1.weight, 0.2)

    mlp.load_state_dict(checkpoint)
    actual = mlp(data)
    np.testing.assert_allclose(
        expected.astype("float32").numpy(),
        actual.astype("float32").numpy(),
        atol=5e-6,
    )
def test_load_quantized():
    """Round-tripping a quantized MLP's state_dict must restore its predictions."""
    data = tensor(np.random.random((2, 28)), dtype="float32")
    data = data.astype(mgb.dtype.qint8(0.1))

    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    for layer, scale in ((mlp.dense0, 0.001), (mlp.dense1, 0.0002)):
        layer.weight = Parameter(layer.weight.astype(mgb.dtype.qint8(scale)).numpy())
    mlp.eval()
    reference = mlp(data)

    with BytesIO() as stream:
        mge.save(mlp.state_dict(), stream)
        stream.seek(0)
        checkpoint = mge.load(stream)

    # Overwrite the weights with different quantization scales before restoring.
    for layer, scale in ((mlp.dense0, 0.00001), (mlp.dense1, 0.2)):
        layer.weight = Parameter(layer.weight.astype(mgb.dtype.qint8(scale)).numpy())

    mlp.load_state_dict(checkpoint)
    restored = mlp(data)
    assertTensorClose(
        reference.astype("float32").numpy(),
        restored.astype("float32").numpy(),
        max_err=5e-6,
    )
def test_quantized_module_user_naming(symbolic):
    """A user-given submodule name must show up in the dumped op names.

    Args:
        symbolic: forwarded unchanged to ``_dump_and_load``.
    """

    class Simple(M.Module):
        def __init__(self, name):
            super().__init__(name=name)
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True, name="user-linear")
            self.dequant = M.DequantStub()

        def forward(self, x):
            out = self.quant(x)
            out = self.linear(out)
            out = self.dequant(out)
            return out

    net = Simple("simple")
    quantize_qat(net)
    quantize(net)
    net.eval()

    dumped = _dump_and_load(net, symbolic)
    expected_names = [
        "x",
        "simple.quant.TypeCvt",
        "simple.user-linear.MatrixMul",
        "simple.user-linear.ADD",
        "simple.user-linear.TypeCvt",
        "simple.dequant.TypeCvt",
    ]
    # Pairwise comparison; zip stops at the shorter of the two sequences.
    for dumped_op, expected in zip(dumped, expected_names):
        assert dumped_op.name == expected
def worker(world_size, args):  # pylint: disable=too-many-statements
    """Evaluate a float, QAT or quantized classifier on the ImageNet validation split.

    Args:
        world_size: number of participating processes (only used for logging here).
        args: parsed options; reads ``arch``, ``mode``, ``checkpoint``, ``data``,
            ``workers`` (plus whatever ``infer`` reads).
    """
    rank = dist.get_rank()
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))

    model = models.__dict__[args.arch]()

    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        # Accept both a bare state dict and a {"state_dict": ...} wrapper.
        if "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        model.load_state_dict(ckpt, strict=False)

    # Quantize only after the (QAT) weights have been loaded.
    if args.mode == "quantized":
        quantize(model)

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        if dist.is_distributed():
            # all_reduce_mean: sum over workers, then divide by the worker count.
            loss, acc1, acc5 = (
                dist.functional.all_reduce_sum(metric) / dist.get_world_size()
                for metric in (loss, acc1, acc5)
            )
        return loss, acc1, acc5

    logger.info("preparing dataset..")
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    eval_transform = T.Compose(
        [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=eval_transform,
        num_workers=args.workers,
    )

    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    if rank == 0:
        logger.info("TEST %f, %f", valid_acc, valid_acc5)
def test_convert_with_custom_mapping():
    """quantize_qat must honour a user-supplied float-to-QAT module mapping."""

    class FloatExample(Float.Module):
        def forward(self, x):
            return x

    class QATExample(QAT.QATModule):
        def forward(self, x):
            return x

        @classmethod
        def from_float_module(cls, float_module):
            # The identity float module has no state to carry over.
            return cls()

    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.example = FloatExample()

        def forward(self, x):
            return self.example(x)

    custom_mapping = {FloatExample: QATExample}
    converted = quantize_qat(Net(), inplace=False, mapping=custom_mapping)
    assert isinstance(converted.example, QATExample)
def test_disable_fake_quant():
    """Disabling fake quantization must make a QAT net reproduce its float output.

    With fake quantization enabled, the QAT output is expected to differ from the
    output with it disabled.
    """

    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.quant = Float.QuantStub()
            self.linear = Float.Linear(3, 3)
            self.dequant = Float.DequantStub()
            self.linear.bias.set_value(np.random.rand(3))

        def forward(self, x):
            x = self.quant(x)
            x = self.linear(x)
            x = self.dequant(x)
            return x

    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
    net = Net()
    y1 = net(x).numpy()
    # qconfig must be passed by keyword: quantize_qat's second positional
    # parameter is `inplace`, so a positional qconfig would be swallowed there
    # and the default qconfig would be used instead.
    net = quantize_qat(net, qconfig=min_max_fakequant_qconfig)
    y2 = net(x).numpy()
    disable_fake_quant(net)
    y3 = net(x).numpy()
    np.testing.assert_allclose(y1, y3)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(y2, y3)
def test_qat_convbn2d():
    """QAT ConvBn2d with fake quant disabled must match the float ConvBn2d.

    Both outputs and BN running statistics are compared in train mode, and
    outputs again in eval mode.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3
    for groups, bias in product([1, 4], [True, False]):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)
        batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

        # Train mode: outputs and the updated running statistics must agree.
        float_out = float_module(batch)
        qat_out = qat_module(batch)
        assertTensorClose(float_out, qat_out, max_err=5e-6)
        assertTensorClose(
            float_module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
        )
        assertTensorClose(
            float_module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
        )

        # Eval mode: outputs computed from the running statistics must agree.
        float_module.eval()
        float_out = float_module(batch)
        qat_module.eval()
        qat_out = qat_module(batch)
        assertTensorClose(float_out, qat_out, max_err=5e-6)
def test_quantize_qat():
    """quantize_qat(inplace=False) must replace each float submodule with its QAT counterpart."""
    qat_net = quantize_qat(
        FloatNet(), inplace=False, qconfig=min_max_fakequant_qconfig
    )
    expectations = (
        (qat_net.quant, QAT.QuantStub),
        (qat_net.linear[0], QAT.Linear),
        (qat_net.linear[1], QAT.Linear),
        (qat_net.dequant, QAT.DequantStub),
    )
    for submodule, expected_cls in expectations:
        assert isinstance(submodule, expected_cls)
def test_quantize_batchmatmul_activation():
    """BatchMatMulActivation must agree across float, QAT, quantized and dumped-graph forms."""
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            out = expand_dims(out, -1)
            out = self.dequant(out)
            return out

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )

    for use_bias in (True, False):
        float_net = TestNet(use_bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Fake quant off: QAT must equal float, first in train mode ...
        float_out = float_net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # ... then in eval mode.
        float_net.eval()
        float_out = float_net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Fake quant on: the truly quantized net must match the QAT simulation.
        enable_fake_quant(qat_net)
        qat_out = qat_net(inputs)
        quantized_net = quantize(qat_net, inplace=False)
        quantized_net.eval()
        quantized_out = quantized_net(inputs)
        np.testing.assert_allclose(
            qat_out.numpy(), quantized_out.numpy(), atol=1e-6
        )

        @jit.trace(capture_as_const=True)
        def f(x):
            quantized_net.eval()
            return quantized_net(x)

        f(inputs)
        # The dumped (NCHW4) graph must also reproduce the quantized output.
        with io.BytesIO() as graph_buf:
            f.dump(graph_buf, enable_nchw4=True)
            graph_buf.seek(0)
            dumped_out = cgtools.load_and_inference(graph_buf, [inputs])[0]
        np.testing.assert_allclose(quantized_out.numpy(), dumped_out, atol=1e-6)
def get_qat_net(inp_dtype, net, num_inp=1, shape=(1, 16, 32, 32)):
    """Convert ``net`` with ``quantize_qat`` and build ``num_inp`` qint8-annotated inputs.

    Each input holds random data in [0, 16) that is rounded through ``inp_dtype``,
    converted back to a float tensor, and tagged with ``inp_dtype``'s scale and
    the builtin ``qint8`` dtype meta.

    Args:
        inp_dtype: quantized dtype used to round the input data.
        net: float module to convert.
        num_inp: number of inputs to create.
        shape: shape of every input.

    Returns:
        ``(qat_net, inputs)`` where ``inputs`` is a list of length ``num_inp``.
    """
    qat_net = quantize_qat(net)

    def make_input():
        raw = mge.tensor(np.random.random(shape)) * 16
        rounded = raw.astype(inp_dtype)
        tagged = mge.tensor(dtype.convert_from_qint8(rounded.numpy()))
        tagged.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
        tagged.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]
        return tagged

    inputs = [make_input() for _ in range(num_inp)]
    return qat_net, inputs
def test_disable_quantize():
    """A submodule marked with disable_quantize() must stay a float module after quantize_qat."""

    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.conv = Float.ConvBnRelu2d(3, 3, 3)
            self.conv.disable_quantize()

        def forward(self, x):
            return self.conv(x)

    converted_conv = quantize_qat(Net(), inplace=False).conv
    assert isinstance(converted_conv, Float.ConvBnRelu2d)
    assert isinstance(converted_conv.conv, Float.Conv2d)
def test_qat_conv(padding, padding_mode):
    """QAT Conv2d + ConvRelu2d with fake quant disabled must match the float net.

    Args:
        padding: padding forwarded to the first Conv2d.
        padding_mode: padding mode forwarded to the first Conv2d.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                groups=groups,
                bias=bias,
                padding=padding,
                padding_mode=padding_mode,
            )
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, groups=groups, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.conv(out)
            out = self.conv_relu(out)
            out = self.dequant(out)
            return out

    batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for groups, bias in product([1, 4], [True, False]):
        float_net = TestNet(groups, bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        expected = float_net(batch)
        actual = qat_net(batch)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())

        float_net.eval()
        expected = float_net(batch)
        qat_net.eval()
        actual = qat_net(batch)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())
def test_enable_and_disable_all():
    """Toggling fake quantization off and back on must switch between float and QAT outputs."""
    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
    net = Net()
    y1 = net(x).numpy()
    # qconfig must be passed by keyword: quantize_qat's second positional
    # parameter is `inplace`, so a positional qconfig would be swallowed there
    # and the default qconfig would be used instead.
    net = quantize_qat(net, qconfig=min_max_fakequant_qconfig)

    init_observer(net, x)

    y2 = net(x).numpy()
    disable_fake_quant(net)
    y3 = net(x).numpy()
    enable_fake_quant(net)
    y4 = net(x).numpy()
    # Fake quant off reproduces the float net; re-enabling restores the QAT output.
    np.testing.assert_allclose(y1, y3)
    np.testing.assert_allclose(y2, y4)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(y2, y3)
def test_qat_conv_qint8():
    """Converting a QAT Conv2d fed a qint8-annotated input must match the traced result."""

    class QConvOpr(M.Module):
        def __init__(self):
            super().__init__()
            self.normal_conv = M.Conv2d(
                3,
                30,
                3,
                stride=(2, 3),
                padding=(3, 1),
                dilation=(2, 2),
            )
            # Replace the bias with random values (presumably so a non-trivial
            # bias goes through conversion).
            self.normal_conv.bias = mge.Parameter(
                np.random.random(self.normal_conv.bias.shape).astype(np.float32)
            )

        def forward(self, x):
            x = self.normal_conv(x)
            return x

    net = QConvOpr()
    qat_net = quantize_qat(net)

    inp_dtype = dtype.qint8(16.0 / 128)
    data = mge.tensor(np.random.random((1, 3, 224, 224))) * 16
    data = data.astype(inp_dtype)
    # Float tensor carrying the qint8 quantization parameters of the input.
    inp = mge.tensor(dtype.convert_from_qint8(data.numpy()))
    inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
    inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    traced_module, tm_result = get_traced_module(qat_net, inp)

    inp = inp.astype(inp_dtype)
    out_dtype = traced_module.graph.outputs[0].qparams
    scale = out_dtype.scale.numpy()
    _test_convert_result(
        inp,
        traced_module,
        tm_result,
        scale=scale,
        require_quantize=True,
        max_err=max_error,
    )
def test_qat_convrelu():
    """Converting a QAT ConvRelu2d fed a qint8-annotated input must match the traced result."""
    qat_net = quantize_qat(ConvRelu2dOpr())

    inp_dtype = dtype.qint8(16.0 / 128)
    rounded = (mge.tensor(np.random.random((1, 3, 224, 224))) * 16).astype(inp_dtype)
    # Float tensor tagged with the qint8 scale/dtype of the input.
    float_inp = mge.tensor(dtype.convert_from_qint8(rounded.numpy()))
    float_inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
    float_inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    traced_module, tm_result = get_traced_module(qat_net, float_inp)

    quantized_inp = float_inp.astype(inp_dtype)
    output_scale = traced_module.graph.outputs[0].qparams.scale.numpy()
    _test_convert_result(
        quantized_inp,
        traced_module,
        tm_result,
        scale=output_scale,
        require_quantize=True,
        max_err=max_error,
    )
def test_qat_convbn2d():
    """QAT ConvBn2d with fake quant disabled must match the float ConvBn2d.

    In train mode, both outputs and BN running statistics are compared; in eval
    mode, outputs are compared again.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3
    for groups, bias in product([1, 4], [True, False]):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)
        batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

        float_out = float_module(batch)
        qat_out = qat_module(batch)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy(), atol=5e-6)
        np.testing.assert_allclose(
            float_module.bn.running_mean.numpy(),
            qat_module.bn.running_mean.numpy(),
            atol=5e-8,
        )
        np.testing.assert_allclose(
            float_module.bn.running_var.numpy(),
            qat_module.bn.running_var.numpy(),
            atol=5e-7,
        )

        float_module.eval()
        float_out = float_module(batch)
        qat_module.eval()
        qat_out = qat_module(batch)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy(), atol=5e-6)
def test_qat_batchmatmul_activation():
    """QAT BatchMatMulActivation with fake quant disabled must match the float module."""
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            out = self.dequant(out)
            return out

    batch_inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for use_bias in (True, False):
        float_net = TestNet(use_bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        expected = float_net(batch_inputs)
        actual = qat_net(batch_inputs)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())

        float_net.eval()
        expected = float_net(batch_inputs)
        qat_net.eval()
        actual = qat_net(batch_inputs)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())
def test_qat_conv_transpose2d():
    """QAT ConvTranspose2d with fake quant disabled must match the float module."""
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = ConvTranspose2d(
                in_channels, out_channels, kernel_size, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.conv(out)
            out = self.dequant(out)
            return out

    batch = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for use_bias in [True, False]:
        float_net = TestNet(use_bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        expected = float_net(batch)
        actual = qat_net(batch)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())

        float_net.eval()
        expected = float_net(batch)
        qat_net.eval()
        actual = qat_net(batch)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())
from test.utils import ConvOpr, dump_mge_model

import megengine as mge
import numpy as np
from megengine.core.tensor import dtype
from megengine.quantization.quantize import quantize_qat
from megengine.traced_module import trace_module

# Script: export a ConvOpr test model as float and QAT TracedModules
# ("float_model.tm", "qat_model.tm") plus a dumped float graph ("float_model").
if __name__ == "__main__":
    net = ConvOpr("normal")

    # Float export: TracedModule pickle and dumped graph.
    mge.save(trace_module(net, mge.tensor(net.data)), "float_model.tm")
    dump_mge_model(net, net.data, "float_model")

    # quantize_qat's return value is used as the QAT model.
    qat_net = quantize_qat(net)

    # Build a float input annotated with qint8 quantization parameters.
    inp_dtype = dtype.qint8(16.0 / 128)
    rounded = (mge.tensor(np.random.random((1, 3, 224, 224))) * 16).astype(inp_dtype)
    inp = mge.tensor(dtype.convert_from_qint8(rounded.numpy()))
    inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
    inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    mge.save(trace_module(qat_net, inp), "qat_model.tm")
def worker(rank, world_size, args):  # pylint: disable=too-many-statements
    """Calibrate, quantize and evaluate a pretrained classifier, then save it.

    Steps:
    1. Load a float checkpoint.
    2. Run calibration observers over the validation set.
    3. Convert to a quantized model and evaluate it.
    4. Save the quantized state dict.

    Args:
        rank: this process's rank; also used as the device index.
        world_size: number of processes; > 1 initializes a process group.
        args: parsed options; reads ``save``, ``arch``, ``mode``, ``checkpoint``,
            ``data``, ``workers`` (plus whatever ``infer`` reads).
    """
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()

    # load calibration model
    assert args.checkpoint
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)

    # Build valid datasets
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )

    # calibration: keep the classifier head in float, observe everything else.
    model.fc.disable_quantize()
    model = quantize_qat(model, qconfig=Q.calibration_qconfig)

    # calculate scale
    @jit.trace(symbolic=True)
    def calculate_scale(image, label):
        model.eval()
        enable_observer(model)
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    infer(calculate_scale, valid_queue, args)

    # quantized
    model = quantize(model)

    # eval quantized model
    @jit.trace(symbolic=True)
    def eval_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)

    # Metrics are all-reduced, so every rank holds the same values; log and
    # save from rank 0 only so processes do not race on the same checkpoint file.
    if rank == 0:
        logger.info("TEST %f, %f", valid_acc, valid_acc5)
        ckpt_path = os.path.join(save_dir, "checkpoint-calibration.pkl")
        mge.save({"step": -1, "state_dict": model.state_dict()}, ckpt_path)
        logger.info("save in {}".format(ckpt_path))
def worker(world_size, args):  # pylint: disable=too-many-statements
    """Train a float or QAT ImageNet classifier with periodic validation and checkpoints.

    Args:
        world_size: number of processes; > 1 broadcasts initial parameters and
            all-reduces gradients.
        args: parsed options; reads ``save``, ``arch``, ``mode``, ``data``,
            ``workers``, ``report_freq`` (plus whatever ``infer`` reads).
    """
    rank = dist.get_rank()
    if world_size > 1:
        logger.info("init distributed process group {} / {}".format(rank, world_size))

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    # 1280000 is used as the (approximate) number of training images per epoch.
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if world_size > 1:
        # Sync parameters
        dist.bcast_list_(model.parameters(), dist.WORLD)

    # Autodiff gradient manager
    gm = autodiff.GradManager().attach(
        model.parameters(),
        callbacks=dist.make_allreduce_cb("MEAN") if world_size > 1 else None,
    )

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    def train_func(image, label):
        with gm:
            model.train()
            logits = model(image)
            loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
            acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
            gm.backward(loss)
            optimizer.step().clear_grad()
        return loss, acc1, acc5

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset, batch_size=cfg.BATCH_SIZE, drop_last=True)
    )
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose(
            [
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                cfg.COLOR_JITTOR,
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
        ),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set and return the learning rate for ``step`` according to ``cfg.SCHEDULER``."""
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA ** bisect.bisect_right(
                cfg.SCHEDULER_STEPS, epoch
            )
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(0, total_steps):
        # Learning-rate schedule (linear or multistep, see adjust_learning_rate).
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = mge.tensor(image, dtype="float32")
        label = mge.tensor(label, dtype="int32")

        n = image.shape[0]

        loss, acc1, acc5 = train_func(image, label)

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN e%d %06d %f %s %s %s %s",
                epoch,
                step,
                learning_rate,
                objs,
                top1,
                top5,
                total_time,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step != 0 and step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {"step": step, "state_dict": model.state_dict()},
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 writes the final checkpoint, matching the periodic save above;
    # otherwise every process would write the same file concurrently.
    if rank == 0:
        mge.save(
            {"step": step, "state_dict": model.state_dict()},
            os.path.join(save_dir, "checkpoint-final.pkl"),
        )
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
def worker(world_size, args):  # pylint: disable=too-many-statements
    """Calibrate, quantize and evaluate a pretrained classifier, then save it.

    Steps:
    1. Load a float checkpoint.
    2. Run calibration observers over the validation set.
    3. Convert to a quantized model and evaluate it.
    4. Save the quantized state dict under ``<save>/<arch>.calibration``.

    Args:
        world_size: number of participating processes (used for logging here).
        args: parsed options; reads ``save``, ``arch``, ``checkpoint``, ``data``,
            ``workers`` (plus whatever ``infer`` reads).
    """
    rank = dist.get_rank()
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))

    save_dir = os.path.join(args.save, args.arch + "." + "calibration")
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()

    # load calibration model
    assert args.checkpoint
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)

    # Build valid datasets
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
        ),
        num_workers=args.workers,
    )

    def compute_metrics(image, label):
        """Forward ``model`` and return (loss, top-1, top-5), averaged over workers if distributed."""
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
            acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
            acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
        return loss, acc1, acc5

    # calibration: keep the classifier head in float, observe everything else.
    model.fc.disable_quantize()
    model = quantize_qat(model, qconfig=Q.calibration_qconfig)

    # calculate scale
    def calculate_scale(image, label):
        model.eval()
        enable_observer(model)
        return compute_metrics(image, label)

    infer(calculate_scale, valid_queue, args)

    # quantized
    model = quantize(model)

    # eval quantized model
    def eval_func(image, label):
        model.eval()
        return compute_metrics(image, label)

    _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)

    # Metrics are all-reduced, so every rank holds the same values; log and
    # save from rank 0 only so processes do not race on the same checkpoint file.
    if rank == 0:
        logger.info("TEST %f, %f", valid_acc, valid_acc5)
        ckpt_path = os.path.join(save_dir, "checkpoint-calibration.pkl")
        mge.save({"step": -1, "state_dict": model.state_dict()}, ckpt_path)
        logger.info("save in {}".format(ckpt_path))
def main():
    """Classify one image with a float/QAT/quantized model and print the top-5 classes.

    Optionally (``--dump``) also writes the traced graph to
    ``<arch>.<mode>.megengine`` and the state dict to ``<arch>.<mode>.pkl``.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    parser.add_argument(
        "-m",
        "--mode",
        default="quantized",
        type=str,
        choices=["normal", "qat", "quantized"],
        help="Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert mode to int8 quantized, inference only",
    )
    parser.add_argument("--dump", action="store_true", help="Dump quantized model")
    args = parser.parse_args()

    model = models.__dict__[args.arch]()

    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if args.mode == "quantized":
        quantize(model)

    # The checkpoint is loaded after conversion, so it must match the converted model.
    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    rpath = os.path.realpath(__file__ + "/../../")
    if args.image is None:
        path = rpath + "/assets/cat.jpg"
    else:
        path = args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None, not raising.
        raise FileNotFoundError("cannot read image: {}".format(path))

    transform = T.Compose(
        [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
    )

    @trace(symbolic=True, capture_as_const=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    processed_img = transform.apply(image)[np.newaxis, :]
    processed_img = mge.tensor(processed_img, dtype="float32")

    probs = infer_func(processed_img)

    top_probs, classes = F.topk(probs, k=5, descending=True)

    if args.dump:
        output_file = ".".join([args.arch, args.mode, "megengine"])
        logger.info("Dump to {}".format(output_file))
        infer_func.dump(output_file, arg_names=["data"])
        # Build the .pkl name directly instead of str.replace, which would also
        # rewrite any "megengine" substring inside the arch name.
        mge.save(model.state_dict(), ".".join([args.arch, args.mode, "pkl"]))

    with open(rpath + "/assets/imagenet_class_info.json", encoding="utf-8") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
        zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))
    ):
        print(
            "{}: class = {:20s} with probability = {:4.1f} %".format(
                rank, imagenet_class_index[str(classid)][1], 100 * prob
            )
        )
def quantized_resnet18(**kwargs):
    """Build ``resnet18(**kwargs)`` and run it through QAT conversion and quantization.

    The return values of ``quantize_qat``/``quantize`` are not used, so this
    relies on both converting the model in place.
    """
    net = resnet18(**kwargs)
    for convert in (quantize_qat, quantize):
        convert(net)
    return net