def prune_fc(module: Linear,
             keep_idxes: List[int],
             bn_num_channels: int = None):
    """Prune the input features of a fully connected layer in place.

    Only the columns of ``module.weight`` (the input-feature axis) are
    selected; the rows, and therefore ``out_features``, are left untouched.

    Args:
        module: layer whose ``weight`` columns are pruned. It is modified in
            place: ``weight.data`` is replaced by the selected columns,
            ``weight.grad`` is reset to ``None`` and ``in_features`` is
            updated to the new column count.
        keep_idxes: indexes to keep. Without ``bn_num_channels`` these are
            column indexes of ``module.weight``. With ``bn_num_channels``
            they are channel indexes, each of which is expanded to the
            contiguous run of ``in_features // bn_num_channels`` columns
            starting at ``idx * (in_features // bn_num_channels)``.
        bn_num_channels: channel count of the previous BN layer. Must divide
            ``module.in_features`` evenly. ``None`` disables the expansion.

    Returns:
        The column indexes that were kept: ``keep_idxes`` itself when
        ``bn_num_channels`` is ``None``, otherwise a new list of plain
        ``int`` column indexes.

    Raises:
        AssertionError: if ``module.in_features`` is not a multiple of
            ``bn_num_channels``.
    """
    if bn_num_channels is not None:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if module.in_features % bn_num_channels != 0:
            raise AssertionError(
                f"in_features ({module.in_features}) is not a multiple of "
                f"bn_num_channels ({bn_num_channels})")

        channel_step = module.in_features // bn_num_channels

        # Each kept channel owns a contiguous run of ``channel_step`` columns.
        keep_idxes = [
            int(idx) * channel_step + offset
            for idx in keep_idxes
            for offset in range(channel_step)
        ]

    # Select the columns first: if an index is invalid this raises before the
    # module has been touched, so it is never left half-updated.
    pruned_weight = module.weight.data[:, keep_idxes]

    module.weight.data = pruned_weight
    module.weight.grad = None
    # Take the feature count from the weight actually stored on the module.
    module.in_features = pruned_weight.shape[1]
    return keep_idxes
Exemplo n.º 2
0
    def update_attributes(self, link: nn.Linear):
        """Synchronise ``link``'s feature counts with its weight's shape.

        Args:
            link: layer whose ``in_features`` and ``out_features`` attributes
                are overwritten in place from ``link.weight.shape``.
        """
        # The shape is unpacked as (out_features, in_features), so the weight
        # must have exactly two dimensions.
        out_features, in_features = link.weight.shape

        link.in_features = in_features
        link.out_features = out_features