Example #1
from diffq import DiffQuantizer, UniformQuantizer


def get_quantizer(model, args, optimizer=None):
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(
            model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        quantizer = UniformQuantizer(
                model, bits=args.qat, min_size=args.q_min_size)
    return quantizer
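A minimal usage sketch for the helper above, assuming the diffq package (https://github.com/facebookresearch/diffq); the argparse.Namespace fields mirror the snippet's argument names, and the model is a toy placeholder:

import argparse

import torch

args = argparse.Namespace(diffq=True, qat=0, q_min_size=0.01)
model = torch.nn.Linear(256, 256)                 # toy stand-in model
optimizer = torch.optim.Adam(model.parameters())
# Returns a DiffQuantizer already hooked into the optimizer via setup_optimizer.
quantizer = get_quantizer(model, args, optimizer=optimizer)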
Example #2
def get_quantizer(model, args, optimizer=None):
    """Return the quantizer given the XP quantization args."""
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(model,
                                  min_size=args.min_size,
                                  group_size=args.group_size)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        quantizer = UniformQuantizer(model,
                                     bits=args.qat,
                                     min_size=args.min_size)
    return quantizer
Example #3
def demucs(pretrained=True,
           extra=False,
           quantized=False,
           hq=False,
           channels=64):
    if not pretrained and (extra or quantized or hq):
        raise ValueError(
            "if extra, quantized, or hq is True, pretrained must be True.")
    model = Demucs(sources=SOURCES, channels=channels)
    if pretrained:
        name = 'demucs'
        if channels != 64:
            name += str(channels)
        quantizer = None
        if sum([extra, quantized, hq]) > 1:
            raise ValueError("Only one of extra, quantized, hq, can be True.")
        if quantized:
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name += '_quantized'
        if extra:
            name += '_extra'
        if hq:
            name += '_hq'
        _load_state(name, model, quantizer)
    return model
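For reference, hypothetical calls to the entry point above and the checkpoint names they resolve to, following the snippet's own name-building logic:

model = demucs()                          # loads 'demucs'
model = demucs(quantized=True)            # loads 'demucs_quantized'
model = demucs(channels=48, hq=True)      # loads 'demucs48_hq'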
Example #4
def demucs(pretrained=True, extra=False, quantized=False):
    if not pretrained and (extra or quantized):
        raise ValueError("if extra or quantized is True, pretrained must be True.")
    model = Demucs()
    if pretrained:
        name = 'demucs'
        quantizer = None
        if extra and quantized:
            raise ValueError("Only one of extra or quantized can be True.")
        if quantized:
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name = 'demucs_quantized'
        if extra:
            name = 'demucs_extra'
        _load_state(name, model, quantizer)
    return model
Example #5
import io
import os
import zlib

import torch
from diffq import DiffQuantizer


def load_model(name):
    model = Demucs()  # Demucs model class assumed to be in scope

    model_hash = {"demucs": "e07c671f", "demucs_quantized": "07afea75"}
    cp = name + "-" + model_hash[name] + ".th"

    if os.path.exists(cp):
        state = torch.load(cp)
    else:
        root = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
        state = torch.hub.load_state_dict_from_url(root + cp,
                                                   map_location="cpu",
                                                   check_hash=True)

    if "quantized" in name:
        quantizer = DiffQuantizer(model, group_size=8, min_size=1)
        buf = io.BytesIO(zlib.decompress(state["compressed"]))
        state = torch.load(buf, "cpu")
        quantizer.restore_quantized_state(state)
        quantizer.detach()
    else:
        model.load_state_dict(state)

    return model
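The loader above expects DiffQ checkpoints wrapped as a dict whose "compressed" key holds zlib-deflated bytes of a torch-saved quantized state. A sketch of the matching save path, assuming diffq's get_quantized_state() and a quantizer still attached to the trained model; the wrapper layout is inferred from the loader, not from a documented format:

import io
import zlib

import torch


def save_quantized(quantizer, path):
    buf = io.BytesIO()
    torch.save(quantizer.get_quantized_state(), buf)  # pack quantized weights
    torch.save({"compressed": zlib.compress(buf.getvalue())}, path)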
Example #6
def load_model(config, checkpoint):
    args = config['args']
    labels = load_label(args.label_path)
    label_size = len(labels)
    config['labels'] = labels
    if config['emb_class'] == 'glove':
        if config['enc_class'] == 'gnb':
            model = TextGloveGNB(config, args.embedding_path, label_size)
        elif config['enc_class'] == 'cnn':
            model = TextGloveCNN(config,
                                 args.embedding_path,
                                 label_size,
                                 emb_non_trainable=True)
        elif config['enc_class'] == 'densenet-cnn':
            model = TextGloveDensenetCNN(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
        elif config['enc_class'] == 'densenet-dsa':
            model = TextGloveDensenetDSA(config,
                                         args.embedding_path,
                                         label_size,
                                         emb_non_trainable=True)
    else:
        if config['emb_class'] == 'bart' and config['use_kobart']:
            from transformers import BartModel
            from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
            bert_tokenizer = get_kobart_tokenizer()
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_model = BartModel.from_pretrained(get_pytorch_kobart_model())
            bert_config = bert_model.config
        elif config['emb_class'] in ['gpt']:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.bos_token = '<|startoftext|>'
            bert_tokenizer.eos_token = '<|endoftext|>'
            bert_tokenizer.cls_token = '<|startoftext|>'
            bert_tokenizer.sep_token = '<|endoftext|>'
            bert_tokenizer.pad_token = '<|pad|>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_pretrained(args.bert_output_dir)
        elif config['emb_class'] in ['t5']:
            from transformers import T5EncoderModel
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_tokenizer.cls_token = '<s>'
            bert_tokenizer.sep_token = '</s>'
            bert_tokenizer.pad_token = '<pad>'
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = T5EncoderModel(bert_config)
        else:
            bert_tokenizer = AutoTokenizer.from_pretrained(
                args.bert_output_dir)
            bert_config = AutoConfig.from_pretrained(args.bert_output_dir)
            bert_model = AutoModel.from_config(bert_config)

        ModelClass = TextBertCNN
        if config['enc_class'] == 'cls':
            ModelClass = TextBertCLS
        elif config['enc_class'] == 'densenet-cnn':
            ModelClass = TextBertDensenetCNN

        model = ModelClass(config, bert_config, bert_model, bert_tokenizer,
                           label_size)

    if args.enable_qat:
        assert args.device == 'cpu'
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        # fuse modules here if applicable, e.g.:
        # model = torch.quantization.fuse_modules(model, [['']])
        model = torch.quantization.prepare_qat(model)
        model.eval()
        model.to('cpu')
        logger.info("[Convert to quantized model with device=cpu]")
        model = torch.quantization.convert(model)
    if args.enable_qat_fx:
        import torch.quantization.quantize_fx as quantize_fx
        qconfig_dict = {
            "": torch.quantization.get_default_qat_qconfig('fbgemm')
        }
        model = quantize_fx.prepare_qat_fx(model, qconfig_dict)
        logger.info("[Convert to quantized model]")
        model = quantize_fx.convert_fx(model)

    if args.enable_diffq:
        quantizer = DiffQuantizer(model)
        config['quantizer'] = quantizer
        quantizer.restore_quantized_state(checkpoint)
    else:
        model.load_state_dict(checkpoint)

    model = model.to(args.device)
    # for name, param in model.named_parameters():
    #     print(name, param.data, param.device, param.requires_grad)
    logger.info("[model] :\n{}".format(model.__str__()))
    logger.info("[Model loaded]")
    return model
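For completeness, a hedged sketch of the save side that would produce the checkpoint consumed above, mirroring the loader's two branches (DiffQ quantized state vs. a plain state_dict); save_checkpoint and save_path are illustrative names:

import torch


def save_checkpoint(config, model, save_path):
    args = config['args']
    if args.enable_diffq:
        # config['quantizer'] is set up during training (see prepare_others below)
        checkpoint = config['quantizer'].get_quantized_state()
    else:
        checkpoint = model.state_dict()
    torch.save(checkpoint, save_path)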
Example #7
def prepare_others(config, model, data_loader, lr=None, weight_decay=None):
    args = config['args']
    accelerator = config.get('accelerator')

    default_lr = lr if lr is not None else args.lr
    default_weight_decay = weight_decay if weight_decay is not None else args.weight_decay

    num_update_steps_per_epoch = math.ceil(
        len(data_loader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.epoch * num_update_steps_per_epoch
    else:
        args.epoch = math.ceil(args.max_train_steps /
                               num_update_steps_per_epoch)
    if args.num_warmup_steps is None:
        if args.warmup_ratio:
            args.num_warmup_steps = args.max_train_steps * args.warmup_ratio
        if args.warmup_epoch:
            args.num_warmup_steps = num_update_steps_per_epoch * args.warmup_epoch
        if args.num_warmup_steps is None: args.num_warmup_steps = 0

    logger.info(
        f"(num_update_steps_per_epoch, max_train_steps, num_warmup_steps): ({num_update_steps_per_epoch}, {args.max_train_steps}, {args.num_warmup_steps})"
    )

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': default_weight_decay,
        },
        {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=default_lr,
                      eps=args.adam_epsilon)

    if args.enable_diffq:
        quantizer = DiffQuantizer(model)
        quantizer.setup_optimizer(optimizer)
        config['quantizer'] = quantizer

    if accelerator:
        # Note: this recreates the optimizer with torch.optim.AdamW, so any
        # DiffQuantizer parameters registered on the optimizer above are
        # not carried over to the new optimizer.
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                      lr=default_lr,
                                      eps=args.adam_epsilon)
        model, optimizer, _ = accelerator.prepare(model, optimizer,
                                                  data_loader)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps)

    try:
        writer = SummaryWriter(log_dir=args.log_dir)
    except Exception:  # avoid a bare except; tensorboard may be unavailable
        writer = None
    return model, optimizer, scheduler, writer
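A compact sketch of how the returned objects would be consumed in a single training step; the DiffQ size penalty follows the library's documented loss + penalty * quantizer.model_size() pattern, and args.diffq_penalty is an assumed hyperparameter name:

def train_step(config, model, batch, optimizer, scheduler, criterion):
    args = config['args']
    inputs, targets = batch
    loss = criterion(model(inputs), targets)
    if args.enable_diffq:
        # differentiable model size in MB, provided by diffq
        loss = loss + args.diffq_penalty * config['quantizer'].model_size()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    return loss.item()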