Example #1
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

# NOTE: assumed import; the example needs a quantization-ready MobileNetV2
# (e.g. torchvision.models.quantization.mobilenet_v2 or a local copy of it).
from torchvision.models.quantization import mobilenet_v2


def test_32to8(checkpoint, split='val'):
    net_fp32 = mobilenet_v2(num_classes=10)
    net_fp32.train()
    net_fp32.qconfig = torch.quantization.get_default_qat_qconfig(
        'fbgemm')  # fbgemm for x86 (PC/server); qnnpack for ARM/mobile
    torch.backends.quantized.engine = 'fbgemm'
    prepared_net_fp32 = torch.quantization.prepare_qat(net_fp32)
    prepared_net_fp32.load_state_dict(torch.load(checkpoint))
    net_int8 = torch.quantization.convert(prepared_net_fp32.cpu().eval())
    # net_int8.load_state_dict(torch.load(checkpoint))
    # print(torch.load(checkpoint))
    net_int8.eval()

    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # net_int8.to(device)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    valset = torchvision.datasets.CIFAR10(root='./data',
                                          train=False,
                                          download=True,
                                          transform=transform)
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=64,
                                            shuffle=False,
                                            num_workers=8)
    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=True,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=8)
    if split == 'train':
        loader = trainloader
        dataset = trainset
    else:
        loader = valloader
        dataset = valset

    with torch.no_grad():
        num_samples = len(dataset)
        counter = 0
        for i, data in tqdm(enumerate(loader, 0)):
            inputs, labels = data
            # inputs = inputs.to(device)
            out = net_int8(inputs).cpu().numpy()
            out = np.argmax(out, axis=1)

            labels = labels.cpu().numpy()
            diff = out - labels
            counter += len(np.where(diff == 0)[0])
    return counter / num_samples * 100
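
A minimal usage sketch for Example #1, assuming a QAT checkpoint produced by a matching training run (the checkpoint path below is hypothetical):

# Hypothetical path to a state dict saved from the prepared (QAT) fp32 model.
acc = test_32to8('checkpoints/last_fp32.pth', split='val')
print('int8 top-1 accuracy on CIFAR-10 val: %.2f%%' % acc)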
Example #2
def test_mobilenet_v2(self):
    from torchvision.models.quantization import mobilenet_v2
    self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))
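
For context, the same torchvision factory can also return an already-converted int8 model when quantize=True; a minimal inference sketch, assuming a torchvision release that still accepts the pretrained=/quantize= keywords:

import torch
from torchvision.models.quantization import mobilenet_v2

# quantize=True downloads int8 weights and returns a converted quantized model.
qnet = mobilenet_v2(pretrained=True, quantize=True)
qnet.eval()
with torch.no_grad():
    out = qnet(torch.randn(1, 3, 224, 224))  # float input; the quant/dequant stubs handle conversion
print(out.shape)  # torch.Size([1, 1000])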
Example #3
def load_model(model_option: dict, num_classes: int):
    model_name = model_option["model"]
    if model_name == "resnet18":
        model = models.resnet18(pretrained=model_option["pretrained"])
        set_parameter_requires_grad(model,
                                    model_option["feature_extract_flag"])
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
    elif model_name == "mobilenetv2":
        model = models.mobilenet_v2(pretrained=model_option["pretrained"])
        set_parameter_requires_grad(model,
                                    model_option["feature_extract_flag"])
        model.classifier[1] = nn.Linear(
            in_features=model.classifier[1].in_features,
            out_features=num_classes)
    elif model_name == "mobilenetv2_q":
        model = MobileNetV2(num_classes=num_classes,
                            width_mult=model_option["width_mult"],
                            pretrained=model_option["pretrained"])

        set_parameter_requires_grad(model,
                                    model_option["feature_extract_flag"])
        model.classifier[1] = nn.Linear(
            in_features=model.classifier[1].in_features,
            out_features=num_classes)
    # elif model_name == "mnasNet":
    #     models = models.mnasnet1_0(pretrained=model_option["pretrained"])
    elif model_name == "squeezenet":
        model = models.squeezenet1_0(pretrained=model_option["pretrained"])
        set_parameter_requires_grad(model,
                                    model_option["feature_extract_flag"])
        model.classifier[1] = nn.Conv2d(512,
                                        num_classes,
                                        kernel_size=(1, 1),
                                        stride=(1, 1))
        model.num_classes = num_classes
    elif model_name == "vgg11_bn":
        model = models.vgg11_bn(pretrained=model_option["pretrained"])
        set_parameter_requires_grad(model,
                                    model_option["feature_extract_flag"])
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vgg16":
        model = models.vgg16(pretrained=model_option["pretrained"])
        set_parameter_requires_grad(model,
                                    model_option["feature_extract_flag"])
        num_features = model.classifier[6].in_features
        features = list(model.classifier.children())[:-1]  # Remove the last layer
        features.extend([nn.Linear(num_features, num_classes)])  # Append a layer with num_classes outputs
        model.classifier = nn.Sequential(*features)  # Replace the model's classifier
    # elif model_name == "shufflenet":
    #     models = models.shufflenet_v2_x1_0(pretrained=model_option["pretrained"])
    elif model_name == "densenet":
        model = models.densenet161(pretrained=model_option["pretrained"])
        set_parameter_requires_grad(model,
                                    model_option["feature_extract_flag"])
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, num_classes)
    elif model_name == "mobilenetv2_q_ssd":
        model = SSD(num_classes=num_classes + 1,
                    backbone_network=model_option["backbone"])
    #############################################################################################################
    # Quantized Models
    # NOTE: this branch compares against an empty string, so it is effectively
    # unreachable; the intended model name appears to be missing in the source.
    elif model_name == "":
        model = q_models.mobilenet_v2()
    else:
        raise Exception("Wrong Model Name... Check config.json " + model_name)

    return model
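
A usage sketch for load_model; the dict mirrors the keys the function reads, and the values are hypothetical:

# Hypothetical config entries matching the keys load_model() looks up.
model_option = {
    "model": "mobilenetv2",
    "pretrained": True,
    "feature_extract_flag": True,
}
model = load_model(model_option, num_classes=10)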
Example #4
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# NOTE: assumed import, as in Example #1; evaluation() is a project-specific
# helper that is not shown here.
from torchvision.models.quantization import mobilenet_v2


def train(args):
    os.makedirs(args.cp, exist_ok=True)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=True,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers)

    valset = torchvision.datasets.CIFAR10(root='./data',
                                          train=False,
                                          download=True,
                                          transform=transform)
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    net_fp32 = mobilenet_v2(num_classes=10)
    net_fp32.train()
    net_fp32.qconfig = torch.quantization.get_default_qat_qconfig(
        'fbgemm')  # fbgemm for x86 (PC/server); qnnpack for ARM/mobile
    torch.backends.quantized.engine = 'fbgemm'

    prepared_net_fp32 = torch.quantization.prepare_qat(net_fp32)

    if args.pretrained:
        print("=> Using pretrained model: {}".format(args.pretrained))
        prepared_net_fp32.load_state_dict(torch.load(args.pretrained))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    prepared_net_fp32.to(device)
    '''
    training loop start here
    '''
    criterion = nn.CrossEntropyLoss().to(device)
    # optimizer = optim.SGD(prepared_net_fp32.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(prepared_net_fp32.parameters(), lr=1e-4)
    for epoch in range(args.num_epoches):
        running_loss = 0.0
        counter = 0.0
        print("=> Training phase:")
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = prepared_net_fp32(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            outputs = outputs.detach().cpu().numpy()
            outputs = np.argmax(outputs, axis=1)
            labels = labels.cpu().numpy()
            diff = outputs - labels
            counter += len(np.where(diff == 0)[0])

            if i % 35 == 34:
                accuracy = counter / (35 * args.batch_size)
                print('[%d, %5d] loss: %.3f - acc: %.3f' %
                      (epoch + 1, i + 1, running_loss /
                       (35 * args.batch_size), accuracy))
                running_loss = 0.0
                counter = 0

        print("=> int8 evaluation phase:")
        net_int8 = torch.quantization.convert(prepared_net_fp32.cpu().eval())
        evaluation(args,
                   net_int8,
                   valloader,
                   criterion,
                   valset,
                   args.cp,
                   bitwidths='int8')
        print("=> fp32 evaluation phase:")
        evaluation(args,
                   prepared_net_fp32,
                   valloader,
                   criterion,
                   valset,
                   args.cp,
                   bitwidths='fp32')
        # The convert step above moved prepared_net_fp32 to the CPU and put it
        # in eval mode; restore it for the next training epoch.
        prepared_net_fp32.to(device)
        prepared_net_fp32.train()
    '''
    training loop end here
    '''
    print('=> Finished training')

    prepared_net_fp32.cpu().eval()
    net_int8 = torch.quantization.convert(prepared_net_fp32)
    torch.save(prepared_net_fp32.state_dict(),
               os.path.join(args.cp, "last_fp32.pth"))
    torch.save(net_int8.state_dict(), os.path.join(args.cp, "last_int8.pth"))
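
A minimal driver sketch covering only the flags that train() itself reads (the project's evaluation() helper may expect additional ones):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cp', default='./checkpoints')        # checkpoint output directory
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--num_epoches', type=int, default=30)
parser.add_argument('--pretrained', default='')             # optional path to a prepared-QAT state dict
args = parser.parse_args()
train(args)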