def test_repr():
    """Verify a module's ``__repr__`` at each stage: float, QAT, then quantized.

    The same network instance is converted in place by ``quantize_qat`` and
    then ``quantize``. After each step its repr must match the expected text.
    """

    class Net(M.Module):
        """Toy network holding a fused conv-bn-relu block and a linear layer."""

        def __init__(self):
            super().__init__()
            self.conv_bn = M.ConvBnRelu2d(3, 3, 3)
            self.linear = M.Linear(3, 3)

        def forward(self, x):
            # Only the submodule tree matters for the repr; forward is identity.
            return x

    float_repr = (
        "Net(\n"
        " (conv_bn): ConvBnRelu2d(\n"
        " (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
        " (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        " )\n"
        " (linear): Linear(in_features=3, out_features=3, bias=True)\n"
        ")")
    qat_repr = (
        "Net(\n"
        " (conv_bn): QAT.ConvBnRelu2d(\n"
        " (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
        " (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        " (act_observer): ExponentialMovingAverageObserver()\n"
        " (act_fake_quant): FakeQuantize()\n"
        " (weight_observer): MinMaxObserver()\n"
        " (weight_fake_quant): FakeQuantize()\n"
        " )\n"
        " (linear): QAT.Linear(\n"
        " in_features=3, out_features=3, bias=True\n"
        " (act_observer): ExponentialMovingAverageObserver()\n"
        " (act_fake_quant): FakeQuantize()\n"
        " (weight_observer): MinMaxObserver()\n"
        " (weight_fake_quant): FakeQuantize()\n"
        " )\n"
        ")")
    quantized_repr = (
        "Net(\n"
        " (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n"
        " (linear): Quantized.Linear()\n"
        ")")

    net = Net()
    # Each stage optionally converts the network in place, then checks its repr.
    # The first stage checks the untouched float network.
    stages = (
        (None, float_repr),
        (quantize_qat, qat_repr),
        (quantize, quantized_repr),
    )
    for convert, expected in stages:
        if convert is not None:
            convert(net)
        assert net.__repr__() == expected
def build_observered_net(net: M.Module, observer_cls):
    """Convert *net* for QAT and run one random batch through it with observers on.

    Args:
        net: float network to convert.
        observer_cls: observer class passed to ``get_observer_config`` to
            build the qconfig.

    Returns:
        The network returned by ``Q.quantize_qat``. It is left in eval mode,
        and its observers are disabled again after the single forward pass.
    """
    qconfig = get_observer_config(observer_cls)
    # The custom fused module needs an explicit mapping to its QAT counterpart.
    converted = Q.quantize_qat(
        net,
        qconfig=qconfig,
        mapping={MyConvBnRelu2d: MyQATConvBnRelu2d},
    )

    # Enable observers for exactly one forward pass on random input, then turn
    # them off so later calls do not update them.
    Q.enable_observer(converted)
    sample = Tensor(np.random.random(size=(5, 3, 32, 32)))
    converted.eval()
    converted(sample)
    Q.disable_observer(converted)

    return converted
def worker(rank, world_size, args):  # pylint: disable=too-many-statements
    """Evaluate an ImageNet classifier on the validation split and log accuracy.

    Args:
        rank: index of this process; also passed as its device id.
        world_size: number of processes; values above 1 join a process group
            and average the metrics across ranks.
        args: parsed CLI namespace (reads ``arch``, ``mode``, ``checkpoint``,
            ``data`` and ``workers``; the whole namespace is passed to ``infer``).
    """
    if world_size > 1:
        # Join the process group before any collective op is traced.
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    model = models.__dict__[args.arch]()

    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        state = mge.load(args.checkpoint)
        # Accept both a wrapped {"state_dict": ...} checkpoint and a bare one.
        if "state_dict" in state:
            state = state["state_dict"]
        model.load_state_dict(state, strict=False)

    if args.mode == "quantized":
        Q.quantize(model)

    def mean_across_ranks(value, key):
        # Sum-reduce under *key*, then divide by the number of ranks.
        return dist.all_reduce_sum(value, key) / dist.get_world_size()

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():
            loss = mean_across_ranks(loss, "valid_loss")
            acc1 = mean_across_ranks(acc1, "valid_acc1")
            acc5 = mean_across_ranks(acc5, "valid_acc5")
        return loss, acc1, acc5

    logger.info("preparing dataset..")
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False)
    eval_transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=128),
        T.ToMode("CHW"),
    ])
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=eval_transform,
        num_workers=args.workers,
    )

    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %f, %f", valid_acc, valid_acc5)
