예제 #1
0
    def test_run_mbm_with_perm(self):
        """Compare run_mbm without a permutation against perm="SORT".

        The un-permuted single-iteration run must end with a strictly higher
        cost than the sorted run. For both runs the reported cost must equal
        the total normalized criterion minus the criterion mass that falls
        inside the returned (input group, output group) blocks.
        """
        torch.manual_seed(0)
        weight = torch.randn((16, 32, 3, 3))
        # Shift every even output channel x even input channel kernel.
        weight[::2, ::2] += 100.0

        num_groups = 4
        in_a, out_a, cost_a = mask_utils.run_mbm(weight,
                                                 num_groups,
                                                 perm=None,
                                                 num_iters=1)
        in_b, out_b, cost_b = mask_utils.run_mbm(weight,
                                                 num_groups,
                                                 perm="SORT")
        self.assertTrue(cost_a > cost_b)

        crit = mask_utils.get_criterion(weight)
        crit /= torch.norm(crit)
        crit = mask_utils._get_numpy(crit)

        # Recompute each cost by hand: total mass minus mass kept in groups.
        for groups_in, groups_out, reported in ((in_a, out_a, cost_a),
                                                (in_b, out_b, cost_b)):
            kept = 0
            for ind_in, ind_out in zip(groups_in, groups_out):
                kept += (crit[ind_out, :])[:, ind_in].sum()
            self.assertTrue(
                np.allclose(crit.sum() - kept, reported, rtol=1e-4))
