def prune_fc(module: Linear, keep_idxes: List[int], bn_num_channels: int = None):
    """Prune the input features of a fully-connected layer in place.

    Args:
        module: Linear layer whose weight columns (input features) are pruned.
            Its ``in_features`` and ``weight`` are overwritten, and any
            existing ``weight.grad`` is dropped because its shape no longer
            matches. The bias is left untouched.
        keep_idxes: Indices to keep. Without ``bn_num_channels`` these are
            column indices of ``module.weight``. With ``bn_num_channels`` they
            are channel indices of the preceding BatchNorm layer.
        bn_num_channels: Channel count of the preceding BatchNorm, or ``None``
            when ``keep_idxes`` already address input features. When given,
            ``module.in_features`` must be a positive multiple of it. Each kept
            channel ``c`` then expands to the contiguous feature block
            ``[c * step, (c + 1) * step)``, where
            ``step = in_features // bn_num_channels``.

    Returns:
        The input-feature indices actually kept, as passed to the weight
        slice. This is ``keep_idxes`` itself when ``bn_num_channels`` is
        ``None``, otherwise a new list of plain ``int`` feature indices.

    Raises:
        ValueError: If ``bn_num_channels`` is not positive, or does not evenly
            divide ``module.in_features``.
    """
    if bn_num_channels is not None:
        # Explicit checks instead of ``assert``: asserts vanish under ``python -O``,
        # and the step computed below would then be silently wrong.
        if bn_num_channels <= 0:
            raise ValueError(
                f"bn_num_channels must be positive, got {bn_num_channels}"
            )
        if module.in_features % bn_num_channels != 0:
            raise ValueError(
                f"in_features ({module.in_features}) is not divisible by "
                f"bn_num_channels ({bn_num_channels})"
            )
        channel_step = module.in_features // bn_num_channels
        # Each channel owns ``channel_step`` consecutive input features.
        # NOTE(review): presumably the previous layer's output was flattened
        # channel-major (e.g. C x H x W -> C*H*W); confirm against the caller.
        expanded = []
        for channel in keep_idxes:
            start = int(channel) * channel_step
            expanded.extend(range(start, start + channel_step))
        keep_idxes = expanded
    module.in_features = len(keep_idxes)
    # Advanced indexing copies the selected columns into a new tensor.
    module.weight.data = module.weight.data[:, keep_idxes]
    module.weight.grad = None
    return keep_idxes
def update_attributes(self, link: nn.Linear):
    """Make ``link``'s recorded feature counts match its weight tensor.

    The weight is unpacked as a 2-D ``(out_features, in_features)`` shape, so
    a weight of any other rank raises ``ValueError`` from the unpack.

    Args:
        link: Linear layer whose ``in_features`` and ``out_features``
            attributes are overwritten from ``link.weight.shape``.
    """
    n_rows, n_cols = link.weight.shape
    # nn.Linear stores its weight as (out_features, in_features).
    link.in_features = n_cols
    link.out_features = n_rows