def prepare_for_quant(self, cfg):
    self.avgpool = prepare_qat_fx(
        self.avgpool,
        {"": torch.quantization.get_default_qat_qconfig()},
        self.custom_config_dict,
    )
    return self
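# For reference, one possible shape of the `self.custom_config_dict` passed to
# prepare_qat_fx above (a sketch; its actual contents are not shown in the
# snippet). FX graph mode accepts a prepare_custom_config_dict that can, for
# example, mark submodules as non-traceable:
custom_config_dict = {
    "non_traceable_module_name": ["some_dynamic_submodule"],  # hypothetical name
}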
def prepare_ptq_linear(qconfig):
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    prepared_model = prepare_qat_fx(
        copy.deepcopy(float_model),
        qconfig_dict)  # fuse modules and insert observers
    calibrate(prepared_model, data_loader_test)  # run calibration on sample data
    return prepared_model
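# A minimal sketch of the `calibrate` helper assumed by prepare_ptq_linear
# above; it is not defined in that snippet. Calibration just runs the
# observer-instrumented model over sample batches so the observers can record
# activation ranges. The (inputs, target) batch layout is an assumption.
import torch

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for inputs, _ in data_loader:
            model(inputs)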
def test_mobilenet_v2_qat(self):
    # verify that mobilenetv2 graph is able to be matched
    import torchvision
    m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
    mp = prepare_qat_fx(m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')})
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    # assume success if no exceptions
    results = get_matching_subgraph_pairs(mp, mq)
def load_model(config, checkpoint):
    opt = config['opt']
    labels = load_label(opt.label_path)
    label_size = len(labels)
    config['labels'] = labels
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config, opt.embedding_path, label_size,
                                 emb_non_trainable=True)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config, opt.embedding_path, label_size,
                                         emb_non_trainable=True)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config, opt.embedding_path, label_size,
                                         emb_non_trainable=True)
    else:
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        bert_config = AutoConfig.from_pretrained(opt.bert_output_dir)
        bert_tokenizer = AutoTokenizer.from_pretrained(opt.bert_output_dir)
        bert_model = AutoModel.from_config(bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls':
            ModelClass = TextBertCLS
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer, label_size)
    if opt.enable_qat:
        assert opt.device == 'cpu'
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if opt.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)
    model.load_state_dict(checkpoint)
    model = model.to(opt.device)
    '''
    for name, param in model.named_parameters():
        print(name, param.data, param.device, param.requires_grad)
    '''
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[Model loaded]")
    return model
def prepare(self, model, configs, attrs):
    model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
    return model
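# A possible shape of the `configs` mapping consumed by prepare() above (a
# sketch; only the global "" entry is used in the snippet). The value at
# configs[""] must itself be a qconfig_dict for prepare_qat_fx:
import torch

configs = {
    "": {"": torch.quantization.get_default_qat_qconfig("fbgemm")},
}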
def prep_qat_train(self):
    qconfig_dict = {
        "": torch.quantization.get_default_qat_qconfig('fbgemm')
    }
    self.model.train()
    self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict)
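# A hedged sketch of the full QAT round trip around prep_qat_train above
# (the training loop and `num_epochs` are assumptions, not part of the
# snippet): train with fake-quant inserted, then convert to a real int8 model.
# self.prep_qat_train()
# for epoch in range(num_epochs):
#     ...  # usual training steps on self.model
# self.model.eval()
# int8_model = quantize_fx.convert_fx(self.model)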
def load_model(config, checkpoint):
    args = config['args']
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config, args.embedding_path, label_size,
                                 emb_non_trainable=True)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config, args.embedding_path, label_size,
                                         emb_non_trainable=True)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config, args.embedding_path, label_size,
                                         emb_non_trainable=True)
    else:
        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
            bert_config = bert_model.config
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_pretrained(args.bert_output_dir)
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = T5EncoderModel(bert_config)
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_config(bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls':
            ModelClass = TextBertCLS
        if config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer, label_size)
    if args.enable_qat:
        assert args.device == 'cpu'
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)
    if args.enable_diffq:
        quantizer = DiffQuantizer(model)
        config['quantizer'] = quantizer
        quantizer.restore_quantized_state(checkpoint)
    else:
        model.load_state_dict(checkpoint)
    model = model.to(args.device)
    '''
    for name, param in model.named_parameters():
        print(name, param.data, param.device, param.requires_grad)
    '''
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[Model loaded]")
    return model
def load(checkpoint_dir, model, **kwargs):
    """Execute the quantize process on the specified model.

    Args:
        checkpoint_dir (dir): The folder of checkpoint. 'best_configure.yaml' and
                              'best_model_weights.pt' are needed in this directory.
                              The 'checkpoint' dir is under the workspace folder,
                              and the workspace folder is defined in the configure
                              yaml file.
        model (object): fp32 model that needs quantization.

    Returns:
        (object): quantized model
    """
    tune_cfg_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_configure.yaml')
    weights_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_model_weights.pt')
    assert os.path.exists(
        tune_cfg_file), "tune configure file %s didn't exist" % tune_cfg_file
    assert os.path.exists(
        weights_file), "weight file %s didn't exist" % weights_file

    with open(tune_cfg_file, 'r') as f:
        tune_cfg = yaml.safe_load(f)

    version = get_torch_version()
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        if version < '1.7':
            q_mapping = tq.default_mappings.DEFAULT_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_static_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_static_quant_module_mappings()
    else:
        if version < '1.7':
            q_mapping = \
                tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_dynamic_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_dynamic_quant_module_mappings()

    if version < '1.7':
        white_list = \
            tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.default_mappings.DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    elif version < '1.8':
        white_list = \
            tq.quantization_mappings.get_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    else:
        white_list = \
            tq.quantization_mappings.get_default_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_default_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}

    if tune_cfg['approach'] == "post_training_dynamic_quant":
        op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
    else:
        op_cfgs = _cfg_to_qconfig(tune_cfg)

    if tune_cfg['framework'] == "pytorch_fx":  # pragma: no cover
        # For torch.fx approach
        assert version >= '1.8', \
            "Please use PyTorch 1.8 or higher version with pytorch_fx backend"
        from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
        q_model = copy.deepcopy(model.eval())
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg['approach'])
        if tune_cfg['approach'] == "quant_aware_training":
            q_model.train()
            q_model = prepare_qat_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        else:
            q_model = prepare_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        q_model = convert_fx(q_model)
        weights = torch.load(weights_file)
        q_model.load_state_dict(weights)
        return q_model

    q_model = copy.deepcopy(model.eval())
    _propagate_qconfig(q_model, op_cfgs, white_list=white_list,
                       approach=tune_cfg['approach'])
    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model.modules()):
        logger.warn("None of the submodule got qconfig applied. Make sure you "
                    "passed correct configuration through `qconfig_dict` or "
                    "by assigning the `.qconfig` attribute directly on submodules")
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        add_observer_(q_model)
    q_model = convert(q_model, mapping=q_mapping, inplace=True)
    weights = torch.load(weights_file)
    q_model.load_state_dict(weights)
    return q_model
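# Example usage of load() above (the checkpoint directory is a hypothetical
# path; 'best_configure.yaml' and 'best_model_weights.pt' must exist in it):
# q_model = load('./workspace/checkpoint', fp32_model)
# q_model.eval()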
def prepare_model(config, bert_model_name_or_path=None):
    args = config['args']
    emb_non_trainable = not args.embedding_trainable
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    # prepare model
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config, args.embedding_path, label_size,
                                 emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config, args.embedding_path, label_size,
                                         emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config, args.embedding_path, label_size,
                                         emb_non_trainable=emb_non_trainable)
    else:
        model_name_or_path = args.bert_model_name_or_path
        if bert_model_name_or_path:
            model_name_or_path = bert_model_name_or_path
        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))
            # 3 new tokens added
            bert_model.resize_token_embeddings(len(bert_tokenizer))
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = T5EncoderModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))
        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls':
            ModelClass = TextBertCLS
        if config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size,
                           feature_based=args.bert_use_feature_based,
                           finetune_last=args.bert_use_finetune_last)
    if args.restore_path:
        checkpoint = load_checkpoint(args.restore_path)
        model.load_state_dict(checkpoint)
    if args.enable_qat:
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model
def prepare_model(config, bert_model_name_or_path=None):
    opt = config['opt']
    emb_non_trainable = not opt.embedding_trainable
    labels = load_label(opt.label_path)
    label_size = len(labels)
    config['labels'] = labels
    # prepare model
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config, opt.embedding_path, label_size,
                                 emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config, opt.embedding_path, label_size,
                                         emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config, opt.embedding_path, label_size,
                                         emb_non_trainable=emb_non_trainable)
    else:
        model_name_or_path = opt.bert_model_name_or_path
        if bert_model_name_or_path:
            model_name_or_path = bert_model_name_or_path
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        bert_model = AutoModel.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path))
        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls':
            ModelClass = TextBertCLS
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size,
                           feature_based=opt.bert_use_feature_based)
    if opt.restore_path:
        checkpoint = load_checkpoint(opt.restore_path, device=opt.device)
        model.load_state_dict(checkpoint)
    if opt.enable_qat:
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
    if opt.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
    model.to(opt.device)
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model
def main(args):
    # data
    train_transform = tv.transforms.Compose([])
    if args.data_augmentation:
        train_transform.transforms.append(
            tv.transforms.RandomCrop(32, padding=4))
        train_transform.transforms.append(tv.transforms.RandomHorizontalFlip())
    train_transform.transforms.append(tv.transforms.ToTensor())
    normalize = tv.transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    train_transform.transforms.append(normalize)

    test_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(), normalize])

    train_dataset = tv.datasets.CIFAR10(root='data/',
                                        train=True,
                                        transform=train_transform,
                                        download=True)
    test_dataset = tv.datasets.CIFAR10(root='data/',
                                       train=False,
                                       transform=test_transform,
                                       download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.bs,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=4)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.bs,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

    # net
    net = tv.models.mobilenet_v2(num_classes=10)
    net.load_state_dict(torch.load('mobilenet_v2.pth', map_location='cpu'))
    net.dropout = torch.nn.Sequential()

    # quantization
    model_to_quantize = copy.deepcopy(net).to(device)
    qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}
    model_to_quantize.train()
    model_prepared = prepare_qat_fx(model_to_quantize, qconfig_dict)

    # optimizer and loss
    optimizer = torch.optim.SGD(model_prepared.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 16], 0.1)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # train
    model_prepared.train()
    for i_epoch in range(20):
        time_s = time.time()
        for i_iter, data in enumerate(train_loader):
            img, label = data
            img, label = img.to(device), label.to(device)
            optimizer.zero_grad()
            feat = model_prepared(img)
            loss = criterion(feat, label)
            loss.backward()
            optimizer.step()
            time_e = time.time()
            print('Epoch:{:3}/20 || Iter: {:4}/{} || '
                  'Loss: {:2.4f} '
                  'ETA: {:.2f}min'.format(
                      i_epoch + 1, i_iter + 1, len(train_loader), loss.item(),
                      (time_e - time_s) * (20 - i_epoch) * len(train_loader) /
                      (i_iter + 1) / 60))
        scheduler.step()

    # to int8
    model_int8 = convert_fx(model_prepared)
    torch.jit.save(torch.jit.script(model_int8), 'int8-qat.pth')

    # valid
    loaded_quantized_model = torch.jit.load('int8-qat.pth')
    correct = 0.
    total = 0.
    with torch.no_grad():
        loaded_quantized_model.eval()
        for images, labels in tqdm(test_loader):
            pred = loaded_quantized_model(images)
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    val_acc = correct / total
    print(val_acc)
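# A hedged note on the conversion step in main() above: int8 kernels produced
# with the fbgemm backend run on CPU, and convert_fx is typically called on the
# prepared model after switching it to eval mode. A sketch of that variant:
# model_prepared.eval()
# model_int8 = convert_fx(model_prepared.to('cpu'))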