def test_run_mbm_with_perm(self):
    """ Compare run_mbm with perm=None against perm="SORT".

    Two things are checked:
      * the cost returned with perm=None (num_iters=1) is larger than the
        cost returned with perm="SORT";
      * for both runs, the returned cost matches the total of the
        normalised criterion minus the entries selected by the returned
        (gnd_in, gnd_out) index groups.
    """
    torch.manual_seed(0)

    weight = torch.randn((16, 32, 3, 3))
    # boost every other index along the first two dimensions
    weight[::2, ::2] += 100.0
    num_groups = 4

    in_plain, out_plain, cost_plain = mask_utils.run_mbm(
        weight, num_groups, perm=None, num_iters=1)
    in_sort, out_sort, cost_sort = mask_utils.run_mbm(
        weight, num_groups, perm="SORT")
    self.assertTrue(cost_plain > cost_sort)

    # criterion normalised by its own norm, converted to a numpy array
    crit = mask_utils.get_criterion(weight)
    crit /= torch.norm(crit)
    crit = mask_utils._get_numpy(crit)

    runs = (
        (in_plain, out_plain, cost_plain),
        (in_sort, out_sort, cost_sort),
    )
    for gnd_in, gnd_out, expected_cost in runs:
        # sum of the criterion entries covered by each group
        covered = 0
        for ind_in, ind_out in zip(gnd_in, gnd_out):
            covered += (crit[ind_out, :])[:, ind_in].sum()
        self.assertTrue(
            np.allclose(crit.sum() - covered, expected_cost, rtol=1e-4))
