def test_32to8(checkpoint, split='val'):
    """Convert a QAT-prepared fp32 MobileNetV2 checkpoint to int8 and return
    its top-1 accuracy (in percent) on a CIFAR-10 split.

    Args:
        checkpoint: path to a state dict saved from the QAT-prepared fp32 model.
        split: ``'train'`` or ``'val'`` — which CIFAR-10 split to evaluate on.

    Returns:
        float: accuracy as a percentage.
    """
    net_fp32 = mobilenet_v2(num_classes=10)
    # prepare_qat must see the model in train mode with a qconfig attached,
    # otherwise the observer/fake-quant modules are not inserted.
    net_fp32.train()
    net_fp32.qconfig = torch.quantization.get_default_qat_qconfig(
        'fbgemm')  # fbgemm for x86; qnnpack for mobile/ARM
    torch.backends.quantized.engine = 'fbgemm'
    prepared_net_fp32 = torch.quantization.prepare_qat(net_fp32)
    prepared_net_fp32.load_state_dict(torch.load(checkpoint))
    # int8 conversion requires a CPU model in eval mode.
    net_int8 = torch.quantization.convert(prepared_net_fp32.cpu().eval())
    net_int8.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Build only the requested split (the original constructed and downloaded
    # both splits and both loaders even though only one was ever used).
    dataset = torchvision.datasets.CIFAR10(root='./data',
                                           train=(split == 'train'),
                                           download=True,
                                           transform=transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=64,
                                         shuffle=False,
                                         num_workers=8)

    with torch.no_grad():
        num_samples = len(dataset)
        correct = 0
        for inputs, labels in tqdm(loader):
            preds = np.argmax(net_int8(inputs).cpu().numpy(), axis=1)
            correct += int((preds == labels.cpu().numpy()).sum())
    return correct / num_samples * 100
def test_mobilenet_v2(self):
    """Exercise the quantizable MobileNetV2 through the shared vision-model harness."""
    from torchvision.models.quantization import mobilenet_v2

    # quantize=False yields the float model (with quant stubs) rather than
    # an already-converted int8 one.
    model = mobilenet_v2(pretrained=True, quantize=False)
    self._test_vision_model(model)
def load_model(model_option: dict, num_classes: int):
    """Build the model named by ``model_option['model']`` and replace its
    classifier head so it outputs ``num_classes`` logits.

    Args:
        model_option: config dict with keys ``model``, ``pretrained``,
            ``feature_extract_flag`` and, for some models, ``width_mult``
            or ``backbone``.
        num_classes: number of output classes for the new head.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: if ``model_option['model']`` is not a recognized name.
    """
    model_name = model_option["model"]
    pretrained = model_option["pretrained"]
    feature_extract = model_option["feature_extract_flag"]

    if model_name == "resnet18":
        model = models.resnet18(pretrained=pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name == "mobilenetv2":
        model = models.mobilenet_v2(pretrained=pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.classifier[1] = nn.Linear(
            in_features=model.classifier[1].in_features,
            out_features=num_classes)
    elif model_name == "mobilenetv2_q":
        model = MobileNetV2(num_classes=num_classes,
                            width_mult=model_option["width_mult"],
                            pretrained=pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.classifier[1] = nn.Linear(
            in_features=model.classifier[1].in_features,
            out_features=num_classes)
    elif model_name == "squeezenet":
        model = models.squeezenet1_0(pretrained=pretrained)
        set_parameter_requires_grad(model, feature_extract)
        # SqueezeNet classifies with a 1x1 conv rather than a Linear head.
        model.classifier[1] = nn.Conv2d(512, num_classes,
                                        kernel_size=(1, 1), stride=(1, 1))
        model.num_classes = num_classes
    elif model_name == "vgg11_bn":
        model = models.vgg11_bn(pretrained=pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features,
                                        num_classes)
    elif model_name == "vgg16":
        model = models.vgg16(pretrained=pretrained)
        set_parameter_requires_grad(model, feature_extract)
        num_features = model.classifier[6].in_features
        # Drop the original final layer and append one sized for our classes.
        features = list(model.classifier.children())[:-1]
        features.append(nn.Linear(num_features, num_classes))
        model.classifier = nn.Sequential(*features)
    elif model_name == "densenet":
        model = models.densenet161(pretrained=pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif model_name == "mobilenetv2_q_ssd":
        # SSD reserves one extra class index for the background label.
        model = SSD(num_classes=num_classes + 1,
                    backbone_network=model_option["backbone"])
    # Quantized models
    elif model_name == "":
        # NOTE(review): matching the empty string looks like a leftover
        # placeholder — confirm the intended config name for the quantized
        # torchvision MobileNetV2 branch.
        model = q_models.mobilenet_v2()
    else:
        raise ValueError("Wrong Model Name... Check config.json " + model_name)
    return model
def train(args):
    """Quantization-aware-train MobileNetV2 on CIFAR-10.

    Runs QAT with fbgemm fake-quant modules, evaluates both the int8-converted
    and the fp32 model after every epoch, and saves the final fp32/int8 state
    dicts under ``args.cp``.

    Args:
        args: namespace with ``cp`` (checkpoint dir), ``batch_size``,
            ``num_workers``, ``num_epoches``, and optional ``pretrained``
            (path to a QAT state dict to resume from).
    """
    os.makedirs(args.cp, exist_ok=True)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers)
    valset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=True, transform=transform)
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)

    net_fp32 = mobilenet_v2(num_classes=10)
    # prepare_qat requires train mode and an attached qconfig.
    net_fp32.train()
    net_fp32.qconfig = torch.quantization.get_default_qat_qconfig(
        'fbgemm')  # fbgemm for x86; qnnpack for mobile/ARM
    torch.backends.quantized.engine = 'fbgemm'
    prepared_net_fp32 = torch.quantization.prepare_qat(net_fp32)

    if args.pretrained:
        print("=> Using pretrained model: {}".format(args.pretrained))
        prepared_net_fp32.load_state_dict(torch.load(args.pretrained))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    prepared_net_fp32.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(prepared_net_fp32.parameters(), lr=1e-4)

    log_every = 35  # report running stats every 35 mini-batches

    for epoch in range(args.num_epoches):
        running_loss = 0.0
        correct = 0
        print("=> Training phase:")
        for i, (inputs, labels) in enumerate(trainloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = prepared_net_fp32(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            preds = np.argmax(outputs.detach().cpu().numpy(), axis=1)
            correct += int((preds == labels.cpu().numpy()).sum())
            if i % log_every == log_every - 1:
                accuracy = correct / (log_every * args.batch_size)
                # loss.item() is already a per-batch mean, so average over the
                # window size only. (The original divided by
                # window * batch_size, under-reporting the loss by a factor
                # of batch_size.)
                print('[%d, %5d] loss: %.3f - acc: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every, accuracy))
                running_loss = 0.0
                correct = 0

        print("=> int8 evaluation phase:")
        # int8 conversion requires a CPU model in eval mode.
        net_int8 = torch.quantization.convert(prepared_net_fp32.cpu().eval())
        evaluation(args, net_int8, valloader, criterion, valset, args.cp,
                   bitwidths='int8')
        print("=> fp32 evaluation phase:")
        evaluation(args, prepared_net_fp32, valloader, criterion, valset,
                   args.cp, bitwidths='fp32')
        # Restore training state: the conversion above moved the model to CPU
        # and switched it to eval mode, and the original never undid that — on
        # a CUDA device the next epoch would fail (inputs on GPU, weights on
        # CPU) and the observers/fake-quant would stay frozen.
        prepared_net_fp32.to(device)
        prepared_net_fp32.train()

    print('=> Finished training')
    prepared_net_fp32.cpu().eval()
    net_int8 = torch.quantization.convert(prepared_net_fp32)
    torch.save(prepared_net_fp32.state_dict(),
               os.path.join(args.cp, "last_fp32.pth"))
    torch.save(net_int8.state_dict(), os.path.join(args.cp, "last_int8.pth"))