def _test_model(tmpdir, model_class, x, target_layers, percent, options=None, train=False, save=False):
    # save initialized model
    if not options:
        options = dict()
    model = model_class(**options)
    if train:
        model.train()
    else:
        model.eval()
    if save:
        PATH = str(tmpdir.join('model.onnx'))
        torch.onnx.export(model, x, PATH, verbose=False,
                          input_names=['input'], output_names=['output'])

    # print(model)

    # pruning with Pruner
    graph = Graph(model, x)

    PATH = str(tmpdir.join('model.png'))
    graph.plot()
    plt.savefig(PATH)

    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()
    info = pruner.apply_rebuild()

    model(x)

    PATH = str(tmpdir.join('model.pth'))
    torch.save(model.state_dict(), PATH)

    # load pruned weight and run
    model = model_class(**options)
    model.load_state_dict(torch.load(PATH))
    if train:
        model.train()
    else:
        model.eval()
    model(x)

    # save pruned onnx
    if save:
        PATH = str(tmpdir.join('model_pruned.onnx'))
        torch.onnx.export(model, x, PATH, verbose=False,
                          input_names=['input'], output_names=['output'])
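# Hedged usage sketch of the _test_model helper above. SimpleNet and the
# per-layer rates are illustrative assumptions, not part of the helper itself;
# tmpdir is the standard pytest fixture.
def test_simplenet_via_helper(tmpdir):
    x = torch.randn((1, 1, 9, 9), requires_grad=False)
    target_layers = ['conv1', 'conv2']
    percent = {'conv1': 0.6, 'conv2': 0.5}
    _test_model(tmpdir, SimpleNet, x, target_layers, percent, save=True)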
def test_pruner():
    model = SimpleNet()
    x = np.ones((1, 1, 9, 9), dtype=np.float32)
    target_layers = ['conv1', 'conv2']
    percent = {'conv1': 0.6}
    default = 0.5

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent, default=default)
    pruner = Pruner(model, x, target_layers, mask)

    print(model.count_params())

    pruner.mask()
    info = pruner.apply_rebuild()
    pruner.reinitialize()

    print(info)
    print(model.count_params())

    model(x)
def __init__(self, model, args, target_layers, pruning_rate, stop_trigger,
             pruning_rate_decay=1 / 8):
    """Progressive Deep Neural Networks Acceleration via Soft Filter Pruning

    https://arxiv.org/abs/1808.07471

    Args:
        model (chainer.Chain): target model.
        target_layers (list): names of the layers to prune.
        pruning_rate (float): sparsity, applied uniformly to every layer in
            target_layers. Range [0, 1); larger means higher compression.
        pruning_rate_decay (float): controls how progressively pruning_rate
            is ramped up; the paper's default is 1/8. It specifies at what
            fraction of max_iteration/epoch the sparsity should reach 3/4 of
            pruning_rate.
        trigger (tuple): how often weights are zeroed, specified like
            (500, 'iteration'); the paper's default is (1, 'epoch').
        stop_trigger (int): total number of training iterations/epochs.
    """
    if not enable_scipy:
        raise ImportError("please install scipy")

    self.model = model
    self.target_layers = target_layers
    self.pruning_rate = pruning_rate
    self.pruning_rate_decay = pruning_rate_decay
    self.stop_trigger = stop_trigger
    self.graph = Graph(model, args)

    initial_pruning_rate = 0.
    self.mask = NormMask(model, self.graph, target_layers,
                         percent=initial_pruning_rate, norm='l2')

    self._pruning_rate_fn = self._init_pruning_rate_fn(
        pruning_rate, pruning_rate_decay, stop_trigger)
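# The body of _init_pruning_rate_fn is not shown here. Below is a minimal
# sketch of one schedule consistent with the docstring above (an assumption,
# not the library's implementation, which apparently relies on scipy): an
# exponential ramp whose sparsity reaches 3/4 of the final pruning_rate after
# pruning_rate_decay * stop_trigger iterations/epochs.
import math

def example_pruning_rate_fn(pruning_rate, pruning_rate_decay, stop_trigger):
    # solve rate(t*) = 0.75 * pruning_rate at t* = pruning_rate_decay * stop_trigger
    k = -math.log(1.0 - 0.75) / (pruning_rate_decay * stop_trigger)

    def rate(t):
        # monotonically approaches pruning_rate as t grows toward stop_trigger
        return pruning_rate * (1.0 - math.exp(-k * t))

    return rate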
def pruning(model, args, target_layers, threshold, default=None):
    graph = Graph(model, args)
    mask = NormMask(model, graph, target_layers, threshold=threshold,
                    default=default, mask_layer='batchnorm')
    info = {}
    info['mask'] = mask()
    info['rebuild'] = rebuild(model, graph, target_layers)
    return info
def test_allnet():
    x = np.zeros((1, 3, 32, 32), dtype=np.float32)
    model = AllSupportedLayersNet()
    model(x)

    percent = 0.8
    target_layers = AllSupportedLayersNet.target_layers

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()
    info = pruner.apply_rebuild()
    print(info)

    model(x)
def test_seq_out():
    class A(chainer.Chain):
        """N-output"""

        target_layers = [
            '/conv1',
            '/conv2',
        ]

        def __init__(self, n_class=10):
            super(A, self).__init__()
            with self.init_scope():
                self.conv1 = chainer.links.Convolution2D(None, 10, 3)
                self.bn1 = chainer.links.BatchNormalization(10)
                self.conv2 = chainer.links.Convolution2D(10, 3)
                self.bn2 = chainer.links.BatchNormalization(10)
                self.fc1 = chainer.links.Linear(None, n_class)
                self.fc2 = chainer.links.Linear(None, n_class)

        def __call__(self, x):
            h = self.conv1(x)
            h = self.bn1(h)
            h = self.conv2(h)
            h = self.bn2(h)
            h1 = self.fc1(h)
            h2 = self.fc2(h)
            return h1, h2

    x = np.zeros((1, 1, 32, 32), dtype=np.float32)
    model = A()
    model(x)

    percent = 0.8
    target_layers = A.target_layers

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()
    info = pruner.apply_rebuild()
    pruner.reinitialize()

    model(x)
def test_no_target_layers():
    x = np.zeros((1, 3, 32, 32), dtype=np.float32)
    model = AllSupportedLayersNet()
    model(x)

    percent = 0.8
    target_layers = []  # empty!

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()
    # The original try/except raised ValueError inside its own try block, so it
    # passed even when apply_rebuild did not raise; assert the failure instead.
    with pytest.raises(ValueError):
        pruner.apply_rebuild()
def test_no_target_layers():
    x = torch.randn((1, 3, 32, 32), requires_grad=False)
    model = models.resnet18()
    model.eval()
    model(x)

    percent = 0.8
    target_layers = []  # empty!

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent)
    pruner = Pruner(model, x, target_layers, mask)
    pruner.apply_mask()
    # Same fix as the chainer variant: the original pattern always passed.
    with pytest.raises(ValueError):
        pruner.apply_rebuild()
def test_normmask():
    model = SimpleNet()
    x = np.ones((1, 1, 9, 9), dtype=np.float32)
    target_layers = ['/conv1', '/conv2']
    percent = {'/conv1': 0.8}
    default = 0.4

    graph = Graph(model, x)
    mask = NormMask(model, graph, target_layers, percent=percent, default=default)
    info = mask()

    assert info[0]['name'] == '/conv1'
    assert info[0]['before'] == 10
    assert info[0]['after'] == 2
    assert info[1]['name'] == '/conv2'
    assert info[1]['before'] == 10
    assert info[1]['after'] == 6
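# Why the asserted counts hold (assuming NormMask truncates the masked-channel
# count, which is consistent with the asserts above): the channels kept per
# layer are before - int(before * rate).
def expected_remaining(before, rate):
    return before - int(before * rate)

assert expected_remaining(10, 0.8) == 2  # /conv1 at percent 0.8
assert expected_remaining(10, 0.4) == 6  # /conv2 falls back to default 0.4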
def pruning(model, args, target_conv_layers, threshold, default=None):
    """Apply mask and rebuild for Network Slimming

    Args:
        model (torch.nn.Module, chainer.Chain): target model.
        args: dummy inputs of target model.
        target_conv_layers (list[str]): names of the convolution layers to prune.
        threshold (float, dict): mask threshold for BatchNorm2d.weight.
        default (float, Optional): default threshold (available only if
            threshold is a dict).

    Returns:
        dict: pruning runtime information
    """
    graph = Graph(model, args)
    mask = NormMask(model, graph, target_conv_layers, threshold=threshold,
                    default=default, mask_layer='batchnorm')
    info = {}
    info['mask'] = mask()
    info['rebuild'] = rebuild(model, graph, target_conv_layers)
    return info
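# Hedged usage sketch for the Network Slimming entry point above. The
# torchvision model and the layer names are illustrative assumptions, and the
# method presumes the BN scales were trained with an L1 sparsity penalty.
import torch
from torchvision import models

def example_network_slimming():
    model = models.vgg11_bn()
    model.eval()
    x = torch.randn((1, 3, 224, 224), requires_grad=False)
    # prune channels whose BatchNorm2d.weight (gamma) falls below 0.01
    info = pruning(model, x, ['features.0', 'features.4'], threshold=0.01)
    print(info['mask'])
    print(info['rebuild'])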
def test_simplenet():
    net = SimpleNet()
    x = np.ones((1, 1, 9, 9), dtype=np.float32)
    net(x)

    assert net.conv1.W.shape == (3, 1, 3, 3)  # (oc, ic, kh, kw)
    assert net.bn1_1.gamma.shape == (3, )
    assert net.bn1_1.beta.shape == (3, )
    assert net.bn1_2.gamma.shape == (3, )
    assert net.bn1_2.beta.shape == (3, )
    assert net.bn1_3.gamma.shape == (3, )
    assert net.bn1_3.beta.shape == (3, )
    assert net.conv2.W.shape == (4, 3, 3, 3)  # (oc, ic, kh, kw)
    assert net.bn2.gamma.shape == (4, )
    assert net.bn2.beta.shape == (4, )
    assert net.fc.W.shape == (10, 4)  # (o, i)

    # mask: zero out one output channel of conv1 so rebuild removes it
    target_channel = 1
    net.conv1.W.array[target_channel, ...] = 0

    graph = Graph(net, x)
    target_layers = ['/conv1']
    rebuild(net, graph, target_layers)

    # conv1 lost one output channel; the trailing BNs and conv2's input follow,
    # while conv2's output channels (and hence fc) are unchanged
    assert net.conv1.W.shape == (2, 1, 3, 3)  # (oc, ic, kh, kw)
    assert net.bn1_1.gamma.shape == (2, )
    assert net.bn1_1.beta.shape == (2, )
    assert net.bn1_2.gamma.shape == (2, )
    assert net.bn1_2.beta.shape == (2, )
    assert net.bn1_3.gamma.shape == (2, )
    assert net.bn1_3.beta.shape == (2, )
    assert net.conv2.W.shape == (4, 2, 3, 3)  # (oc, ic, kh, kw)
    assert net.bn2.gamma.shape == (4, )
    assert net.bn2.beta.shape == (4, )
    assert net.fc.W.shape == (10, 4)
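# A hedged extrapolation of the test above, assuming rebuild drops every
# all-zero output channel: zeroing two of conv1's three channels should leave
# exactly one, with conv2's input channels shrinking to match.
def test_simplenet_two_channels():
    net = SimpleNet()
    x = np.ones((1, 1, 9, 9), dtype=np.float32)
    net(x)

    net.conv1.W.array[0, ...] = 0
    net.conv1.W.array[2, ...] = 0

    rebuild(net, Graph(net, x), ['/conv1'])

    assert net.conv1.W.shape == (1, 1, 3, 3)  # one output channel left
    assert net.conv2.W.shape == (4, 1, 3, 3)  # conv2 input channels follow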