Пример #1
0
def _test_model(tmpdir, model_class, x, target_layers, percent, options=None, train=False, save=False):
    """Prune a freshly built model, then check that the pruned weights reload and run.

    Steps: build ``model_class(**options)``, optionally export it to ONNX, plot its
    ``Graph`` to ``model.png``, mask and rebuild it with ``NormMask``/``Pruner``, run a
    forward pass, save the ``state_dict`` to ``model.pth``, load that file into a new
    ``model_class(**options)`` instance, run it again and optionally export it to ONNX.

    Args:
        tmpdir: directory object with a ``join`` method (presumably pytest's ``tmpdir``
            fixture) where every artifact is written.
        model_class: model class to instantiate (used with ``torch.onnx.export``).
        x: example input passed to the model, ``Graph``, ``Pruner`` and the ONNX exporter.
        target_layers (list): names of the layers to prune.
        percent: pruning ratio forwarded unchanged to ``NormMask``.
        options (dict, optional): keyword arguments for ``model_class``; falsy means ``{}``.
        train (bool): put both models in training mode instead of eval mode.
        save (bool): also export ``model.onnx`` (before pruning) and
            ``model_pruned.onnx`` (after reloading the pruned weights).
    """
    if not options:
        options = dict()

    def set_mode(net):
        # Both the original and the reloaded model must run in the same mode.
        if train:
            net.train()
        else:
            net.eval()

    def export_onnx(net, filename):
        # Single place for the export settings shared by both ONNX dumps.
        torch.onnx.export(net, x, str(tmpdir.join(filename)), verbose=False,
                          input_names=['input'],
                          output_names=['output'])

    # save initialized model
    model = model_class(**options)
    set_mode(model)
    if save:
        export_onnx(model, 'model.onnx')

    # pruning with Pruner
    graph = Graph(model, x)
    graph.plot()
    plt.savefig(str(tmpdir.join('model.png')))
    # Close the current figure so repeated calls (e.g. parametrized tests) do not
    # keep accumulating open matplotlib figures.
    plt.close()

    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()
    pruner.apply_rebuild()

    model(x)

    weights_path = str(tmpdir.join('model.pth'))
    torch.save(model.state_dict(), weights_path)

    # load pruned weight and run
    # NOTE(review): this relies on model_class(**options) building layers whose shapes
    # accept the pruned state_dict — confirm for every model_class used here.
    model = model_class(**options)
    model.load_state_dict(torch.load(weights_path))
    set_mode(model)
    model(x)

    # save pruned onnx
    if save:
        export_onnx(model, 'model_pruned.onnx')
Пример #2
0
def test_pruner():
    """Mask, rebuild and reinitialize SimpleNet, then run a forward pass on it."""
    model = SimpleNet()
    x = np.ones((1, 1, 9, 9), dtype=np.float32)

    # NOTE(review): test_normmask names layers '/conv1'; confirm the unprefixed names
    # below match the names this version of the library expects.
    layers = ['conv1', 'conv2']
    graph = Graph(model, x)
    norm_mask = NormMask(model, graph, layers, percent={'conv1': 0.6}, default=0.5)
    pruner = Pruner(model, x, layers, norm_mask)

    print(model.count_params())  # parameter count before pruning

    # NOTE(review): the other tests call pruner.apply_mask(); confirm mask() is intended.
    pruner.mask()

    rebuild_info = pruner.apply_rebuild()
    pruner.reinitialize()
    print(rebuild_info)

    print(model.count_params())  # parameter count after pruning

    model(x)
