Example #1
def conv_quant(module, cfg):
    """Compute fixed-point quantization parameters for a convolution layer."""
    out_int_bits = module.out_int_bits

    # Fractional bits left for weights and inputs after reserving the
    # integer bits and the optional sign bit.
    w_int_bits = int_bits(module.weight)
    w_frac_bits = cfg['bits'] - w_int_bits - cfg['signed']
    assert len(module.inp_int_bits) == 1
    inp_frac_bits = cfg['bits'] - module.inp_int_bits[0] - cfg['signed']

    # If the accumulated products could overflow the accumulator, give up
    # weight fractional bits until the result fits into accum_bits.
    if out_int_bits + w_frac_bits + inp_frac_bits > cfg['accum_bits'] - cfg['signed']:
        w_frac_bits -= (out_int_bits + w_frac_bits + inp_frac_bits
                        - cfg['accum_bits'] + cfg['signed'])

    bias_frac_bits = cfg['bits'] - out_int_bits - cfg['signed']

    params = {
        # right-shift needed to bring the accumulator back into `bits` bits
        'norm': [max(0, out_int_bits + w_frac_bits + inp_frac_bits
                     - cfg['bits'] + cfg['signed'])],
        'weight': integerize(module.weight.data, w_frac_bits, cfg['bits']).cpu().tolist(),
        'bias': integerize(module.bias.data, bias_frac_bits, cfg['bits']).cpu().tolist(),
        'w_frac_bits': [w_frac_bits],
        'b_frac_bits': [bias_frac_bits],
        'inp_frac_bits': [inp_frac_bits],
    }
    params['out_frac_bits'] = [w_frac_bits + inp_frac_bits - params['norm'][0]]
    return params
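
For context, a minimal usage sketch of conv_quant. The int_bits and integerize functions below are stand-ins that assume the usual fixed-point conventions (measure the required integer bits; scale, round, and clamp to the signed range) and are not the library's own implementations; the cfg values and the inp_int_bits/out_int_bits attributes are illustrative only and would normally be filled in by the statistics pass.

import math

import torch
import torch.nn as nn

def int_bits(tensor):
    # assumed behavior: integer bits needed for the largest magnitude in the tensor
    return max(0, math.ceil(math.log2(tensor.detach().abs().max().item() + 1e-12)))

def integerize(tensor, frac_bits, bits):
    # assumed behavior: scale to fixed point, round, and clamp to the signed range
    scaled = torch.round(tensor * (2 ** frac_bits))
    return torch.clamp(scaled, -2 ** (bits - 1), 2 ** (bits - 1) - 1)

cfg = {'bits': 16, 'accum_bits': 32, 'signed': 1}   # illustrative config values
conv = nn.Conv2d(3, 8, kernel_size=3, bias=True)
conv.inp_int_bits = [3]                             # normally set by the stats pass
conv.out_int_bits = 4                               # normally set by the stats pass

params = conv_quant(conv, cfg)
print(params['w_frac_bits'], params['b_frac_bits'], params['out_frac_bits'])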
Example #2
    def collect_stats(self):
        """Run a calibration pass and record per-module integer-bit statistics."""
        # Attach the stats hook to every module so activation ranges are recorded.
        handles = []
        for module in self.model.modules():
            handles.append(module.register_forward_hook(stats_hook))

        dataset = Folder(self.cfg['data_source'], self.loader, self.transform)

        if self.cfg['verbose']:
            logging.info('{} images are used to collect statistics'.format(len(dataset)))

        dataloader = DataLoader(dataset,
                                batch_size=self.cfg['batch_size'],
                                shuffle=False,
                                num_workers=self.cfg['num_workers'],
                                drop_last=False)
        self.model.eval()

        if self.cfg['use_gpu']:
            self.model.cuda()

        with torch.no_grad():
            for imgs in tqdm(dataloader):
                if self.cfg['use_gpu']:
                    # move to GPU first: .cuda() returns a new tensor, which
                    # would drop the int_bits attribute set below
                    imgs = imgs.cuda()
                if self.cfg['raw_input']:
                    imgs.int_bits = self.cfg['bits'] - self.cfg['signed']
                else:
                    imgs.int_bits = int_bits(imgs)
                _ = self.model(imgs)

        for handle in handles:  # remove the forward hooks
            handle.remove()

        if self.cfg['use_gpu']:
            self.model.cpu()
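
The register/run/remove hook pattern used above can be illustrated in isolation. The sketch below is not the library's code; it only shows how register_forward_hook lets a calibration pass record per-module activation statistics, with random tensors standing in for the calibration dataloader.

import torch
import torch.nn as nn

def range_hook(module, inputs, output):
    # remember the largest activation magnitude this module has produced
    peak = output.detach().abs().max().item()
    module.act_peak = max(getattr(module, 'act_peak', 0.0), peak)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
handles = [m.register_forward_hook(range_hook)
           for m in model.modules() if isinstance(m, nn.Conv2d)]

model.eval()
with torch.no_grad():
    for _ in range(4):                 # stand-in for the calibration dataloader
        model(torch.rand(2, 3, 32, 32))

for handle in handles:                 # hooks are removed once stats are collected
    handle.remove()

for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(name, module.act_peak)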
Example #3
def stats_hook(module, inputs, output):
    """Forward hook: track the integer bits required by a module's inputs and output."""
    inp_int_bits = get_int_bits(inputs)

    # Keep the per-input maximum over all batches seen so far.
    if not hasattr(module, 'inp_int_bits'):
        module.inp_int_bits = inp_int_bits
    else:
        for idx, (curr_bits, new_bits) in enumerate(zip(module.inp_int_bits, inp_int_bits)):
            if new_bits > curr_bits:
                module.inp_int_bits[idx] = new_bits

    if isinstance(module, nn.Conv2d):
        # convolutions change the dynamic range, so measure the output directly
        out_int_bits = int_bits(output)
    else:
        out_int_bits = max(inp_int_bits)

    # Only record stats for leaf-level modules: built-in PyTorch layers
    # (excluding containers) and the library's own quantization layers;
    # custom composite modules (Fire, Bottleneck, ...) are skipped.
    is_builtin_layer = (module.__class__ in module_classes(nn)
                        and not isinstance(module, (nn.Sequential, nn.ModuleList)))
    is_quant_layer = module.__class__ in module_classes(gap_quantization.layers)
    if is_builtin_layer or is_quant_layer:
        if not hasattr(module, 'out_int_bits') or out_int_bits > module.out_int_bits:
            module.out_int_bits = out_int_bits
        # propagate the integer-bit info through the network via the output tensor
        set_param(output, 'int_bits', out_int_bits)
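
stats_hook depends on several library helpers that are not shown here (get_int_bits, int_bits, set_param, module_classes). The definitions below are hedged stand-ins sketching their assumed behavior, not the library's actual implementations.

import math

import torch.nn as nn

def int_bits(tensor):
    # assumed: integer bits needed to represent the largest magnitude seen
    return max(0, math.ceil(math.log2(tensor.detach().abs().max().item() + 1e-12)))

def get_int_bits(inputs):
    # assumed: read the int_bits attribute attached upstream, or measure the tensor
    return [getattr(inp, 'int_bits', int_bits(inp)) for inp in inputs]

def set_param(output, name, value):
    # assumed: attach an attribute to the output tensor(s) so downstream hooks see it
    outputs = output if isinstance(output, (list, tuple)) else [output]
    for out in outputs:
        setattr(out, name, value)

def module_classes(package):
    # assumed: all nn.Module subclasses exported by a module or package
    return [obj for obj in vars(package).values()
            if isinstance(obj, type) and issubclass(obj, nn.Module)]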