import logging
import math

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import gap_quantization.layers
# Repo-local helpers; the exact import path is an assumption:
from gap_quantization.utils import (Folder, get_int_bits, int_bits,
                                    integerize, module_classes, set_param)


def conv_quant(module, cfg):
    """Compute fixed-point parameters for a Conv2d layer from collected stats."""
    out_int_bits = module.out_int_bits
    w_int_bits = int_bits(module.weight)
    w_frac_bits = cfg['bits'] - w_int_bits - cfg['signed']
    assert len(module.inp_int_bits) == 1
    inp_frac_bits = cfg['bits'] - module.inp_int_bits[0] - cfg['signed']
    # drop weight fractional bits if the products could overflow the accumulator
    if out_int_bits + w_frac_bits + inp_frac_bits > cfg['accum_bits'] - cfg['signed']:
        w_frac_bits -= (out_int_bits + w_frac_bits + inp_frac_bits
                        - cfg['accum_bits'] + cfg['signed'])
    bias_frac_bits = cfg['bits'] - out_int_bits - cfg['signed']
    params = {
        # right-shift applied to the accumulator to fit the output width
        'norm': [max(0, out_int_bits + w_frac_bits + inp_frac_bits
                     - cfg['bits'] + cfg['signed'])],
        'weight': integerize(module.weight.data, w_frac_bits, cfg['bits']).cpu().tolist(),
        'bias': integerize(module.bias.data, bias_frac_bits, cfg['bits']).cpu().tolist(),
        'w_frac_bits': [w_frac_bits],
        'b_frac_bits': [bias_frac_bits],
        'inp_frac_bits': [inp_frac_bits]
    }
    params['out_frac_bits'] = [w_frac_bits + inp_frac_bits - params['norm'][0]]
    return params
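
# A minimal sketch of the two numeric helpers conv_quant() relies on, under
# the assumption that int_bits() measures the integer bits needed to cover a
# tensor's dynamic range and integerize() converts floats to clamped
# fixed-point integers. The package's real implementations may differ in
# edge-case handling; the *_sketch names are hypothetical.
def int_bits_sketch(tensor):
    # bits needed for the integer part: ceil(log2(max |x|)), floored at 0
    max_abs = tensor.abs().max().item()
    if max_abs == 0:
        return 0
    return max(0, math.ceil(math.log2(max_abs)))


def integerize_sketch(tensor, frac_bits, bits):
    # scale by 2^frac_bits, round, and clamp to the signed range of `bits`
    scaled = torch.round(tensor * (2 ** frac_bits))
    return torch.clamp(scaled, -2 ** (bits - 1), 2 ** (bits - 1) - 1)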
def collect_stats(self):
    """Run calibration data through the model to gather per-layer bit statistics."""
    handles = []
    for module in self.model.modules():
        handles.append(module.register_forward_hook(stats_hook))
    dataset = Folder(self.cfg['data_source'], self.loader, self.transform)
    if self.cfg['verbose']:
        logging.info('{} images are used to collect statistics'.format(len(dataset)))
    dataloader = DataLoader(dataset, batch_size=self.cfg['batch_size'], shuffle=False,
                            num_workers=self.cfg['num_workers'], drop_last=False)
    self.model.eval()
    if self.cfg['use_gpu']:
        self.model.cuda()
    with torch.no_grad():
        for imgs in tqdm(dataloader):
            # transfer to GPU first: .cuda() returns a new tensor that would
            # not carry a Python attribute set on the CPU copy
            if self.cfg['use_gpu']:
                imgs = imgs.cuda()
            if self.cfg['raw_input']:
                imgs.int_bits = self.cfg['bits'] - self.cfg['signed']
            else:
                imgs.int_bits = int_bits(imgs)
            _ = self.model(imgs)
    for handle in handles:  # delete forward hooks
        handle.remove()
    if self.cfg['use_gpu']:
        self.model.cpu()
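
# A hypothetical end-to-end sketch (not the package's documented API): register
# stats_hook on a single Conv2d, push one calibration batch through it to
# populate inp_int_bits/out_int_bits, then emit fixed-point parameters with
# conv_quant(). The cfg keys match those read by the functions above.
def quantize_toy_conv():
    cfg = {'bits': 16, 'accum_bits': 32, 'signed': 1}
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    handle = conv.register_forward_hook(stats_hook)
    calib = torch.rand(4, 3, 32, 32)  # stand-in calibration batch
    calib.int_bits = int_bits(calib)  # seed input stats, as collect_stats does
    with torch.no_grad():
        conv(calib)
    handle.remove()
    return conv_quant(conv, cfg)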
def stats_hook(module, inputs, output):
    """Forward hook: track worst-case integer bits seen at each module's I/O."""
    inp_int_bits = get_int_bits(inputs)
    if not hasattr(module, 'inp_int_bits'):
        module.inp_int_bits = inp_int_bits
    else:
        # keep the maximum integer bits observed per input across batches
        for idx, (curr_inp_int_bits, new_inp_int_bits) in \
                enumerate(zip(module.inp_int_bits, inp_int_bits)):
            if new_inp_int_bits > curr_inp_int_bits:
                module.inp_int_bits[idx] = new_inp_int_bits
    if isinstance(module, nn.Conv2d):
        out_int_bits = int_bits(output)
    else:
        out_int_bits = max(inp_int_bits)
    # record output stats only for elementary layers; ignore custom modules
    # (Fire, Bottleneck, ...) and high-level PyTorch containers
    if (module.__class__ in module_classes(nn)
            and not isinstance(module, (nn.Sequential, nn.ModuleList))) \
            or module.__class__ in module_classes(gap_quantization.layers):
        if not hasattr(module, 'out_int_bits') or out_int_bits > module.out_int_bits:
            module.out_int_bits = out_int_bits
    # propagate info through the network
    set_param(output, 'int_bits', out_int_bits)
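
# Sketches of the propagation helpers stats_hook() depends on, assuming the
# convention that tensors carry their statistic in an `int_bits` attribute
# (seeded on the input batch in collect_stats above). The real helpers may
# unwrap nested tuples/lists recursively; these names are hypothetical.
def get_int_bits_sketch(inputs):
    # read the int_bits attribute off every input tensor of the module
    return [getattr(inp, 'int_bits', 0) for inp in inputs]


def set_param_sketch(output, name, value):
    # attach the statistic to the output tensor(s) so downstream hooks see it
    if isinstance(output, (tuple, list)):
        for out in output:
            setattr(out, name, value)
    else:
        setattr(output, name, value)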