예제 #2
0
def draw_step_by_step(C, G, fp, total_cost, num_iters=100):
    """Draw a 2x2 step-by-step illustration of the grouping algorithm.

    The first three panels show ``C`` permuted by ``group_sort`` for
    ``min_g`` in ``[G - 1, G // 2, 0]`` (titled with ``min_g + 1``). The
    fourth panel shows ``C`` permuted by the groups from
    ``run_mbm(perm='GRPS')``. Every title reports ``get_cost`` of the
    permuted matrix as a percentage of ``total_cost``. The figure is saved
    to ``fp`` and then closed.

    Args:
        C: matrix indexed as ``C[rows, :][:, cols]``.
        G: number of groups.
        fp: output path passed to ``fig.savefig``.
        total_cost: denominator for the percentages shown in the titles.
        num_iters: forwarded to ``group_sort`` and ``run_mbm``.
    """
    min_gs = [G - 1, G // 2, 0]
    fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(6, 6))

    def _show(ax, title, mat):
        """Render ``mat`` on ``ax`` with ``title`` and no axis ticks."""
        ax.set_title(title)
        ax.matshow(mat, cmap=CMAP)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])

    try:
        for idx, min_g in enumerate(min_gs):
            ind_in, ind_out = group_sort(C,
                                         G,
                                         num_iters=num_iters,
                                         min_g=min_g)
            C_ = C[ind_out, :][:, ind_in]
            cost = get_cost(C_, G)
            # Raw string: "\%" is meant for the mathtext title and is an
            # invalid escape sequence in a plain Python string literal.
            _show(axes[idx // 2, idx % 2],
                  r'$min_g = {} ({:.2f}\%)$'.format(
                      min_g + 1, cost / total_cost * 100), C_)

        # TODO: don't show these
        # Last panel: result from run_mbm, with the per-group channel lists
        # flattened into one permutation per axis.
        gnd_in, gnd_out, _ = run_mbm(C, G, perm='GRPS', num_iters=num_iters)
        ind_in = [ch for group in gnd_in for ch in group]
        ind_out = [ch for group in gnd_out for ch in group]
        C_ = C[ind_out, :][:, ind_in]
        _show(axes[1, 1],
              r'$MBM({:.2f}\%)$'.format(get_cost(C_, G) / total_cost * 100),
              C_)

        plt.tight_layout()
        fig.savefig(fp)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
예제 #3
0
    def test_run_mbm(self):
        """run_mbm should pick out a planted block-diagonal structure.

        Original tag: "GRPS N_S=10". Four 8x8 diagonal channel blocks of a
        (32, 32, 3, 3) tensor are shifted by +100, so with G=4 each selected
        3x3 kernel should have a norm near sqrt(9 * 100**2) = 300.
        """
        # Seed the RNG (as test_run_mbm_with_perm does) so the +/-3 tolerance
        # below is checked against fixed noise instead of failing randomly.
        torch.manual_seed(0)
        ten = torch.randn((32, 32, 3, 3))
        for start in range(0, 32, 8):
            ten[start:start + 8, start:start + 8] += 100.0

        G = 4
        # NOTE(review): other callers unpack run_mbm as (gnd_in, gnd_out,
        # cost) and index rows with gnd_out; the planted blocks are symmetric,
        # so using the first result on dim 0 here is harmless.
        row_ind, col_ind, cost = mask_utils.run_mbm(ten, G)
        crit = ten[row_ind, col_ind, :, :].norm(dim=(2, 3))
        # criterion should be around 300 in this case
        self.assertTrue(((crit - 300).abs() < 3).all())
예제 #4
0
    def find_group_candidates(self,
                              mod,
                              relative=False,
                              allow_depthwise=False,
                              min_factor=None,
                              max_groups=None,
                              **kwargs):
        """Find group-number candidates for ``mod`` and score each one.

        Candidates are the common divisors of ``mod.out_channels`` (F) and
        ``mod.in_channels`` (C), filtered as described below, and each is
        scored with ``mask_utils.run_mbm``.

        Args:
            mod: the ``MaskConv2d`` module to analyse.
            relative: if True, divide each cost by the sum of the per-kernel
                norms ``mod.weight.norm(dim=(2, 3))``.
            allow_depthwise: if False, drop the candidate ``G == F == C``
                (one channel per group).
            min_factor: if given, drop the largest candidates while
                ``min(F / G, C / G) < min_factor``, always keeping at least
                one candidate.
            max_groups: if truthy, drop candidates larger than this value.
            **kwargs: additional requirements, forwarded to
                ``mask_utils.run_mbm``.

        Returns:
            ``(Gs, costs)``: the ascending candidate group counts and the
            matching costs.
        """
        assert isinstance(mod, MaskConv2d)
        C = mod.in_channels
        F = mod.out_channels
        W = mod.weight

        # Common divisors, ascending.
        Gs = sorted(factors(F).intersection(factors(C)))
        if max_groups:
            Gs = [g for g in Gs if g <= max_groups]

        # Guard on Gs: filtering above may have emptied the list.
        if not allow_depthwise and Gs and Gs[-1] == F == C:
            Gs.pop()

        if min_factor is not None:
            # Remove the largest candidates (fewest channels per group) until
            # the per-group channel count reaches the threshold; keep >= 1.
            while len(Gs) >= 2 and min(F, C) / Gs[-1] < min_factor:
                Gs.pop()

        # The normalizer does not depend on G, so compute it once.
        norm_sum = W.norm(dim=(2, 3)).sum().item() if relative else None

        # Get the best cost for all group candidates.
        costs = []
        for G in Gs:
            _, _, cost = mask_utils.run_mbm(W, G, **kwargs)
            if relative:
                cost = cost / norm_sum
            costs.append(cost)

        return Gs, costs
예제 #5
0
  def find_group_candidates(self, mod, **kwargs):
    """Find group-number candidates in ``mod`` and their MBM costs.

    Candidates are the common divisors of the first two weight dimensions
    (``F, C = W.shape[:2]``), excluding the smallest one.

    Args:
      mod: module whose weight is obtained via
        ``model_utils.get_weight_parameter``.
      **kwargs: additional requirements, forwarded to ``mask_utils.run_mbm``
        (previously accepted but silently ignored).

    Returns:
      ``(Gs, costs)``: ascending candidate group counts and their costs.
    """
    W = model_utils.get_weight_parameter(mod)
    F, C = W.shape[:2]

    # Common divisors, ascending; the smallest should be 1 (no grouping),
    # which is not a useful candidate. Slicing avoids an IndexError if empty.
    Gs = sorted(factors(F).intersection(factors(C)))[1:]

    costs = []
    for G in Gs:
      _, _, cost = mask_utils.run_mbm(W, G, **kwargs)
      costs.append(cost)

    return Gs, costs
예제 #6
0
def collect_random_stats(size,
                         G,
                         data_dir,
                         resume=False,
                         num_samples=100,
                         num_iters=100,
                         print_freq=100):
    """Collect cost ratios of run_mbm on randomly sampled test matrices.

    For each of ``num_samples`` samples, ``generate_test_matrix(size, G)``
    returns ``(C0, C)``. ``C`` is permuted by the groups found by
    ``run_mbm(perm='GRPS', num_iters=num_iters)``, and the ratio
    ``get_cost(permuted C, G) / get_cost(C0, G)`` is recorded.

    If ``resume`` is set, the ratios are loaded from the data file in
    ``data_dir`` instead of being recomputed. Otherwise they are computed
    and saved there, and the directory is created if needed.

    Returns:
        numpy array of ``num_samples`` ratios.
    """
    # Where the data file is stored.
    # NOTE(review): the file name encodes only num_iters and num_samples, not
    # ``size`` or ``G``, so resuming after a run with a different size/G
    # loads that run's results. Kept as-is so existing data files still load.
    fp = os.path.join(
        data_dir,
        'random_stats_NI_{}_NS_{}.npy'.format(num_iters, num_samples))

    if resume:
        assert os.path.isfile(fp), 'No stats file to resume from: {}'.format(
            fp)
        # Already on disk; no need to write it back.
        return np.load(fp)

    ratios = np.zeros(num_samples)  # where to store result.
    for idx in range(num_samples):
        if idx % print_freq == 0:
            print('[{}/{}] Sampling ...'.format(idx, num_samples))

        C0, C = generate_test_matrix(size, G)

        gnd_in, gnd_out, _ = run_mbm(C,
                                     G,
                                     perm='GRPS',
                                     num_iters=num_iters)
        # Flatten the per-group channel lists into one permutation per axis
        # (distinct names so the sample index is never shadowed).
        ind_in = [ch for group in gnd_in for ch in group]
        ind_out = [ch for group in gnd_out for ch in group]
        C_ = C[ind_out, :][:, ind_in]

        ratios[idx] = get_cost(C_, G) / get_cost(C0, G)

    # Save to file, creating the target directory if it does not exist yet.
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    np.save(fp, ratios)

    return ratios
예제 #7
0
    def get_next_cost(self, model, state_map, normalized=False, **kwargs):
        """Get the cost of moving each tracked module to its next group count.

        Args:
            model: ``nn.Module`` whose named submodules are scanned.
            state_map: maps a submodule name to ``state``, where ``state[0]``
                is the current index into the candidate list ``state[1]``;
                the candidate at ``state[0] + 1`` is evaluated.
            normalized: forwarded to ``mask_utils.run_mbm``; also selects how
                the returned criterion is turned into a cost (see Returns).
            **kwargs: forwarded to ``mask_utils.run_mbm``.

        Returns:
            OrderedDict of submodule name -> next cost, in
            ``model.named_modules()`` order. Submodules not in ``state_map``
            or already at their last candidate are omitted. When not
            normalized the value is ``1 - crit / sum(W.norm(dim=(2, 3)))``,
            otherwise ``-crit.item()``.
        """
        assert isinstance(model, nn.Module)

        next_costs = OrderedDict()
        for name, mod in model.named_modules():
            if name not in state_map:
                continue
            state = state_map[name]
            cur_idx, candidates = state[0], state[1]
            # Skip modules already at the end of their candidate list (also
            # guards against an index past the end).
            if cur_idx + 1 >= len(candidates):
                continue

            next_G = candidates[cur_idx + 1]
            W = mod.weight
            # run the heuristic algorithm to calculate cost
            _, _, crit = mask_utils.run_mbm(W,
                                            next_G,
                                            normalized=normalized,
                                            **kwargs)
            if not normalized:
                crit /= W.norm(dim=(2, 3)).sum().item()
                next_costs[name] = 1 - crit
            else:
                next_costs[name] = -crit.item()

        return next_costs
예제 #8
0
def draw_model_stats(arch, grps, data_dir, num_iters=None):
    """Draw the statistics of several models.

    For every ``ni`` in ``num_iters`` and ``G`` in ``grps``, each
    ``nn.Conv2d`` of the pretrained ImageNet model ``arch`` whose output and
    input channel counts are divisible by ``G`` is scored with
    ``run_mbm(perm='GRPS', num_iters=ni)`` on its per-kernel norm matrix.
    The summed cost over those layers, as a percentage of their summed
    norms, is drawn as a grouped bar chart saved to a PDF in ``data_dir``.
    The same numbers are written to a CSV next to it.

    Args:
        arch: model name passed to ``utils.load_model``.
        grps: group counts to evaluate.
        data_dir: output directory.
        num_iters: list of ``num_iters`` values; defaults to ``[1]``.
    """
    if not num_iters:
        num_iters = [1]
    fp = os.path.join(
        data_dir, 'model_stats_{}_NI_{}_G_{}.pdf'.format(
            arch, '-'.join([str(ni) for ni in num_iters]),
            '-'.join([str(g) for g in grps])))

    print('Plot to file: {}'.format(fp))

    print('Running on model {} ...'.format(arch))

    model = utils.load_model(arch, 'imagenet', pretrained=True)
    results = {'num_iters': [], 'num_groups': [], 'ratio': []}

    for ni in num_iters:
        for G in grps:
            print('G = {} NI = {}'.format(G, ni))

            # Accumulate cost and total criterion over the eligible layers.
            sum_cost, total_cost = 0, 0
            for mod in model.modules():
                if not isinstance(mod, nn.Conv2d):
                    continue

                W = mod.weight
                F, C = W.shape[:2]
                if F % G != 0 or C % G != 0:
                    continue

                # Per-kernel norm matrix (named apart from the channel count C).
                crit = W.detach().norm(dim=(2, 3)).cpu().numpy()
                _, _, cost = run_mbm(crit, G, perm='GRPS', num_iters=ni)
                sum_cost += cost
                total_cost += crit.sum()

            if not total_cost:
                # No layer is divisible by G: the ratio is undefined.
                print('Skipping G = {} NI = {}: no eligible Conv2d layer'.format(
                    G, ni))
                continue

            results['num_iters'].append('$N_S={}$'.format(ni))
            results['num_groups'].append('$G={}$'.format(G))
            results['ratio'].append(sum_cost / total_cost * 100)

    df = pd.DataFrame(results)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        sns.barplot(x='num_groups', y='ratio', hue='num_iters', data=df, ax=ax)
        ax.legend()
        plt.tight_layout()
        fig.savefig(fp)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    df.to_csv(fp.replace('.pdf', '.csv'))