Exemplo n.º 1
0
 def test_quantization_saved(self):
     """Quantize, save and reload a model for every config, then benchmark.

     Steps:
         * For each of ``dynamic_yaml.yaml``, ``qat_yaml.yaml`` and
           ``ptq_yaml.yaml``, quantize the model and save it to ``./saved``.
         * Reload it with ``lpot.utils.pytorch.load`` and run it through
           ``eval_func``.
         * Remove ``./saved`` after every iteration, including when one of
           those steps raises. This keeps a failing config from leaving stale
           artifacts for the next iteration or the next test run.
         * Finally, benchmark the ``model`` from the last (``ptq_yaml.yaml``)
           iteration twice with ``ptq_yaml.yaml``.
     """
     # Imported locally so the test does not rely on a module-level ``load``.
     from lpot.utils.pytorch import load

     for fake_yaml in [
             'dynamic_yaml.yaml', 'qat_yaml.yaml', 'ptq_yaml.yaml'
     ]:
         if fake_yaml == 'dynamic_yaml.yaml':
             model = torchvision.models.resnet18()
         else:
             model = copy.deepcopy(self.model)
         if fake_yaml == 'ptq_yaml.yaml':
             # Only the PTQ config runs on an eval-mode, fused model.
             model.eval().fuse_model()
         quantizer = Quantization(fake_yaml)
         dataset = quantizer.dataset('dummy', (100, 3, 256, 256),
                                     label=True)
         quantizer.model = common.Model(model)
         if fake_yaml == 'qat_yaml.yaml':
             # QAT gets a training callback (q_func) instead of a
             # calibration dataloader.
             quantizer.q_func = q_func
         else:
             quantizer.calib_dataloader = common.DataLoader(dataset)
         quantizer.eval_dataloader = common.DataLoader(dataset)
         try:
             q_model = quantizer()
             q_model.save('./saved')
             # Load configure and weights by lpot.utils
             saved_model = load("./saved", model)
             eval_func(saved_model)
         finally:
             # Clean up even on failure so './saved' never leaks.
             shutil.rmtree('./saved', ignore_errors=True)
     from lpot.experimental import Benchmark
     evaluator = Benchmark('ptq_yaml.yaml')
     # Load configure and weights by lpot.model
     evaluator.model = common.Model(model)
     evaluator.b_dataloader = common.DataLoader(dataset)
     evaluator()
     evaluator.model = common.Model(model)
     evaluator()
Exemplo n.º 2
0
    def test_quantization_saved(self):
        """Quantize, save and reload for every config, then compare accuracy.

        Steps:
            * For each of ``dynamic_yaml.yaml``, ``qat_yaml.yaml`` and
              ``ptq_yaml.yaml``, quantize the model and save it to
              ``./saved``.
            * Reload it with ``lpot.utils.pytorch.load`` and run it through
              ``eval_func``.
            * Remove ``./saved`` after every iteration, including on failure,
              so the test does not leave files behind.
            * Finally, benchmark the ``model`` from the last
              (``ptq_yaml.yaml``) iteration twice with ``ptq_yaml.yaml``.
              Assert that the first run's accuracy is at most 0.01 below the
              second run's.
        """
        import shutil

        from lpot.utils.pytorch import load

        for fake_yaml in [
                'dynamic_yaml.yaml', 'qat_yaml.yaml', 'ptq_yaml.yaml'
        ]:
            if fake_yaml == 'dynamic_yaml.yaml':
                model = torchvision.models.resnet18()
            else:
                model = copy.deepcopy(self.model)
            if fake_yaml == 'ptq_yaml.yaml':
                # Only the PTQ config runs on an eval-mode, fused model.
                model.eval().fuse_model()
            quantizer = Quantization(fake_yaml)
            dataset = quantizer.dataset('dummy', (100, 3, 256, 256),
                                        label=True)
            quantizer.model = common.Model(model)
            quantizer.calib_dataloader = common.DataLoader(dataset)
            quantizer.eval_dataloader = common.DataLoader(dataset)
            if fake_yaml == 'qat_yaml.yaml':
                quantizer.q_func = q_func
            try:
                q_model = quantizer()
                q_model.save('./saved')
                # Load configure and weights by lpot.utils
                saved_model = load("./saved", model)
                eval_func(saved_model)
            finally:
                # Without this, './saved' was left on disk after the test.
                shutil.rmtree('./saved', ignore_errors=True)
        from lpot.experimental import Benchmark
        evaluator = Benchmark('ptq_yaml.yaml')
        # Load configure and weights by lpot.model
        evaluator.model = common.Model(model)
        evaluator.b_dataloader = common.DataLoader(dataset)
        results = evaluator()
        evaluator.model = common.Model(model)
        fp32_results = evaluator()
        # One-sided check: only a drop of 0.01 or more fails; a higher
        # first-run accuracy passes. assertLess reports both values on failure.
        self.assertLess(
            fp32_results['accuracy'][0] - results['accuracy'][0], 0.01)
