Example #1
def ranked_filter_pruning(config, ratio_to_prune, is_parallel):
    """Test L1 ranking and pruning of filters.
    First we rank and prune the filters of a convolutional layer using
    an L1RankedStructureParameterPruner.  Then we physically remove the
    filters from the model (via a "thinning" process).
    """
    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset,
                                               is_parallel)

    for pair in config.module_pairs:
        # Test that we can access the weights tensor of the first convolution in layer 1
        conv1_p = distiller.model_find_param(model, pair[0] + ".weight")
        assert conv1_p is not None
        num_filters = conv1_p.size(0)

        # Test that there are no zero-filters
        assert distiller.sparsity_3D(conv1_p) == 0.0

        # Create a filter-ranking pruner
        pruner = distiller.pruning.L1RankedStructureParameterPruner(
            "filter_pruner",
            group_type="Filters",
            desired_sparsity=ratio_to_prune,
            weights=pair[0] + ".weight")
        pruner.set_param_mask(conv1_p,
                              pair[0] + ".weight",
                              zeros_mask_dict,
                              meta=None)

        conv1 = common.find_module_by_name(model, pair[0])
        assert conv1 is not None
        # Test that the mask has the correct fraction of filters pruned.
        # E.g., if we ask for 10% but the layer has only 16 filters, we have
        # to settle for pruning 1/16 of them.
        expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
        expected_pruning = expected_cnt_removed_filters / conv1.out_channels
        masker = zeros_mask_dict[pair[0] + ".weight"]
        assert masker is not None
        assert distiller.sparsity_3D(masker.mask) == expected_pruning

        # Use the mask to prune
        assert distiller.sparsity_3D(conv1_p) == 0
        masker.apply_mask(conv1_p)
        assert distiller.sparsity_3D(conv1_p) == expected_pruning

        # Locate the downstream convolution; before thinning, both layers
        # still have their original channel counts
        conv2 = common.find_module_by_name(model, pair[1])
        assert conv2 is not None
        assert conv1.out_channels == num_filters
        assert conv2.in_channels == num_filters

    # Test thinning
    distiller.remove_filters(model,
                             zeros_mask_dict,
                             config.arch,
                             config.dataset,
                             optimizer=None)
    assert conv1.out_channels == num_filters - expected_cnt_removed_filters
    assert conv2.in_channels == num_filters - expected_cnt_removed_filters
    return model, zeros_mask_dict
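
For context, invoking this helper might look like the sketch below; the Config container and the module-pair values are assumptions for illustration, not the test suite's actual fixtures.

from collections import namedtuple

# Hypothetical config; the real tests build this from their own fixtures.
Config = namedtuple("Config", ["arch", "dataset", "module_pairs"])
config = Config(arch="resnet20_cifar", dataset="cifar10",
                module_pairs=[("layer1.0.conv1", "layer1.0.conv2")])
model, zeros_mask_dict = ranked_filter_pruning(config, ratio_to_prune=0.1,
                                               is_parallel=False)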
Example #2
    def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
        assert param.dim() == 4, "This pruning is only supported for 4D weights"
        if param.grad is None:
            msglogger.info("Skipping gradient pruning of %s because it does not have a gradient yet", param_name)
            return
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return

        # Compute the element-wise product of the filters and the filter gradients
        view_filters = param.view(param.size(0), -1)
        view_filter_grads = param.grad.view(param.size(0), -1)
        weighted_gradients = view_filter_grads * view_filters
        weighted_gradients = weighted_gradients.sum(dim=1)

        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
        filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
        mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map)
        zeros_mask_dict[param_name].mask = mask

        msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       param_name,
                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
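
Several of these examples delegate to a _mask_from_filter_order helper. The following is a minimal sketch of what it plausibly does, assuming the listed indices are the filters to keep; it is an illustration, not distiller's actual implementation.

import torch

def _mask_from_filter_order(filters_to_keep, param, num_filters, binary_map=None):
    # Sketch: build a per-filter 0/1 keep-map, then broadcast it over each
    # filter's weights so it can be multiplied into the 4D weight tensor.
    if binary_map is None:
        binary_map = torch.zeros(num_filters, device=param.device)
        binary_map[filters_to_keep] = 1
    mask = binary_map.view(num_filters, 1, 1, 1).expand_as(param).clone()
    return mask, binary_map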
Example #3
    def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
        assert param.dim() == 4, "This pruning is only supported for 4D weights"

        # Use the parameter name to locate the module that has the activation sparsity statistics
        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
        module = distiller.find_module_by_fq_name(model, fq_name)
        if module is None:
            raise ValueError("Could not find a layer named %s in the model."
                             "\nMake sure to use assign_layer_fq_names()" % fq_name)
        if not hasattr(module, self.activation_rank_criterion):
            raise ValueError("Could not find attribute \"{}\" in module %s"
                             "\nMake sure to use SummaryActivationStatsCollector(\"{}\")".
                             format(self.activation_rank_criterion, fq_name, self.activation_rank_criterion))

        quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return

        # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters
        filters_ordered_by_criterion = np.argsort(quality_criterion)[:-num_filters_to_prune]
        mask, binary_map = _mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map)
        zeros_mask_dict[param_name].mask = mask

        msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       param_name,
                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
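
The string manipulation that derives fq_name maps a convolution's parameter name onto the ReLU module that holds the activation statistics; a quick illustration with a hypothetical ResNet-style name:

param_name = "layer1.0.conv1.weight"
fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
assert fq_name == "layer1.0.relu1"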
Example #4
    def rank_and_prune_filters(self,
                               fraction_to_prune,
                               param,
                               param_name,
                               zeros_mask_dict,
                               model,
                               binary_map=None):
        assert param.dim() == 4, "This pruner is only supported for 4D weights"
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)

        keep_prob = 1 - fraction_to_prune
        if binary_map is None:
            binary_map = torch.bernoulli(
                torch.as_tensor([keep_prob] * num_filters))
        mask, _ = _mask_from_filter_order(None, param, num_filters, binary_map)
        mask = mask.to(param.device)
        # Compensate for dropping filters
        pruning_factor = binary_map.sum() / num_filters
        mask.div_(pruning_factor)

        zeros_mask_dict[param_name].mask = mask
        msglogger.debug(
            "BernoulliFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
            param_name,
            distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
            fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
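
The division by pruning_factor is inverted-dropout-style compensation: surviving filters are scaled up so that the mask's total magnitude matches the unpruned case. A standalone numeric check (illustration only):

import torch

num_filters = 16
binary_map = torch.bernoulli(torch.full((num_filters,), 0.75))
kept_fraction = binary_map.sum() / num_filters
if kept_fraction > 0:
    mask = binary_map / kept_fraction   # survivors scaled by num_filters/kept
    assert torch.isclose(mask.sum(), torch.tensor(float(num_filters)))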
Example #5
    def rank_and_prune_filters(self,
                               fraction_to_prune,
                               param,
                               param_name,
                               zeros_mask_dict,
                               model,
                               binary_map=None):
        assert param.dim() == 4, "This pruning is only supported for 4D weights"
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)

        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters",
                           100 * fraction_to_prune)
            return

        filters_ordered_randomly = np.random.permutation(
            num_filters)[:-num_filters_to_prune]
        mask, binary_map = _mask_from_filter_order(filters_ordered_randomly,
                                                   param, num_filters,
                                                   binary_map)
        zeros_mask_dict[param_name].mask = mask

        msglogger.info(
            "RandomRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
            param_name,
            distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
            fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
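
The slicing idiom here is worth noting: np.random.permutation(n)[:-k] yields n-k randomly chosen filter indices to keep. A tiny check:

import numpy as np

keep = np.random.permutation(8)[:-2]   # 8 filters, prune 2 at random
assert len(keep) == 6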
Example #6
    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
        assert param.dim() == 4, "This thresholding is only supported for 4D weights"

        # Use the parameter name to locate the module that has the activation sparsity statistics
        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
        module = distiller.find_module_by_fq_name(model, fq_name)
        if module is None:
            raise ValueError("Could not find a layer named %s in the model."
                             "\nMake sure to use assign_layer_fq_names()" % fq_name)
        if not hasattr(module, 'apoz_channels'):
            raise ValueError("Could not find attribute \'apoz_channels\' in module %s."
                             "\nMake sure to use SummaryActivationStatsCollector(\"apoz_channels\")" % fq_name)

        apoz, std = module.apoz_channels.value()
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return

        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
        filters_ordered_by_apoz = np.argsort(-apoz)[:-num_filters_to_prune]
        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_apoz,
                                                                                               param, num_filters)

        msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       param_name,
                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, num_filters_to_prune, num_filters)
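
APoZ, by its name, is the Average Percentage of Zeros per output channel of the post-ReLU activations. The following is a standalone sketch of how such a per-channel statistic could be computed; it is an illustration, not the collector's actual code.

import torch

def apoz_per_channel(activations):
    # activations: (N, C, H, W) post-ReLU feature maps.
    # Returns, per channel, the fraction of zero entries averaged over
    # the batch and spatial dimensions.
    return (activations == 0).float().mean(dim=(0, 2, 3))

acts = torch.relu(torch.randn(8, 16, 32, 32))
apoz = apoz_per_channel(acts)   # shape (16,); higher suggests a less useful channel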
Example #7
def test_sparsity():
    zeros = torch.zeros(2, 3, 5, 6)
    print(distiller.sparsity(zeros))
    assert distiller.sparsity(zeros) == 1.0
    assert distiller.sparsity_3D(zeros) == 1.0
    assert distiller.density_3D(zeros) == 0.0

    ones = torch.ones(12, 43, 4, 6)
    assert distiller.sparsity(ones) == 0.0
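
The semantics these assertions pin down can be restated compactly. The definitions below are a sketch equivalent to the tested behavior, not distiller's implementation:

import torch

def sparsity(t):
    # Fraction of elements that are exactly zero.
    return 1.0 - float(torch.count_nonzero(t)) / t.numel()

def sparsity_3D(t):
    # Fraction of filters (dim-0 slices of a 4D weight tensor) that are all-zero.
    per_filter = t.abs().view(t.size(0), -1).sum(dim=1)
    return float((per_filter == 0).sum()) / t.size(0)

assert sparsity(torch.zeros(2, 3, 5, 6)) == 1.0
assert sparsity_3D(torch.zeros(2, 3, 5, 6)) == 1.0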
Example #8
def test_ranked_filter_pruning():
    model, zeros_mask_dict = setup_test("resnet20_cifar", "cifar10")

    # Test that we can access the weights tensor of the first convolution in layer 1
    conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
    assert conv1_p is not None

    # Test that there are no zero-channels
    assert distiller.sparsity_3D(conv1_p) == 0.0

    # Create a filter-ranking pruner
    reg_regims = {"layer1.0.conv1.weight": [0.1, "3D"]}
    pruner = distiller.pruning.L1RankedStructureParameterPruner(
        "filter_pruner", reg_regims)
    pruner.set_param_mask(conv1_p,
                          "layer1.0.conv1.weight",
                          zeros_mask_dict,
                          meta=None)

    conv1 = find_module_by_name(model, "layer1.0.conv1")
    assert conv1 is not None
    # Test that the mask has the correct fraction of filters pruned.
    # We asked for 10%, but there are only 16 filters, so we have to settle
    # for pruning 1/16 of them
    expected_pruning = int(0.1 * conv1.out_channels) / conv1.out_channels
    assert distiller.sparsity_3D(
        zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning

    # Use the mask to prune
    assert distiller.sparsity_3D(conv1_p) == 0
    zeros_mask_dict["layer1.0.conv1.weight"].apply_mask(conv1_p)
    assert distiller.sparsity_3D(conv1_p) == expected_pruning

    # Locate the downstream convolution; before thinning, both layers still
    # have their original channel counts
    conv2 = find_module_by_name(model, "layer1.0.conv2")
    assert conv2 is not None
    assert conv1.out_channels == 16
    assert conv2.in_channels == 16

    # Test thinning
    distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar",
                             "cifar10")
    assert conv1.out_channels == 15
    assert conv2.in_channels == 15
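
Note that this test uses an older constructor form, in which the pruner receives a reg_regims dict mapping a parameter name to [fraction, "3D"], whereas Examples #1 and #16 pass group_type/desired_sparsity/weights keyword arguments instead. Both forms appear to configure the same 10% filter pruning.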
Example #9
    def rank_and_prune_filters(self,
                               fraction_to_prune,
                               param,
                               param_name,
                               zeros_mask_dict,
                               model,
                               binary_map=None):
        assert param.dim() == 4, "This pruning is only supported for 4D weights"

        # Use the parameter name to locate the module that has the activation sparsity statistics
        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
        distiller.assign_layer_fq_names(model)
        module = distiller.find_module_by_fq_name(model, fq_name)
        assert module is not None

        if not hasattr(module, self.activation_rank_criterion):
            raise ValueError(
                "Could not find attribute \"%s\" in module %s\n"
                "\tThis is pruner uses activation statistics collected during forward-"
                "passes of the network.\n"
                "\tThis error is an indication that these statistics "
                "have not been collected yet.\n"
                "\tMake sure to use SummaryActivationStatsCollector(\"%s\")\n"
                "\tFor more info see issue #444 (https://github.com/NervanaSystems/distiller/issues/444)"
                % (self.activation_rank_criterion, fq_name,
                   self.activation_rank_criterion))

        quality_criterion, std = getattr(
            module, self.activation_rank_criterion).value()
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters",
                           100 * fraction_to_prune)
            return

        # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters
        filters_ordered_by_criterion = np.argsort(
            quality_criterion)[:-num_filters_to_prune]
        mask, binary_map = _mask_from_filter_order(
            filters_ordered_by_criterion, param, num_filters, binary_map)
        zeros_mask_dict[param_name].mask = mask

        msglogger.info(
            "ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
            param_name,
            distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
            fraction_to_prune, num_filters_to_prune, num_filters)
        return binary_map
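
As the error message above indicates, the per-module statistic must first be populated by forward passes run under a SummaryActivationStatsCollector. The following sketch shows that setup; the exact constructor arguments and summary function are assumptions based on the message's hint, not verified API.

# Sketch only: argument names and order are assumptions.
distiller.assign_layer_fq_names(model)
collector = distiller.data_loggers.SummaryActivationStatsCollector(
    model, "mean_channels", distiller.utils.activation_channels_l1)
collector.start()                 # install forward hooks
for inputs, _ in data_loader:     # data_loader is a placeholder
    model(inputs)                 # forward passes accumulate the statistic
collector.stop()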
Example #10
def prune_tensor(param, param_name, fraction_to_prune, zeros_mask_dict):
    """Prune filters from a parameter tensor.

    Returns the filter-sparsity of the tensor.
    """
    # Create a filter-ranking pruner
    pruner = distiller.pruning.L1RankedStructureParameterPruner(
        name=None,
        group_type="Filters",
        desired_sparsity=fraction_to_prune,
        weights=param_name)
    pruner.set_param_mask(param, param_name, zeros_mask_dict, meta=None)
    # Use the mask to prune
    zeros_mask_dict[param_name].apply_mask(param)
    return distiller.sparsity_3D(param)
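
A hypothetical call, with the parameter name and ratio chosen for illustration:

param = distiller.model_find_param(model, "layer1.0.conv1.weight")
achieved = prune_tensor(param, "layer1.0.conv1.weight",
                        fraction_to_prune=0.3, zeros_mask_dict=zeros_mask_dict)
print("filter sparsity after pruning: %.3f" % achieved)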
Example #11
def weights_sparsity_summary(model,
                             return_total_sparsity=False,
                             param_dims=[2, 4]):

    df = pd.DataFrame(columns=[
        'Name', 'Shape', 'NNZ (dense)', 'NNZ (sparse)', 'Cols (%)', 'Rows (%)',
        'Ch (%)', '2D (%)', '3D (%)', 'Fine (%)', 'Std', 'Mean', 'Abs-Mean'
    ])
    pd.set_option('precision', 2)
    params_size = 0
    sparse_params_size = 0
    summary_param_types = ['weight', 'bias']
    for name, param in model.state_dict().items():
        # Extract just the actual parameter's name, which in this context we treat as its "type"
        curr_param_type = name.split('.')[-1]
        if param.dim() in param_dims and curr_param_type in summary_param_types:
            _density = distiller.density(param)
            params_size += torch.numel(param)
            sparse_params_size += param.numel() * _density
            df.loc[len(df.index)] = ([
                name,
                distiller.size_to_str(param.size()),
                torch.numel(param),
                int(_density * param.numel()),
                distiller.sparsity_cols(param) * 100,
                distiller.sparsity_rows(param) * 100,
                distiller.sparsity_ch(param) * 100,
                distiller.sparsity_2D(param) * 100,
                distiller.sparsity_3D(param) * 100, (1 - _density) * 100,
                param.std().item(),
                param.mean().item(),
                param.abs().mean().item()
            ])

    total_sparsity = (1 - sparse_params_size / params_size) * 100

    df.loc[len(df.index)] = ([
        'Total sparsity:', '-', params_size,
        int(sparse_params_size), 0, 0, 0, 0, 0, total_sparsity, 0, 0, 0
    ])

    if return_total_sparsity:
        return df, total_sparsity
    return df
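
Typical usage would be to print the per-layer table together with the overall figure, e.g.:

df, total = weights_sparsity_summary(model, return_total_sparsity=True)
print(df.to_string(index=False))
print("Total sparsity: %.2f%%" % total)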
Example #12
def test_sparsity():
    zeros = torch.zeros(2, 3, 5, 6)
    print(distiller.sparsity(zeros))
    assert distiller.sparsity(zeros) == 1.0
    assert distiller.sparsity_3D(zeros) == 1.0
    assert distiller.density_3D(zeros) == 0.0
    ones = torch.ones(12, 43, 4, 6)
    assert distiller.sparsity(ones) == 0.0
    x = torch.tensor([[1., 2., 0, 4., 0], [1., 2., 0, 4., 0]])
    assert distiller.density(x) == 0.6
    assert distiller.density_cols(x, transposed=False) == 0.6
    assert distiller.sparsity_rows(x, transposed=False) == 0
    x = torch.tensor([[0., 0., 0], [1., 4., 0], [1., 2., 0], [0., 0., 0]])
    assert distiller.density(x) == 4 / 12
    assert distiller.sparsity_rows(x, transposed=False) == 0.5
    assert common.almost_equal(distiller.sparsity_cols(x, transposed=False),
                               1 / 3)
    assert common.almost_equal(distiller.sparsity_rows(x), 1 / 3)
Example #13
def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4]):

    df = pd.DataFrame(columns=['Name', 'Shape', 'NNZ (dense)', 'NNZ (sparse)',
                               'Cols (%)','Rows (%)', 'Ch (%)', '2D (%)', '3D (%)',
                               'Fine (%)', 'Std', 'Mean', 'Abs-Mean'])
    pd.set_option('precision', 2)
    params_size = 0
    sparse_params_size = 0
    for name, param in model.state_dict().items():
        if (param.dim() in param_dims) and any(t in name for t in ['weight', 'bias']):
            _density = distiller.density(param)
            params_size += torch.numel(param)
            sparse_params_size += param.numel() * _density
            df.loc[len(df.index)] = ([
                name,
                distiller.size_to_str(param.size()),
                torch.numel(param),
                int(_density * param.numel()),
                distiller.sparsity_cols(param)*100,
                distiller.sparsity_rows(param)*100,
                distiller.sparsity_ch(param)*100,
                distiller.sparsity_2D(param)*100,
                distiller.sparsity_3D(param)*100,
                (1-_density)*100,
                param.std().item(),
                param.mean().item(),
                param.abs().mean().item()
            ])

    total_sparsity = (1 - sparse_params_size/params_size)*100

    df.loc[len(df.index)] = ([
        'Total sparsity:',
        '-',
        params_size,
        int(sparse_params_size),
        0, 0, 0, 0, 0,
        total_sparsity,
        0, 0, 0])

    if return_total_sparsity:
        return df, total_sparsity
    return df
Example #14
    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
        num_filters = param.size(0)
        num_filters_to_prune = int(fraction_to_prune * num_filters)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return

        # Compute the element-wise product of the filters and the filter gradients
        view_filters = param.view(param.size(0), -1)
        view_filter_grads = param.grad.view(param.size(0), -1)
        weighted_gradients = view_filter_grads * view_filters
        weighted_gradients = weighted_gradients.sum(dim=1)

        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
        filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_gradient,
                                                                                               param, num_filters)
        msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       param_name,
                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, num_filters_to_prune, num_filters)
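
Unlike Example #2, this variant dereferences param.grad without a guard, so gradients must already exist when it runs. A sketch of populating them first (the loss, inputs, and targets are placeholders):

import torch.nn.functional as F

model.zero_grad()
output = model(dummy_input)                    # dummy_input: placeholder batch
loss = F.cross_entropy(output, dummy_targets)  # dummy_targets: placeholder labels
loss.backward()                                # now every weight has a .grad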
Example #15
def weights_sparsity_summary(model,
                             return_total_sparsity=False,
                             param_dims=[2, 4]):
    df = pd.DataFrame(columns=[
        "Name",
        "Shape",
        "NNZ (dense)",
        "NNZ (sparse)",
        "Cols (%)",
        "Rows (%)",
        "Ch (%)",
        "2D (%)",
        "3D (%)",
        "Fine (%)",
        "Std",
        "Mean",
        "Abs-Mean",
    ])
    pd.set_option("precision", 2)
    params_size = 0
    sparse_params_size = 0
    for name, param in model.state_dict().items():
        # Keep only weight/bias parameters of the requested dimensionality
        if param.dim() in param_dims and any(t in name
                                             for t in ["weight", "bias"]):
            _density = distiller.density(param)
            params_size += torch.numel(param)
            sparse_params_size += param.numel() * _density
            df.loc[len(df.index)] = [
                name,
                distiller.size_to_str(param.size()),
                torch.numel(param),
                int(_density * param.numel()),
                distiller.sparsity_cols(param) * 100,
                distiller.sparsity_rows(param) * 100,
                distiller.sparsity_ch(param) * 100,
                distiller.sparsity_2D(param) * 100,
                distiller.sparsity_3D(param) * 100,
                (1 - _density) * 100,
                param.std().item(),
                param.mean().item(),
                param.abs().mean().item(),
            ]

    total_sparsity = (1 - sparse_params_size / params_size) * 100

    df.loc[len(df.index)] = [
        "Total sparsity:",
        "-",
        params_size,
        int(sparse_params_size),
        0,
        0,
        0,
        0,
        0,
        total_sparsity,
        0,
        0,
        0,
    ]

    if return_total_sparsity:
        return df, total_sparsity
    return df
Example #16
def ranked_filter_pruning(config,
                          ratio_to_prune,
                          is_parallel,
                          rounding_fn=math.floor):
    """Test L1 ranking and pruning of filters.
    First we rank and prune the filters of a convolutional layer using
    an L1RankedStructureParameterPruner.  Then we physically remove the
    filters from the model (via a "thinning" process).
    """
    logger.info("executing: %s (invoked by %s)" %
                (inspect.currentframe().f_code.co_name,
                 inspect.currentframe().f_back.f_code.co_name))

    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset,
                                               is_parallel)

    for pair in config.module_pairs:
        # Test that we can access the weights tensor of the first convolution in layer 1
        conv1_p = distiller.model_find_param(model, pair[0] + ".weight")
        assert conv1_p is not None
        num_filters = conv1_p.size(0)

        # Test that there are no zero-filters
        assert distiller.sparsity_3D(conv1_p) == 0.0

        # Create a filter-ranking pruner
        pruner = distiller.pruning.L1RankedStructureParameterPruner(
            "filter_pruner",
            group_type="Filters",
            desired_sparsity=ratio_to_prune,
            weights=pair[0] + ".weight",
            rounding_fn=rounding_fn)
        pruner.set_param_mask(conv1_p,
                              pair[0] + ".weight",
                              zeros_mask_dict,
                              meta=None)

        conv1 = common.find_module_by_name(model, pair[0])
        assert conv1 is not None
        # Test that the mask has the correct fraction of filters pruned.
        # E.g., if we ask for 10% but the layer has only 16 filters, we have
        # to settle for pruning 1/16 of them.
        expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
        expected_pruning = expected_cnt_removed_filters / conv1.out_channels
        masker = zeros_mask_dict[pair[0] + ".weight"]
        assert masker is not None
        assert distiller.sparsity_3D(masker.mask) == expected_pruning

        # Use the mask to prune
        assert distiller.sparsity_3D(conv1_p) == 0
        masker.apply_mask(conv1_p)
        assert distiller.sparsity_3D(conv1_p) == expected_pruning

        # Locate the downstream convolution; before thinning, both layers
        # still have their original channel counts
        conv2 = common.find_module_by_name(model, pair[1])
        assert conv2 is not None
        assert conv1.out_channels == num_filters
        assert conv2.in_channels == num_filters

    # Test thinning
    distiller.remove_filters(model,
                             zeros_mask_dict,
                             config.arch,
                             config.dataset,
                             optimizer=None)
    assert conv1.out_channels == num_filters - expected_cnt_removed_filters
    assert conv2.in_channels == num_filters - expected_cnt_removed_filters

    # Test the thinned model
    dummy_input = distiller.get_dummy_input(config.dataset,
                                            distiller.model_device(model))
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    return model, zeros_mask_dict
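
The run_forward_backward helper is not shown in these examples; the following is a plausible sketch of what such a test utility does, with the loss and target being assumptions:

import torch
import torch.nn.functional as F

def run_forward_backward(model, optimizer, dummy_input):
    # One training step on dummy data, to check that the thinned model's
    # graph and the optimizer's state still line up after filter removal.
    # (The real helper may differ; the target below is an assumption.)
    optimizer.zero_grad()
    output = model(dummy_input)
    target = torch.zeros(output.size(0), dtype=torch.long, device=output.device)
    loss = F.cross_entropy(output, target)
    loss.backward()
    optimizer.step()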