Example #1
def prepare_for_quant(self, cfg):
    # Quantize only the avgpool submodule with FX graph mode QAT,
    # passing the class's custom prepare config through.
    self.avgpool = prepare_qat_fx(
        self.avgpool,
        {"": torch.quantization.get_default_qat_qconfig()},
        self.custom_config_dict,
    )
    return self
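The `self.custom_config_dict` above is defined elsewhere in the source class. As a rough, hypothetical sketch, such a prepare config typically marks submodules that FX symbolic tracing should skip:

import torch
from torch import nn

# Hypothetical prepare_custom_config_dict: the keys below belong to the FX
# graph mode quantization API, but the module name and class are purely
# illustrative.
custom_config_dict = {
    "non_traceable_module_name": ["pos_encoding"],          # made-up name
    "non_traceable_module_class": [nn.TransformerEncoder],  # illustrative
}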
Example #2
def prepare_ptq_linear(qconfig):
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    prepared_model = prepare_qat_fx(
        copy.deepcopy(float_model),
        qconfig_dict)  # fuse modules and insert observers
    calibrate(prepared_model,
              data_loader_test)  # run calibration on sample data
    return prepared_model
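The `calibrate` helper, `float_model`, and `data_loader_test` come from the surrounding source. A minimal sketch of such a calibration loop, under those assumptions:

import torch

def calibrate(model, data_loader):
    # Feed sample data through the prepared model so the inserted observers
    # can record activation statistics; no gradients are needed.
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)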
Example #3
def test_mobilenet_v2_qat(self):
    # verify that mobilenetv2 graph is able to be matched
    import torchvision
    m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
    mp = prepare_qat_fx(m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')})
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    # assume success if no exceptions
    results = get_matching_subgraph_pairs(mp, mq)
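Continuing the snippet above, the returned matches can be inspected; in the FX Numeric Suite, get_matching_subgraph_pairs returns a dict keyed by match name (an assumption worth verifying against your PyTorch version):

for match_name, subgraph_pair in results.items():
    # Each entry pairs a subgraph from the QAT-prepared model with its
    # counterpart in the converted int8 model.
    print(match_name, subgraph_pair)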
Example #4
def load_model(config, checkpoint):
    opt = config['opt']
    labels = load_label(opt.label_path)
    label_size = len(labels)
    config['labels'] = labels
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 opt.embedding_path,
                                 label_size,
                                 emb_non_trainable=True)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
    else:
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        bert_config = AutoConfig.from_pretrained(opt.bert_output_dir)
        bert_tokenizer = AutoTokenizer.from_pretrained(opt.bert_output_dir)
        bert_model = AutoModel.from_config(bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls': ModelClass = TextBertCLS
        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size)
    if opt.enable_qat:
        assert opt.device == 'cpu'
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if opt.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)

    model.load_state_dict(checkpoint)
    model = model.to(opt.device)
    '''
    for name, param in model.named_parameters():
        print(name, param.data, param.device, param.requires_grad)
    '''
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[Model loaded]")
    return model
Example #5
def prepare(self, model, configs, attrs):
    model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
    return model
Example #6
def prep_qat_train(self):
    qconfig_dict = {
        "": torch.quantization.get_default_qat_qconfig('fbgemm')
    }
    self.model.train()
    self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict)
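Only the prepare step is shown here. A matching convert step after QAT fine-tuning, sketched under the assumption of the same class (the method name is made up):

def prep_qat_eval(self):
    # After fine-tuning, switch to eval mode and lower the fake-quantized
    # graph to a real int8 model.
    self.model.eval()
    self.model = quantize_fx.convert_fx(self.model)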
Example #7
def load_model(config, checkpoint):
    args = config['args']
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=True)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
    else:
        # needed by the transformer branches below; imported at module
        # level in the original source
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
            bert_config = bert_model.config
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_pretrained(args.bert_output_dir)
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = T5EncoderModel(bert_config)
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_config(bert_config)

        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls': ModelClass = TextBertCLS
        if config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN

        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size)

    if args.enable_qat:
        assert args.device == 'cpu'
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)

    if args.enable_diffq:
        quantizer = DiffQuantizer(model)
        config['quantizer'] = quantizer
        quantizer.restore_quantized_state(checkpoint)
    else:
        model.load_state_dict(checkpoint)

    model = model.to(args.device)
    '''
    for name, param in model.named_parameters():
        print(name, param.data, param.device, param.requires_grad)
    '''
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[Model loaded]")
    return model
Example #8
def load(checkpoint_dir, model, **kwargs):
    """Execute the quantize process on the specified model.

    Args:
        checkpoint_dir (dir): The folder of checkpoint.
                              'best_configure.yaml' and 'best_model_weights.pt' are needed
                              in This directory. 'checkpoint' dir is under workspace folder
                              and workspace folder is define in configure yaml file.
        model (object): fp32 model need to do quantization.

    Returns:
        (object): quantized model
    """

    tune_cfg_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_configure.yaml')
    weights_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_model_weights.pt')
    assert os.path.exists(
        tune_cfg_file), "tune configure file %s doesn't exist" % tune_cfg_file
    assert os.path.exists(
        weights_file), "weight file %s doesn't exist" % weights_file

    with open(tune_cfg_file, 'r') as f:
        tune_cfg = yaml.safe_load(f)

    version = get_torch_version()
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        if version < '1.7':
            q_mapping = tq.default_mappings.DEFAULT_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_static_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_static_quant_module_mappings()
    else:
        if version < '1.7':
            q_mapping = \
                tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_dynamic_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_dynamic_quant_module_mappings()

    if version < '1.7':
        white_list = \
            tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.default_mappings.DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    elif version < '1.8':
        white_list = \
            tq.quantization_mappings.get_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    else:
        white_list = \
            tq.quantization_mappings.get_default_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_default_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}

    if tune_cfg['approach'] == "post_training_dynamic_quant":
        op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
    else:
        op_cfgs = _cfg_to_qconfig(tune_cfg)

    if tune_cfg['framework'] == "pytorch_fx":  # pragma: no cover
        # For torch.fx approach
        assert version >= '1.8', \
                      "Please use PyTorch 1.8 or a higher version with the pytorch_fx backend"
        from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
        q_model = copy.deepcopy(model.eval())
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg['approach'])
        if tune_cfg['approach'] == "quant_aware_training":
            q_model.train()
            q_model = prepare_qat_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        else:
            q_model = prepare_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        q_model = convert_fx(q_model)
        weights = torch.load(weights_file)
        q_model.load_state_dict(weights)
        return q_model

    q_model = copy.deepcopy(model.eval())
    _propagate_qconfig(q_model,
                       op_cfgs,
                       white_list=white_list,
                       approach=tune_cfg['approach'])
    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model.modules()):
        logger.warning(
            "None of the submodules got a qconfig applied. Make sure you "
            "passed the correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly on submodules")
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        add_observer_(q_model)
    q_model = convert(q_model, mapping=q_mapping, inplace=True)
    weights = torch.load(weights_file)
    q_model.load_state_dict(weights)
    return q_model
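A hypothetical call, assuming a workspace laid out as the docstring describes (the path and fp32_model are placeholders):

# Load the tuned int8 model back from the checkpoint directory.
q_model = load('./workspace/checkpoint', fp32_model)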
Example #9
def prepare_model(config, bert_model_name_or_path=None):
    args = config['args']
    emb_non_trainable = not args.embedding_trainable
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    # prepare model
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
    else:
        model_name_or_path = args.bert_model_name_or_path
        if bert_model_name_or_path:
            model_name_or_path = bert_model_name_or_path

        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))
            # 3 new tokens added
            bert_model.resize_token_embeddings(len(bert_tokenizer))
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = T5EncoderModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))

        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            bert_model = AutoModel.from_pretrained(
                model_name_or_path,
                from_tf=bool(".ckpt" in model_name_or_path))

        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls': ModelClass = TextBertCLS
        if config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN

        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           feature_based=args.bert_use_feature_based,
                           finetune_last=args.bert_use_finetune_last)
    if args.restore_path:
        checkpoint = load_checkpoint(args.restore_path)
        model.load_state_dict(checkpoint)
    if args.enable_qat:
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)

    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model
Example #10
def prepare_model(config, bert_model_name_or_path=None):
    opt = config['opt']
    emb_non_trainable = not opt.embedding_trainable
    labels = load_label(opt.label_path)
    label_size = len(labels)
    config['labels'] = labels
    # prepare model
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, opt.embedding_path, label_size)
        if config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 opt.embedding_path,
                                 label_size,
                                 emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
        if config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         opt.embedding_path,
                                         label_size,
                                         emb_non_trainable=emb_non_trainable)
    else:
        model_name_or_path = opt.bert_model_name_or_path
        if bert_model_name_or_path:
            model_name_or_path = bert_model_name_or_path
        from transformers import AutoTokenizer, AutoConfig, AutoModel
        bert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        bert_model = AutoModel.from_pretrained(
            model_name_or_path, from_tf=bool(".ckpt" in model_name_or_path))
        bert_config = bert_model.config
        # bert model reduction
        reduce_bert_model(config, bert_model, bert_config)
        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls': ModelClass = TextBertCLS
        model = ModelClass(config,
                           bert_config,
                           bert_model,
                           bert_tokenizer,
                           label_size,
                           feature_based=opt.bert_use_feature_based)
    if opt.restore_path:
        checkpoint = load_checkpoint(opt.restore_path, device=opt.device)
        model.load_state_dict(checkpoint)
    if opt.enable_qat:
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        '''
        # fuse if applicable
        # model = torch.quantization.fuse_modules(model, [['']])
        '''
        model = torch.quantization.prepare_qat(model)
    if opt.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        model.train()
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)

    model.to(opt.device)
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[model prepared]")
    return model
Example #11
def main(args):
    # data
    train_transform = tv.transforms.Compose([])
    if args.data_augmentation:
        train_transform.transforms.append(
            tv.transforms.RandomCrop(32, padding=4))
        train_transform.transforms.append(tv.transforms.RandomHorizontalFlip())
    train_transform.transforms.append(tv.transforms.ToTensor())
    normalize = tv.transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    train_transform.transforms.append(normalize)

    test_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(), normalize])

    train_dataset = tv.datasets.CIFAR10(root='data/',
                                        train=True,
                                        transform=train_transform,
                                        download=True)

    test_dataset = tv.datasets.CIFAR10(root='data/',
                                       train=False,
                                       transform=test_transform,
                                       download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.bs,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=4)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.bs,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

    # net
    net = tv.models.mobilenet_v2(num_classes=10)
    net.load_state_dict(torch.load('mobilenet_v2.pth', map_location='cpu'))
    net.dropout = torch.nn.Sequential()

    # quantization
    model_to_quantize = copy.deepcopy(net).to(device)
    qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}
    model_to_quantize.train()
    model_prepared = prepare_qat_fx(model_to_quantize, qconfig_dict)
    # optimizer and loss
    optimizer = torch.optim.SGD(model_prepared.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 16], 0.1)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # train
    model_prepared.train()
    for i_epoch in range(20):
        time_s = time.time()
        for i_iter, data in enumerate(train_loader):
            img, label = data
            img, label = img.to(device), label.to(device)

            optimizer.zero_grad()
            feat = model_prepared(img)
            loss = criterion(feat, label)
            loss.backward()
            optimizer.step()
            time_e = time.time()

            print(
                'Epoch:{:3}/20 || Iter: {:4}/{} || '
                'Loss: {:2.4f} '
                'ETA: {:.2f}min'.format(i_epoch + 1, i_iter + 1,
                                        len(train_loader), loss.item(),
                                        (time_e - time_s) * (20 - i_epoch) *
                                        len(train_loader) / (i_iter + 1) / 60))
        scheduler.step()

    # to int8
    model_int8 = convert_fx(model_prepared)
    torch.jit.save(torch.jit.script(model_int8), 'int8-qat.pth')

    # valid
    loaded_quantized_model = torch.jit.load('int8-qat.pth')
    correct = 0.
    total = 0.
    with torch.no_grad():
        loaded_quantized_model.eval()
        for images, labels in tqdm(test_loader):
            # the converted fbgemm model runs on CPU, so no device transfer
            # is needed here
            pred = loaded_quantized_model(images)

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        val_acc = correct / total
        print(val_acc)
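As a follow-up sanity check, the serialized int8 model can be compared against an fp32 baseline on disk (a sketch; it assumes the `net` from above is scriptable, which holds for torchvision's mobilenet_v2):

import os

# Script and save the fp32 baseline, then compare file sizes; int8
# quantization typically shrinks the weights by roughly 4x.
torch.jit.save(torch.jit.script(net.eval().cpu()), 'fp32.pth')
print('fp32: {:.1f} MB'.format(os.path.getsize('fp32.pth') / 1e6))
print('int8: {:.1f} MB'.format(os.path.getsize('int8-qat.pth') / 1e6))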