Пример #3
0
    def __init__(self, model, args, target_layers,
                 pruning_rate, stop_trigger, pruning_rate_decay=1 / 8):
        """Set up Soft Filter Pruning with a progressively scheduled pruning rate.

        Based on "Progressive Deep Neural Networks Acceleration via Soft Filter Pruning",
        https://arxiv.org/abs/1808.07471

        Args:
            model (chainer.Chain): model to prune.
            args: dummy inputs used to build the ``Graph`` of ``model``.
            target_layers (list): layers to prune.
            pruning_rate (float): target sparsity in [0, 1), applied uniformly to every
                layer in ``target_layers``; larger values mean stronger compression.
            stop_trigger (int): total number of training iterations/epochs.
            pruning_rate_decay (float): adjusts how fast the pruning rate ramps up. It
                sets at which fraction of the training's max iteration/epoch the
                sparsity reaches 3/4 of ``pruning_rate``. The paper's default is 1/8.

        Raises:
            ImportError: if ``enable_scipy`` is false (scipy is required).

        Note:
            The original docstring also described a ``trigger`` tuple controlling how often
            weights are zeroed (e.g. ``(500, 'iteration')``; paper default
            ``(1, 'epoch')``). This constructor takes no such argument — presumably it is
            configured elsewhere (TODO confirm).
        """
        if not enable_scipy:
            raise ImportError("please install scipy")

        # Keep the schedule configuration on the instance.
        self.model = model
        self.target_layers = target_layers
        self.pruning_rate = pruning_rate
        self.pruning_rate_decay = pruning_rate_decay
        self.stop_trigger = stop_trigger

        self.graph = Graph(model, args)

        # The L2-norm mask starts at 0.0 (nothing pruned); later rates presumably come
        # from self._pruning_rate_fn — confirm.
        self.mask = NormMask(model, self.graph, target_layers, percent=0., norm='l2')

        self._pruning_rate_fn = self._init_pruning_rate_fn(
            self.pruning_rate, self.pruning_rate_decay, self.stop_trigger)
Пример #4
0
def pruning(model, args, target_layers, threshold, default=None):
    """Apply a BatchNorm-based norm mask to ``model`` and then rebuild it.

    Args:
        model: target model.
        args: dummy inputs used to build the model's ``Graph``.
        target_layers (list): layers to prune.
        threshold: mask threshold forwarded to ``NormMask``.
        default: default threshold forwarded to ``NormMask``.

    Returns:
        dict: ``'mask'`` is the result of calling the mask, ``'rebuild'`` is the result
        of ``rebuild``.
    """
    graph = Graph(model, args)
    norm_mask = NormMask(model, graph, target_layers,
                         threshold=threshold, default=default, mask_layer='batchnorm')
    # Dict displays evaluate left to right, so the mask is applied before the rebuild.
    return {
        'mask': norm_mask(),
        'rebuild': rebuild(model, graph, target_layers),
    }
Пример #5
0
def test_allnet():
    """Prune AllSupportedLayersNet's target layers and run the pruned model again."""
    model = AllSupportedLayersNet()
    x = np.zeros((1, 3, 32, 32), dtype=np.float32)

    model(x)  # forward pass before building the graph

    layers = AllSupportedLayersNet.target_layers
    graph = Graph(model, x)
    pruner = Pruner(model, x, layers, NormMask(model, graph, layers, percent=0.8))
    pruner.apply_mask()
    print(pruner.apply_rebuild())

    model(x)
Пример #6
0
def test_seq_out():
    """Prune a chain whose forward pass returns a tuple of two outputs."""

    class A(chainer.Chain):
        """N-output"""

        target_layers = ['/conv1', '/conv2']

        def __init__(self, n_class=10):
            super(A, self).__init__()
            links = chainer.links
            with self.init_scope():
                self.conv1 = links.Convolution2D(None, 10, 3)
                self.bn1 = links.BatchNormalization(10)
                # NOTE(review): two-argument form; presumably read by chainer as
                # (out_channels=10, ksize=3) with lazy in_channels, matching bn2 — confirm.
                self.conv2 = links.Convolution2D(10, 3)
                self.bn2 = links.BatchNormalization(10)
                self.fc1 = links.Linear(None, n_class)
                self.fc2 = links.Linear(None, n_class)

        def __call__(self, x):
            features = self.bn2(self.conv2(self.bn1(self.conv1(x))))
            # Both heads consume the same features.
            return self.fc1(features), self.fc2(features)

    model = A()
    x = np.zeros((1, 1, 32, 32), dtype=np.float32)

    model(x)

    graph = Graph(model, x)
    norm_mask = NormMask(model, graph, A.target_layers, percent=0.8)
    pruner = Pruner(model, x, A.target_layers, norm_mask)
    pruner.apply_mask()
    pruner.apply_rebuild()
    pruner.reinitialize()

    model(x)