def draw_step_by_step(C, G, fp, total_cost, num_iters=100):
    """ The step-by-step algorithm illustration graph.

    Draws a 2x2 grid of matrix plots and saves it to `fp`:
      * the first three panels show `C` permuted by the indices returned
        from `group_sort` with min_g set to G - 1, G // 2 and 0;
      * the last panel shows `C` permuted by the flattened index groups
        returned from `run_mbm` with perm='GRPS'.
    Each title reports `get_cost` of the permuted matrix as a percentage
    of `total_cost`.

    Args:
        C: matrix to illustrate; indexed as C[rows, :][:, cols].
        G: number of groups, forwarded to group_sort/get_cost/run_mbm.
        fp: path the figure is saved to.
        total_cost: denominator of the percentage shown in the titles.
        num_iters: forwarded to group_sort and run_mbm.
    """

    def _show(ax, mat, title):
        # One panel: titled matrix plot without axis ticks.
        ax.set_title(title)
        ax.matshow(mat, cmap=CMAP)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])  # TODO: don't show these

    min_gs = [G - 1, G // 2, 0]
    fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(6, 6))

    for idx, min_g in enumerate(min_gs):
        ind_in, ind_out = group_sort(C, G, num_iters=num_iters, min_g=min_g)
        C_ = C[ind_out, :][:, ind_in]
        cost = get_cost(C_, G)
        # Raw string: '\%' is not a valid escape sequence in a normal
        # literal; the resulting text is unchanged.
        title = r'$min_g = {} ({:.2f}\%)$'.format(
            min_g + 1, cost / total_cost * 100)
        _show(axes[idx // 2, idx % 2], C_, title)

    # result from run MBM
    gnd_in, gnd_out, _ = run_mbm(C, G, perm='GRPS', num_iters=num_iters)
    ind_in = [i for grp in gnd_in for i in grp]
    ind_out = [i for grp in gnd_out for i in grp]
    C_ = C[ind_out, :][:, ind_in]
    title = r'$MBM({:.2f}\%)$'.format(get_cost(C_, G) / total_cost * 100)
    _show(axes[1, 1], C_, title)

    plt.tight_layout()
    fig.savefig(fp)
def test_run_mbm(self):
    """ Run run_mbm with default options on a block-diagonal tensor.

    (Original note: GRPS N_S=10.)
    """
    num_channels, block = 32, 8
    ten = torch.randn((num_channels, num_channels, 3, 3))
    # boost the four 8x8 blocks on the diagonal of the first two dims
    for start in range(0, num_channels, block):
        stop = start + block
        ten[start:stop, start:stop] += 100.0

    G = 4
    row_ind, col_ind, cost = mask_utils.run_mbm(ten, G)

    crit = ten[row_ind, col_ind, :, :].norm(dim=(2, 3))
    # criterion should be around 300 in this case: a boosted 3x3 kernel
    # has nine entries close to 100, i.e. a norm close to 3 * 100
    within_tolerance = (crit - 300).abs() < 3
    self.assertTrue(within_tolerance.all())
def find_group_candidates(self, mod, relative=False, allow_depthwise=False,
                          min_factor=None, max_groups=None, **kwargs):
    """ Find group number candidates in module.

    Candidates are the common divisors of the module's in/out channel
    counts (G = 1 is kept), optionally filtered, and each one is scored
    with `mask_utils.run_mbm`.

    Args:
        mod: the module to inspect; must be a MaskConv2d.
        relative: if true, divide every cost by the sum of
            `W.norm(dim=(2, 3))` of the module's weight.
        allow_depthwise: if false, drop the candidate that equals both
            the input and output channel count.
        min_factor: if given, drop the largest candidates while
            min(F / G, C / G) is below this value; at least one
            candidate is always kept by this filter.
        max_groups: if truthy, drop candidates larger than this value.
        **kwargs: forwarded to `mask_utils.run_mbm`.

    Returns:
        (Gs, costs): the remaining candidates in ascending order and the
        cost computed for each of them.
    """
    assert isinstance(mod, MaskConv2d)

    C = mod.in_channels
    F = mod.out_channels
    W = mod.weight

    # common divisors, ascending
    Gs = sorted(factors(F).intersection(factors(C)))

    if max_groups:
        # Filtering (instead of popping from the tail) cannot run past
        # the start of the list when every candidate is too large.
        Gs = [G for G in Gs if G <= max_groups]

    if not allow_depthwise and Gs and Gs[-1] == F and Gs[-1] == C:
        del Gs[-1]

    if min_factor is not None:
        # remove those factors that are below a threshold
        while len(Gs) >= 2 and min(F / Gs[-1], C / Gs[-1]) < min_factor:
            Gs.pop()

    # The normaliser does not depend on G, so compute it once.
    norm_total = W.norm(dim=(2, 3)).sum().item() if relative else None

    # get the best cost for all group candidates
    costs = []
    for G in Gs:
        _, _, cost = mask_utils.run_mbm(W, G, **kwargs)
        if relative:
            cost = cost / norm_total
        costs.append(cost)

    return Gs, costs
def find_group_candidates(self, mod, **kwargs):
    """ Find group number candidates in module.

    The candidates are the common divisors of the first two dimensions of
    the module's weight, with the smallest one removed. Each candidate is
    scored with `mask_utils.run_mbm`.

    Note: `kwargs` is accepted for passing additional requirements, but it
    is not used by this implementation.

    Returns:
        (Gs, costs): the candidates in ascending order and their costs.
    """
    weight = model_utils.get_weight_parameter(mod)
    num_out, num_in = weight.shape[:2]

    # common divisors
    shared = factors(num_out).intersection(factors(num_in))
    candidates = sorted(shared)
    candidates.pop(0)  # should be 1

    results = (mask_utils.run_mbm(weight, g) for g in candidates)
    costs = [cost for _, _, cost in results]

    return candidates, costs
def collect_random_stats(size, G, data_dir, resume=False, num_samples=100,
                         num_iters=100, print_freq=100):
    """ Collect the statistics from randomly sampled test matrices.

    If resume is specified, we will load from the data file. Otherwise
    `num_samples` matrices are generated with `generate_test_matrix`, each
    is permuted by the result of `run_mbm` (perm='GRPS'), and the ratio
    between `get_cost` of the permuted matrix and of the reference matrix
    is stored and saved.

    Args:
        size, G: forwarded to `generate_test_matrix`; G is also the group
            number given to `run_mbm` and `get_cost`.
        data_dir: directory holding the .npy data file.
        resume: load the stored ratios instead of sampling.
        num_samples: number of matrices to sample.
        num_iters: forwarded to `run_mbm`.
        print_freq: print a progress line every this many samples.

    Returns:
        Array of ratios, one per sample.
    """
    # where the data file is stored
    file_name = 'random_stats_NI_{}_NS_{}.npy'.format(num_iters, num_samples)
    fp = os.path.join(data_dir, file_name)

    # Decide how to deal with the stats data
    # NOTE(review): sampling is skipped entirely when resuming - confirm
    # this matches the intended control flow.
    if resume:
        assert os.path.isfile(fp)
        return np.load(fp)

    ratios = np.zeros(num_samples)  # where to store result.

    for sample_idx in range(num_samples):
        if sample_idx % print_freq == 0:
            print('[{}/{}] Sampling ...'.format(sample_idx, num_samples))

        reference, shuffled = generate_test_matrix(size, G)
        gnd_in, gnd_out, _ = run_mbm(
            shuffled, G, perm='GRPS', num_iters=num_iters)

        # flatten the per-group index lists into full permutations
        cols = [idx for grp in gnd_in for idx in grp]
        rows = [idx for grp in gnd_out for idx in grp]
        recovered = shuffled[rows, :][:, cols]

        ratios[sample_idx] = get_cost(recovered, G) / get_cost(reference, G)

    # save to file
    np.save(fp, ratios)
    return ratios
def get_next_cost(self, model, state_map, normalized=False, **kwargs):
    """ Get next cost map.

    For every named module of `model` that has an entry in `state_map`,
    compute the cost of moving to its next group number.

    Each state is indexed as `state[0]` (current position) and `state[1]`
    (sequence of group numbers); the next group number is
    `state[1][state[0] + 1]`. Modules whose position already refers to the
    last entry are skipped.

    Args:
        model: an nn.Module.
        state_map: mapping from module name to its state.
        normalized: forwarded to `mask_utils.run_mbm`; also selects how
            the returned criterion is turned into a cost.
        **kwargs: forwarded to `mask_utils.run_mbm`.

    Returns:
        OrderedDict mapping module name to its next cost.
    """
    assert isinstance(model, nn.Module)

    next_costs = OrderedDict()
    for name, mod in model.named_modules():
        if name not in state_map:
            continue

        state = state_map[name]
        if len(state[1]) == state[0] + 1:
            # already at the end: there is no next group number
            continue

        next_G = state[1][state[0] + 1]
        W = mod.weight

        # run the heuristic algorithm to calculate cost
        _, _, crit = mask_utils.run_mbm(
            W, next_G, normalized=normalized, **kwargs)

        if normalized:
            next_costs[name] = -crit.item()
        else:
            crit /= W.norm(dim=(2, 3)).sum().item()
            next_costs[name] = 1 - crit

    return next_costs
def draw_model_stats(arch, grps, data_dir, num_iters=None):
    """ Draw the statistics of several models.

    Loads the pretrained model `arch`, and for every combination of
    `num_iters` and `grps` runs `run_mbm` (perm='GRPS') on the per-kernel
    norm matrix of each nn.Conv2d layer whose first two weight dimensions
    are both divisible by G. The summed cost is reported as a percentage
    of the summed norm matrix total, drawn as a bar plot and also written
    as CSV next to the PDF.

    Args:
        arch: model name passed to `utils.load_model`.
        grps: iterable of group numbers to evaluate.
        data_dir: directory the .pdf and .csv files are written to.
        num_iters: iterable of iteration counts; defaults to [1] when
            empty or None.

    A (num_iters, G) combination for which no layer qualifies is recorded
    with a NaN ratio rather than raising ZeroDivisionError.
    """
    if not num_iters:
        num_iters = [1]

    file_name = 'model_stats_{}_NI_{}_G_{}.pdf'.format(
        arch,
        '-'.join([str(ni) for ni in num_iters]),
        '-'.join([str(g) for g in grps]))
    fp = os.path.join(data_dir, file_name)
    print('Plot to file: {}'.format(fp))

    fig, ax = plt.subplots(figsize=(5, 4))

    print('Running on model {} ...'.format(arch))
    model = utils.load_model(arch, 'imagenet', pretrained=True)

    results = {'num_iters': [], 'num_groups': [], 'ratio': []}
    for ni in num_iters:
        for G in grps:
            print('G = {} NI = {}'.format(G, ni))

            # Collect statistics for a single model: (cost, total) per layer
            stats = []
            for _, mod in model.named_modules():
                if not isinstance(mod, nn.Conv2d):
                    continue

                W = mod.weight
                F, C = W.shape[:2]
                if F % G != 0 or C % G != 0:
                    continue

                crit = W.norm(dim=(2, 3)).cpu().detach().numpy()
                _, _, cost = run_mbm(crit, G, perm='GRPS', num_iters=ni)
                stats.append((cost, crit.sum()))

            # Summarise results
            if stats:
                sum_cost = sum(cost for cost, _ in stats)
                total_cost = sum(total for _, total in stats)
                ratio = sum_cost / total_cost * 100
            else:
                # no layer is divisible by this G
                ratio = float('nan')

            results['num_iters'].append('$N_S={}$'.format(ni))
            results['num_groups'].append('$G={}$'.format(G))
            results['ratio'].append(ratio)

    df = pd.DataFrame(results)
    sns.barplot(x='num_groups', y='ratio', hue='num_iters', data=df)
    ax.legend()
    plt.tight_layout()
    fig.savefig(fp)

    # Swap only the extension; str.replace would also rewrite a '.pdf'
    # occurring inside data_dir.
    df.to_csv(os.path.splitext(fp)[0] + '.csv')