Exemplo n.º 3
0
def main_worker(gpu, ngpus_per_node, args):
    """Build model, optimizer and data loaders, then run the requested mode.

    Args:
        gpu: index of this worker. It is printed, and with
            ``args.multiprocessing_distributed`` it is used to derive the
            global rank.
        ngpus_per_node: workers per node. Used to derive the global rank, to
            split ``args.batch_size``/``args.workers`` when one GPU per
            process is used, and to pick which process saves checkpoints.
        args: parsed command-line namespace. It is mutated in place
            (``rank``, ``batch_size``, ``workers``, ``start_epoch``).

    Modes are checked in this order, and each of the first three returns
    early:
        * ``args.evaluate``: run ``validate`` once on the validation set.
        * ``args.tune``: fuse the model and run LPOT ``Quantization`` with
          ``training_func_for_lpot`` as ``q_func``, saving the result to
          ``args.tuned_checkpoint``.
        * ``args.benchmark``: fuse the model, optionally load the tuned int8
          checkpoint (``args.int8``), then run ``validate``.
        * Otherwise, train from ``args.start_epoch`` to ``args.epochs``,
          validating and saving a checkpoint each epoch.

    The module-level ``best_acc1`` is read and updated.
    """
    global best_acc1
    print("Use CPU: {} for training".format(gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, quantize=False)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            # DistributedDataParallelCPU was deprecated and then removed from
            # PyTorch. DistributedDataParallel accepts CPU modules when
            # device_ids is left unset.
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # The DataParallel wrapper is relied on further down: the tune and
        # benchmark paths call model.module.fuse_model().
        if args.arch.startswith(('alexnet', 'vgg')):
            model.features = torch.nn.DataParallel(model.features)
            # This is the CPU path, so only move to CUDA when it exists.
            if torch.cuda.is_available():
                model.cuda()
        else:
            model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Map storages explicitly. A checkpoint written on a GPU must
            # still load on a CPU-only worker; with --gpu, load it onto the
            # selected device.
            if args.gpu is None:
                map_location = 'cpu'
            else:
                map_location = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.resume, map_location=map_location)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    if args.tune:

        def training_func_for_lpot(model):
            """Quantization-aware fine-tuning callback passed to LPOT as q_func.

            Runs ``epochs`` short epochs of at most ``iters`` batches each,
            using a fresh SGD optimizer. At the end of each epoch:
                * if its index is > 3, observers are disabled;
                * if its index is > 2, batch-norm statistics are frozen.
            """
            epochs = 8
            iters = 30
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
            for nepoch in range(epochs):
                model.train()
                for cnt, (image, target) in enumerate(train_loader, start=1):
                    print('.', end='')
                    output = model(image)
                    loss = criterion(output, target)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    if cnt >= iters:
                        break

                if nepoch > 3:
                    # Freeze quantizer parameters
                    model.apply(torch.quantization.disable_observer)
                if nepoch > 2:
                    # Freeze batch norm mean and variance estimates
                    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        model.module.fuse_model()
        from lpot.experimental import Quantization, common
        quantizer = Quantization(args.config)
        quantizer.model = common.Model(model)
        quantizer.q_func = training_func_for_lpot
        quantizer.eval_dataloader = val_loader
        q_model = quantizer()
        q_model.save(args.tuned_checkpoint)
        return

    if args.benchmark:
        model.eval()
        model.module.fuse_model()
        if args.int8:
            from lpot.utils.pytorch import load
            new_model = load(
                os.path.abspath(os.path.expanduser(args.tuned_checkpoint)),
                model)
        else:
            new_model = model
        validate(val_loader, new_model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Save from every process unless multiprocessing-distributed, in
        # which case only the first process on each node saves.
        if (not args.multiprocessing_distributed
                or args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)