Пример #7
0
def test_no_target_layers():
    """``Pruner.apply_rebuild`` must raise ``ValueError`` when ``target_layers`` is empty."""
    x = np.zeros((1, 3, 32, 32), dtype=np.float32)
    model = AllSupportedLayersNet()

    model(x)

    percent = 0.8
    target_layers = []  # empty!

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()

    try:
        pruner.apply_rebuild()
    except ValueError:
        pass  # expected
    else:
        # Raise AssertionError (not ValueError) from the else branch so a missing
        # error is reported instead of being caught by the except clause above.
        raise AssertionError(
            'apply_rebuild() should raise ValueError when target_layers is empty')
Пример #8
0
def test_no_target_layers():
    """``Pruner.apply_rebuild`` must raise ``ValueError`` when ``target_layers`` is empty."""
    x = torch.randn((1, 3, 32, 32), requires_grad=False)
    model = models.resnet18()
    model.eval()

    model(x)

    percent = 0.8
    target_layers = []  # empty!

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()

    try:
        pruner.apply_rebuild()
    except ValueError:
        pass  # expected
    else:
        # Raise AssertionError (not ValueError) from the else branch so a missing
        # error is reported instead of being caught by the except clause above.
        raise AssertionError(
            'apply_rebuild() should raise ValueError when target_layers is empty')
Пример #9
0
def test_normmask():
    """NormMask reports the expected channel counts before and after masking."""
    model = SimpleNet()
    x = np.ones((1, 1, 9, 9), dtype=np.float32)

    layers = ['/conv1', '/conv2']
    graph = Graph(model, x)
    info = NormMask(model, graph, layers, percent={'/conv1': 0.8}, default=0.4)()

    # (layer name, channels before, channels after), in the order of `info`.
    expected = [
        ('/conv1', 10, 2),  # listed in `percent` (0.8)
        ('/conv2', 10, 6),  # not in `percent`; presumably uses `default` (0.4)
    ]
    for idx, (name, before, after) in enumerate(expected):
        entry = info[idx]
        assert entry['name'] == name
        assert entry['before'] == before
        assert entry['after'] == after
Пример #10
0
def pruning(model, args, target_conv_layers, threshold, default=None):
    """Apply mask and rebuild for Network Slimming.

    The mask is computed on the BatchNorm layers (``mask_layer='batchnorm'``); the model
    is then rebuilt for the given convolution layers.

    Args:
        model (torch.nn.Module, chainer.Chain): target model.
        args: dummy inputs of target model.
        target_conv_layers (list[str]): convolution layers to prune.
        threshold (float, dict): mask threshold for BatchNorm2d.weight.
        default (float, Optional): default threshold (available only if threshold is dict).

    Returns:
        dict: pruning runtime information with the keys ``'mask'`` and ``'rebuild'``.
    """
    graph = Graph(model, args)
    bn_mask = NormMask(
        model,
        graph,
        target_conv_layers,
        threshold=threshold,
        default=default,
        mask_layer='batchnorm',
    )
    # Keyword arguments are evaluated left to right: masking happens before rebuilding.
    return dict(
        mask=bn_mask(),
        rebuild=rebuild(model, graph, target_conv_layers),
    )
Пример #11
0
def test_simplenet():
    """Zeroing one conv1 output channel makes rebuild drop it from conv1 and its consumers."""
    net = SimpleNet()
    x = np.ones((1, 1, 9, 9), dtype=np.float32)

    net(x)

    def check_shapes(conv1_out):
        # Links are looked up on `net` on every call, so objects replaced by
        # rebuild() are seen. Conv weights are (oc, ic, kh, kw); fc weight is (o, i).
        assert net.conv1.W.shape == (conv1_out, 1, 3, 3)
        for bn in (net.bn1_1, net.bn1_2, net.bn1_3):
            assert bn.gamma.shape == (conv1_out, )
            assert bn.beta.shape == (conv1_out, )
        assert net.conv2.W.shape == (4, conv1_out, 3, 3)
        assert net.bn2.gamma.shape == (4, )
        assert net.bn2.beta.shape == (4, )
        assert net.fc.W.shape == (10, 4)

    check_shapes(3)

    # mask: zero out every weight of output channel 1 of conv1
    net.conv1.W.array[1, ...] = 0

    graph = Graph(net, x)
    rebuild(net, graph, ['/conv1'])

    check_shapes(2)