def prune_model(model, amount, prune_mask, method=prune.L1Unstructured):
    model.to('cpu')
    model.mask_to_device('cpu')
    for name, module in model.named_modules():  # re-apply current mask to the model
        if isinstance(module, torch.nn.Linear):
            prune.custom_from_mask(module, "weight", prune_mask[name])

    parameters_to_prune = (
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
        (model.fc4, 'weight'),
    )
    prune.global_unstructured(  # global prune the model
        parameters_to_prune,
        pruning_method=method,
        amount=amount,
    )

    for name, module in model.named_modules():  # make pruning "permanent" by removing the orig/mask values from the state dict
        if isinstance(module, torch.nn.Linear):
            torch.logical_and(module.weight_mask, prune_mask[name],
                              out=prune_mask[name])  # update the cumulative progress mask
            prune.remove(module, 'weight')  # remove all those values in the global pruned model

    return model
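
A minimal usage sketch (illustrative, not from the source repo): it assumes a four-layer MLP exposing fc1-fc4 plus the mask_to_device helper the function expects, and shows how the cumulative prune_mask dict is threaded through repeated calls.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class MLP(nn.Module):  # hypothetical model matching prune_model's expectations
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 100)
        self.fc4 = nn.Linear(100, 10)

    def mask_to_device(self, device):
        # no-op here: the cumulative masks live in the external prune_mask dict
        pass

model = MLP()
# start from all-ones masks, one per Linear layer (nothing pruned yet)
prune_mask = {name: torch.ones_like(m.weight)
              for name, m in model.named_modules()
              if isinstance(m, nn.Linear)}
# amount counts against all weights, so successive calls pass increasing
# values; prune_mask accumulates the surviving connections between calls
for amount in (0.2, 0.4, 0.6):
    model = prune_model(model, amount=amount, prune_mask=prune_mask)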
Example #2
def prune_fc(layer, paths) -> None:
    names = [name for (name, content) in layer.named_parameters() 
             if "weight" in name]
    contents = [content for (name, content) in layer.named_parameters() 
                if "weight" in name]
    for name, content in zip(names, contents):
        prune.custom_from_mask(layer, name=name, mask=paths.get_fc())
Example #3
def prune_conv_layer(model,
                     layer_index,
                     filter_index,
                     criterion='lrp',
                     cuda_flag=False):
    ''' input parameters
    1. model: the current model
    2. layer_index: the layer to prune (a module name from model.named_modules())
    3. filter_index: index of the filter to prune within that layer
    '''

    conv = dict(model.named_modules())[layer_index]

    if not hasattr(conv, "output_mask"):
        # Instantiate output mask tensor of shape (num_output_channels, )
        conv.output_mask = torch.ones(conv.weight.shape[0])

    # Make sure the filter was not pruned before
    assert conv.output_mask[filter_index] != 0

    conv.output_mask[filter_index] = 0

    mask_weight = conv.output_mask.view(-1, 1, 1, 1).expand_as(conv.weight)
    torch_prune.custom_from_mask(conv, "weight", mask_weight)

    if conv.bias is not None:
        mask_bias = conv.output_mask
        torch_prune.custom_from_mask(conv, "bias", mask_bias)

    if cuda_flag:
        conv.weight = conv.weight.cuda()
        # conv.module.bias = conv.module.bias.cuda()

    return model
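
A hedged usage sketch: layer_index is a module name as yielded by model.named_modules(), not an integer, and torch_prune is assumed to alias torch.nn.utils.prune. The torchvision model and layer name below are purely illustrative.

import torchvision

model = torchvision.models.vgg11()
# zero out filter 7 of the first conv layer; repeated calls on other
# filters accumulate in the layer's output_mask buffer
model = prune_conv_layer(model, layer_index='features.0', filter_index=7)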
Example #4
    def global_unstructured(
        self, pruning_method: torch.nn.utils.prune.BasePruningMethod, **kwargs
    ):
        """Based on
        https://pytorch.org/docs/stable/_modules/torch/nn/utils/prune.html#global_unstructured.
        Modify scores depending on the algorithm.
        """
        assert isinstance(self.params_to_prune, Iterable)

        scores = self.get_prune_score()

        t = torch.nn.utils.parameters_to_vector(scores)
        # similarly, flatten the masks (if they exist), or use a flattened vector
        # of 1s of the same dimensions as t
        default_mask = torch.nn.utils.parameters_to_vector(
            [
                getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
                for (module, name) in self.params_to_prune  # type: ignore
            ]
        )

        # use the canonical pruning methods to compute the new mask, even if the
        # parameter is now a flattened out version of `parameters`
        container = prune.PruningContainer()
        container._tensor_name = "temp"  # type: ignore
        method = pruning_method(**kwargs)
        method._tensor_name = "temp"  # type: ignore
        if method.PRUNING_TYPE != "unstructured":
            raise TypeError(
                'Only "unstructured" PRUNING_TYPE supported for '
                "the `pruning_method`. Found method {} of type {}".format(
                    pruning_method, method.PRUNING_TYPE
                )
            )

        container.add_pruning_method(method)

        # use the `compute_mask` method from `PruningContainer` to combine the
        # mask computed by the new method with the pre-existing mask
        final_mask = container.compute_mask(t, default_mask)

        # Pointer for slicing the mask to match the shape of each parameter
        pointer = 0
        for module, name in self.params_to_prune:  # type: ignore
            param = getattr(module, name)
            # The length of the parameter
            num_param = param.numel()
            # Slice the mask, reshape it
            param_mask = final_mask[pointer : pointer + num_param].view_as(param)
            # Assign the correct pre-computed mask to each parameter and add it
            # to the forward_pre_hooks like any other pruning method
            prune.custom_from_mask(module, name, param_mask)

            # Increment the pointer to continue slicing the final_mask
            pointer += num_param
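
This method belongs to a larger pruner class that is not shown. A minimal host-class sketch, assuming the method is in scope as a plain function and using absolute weight magnitude as the score (the original get_prune_score may differ):

from collections.abc import Iterable

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class MagnitudePruner:
    def __init__(self, params_to_prune):
        self.params_to_prune = list(params_to_prune)  # [(module, param_name), ...]

    def get_prune_score(self):
        # assumed scoring: absolute weight magnitude
        return [getattr(module, name).abs() for module, name in self.params_to_prune]

    global_unstructured = global_unstructured  # attach the method defined above

layers = [nn.Linear(8, 8), nn.Linear(8, 4)]
pruner = MagnitudePruner([(layer, "weight") for layer in layers])
pruner.global_unstructured(prune.L1Unstructured, amount=0.5)
print(sum((layer.weight == 0).sum().item() for layer in layers))  # 48 of 96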
Example #5
def prune_layer(layer, prune_rate, reset=True):
    """ Prune given 'layer' with 'prune_rate' and reset surviving weights to their initial values, if 'reset' is True.
    Calls with 'prune_rate'=0.0 do not remove weights. """
    if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
        pruned_mask = prune_mask(layer, prune_rate)
        if reset:
            prune.remove(layer, name='weight')  # temporarily remove pruning
            layer.weight = nn.Parameter(
                layer.weight_init.clone())  # set weights to initial weights
        prune.custom_from_mask(layer, name='weight',
                               mask=pruned_mask)  # apply pruned mask
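
A usage sketch, assuming the repo's companion prune_mask() helper (exercised as mp.prune_mask in the tests further below) is in scope; the layer needs a weight_init buffer and an already-applied mask:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(5, 2)
layer.register_buffer('weight_init', layer.weight.detach().clone())  # rewind target
prune.custom_from_mask(layer, name='weight', mask=torch.ones_like(layer.weight))
prune_layer(layer, prune_rate=0.2, reset=True)  # prune 20%, rewind the survivors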
Example #6
    def setup_graph(self):
        # initialize the masks
        if self._mask_init_method == 'random':
            for i, mask in enumerate(self.masks):
                size = mask.size()
                self.masks[i] = torch.tensor(get_mask_random(size, self._default_sparsity), dtype=mask.dtype).to(device)

        # initialize masked weight
        for i, layer in enumerate(self.layers):
            if prune.is_pruned(layer):
                prune.remove(layer, 'weight')
            prune.custom_from_mask(layer, 'weight', self.masks[i])
Example #7
def prune_connections(model, paths, output_type="m1"):
    """
    Prune connections in model in accordance with the paths found by the
    algorithm.

    Parameters
    ----------
    model : Model
        Model which is to be pruned
    paths : Paths
        Paths object specifying which paths in the model must be pruned
    output_type : str, optional
        Whether the cell type will be M1 or M2. The default is "m1".

    Returns
    -------
    model: Model
        Pruned model.

    """
    if output_type == "m1":  # Output type M1
        path = torch.from_numpy(paths.m1)
        prune.custom_from_mask(model.x2h, "weight", path)
        prune.custom_from_mask(model.h2h, "weight", path)

    elif output_type == "m2":  # Output type M2
        path = torch.from_numpy(paths.m2)
        prune.custom_from_mask(model.x2h, "weight", path)
        prune.custom_from_mask(model.h2h, "weight", path)
        
    else:
        raise ValueError("output_type kwarg must be m1 or m2")
    
    return model
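
The Model and Paths types are project-specific; a hedged sketch of the shapes prune_connections expects, where x2h and h2h are square Linear layers and paths.m1 is a 0/1 NumPy array matching their weight shape:

import numpy as np
import torch.nn as nn

class TinyModel(nn.Module):  # illustrative stand-in for the project's Model
    def __init__(self, n):
        super().__init__()
        self.x2h = nn.Linear(n, n)
        self.h2h = nn.Linear(n, n)

class TinyPaths:  # illustrative stand-in for the project's Paths
    def __init__(self, n):
        self.m1 = np.eye(n, dtype=np.float32)  # keep only diagonal connections
        self.m2 = np.eye(n, dtype=np.float32)

model = prune_connections(TinyModel(4), TinyPaths(4), output_type="m1")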
Example #8
def prune_cell_m2(layer, paths) -> None:
    """Prune M2 cell. Since M2 cells only have one output, the paths 
    are adapted accordingly."""
    # Copy names, contents
    names = [name for (name, content) in layer.named_parameters() 
             if "weight" in name]
    contents = [content for (name, content) in layer.named_parameters() 
                if "weight" in name]
    # Pruning step
    for name, content in zip(names, contents):
        if "hh" in name:  # hh layer has 4*n_out, n_out dims
            prune.custom_from_mask(layer, name=name, 
                                   mask=paths.get_m2(True))
        else:  # ih layer has 4*n_out, n_in dims
            prune.custom_from_mask(layer, name=name, mask=paths.get_m2())
Example #9
def prune_cell_m1(layer, paths) -> None:
    """Prune M1 cell.
    
    Shape of contents entry is 4*hiddens, n_input + n_output
    """
    # Copy names, contents
    names = [name for (name, content) in layer.named_parameters() 
             if "weight" in name]
    contents = [content for (name, content) in layer.named_parameters() 
                if "weight" in name]
    
    # Do we also want to prune the biases? And do we even need contents?
    
    # Pruning step
    for name, content in zip(names, contents):
        prune.custom_from_mask(layer, name=name, mask=paths.get_m1())
Example #10
    def test_prune_conv_layer_correctly(self):
        """ Prune the mask for a pruned convolutional layer in one step.
        Should zero out the three weights (as ceil(12*0.2)=3) with the lowest magnitude. """
        # initialize conv layer with 16 given weights and pruned mask
        initial_weights = torch.tensor([
            1.2, -0.1, 1.2, 4.3, -2.1, -1.1, -0.8, 1.2, 0.5, 0.2, 0.4, 1.4,
            2.2, -0.8, 0.4, 0.9
        ]).view(2, 2, 2, 2)
        initial_mask = torch.tensor(
            [1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0.,
             1.]).view(2, 2, 2, 2)
        test_layer = nn.Conv2d(2, 2, kernel_size=2, padding=1)
        test_layer.weight = nn.Parameter(initial_weights.clone())
        test_layer.register_buffer('weight_init', initial_weights.clone())
        test_layer = prune.custom_from_mask(test_layer,
                                            name='weight',
                                            mask=initial_mask)

        mp.prune_layer(layer=test_layer, prune_rate=0.2)

        expected_weights = torch.tensor([
            1.2, -0., 1.2, 4.3, -2.1, -1.1, -0., 1.2, 0., 0., 0., 1.4, 2.2,
            -0., 0., 0.9
        ]).view(2, 2, 2, 2)
        self.assertIs((test_layer.weight == expected_weights).all().item(),
                      True)
Example #11
def pruning_generate(model, state_dict):

    for (name, m) in model.named_modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            prune.custom_from_mask(m,
                                   name='weight',
                                   mask=state_dict[name + ".weight_mask"])
Example #12
    def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
        # add noise for a slight bit of randomness
        score = self.grad_dict[layer].abs() + noise_std

        score_drop = score * layer_mask
        layer_mask_dropped, n_prune = self.drop_minimum(score_drop, layer_mask)

        # Randomly revive n_prune many connections from non-existing connections.
        score_grow = score * (~layer_mask_dropped)
        layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, n_prune)

        # update the weight
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', layer_mask)

        return layer_mask, new_mask
Example #13
def prune_model(model, index_list):
    mag_w_mask = torch.zeros(100, 512)
    for index in index_list:
        mag_w_mask[:, index] = 1
    module = model.fc[0]
    pruned_w = prune.custom_from_mask(module, name='weight', mask=mag_w_mask)
    model.fc[0] = pruned_w
    return model
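
This prune_model (unrelated to Example #1's) hard-codes a (100, 512) mask, so it expects model.fc[0] to be nn.Linear(512, 100); the mask keeps only the input columns named in index_list. A sketch with an illustrative stand-in model:

import torch.nn as nn

class Net(nn.Module):  # hypothetical model with the expected fc head
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(512, 100), nn.ReLU())

net = prune_model(Net(), index_list=[0, 3, 17])  # keep 3 of the 512 input columns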
Example #14
    def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
        layer_weight = layer.weight
        layer_grad = self.grad_dict[layer]

        # Remove weight smaller than adaptive threshold
        layer_mask_dropped, drop_n = sparsify_weight(layer_weight.abs(), layer_mask, self._drop_fraction)

        # Grow weight whose gradient larger than adaptive threshold
        score_grow = layer_grad * (~layer_mask_dropped)
        layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, int(drop_n*self._grow_fraction))

        # update the weight
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', layer_mask)

        return layer_mask, new_mask
Example #15
    def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
        layer_weight = layer.weight

        # Add noise for slight bit of randomness and drop
        masked_weights = layer_mask * layer_weight
        score_drop = masked_weights.abs() + self._random_normal(layer_weight.size(), std=noise_std)
        layer_mask_dropped, n_prune = self.drop_minimum(score_drop, layer_mask)

        # Randomly revive n_prune many connections from non-existing connections.
        score_grow = self.grad_dict[layer].abs() * (~layer_mask_dropped) + self._random_uniform(layer_weight.size()) * noise_std
        layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, n_prune)

        # update the weight
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', layer_mask)

        return layer_mask, new_mask
Example #16
 def load_state(self,
                modelpath='/raid/zhassylbekov/sungbae/model/initstate.pth'):
     #    umask = dict(self.u_embeddings.named_buffers())['weight_mask'].cpu()
     vmask = dict(self.v_embeddings.named_buffers())['weight_mask'].cpu()
     cuda_using = next(self.parameters()).is_cuda
     self = SkipGramModel(self.vocab_size, self.emb_dimension)
     self.load_state_dict(torch.load(modelpath))
     #    prune.custom_from_mask(self.u_embeddings, name='weight', mask=umask)
     prune.custom_from_mask(self.v_embeddings,
                            name='weight',
                            mask=vmask)
     if cuda_using:
         self.cuda()
Example #17
    def prune_step_change(self, pstep, prune_mode):
        cuda_using = next(self.parameters()).is_cuda
        # re-import the fixed initial u and v weights
        ui, vi = self.load_weights()
        # snapshot the current weights
        uc, vc = (self.u_embeddings.weight.data.clone().cpu(),
                  self.v_embeddings.weight.data.clone().cpu())
        # snapshot the current masks
        if not list(self.u_embeddings.named_buffers()):
            #prune.identity(self.u_embeddings, name='weight')
            prune.identity(self.v_embeddings, name='weight')
        #umask = dict(self.u_embeddings.named_buffers())['weight_mask'].cpu()
        vmask = dict(self.v_embeddings.named_buffers())['weight_mask'].cpu()

        #u_temp = torch.nn.Embedding(self.vocab_size, self.emb_dimension)
        v_temp = torch.nn.Embedding(self.vocab_size, self.emb_dimension)

        if prune_mode == 'change':
            f = lambda x, y: x - y
        elif prune_mode == 'absolute change':
            f = lambda x, y: torch.abs(x) - torch.abs(y)
        else:
            f = lambda x, y: x
        # weights to be left must have higher function outputs
        #u_temp.weight.data.copy_(f(uc,ui))
        v_temp.weight.data.copy_(f(vc, vi))

        #prune.custom_from_mask(u_temp,name='weight',mask=umask)
        prune.custom_from_mask(v_temp, name='weight', mask=vmask)
        if cuda_using:
            #    u_temp.cuda()
            v_temp.cuda()
        #prune.l1_unstructured(u_temp, name='weight', amount=pstep)
        prune.l1_unstructured(v_temp, name='weight', amount=pstep)
        # checked: cuda <-> cpu crash does not occur
        #u_temp.weight.data.copy_(uc)
        v_temp.weight.data.copy_(vc)

        #self.u_embeddings = u_temp
        self.v_embeddings = v_temp
Example #18
def prune_equal_fanin(
    model: torch.nn.Module,
    epoch: int,
    prune_epoch: int,
    k: int = 2,
    device: torch.device = torch.device('cpu')) -> torch.nn.Module:
    """
    Prune the linear layers of the network such that each neuron has the same fan-in.

    :param model: pytorch model.
    :param epoch: current training epoch.
    :param prune_epoch: training epoch when pruning needs to be applied.
    :param k: fan-in.
    :param device: cpu or cuda device.
    :return: Pruned model
    """
    if epoch != prune_epoch:
        return model

    model.eval()
    for i, module in enumerate(model.children()):
        # prune only Linear layers
        if isinstance(module, torch.nn.Linear):
            # create mask
            mask = torch.ones(module.weight.shape)
            # identify weights with the lowest absolute values
            param_absneg = -torch.abs(module.weight)
            idx = torch.topk(param_absneg, k=param_absneg.shape[1] - k,
                             dim=1)[1]
            for j in range(len(idx)):
                mask[j, idx[j]] = 0
            # prune
            mask = mask.to(device)
            prune.custom_from_mask(module, name="weight", mask=mask)
            # print(f"Pruned {k}/{module.weight.shape[1]} weights")

    return model
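
A brief usage sketch: with k=2 every neuron keeps only its two largest-magnitude incoming weights, so each row of the resulting weight_mask sums to 2.

import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
model = prune_equal_fanin(model, epoch=10, prune_epoch=10, k=2)
print(model[0].weight_mask.sum(dim=1))  # tensor([2., 2., 2., 2., 2.])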
Example #19
    def test_prune_mask_for_linear_layer_correctly(self):
        """ Prune the mask for an unpruned linear layer in one step.
        Should zero out the two weights with the lowest magnitude. """
        # initialize linear layer with 10 given weights and unpruned mask
        initial_weights = torch.tensor([[1., -2., 3., -1.5, -3.],
                                        [-1., 2., -4., 0.5, 1.5]])
        test_layer = nn.Linear(5, 2)
        test_layer.weight = nn.Parameter(initial_weights.clone())
        test_layer = prune.custom_from_mask(test_layer,
                                            name='weight',
                                            mask=torch.ones_like(
                                                test_layer.weight))

        test_mask_pruned = mp.prune_mask(layer=test_layer, prune_rate=0.2)
        self.assertIs(
            test_mask_pruned.equal(
                torch.tensor([[0., 1., 1., 1., 1.], [1., 1., 1., 0., 1.]])),
            True)
Example #20
    def test_do_not_apply_init_weight_after_pruning_linear_layer(self):
        """ Generate, modify and prune an unpruned linear layer.
        Its weights should not be reset. """
        # initialize linear layer with 6 given weights and unpruned mask
        initial_weights = torch.tensor([[1., -2., 3.], [-4., 5., -6.]])
        test_layer = nn.Linear(3, 2)
        test_layer.weight = nn.Parameter(
            2 * initial_weights.clone())  # fake training, i.e. store modified weights
        test_layer.register_buffer('weight_init', initial_weights.clone())
        test_layer = prune.custom_from_mask(test_layer,
                                            name='weight',
                                            mask=torch.ones_like(
                                                test_layer.weight))

        mp.prune_layer(layer=test_layer, prune_rate=0.2, reset=False)

        expected_weights = torch.tensor([[0., -0., 6.], [-8., 10., -12.]])
        self.assertIs(test_layer.weight.equal(expected_weights), True)
Example #21
    def test_prune_linear_layer_correctly(self):
        """ Prune the mask for a pruned linear layer in one step.
        Should zero out the two weights (as ceil(8*0.2)=2) with the lowest magnitude. """
        # initialize linear layer with 10 given weights and pruned mask
        initial_weights = torch.tensor([[1., -2., 3., -1.5, -3.],
                                        [-1., 2., -4., 0.5, 1.5]])
        initial_mask = torch.tensor([[0., 1., 1., 1., 1.],
                                     [1., 1., 1., 0., 1.]])
        test_layer = nn.Linear(5, 2)
        test_layer.weight = nn.Parameter(initial_weights.clone())
        test_layer.register_buffer('weight_init', initial_weights.clone())
        test_layer = prune.custom_from_mask(test_layer,
                                            name='weight',
                                            mask=initial_mask)

        mp.prune_layer(layer=test_layer, prune_rate=0.2)

        expected_weights = torch.tensor([[0., -2., 3., -0., -3.],
                                         [-0., 2., -4., 0., 1.5]])
        self.assertIs(test_layer.weight.equal(expected_weights), True)
Example #22
    def test_prune_mask_for_conv_layer_correctly(self):
        """ Prune the mask for an unpruned convolutional layer in one step.
        Should zero out the two weights with the lowest magnitude. """
        # Initialize conv layer with 16 given weights and unpruned mask
        initial_weights = torch.tensor([
            1.2, -0.1, 1.2, 4.3, -2.1, -1.1, -0.8, 1.2, 0.5, 0.2, 0.4, 1.4,
            2.2, -0.8, 0.4, 0.9
        ]).view(2, 2, 2, 2)
        test_layer = nn.Conv2d(2, 2, kernel_size=2, padding=1)
        test_layer.weight = nn.Parameter(initial_weights.clone())
        test_layer = prune.custom_from_mask(test_layer,
                                            name='weight',
                                            mask=torch.ones_like(
                                                test_layer.weight))

        test_mask_pruned = mp.prune_mask(layer=test_layer, prune_rate=0.2)

        expected_mask = torch.tensor(
            [1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0.,
             1.]).view(2, 2, 2, 2)
        self.assertIs(test_mask_pruned.equal(expected_mask), True)
Example #23
 def __init__(self,
              in_planes,
              planes,
              num_bits,
              num_bits_weight,
              stride,
              type_prune,
              sparsity,
              layer_num,
              option='A'):
     super(ResNetBlock, self).__init__()
     if in_planes == 3:
         op = QConv2d(in_planes,
                      planes,
                      num_bits,
                      num_bits_weight,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False)
         if type_prune == 'channel':
             op = prune.ln_structured(op,
                                      name='weight',
                                      amount=sparsity,
                                      n=2,
                                      dim=0)
         elif type_prune == 'group':
             width = 4
             tmp_pruned = op.weight.data.clone()
             original_size = tmp_pruned.size()
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             append_size = width - tmp_pruned.shape[1] % width
             tmp_pruned = torch.cat(
                 (tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
             tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
             tmp_pruned = tmp_pruned.pow(2.0).mean(
                 2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
             tmp = tmp_pruned.flatten()
             num = tmp.shape[0] * (1 - sparsity)
             top_k = torch.topk(tmp, int(num), sorted=True)
             threshold = top_k.values[-1]
             tmp_pruned = tmp_pruned.ge(threshold)
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             tmp_pruned = tmp_pruned[:, 0:op.weight.data[0].nelement()]
             tmp_pruned = tmp_pruned.contiguous().view(original_size)
             op = prune.custom_from_mask(op, name='weight', mask=tmp_pruned)
         self.add_module("conv", op)
         bn_op = nn.BatchNorm2d(planes)
         self.add_module("bn", bn_op)
         self.add_module("relu", nn.ReLU(inplace=True))
     elif in_planes == 1:
         self.add_module("avg_pool", nn.AvgPool2d(kernel_size=8, stride=1))
         self.add_module("flatten", Flatten())
         op = nn.Linear(in_features=64, out_features=10)
         if type_prune == 'channel':
             op = prune.ln_structured(op,
                                      name='weight',
                                      amount=sparsity,
                                      n=2,
                                      dim=0)
         elif type_prune == 'group':
             width = 4
             tmp_pruned = op.weight.data.clone()
             original_size = tmp_pruned.size()
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             append_size = width - tmp_pruned.shape[1] % width
             tmp_pruned = torch.cat(
                 (tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
             tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
             tmp_pruned = tmp_pruned.pow(2.0).mean(
                 2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
             tmp = tmp_pruned.flatten()
             num = tmp.shape[0] * (1 - sparsity)
             top_k = torch.topk(tmp, int(num), sorted=True)
             threshold = top_k.values[-1]
             tmp_pruned = tmp_pruned.ge(threshold)
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             tmp_pruned = tmp_pruned[:, 0:op.weight.data[0].nelement()]
             tmp_pruned = tmp_pruned.contiguous().view(original_size)
             op = prune.custom_from_mask(op, name='weight', mask=tmp_pruned)
         self.add_module("fc", op)
     else:
         op = BasicBlock(in_planes, planes, num_bits, num_bits_weight,
                         stride, type_prune, sparsity, layer_num, option)
         self.add_module("conv", op)
Example #24
def setup_masks(layer):
    """ Setup a mask of ones for all linear and convolutional layers. """
    if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
        prune.custom_from_mask(layer,
                               name='weight',
                               mask=torch.ones_like(layer.weight))
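
setup_masks pairs naturally with Module.apply, which visits every submodule; afterwards each Linear/Conv2d weight is reparametrized with an identity mask so later mask updates and prune.remove calls work uniformly:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(128, 10))
model.apply(setup_masks)  # each weight gains weight_orig and an all-ones weight_mask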
Example #25
def main(args):
    save_folder = args.save_folder

    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(model_folder)

    setattr(args, 'model_folder', model_folder)

    logger = create_logger(model_folder, 'train', 'info')
    print_args(args, logger)

    
    # seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic=True
        torch.backends.cudnn.benchmark=False

    if "ResNet" in args.model :
        depth_ = args.model.split('-')[1]
        
        # when it is dst, the resnet block is special
        # (assumption: dst uses args.prune_type; otherwise no prune type is set)
        p_type = args.prune_type if args.prune_method == 'dst' else None

        res_dict = {
            '20': resnet20(num_classes=int(args.dataset.split('-')[1]), prune_type=p_type),
            '32': resnet32(num_classes=int(args.dataset.split('-')[1]), prune_type=p_type),
            '44': resnet44(num_classes=int(args.dataset.split('-')[1]), prune_type=p_type),
            '56': resnet56(num_classes=int(args.dataset.split('-')[1]), prune_type=p_type),
        }
            
        net = res_dict[depth_]


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)

    # set trainer
    trainer = Trainer(args, logger)

    # loss
    loss = nn.CrossEntropyLoss()

    # dataloader
    kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
    if args.dataset == 'cifar-10':
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.Pad(4),
                            transforms.RandomCrop(32),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                        ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                        ])),
            batch_size=args.batch_size, shuffle=False, **kwargs)
    elif args.dataset == 'cifar-100':
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.Pad(4),
                            transforms.RandomCrop(32),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                        ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                        ])),
            batch_size=args.batch_size, shuffle=False, **kwargs)

    # optimizer & scheduler
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)

    # pruning
    if args.prune_method=='global':
        if args.prune_type=='group':
            tmps = []
            for n,conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1]==3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    append_size = 4 - tmp_pruned.shape[1] % 4
                    tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, 4)
                    tmp_pruned = tmp_pruned.abs().mean(2, keepdim=True).expand(tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  # e.g. sparsity 0.2
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n,conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1]==3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    append_size = 4 - tmp_pruned.shape[1] % 4
                    tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, 4)
                    tmp_pruned = tmp_pruned.abs().mean(2, keepdim=True).expand(tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned[:, 0: conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(original_size)

                    prune.custom_from_mask(conv, name='weight', mask=tmp_pruned)
        elif args.prune_type =='filter':
            tmps = []
            for n,conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1]==3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  # e.g. sparsity 0.5
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n,conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1]==3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned[:, 0: conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(original_size)

                    prune.custom_from_mask(conv, name='weight', mask=tmp_pruned)        

        print(f'model pruned!! (sparsity: {args.sparsity:.2f}, prune_method: {args.prune_method}, prune_type: {args.prune_type}-level pruning)')
    
    elif args.prune_method=='uniform':
        assert False, 'uniform code is not ready'

    elif args.prune_method =='dst':
        print(f'model pruned!! (prune_method: {args.prune_method}, prune_type: {args.prune_type}-level pruning)')

    
    # Training
    trainer.train(net, loss, device, train_loader, test_loader, optimizer=optimizer, scheduler=scheduler)
Example #26
def main(args):
    save_folder = args.save_folder

    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(model_folder)

    setattr(args, 'model_folder', model_folder)

    logger = create_logger(model_folder, 'train', 'info')
    print_args(args, logger)

    # seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # ResNet
    if "ResNet" in args.model:
        depth_ = args.model.split('-')[1]

        # when it is dst, the resnet block is special
        # (assumption: dst uses args.prune_type; otherwise no prune type is set)
        p_type = args.prune_type if args.prune_method == 'dst' else None

        res_dict = {
            '18':
            resnet18(pretrained=args.pretrained,
                     progress=True,
                     prune_type=p_type),
            '34':
            resnet34(pretrained=args.pretrained,
                     progress=True,
                     prune_type=p_type),
            '50':
            resnet50(pretrained=args.pretrained,
                     progress=True,
                     prune_type=p_type),
            '101':
            resnet101(pretrained=args.pretrained,
                      progress=True,
                      prune_type=p_type)
        }

        net = res_dict[depth_]

    #elif 'efficientnet' in args.model:
    #    net = EfficientNet.from_pretrained(args.model)
    elif args.model == 'efficientnet_b0':
        print('efficientnet-b0 load...')
        net = efficientnet_b0(pretrained=args.pretrained)

    # MobileNet
    elif args.model == "mobilenetv3-large-1.0":
        print('mobilenetv3-large-1.0')
        net = mobilenetv3_large_100(pretrained=args.pretrained)

    elif args.model == 'once-mobilenetv3-large-1.0':
        print('once-mobilenetv3-large-1.0')
        net, image_size = ofa_specialized(
            'note8_lat@[email protected]_finetune@25', pretrained=args.pretrained)

    elif args.model == 'mobilenetv2-120d':
        print('mobilenetv2-120d load...')
        net = mobilenetv2_120d(pretrained=args.pretrained)

    # conv1 trainable
    if args.conv1_not_train:
        print('conv1 weight not train')
        if args.model == "mobilenetv3-large-1.0":
            for param in net.conv_stem.parameters():
                param.requires_grad = False
        elif "ResNet" in args.model:
            for param in net.conv1.parameters():
                param.requires_grad = False

        else:
            assert False, 'not ready'  # asserting a tuple is always true, so the original never fired

    # custom pretrain path
    if args.pretrain_path:
        print('load custom pretrain weight...')
        net.load_state_dict(torch.load(args.pretrain_path))

    net2 = copy.deepcopy(net)  # for save removed_models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = nn.DataParallel(net)
    net.to(device)

    # KD
    if args.KD:
        print('knowledge distillation model load!')
        teacher_net, image_size = ofa_specialized(
            'flops@[email protected]_finetune@75', pretrained=True)  # 79.6%
        teacher_net = nn.DataParallel(teacher_net)
        teacher_net.to(device)

    # set trainer
    if args.KD:
        trainer = Trainer_KD(args, logger)
    else:
        trainer = Trainer(args, logger)

    # loss
    loss = nn.CrossEntropyLoss()

    # dataloader
    if args.model != 'once-mobilenetv3-large-1.0': image_size = 224
    if args.dataset == 'imagenet':
        train_loader = torch.utils.data.DataLoader(
            datasets.ImageNet(
                '/data/imagenet/',
                split='train',
                download=False,
                transform=transforms.Compose([
                    transforms.RandomResizedCrop(image_size),
                    transforms.RandomHorizontalFlip(),  #ImageNetPolicy(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_worker,
            pin_memory=True)
        #
        test_loader = torch.utils.data.DataLoader(datasets.ImageNet(
            '/data/imagenet/',
            split='val',
            download=False,
            transform=transforms.Compose([
                transforms.Resize(int(image_size / 0.875)),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.num_worker,
                                                  pin_memory=True)

    # optimizer & scheduler
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.scheduler == 'multistep':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=eval(args.multi_step_epoch),
            gamma=args.multi_step_gamma)
    elif args.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            mode='max',
            patience=3,
            verbose=True,
            factor=0.3,
            threshold=1e-4,
            min_lr=1e-6)

    # pruning
    if args.prune_method == 'global':
        if args.prune_type == 'group_filter':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()  # (out, ch, h, w)
                    tmp_pruned = tmp_pruned.view(original_size[0],
                                                 -1)  # (out, inp)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1,
                                                 args.block_size)  # out, -1, 4
                    tmp_pruned = tmp_pruned.abs().mean(2, keepdim=True).expand(
                        tmp_pruned.shape)  # out, -1, 4
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.2
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1,
                                                 args.block_size)  # out, -1, 4
                    tmp_pruned = tmp_pruned.abs().mean(2, keepdim=True).expand(
                        tmp_pruned.shape)  # out,-1, 4
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0],
                                                 -1)  # out, inp
                    #tmp_pruned = tmp_pruned[:, 0: conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(
                        original_size)  # out, ch, h, w

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)

        elif args.prune_type == 'group_channel':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()  # (out, ch, h, w)
                    tmp_pruned = tmp_pruned.view(original_size[0],
                                                 -1)  # (out, inp)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(
                        -1, args.block_size, tmp_pruned.shape[1])  # out, -1, 4
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)  # out, -1, 4
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.2
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    #append_size = 4 - tmp_pruned.shape[1] % 4
                    #tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
                    tmp_pruned = tmp_pruned.view(
                        -1, args.block_size, tmp_pruned.shape[1])  # out, -1, 4
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)  # out,-1, 4
                    tmp_pruned = tmp_pruned.ge(threshold)
                    #tmp_pruned = tmp_pruned.view(original_size[0], -1) # out, inp
                    #tmp_pruned = tmp_pruned[:, 0: conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(
                        original_size)  # out, ch, h, w

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)

        elif args.prune_type == 'filter':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.5
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(1, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned[:,
                                            0:conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(original_size)

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)

        elif args.prune_type == 'channel':
            tmps = []
            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(0, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmps.append(tmp)

            tmps = torch.cat(tmps)
            num = tmps.shape[0] * (1 - args.sparsity)  #sparsity 0.5
            top_k = torch.topk(tmps, int(num), sorted=True)
            threshold = top_k.values[-1]

            for n, conv in enumerate(net.modules()):
                if isinstance(conv, nn.Conv2d):
                    if conv.weight.shape[1] <= 3:
                        continue
                    tmp_pruned = conv.weight.data.clone()
                    original_size = tmp_pruned.size()
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned.abs().mean(0, keepdim=True).expand(
                        tmp_pruned.shape)
                    tmp = tmp_pruned.flatten()
                    tmp_pruned = tmp_pruned.ge(threshold)
                    tmp_pruned = tmp_pruned.view(original_size[0], -1)
                    tmp_pruned = tmp_pruned[:,
                                            0:conv.weight.data[0].nelement()]
                    tmp_pruned = tmp_pruned.contiguous().view(original_size)

                    prune.custom_from_mask(conv,
                                           name='weight',
                                           mask=tmp_pruned)
        print(
            f'model pruned!! (sparsity: {args.sparsity:.2f}, prune_method: {args.prune_method}, prune_type: {args.prune_type}-level pruning)'
        )

    elif args.prune_method == 'uniform':
        assert False, 'uniform code is not ready'

    elif args.prune_method == 'dst':
        print(
            f'model pruned!! (prune_method: {args.prune_method}, prune_type: {args.prune_type}-level pruning)'
        )

    elif args.prune_method is None:
        print('Not pruned model training started!')

    # Training
    if args.KD:
        trainer.train(net,
                      teacher_net,
                      loss,
                      device,
                      train_loader,
                      test_loader,
                      optimizer=optimizer,
                      scheduler=scheduler)
    else:
        trainer.train(net,
                      loss,
                      device,
                      train_loader,
                      test_loader,
                      optimizer=optimizer,
                      scheduler=scheduler)

    # save removed models
    filename = os.path.join(args.model_folder, 'pruned_models.pth')
    temp = torch.load(filename)
    temp_dict = OrderedDict()
    for i in temp:
        if ('orig' in i):
            value = temp[i] * temp[i.split('_orig')[0] + '_mask']
            temp_dict[i.split('module.')[1].split('_orig')[0]] = value
        elif 'mask' not in i:
            temp_dict[i.split('module.')[1]] = temp[i]
    net2.load_state_dict(temp_dict)
    save_model(net2, os.path.join(args.model_folder, 'removed_models.pth'))
    print('saved removed models')
Example #27
    def __init__(self,
                 in_planes,
                 planes,
                 num_bits,
                 num_bits_weight,
                 stride,
                 type_prune,
                 sparsity,
                 layer_num,
                 option='A'):
        super(BasicBlock, self).__init__()

        self.conv1 = QConv2d(in_planes,
                             planes,
                             num_bits,
                             num_bits_weight,
                             kernel_size=3,
                             stride=stride,
                             padding=1,
                             bias=False)

        if type_prune == 'channel':
            self.conv1 = prune.ln_structured(self.conv1,
                                             name='weight',
                                             amount=sparsity,
                                             n=2,
                                             dim=0)
        elif type_prune == 'group':
            width = 4
            tmp_pruned = self.conv1.weight.data.clone()
            original_size = tmp_pruned.size()
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            append_size = width - tmp_pruned.shape[1] % width
            tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]),
                                   1)
            tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
            tmp_pruned = tmp_pruned.pow(2.0).mean(
                2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
            tmp = tmp_pruned.flatten()
            num = tmp.shape[0] * (1 - sparsity)
            top_k = torch.topk(tmp, int(num), sorted=True)
            threshold = top_k.values[-1]
            tmp_pruned = tmp_pruned.ge(threshold)
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            tmp_pruned = tmp_pruned[:, 0:self.conv1.weight.data[0].nelement()]
            tmp_pruned = tmp_pruned.contiguous().view(original_size)
            self.conv1 = prune.custom_from_mask(self.conv1,
                                                name='weight',
                                                mask=tmp_pruned)

        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = QConv2d(planes,
                             planes,
                             num_bits,
                             num_bits_weight,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             bias=False)
        if type_prune == 'channel':
            self.conv2 = prune.ln_structured(self.conv2,
                                             name='weight',
                                             amount=sparsity,
                                             n=2,
                                             dim=0)
        elif type_prune == 'group':
            width = 4
            tmp_pruned = self.conv2.weight.data.clone()
            original_size = tmp_pruned.size()
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            append_size = width - tmp_pruned.shape[1] % width
            tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]),
                                   1)
            tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
            tmp_pruned = tmp_pruned.pow(2.0).mean(
                2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
            tmp = tmp_pruned.flatten()
            num = tmp.shape[0] * (1 - sparsity)
            top_k = torch.topk(tmp, int(num), sorted=True)
            threshold = top_k.values[-1]
            tmp_pruned = tmp_pruned.ge(threshold)
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            tmp_pruned = tmp_pruned[:, 0:self.conv2.weight.data[0].nelement()]
            tmp_pruned = tmp_pruned.contiguous().view(original_size)
            self.conv2 = prune.custom_from_mask(self.conv2,
                                                name='weight',
                                                mask=tmp_pruned)

        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x: F.pad(
                    x[:, :, ::2, ::2],
                    (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    QConv2d(in_planes,
                            self.expansion * planes,
                            num_bits,
                            num_bits_weight,
                            kernel_size=1,
                            stride=stride,
                            bias=False),
                    nn.BatchNorm2d(self.expansion * planes))