def worker(rank, world_size, args):  # pylint: disable=too-many-statements
    """Fine-tune an ImageNet classifier in float ("normal") or QAT mode.

    Checkpoints are written to ``<args.save>/<arch>.<mode>/`` by rank 0 only.
    Validation runs periodically and once at the end on every rank, because
    the traced validation graph performs collective all-reduces.

    Args:
        rank: index of this process; also passed as its device id.
        world_size: number of processes; values above 1 enable distributed
            training and scale the learning rate.
        args: parsed CLI namespace (reads ``arch``, ``mode``, ``save``,
            ``checkpoint``, ``data``, ``workers`` and ``report_freq``).

    Raises:
        ValueError: if ``args.mode == "quantized"`` (that mode is inference
            only), or if ``cfg.SCHEDULER`` is neither "Linear" nor "Multistep".
    """
    # Fail fast: quantized models cannot be trained. Reject before any process
    # group setup, directory creation or checkpoint loading is done.
    if args.mode == "quantized":
        raise ValueError("mode = quantized only used during inference")

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group %d / %d", rank, world_size)
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    # exist_ok makes a separate existence check redundant (and race-free
    # when several ranks reach this line at once).
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_finetune_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    # NOTE(review): 1280000 presumably approximates the ImageNet train-set size.
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        # Convert before loading, so the checkpoint is matched against the
        # QAT module tree.
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    def _all_reduce_mean(value, key):
        # Average *value* over all ranks: sum-reduce under *key*, then divide.
        return dist.all_reduce_sum(value, key) / dist.get_world_size()

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():
            loss = _all_reduce_mean(loss, "train_loss")
            acc1 = _all_reduce_mean(acc1, "train_acc1")
            acc5 = _all_reduce_mean(acc5, "train_acc5")
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():
            loss = _all_reduce_mean(loss, "valid_loss")
            acc1 = _all_reduce_mean(acc1, "valid_acc1")
            acc5 = _all_reduce_mean(acc5, "valid_acc5")
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset, batch_size=cfg.BATCH_SIZE, drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            cfg.COLOR_JITTOR,
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)

    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set and return the LR for *step*, following ``cfg.SCHEDULER``."""
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA**bisect.bisect_right(
                cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    def save_checkpoint(step, filename):
        mge.save(
            {
                "step": step,
                "state_dict": model.state_dict()
            },
            os.path.join(save_dir, filename),
        )

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    step = 0  # keeps `step` bound for the final save/log when total_steps == 0
    for step in range(0, total_steps):
        # Learning-rate schedule (linear or multistep, per cfg.SCHEDULER)
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()

        if step % args.report_freq == 0 and rank == 0:
            logger.info("TRAIN e%d %06d %f %s %s %s %s", epoch, step,
                        learning_rate, objs, top1, top5, total_time)
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()

        if step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            save_checkpoint(step, "checkpoint.pkl")

        # Every rank must run validation: valid_func performs all-reduces.
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 writes checkpoints, as with the periodic saves; otherwise
    # every rank would write the same file concurrently.
    if rank == 0:
        save_checkpoint(step, "checkpoint-final.pkl")
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
def main():
    """Classify one image with an ImageNet model and print the top-5 classes.

    Command-line options select the architecture, an optional checkpoint, the
    input image and the quantization mode. ``--dump`` also writes the traced
    graph (``<arch>.<mode>.megengine``) and the state dict
    (``<arch>.<mode>.pkl``) to the current directory.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)

    parser.add_argument(
        "-m",
        "--mode",
        default="quantized",
        type=str,
        choices=["normal", "qat", "quantized"],
        help="Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert mode to int8 quantized, inference only")
    parser.add_argument("--dump", action="store_true", help="Dump quantized model")
    args = parser.parse_args()

    if args.mode == "quantized":
        # NOTE(review): "cpux" presumably selects the CPU device for int8 inference.
        mge.set_default_device("cpux")

    model = models.__dict__[args.arch]()

    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    if args.mode == "quantized":
        Q.quantize(model)

    path = "../assets/cat.jpg" if args.image is None else args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise FileNotFoundError("cannot read image: {}".format(path))

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=128),
        T.ToMode("CHW"),
    ])

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis of size 1.
    processed_img = transform.apply(image)[np.newaxis, :]

    if args.mode == "quantized":
        processed_img = processed_img.astype("int8")
    else:
        # QAT only simulates int8 on float tensors, so it takes the same
        # float32 input as the normal float model.
        processed_img = processed_img.astype("float32")

    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    if args.dump:
        output_stem = ".".join([args.arch, args.mode])
        output_file = output_stem + ".megengine"
        logger.info("Dump to %s", output_file)
        infer_func.dump(output_file, arg_names=["data"])
        # Build the .pkl name directly: str.replace would also rewrite any
        # "megengine" that appears inside the arch name.
        mge.save(model.state_dict(), output_stem + ".pkl")

    with open("../assets/imagenet_class_info.json", encoding="utf-8") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
            zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))):
        print("{}: class = {:20s} with probability = {:4.1f} %".format(
            rank, imagenet_class_index[str(classid)][1], 100 * prob))