def prune_fc_layer_with_craig(
    layer: nn.Linear,
    prune_percent_per_layer: float,
    similarity_metric: Union[Text, Dict] = "",
    prune_type: Text = "craig",
    **kwargs,
) -> Tuple[List[int], List[float]]:
    """Prune the output nodes of ``layer`` down to a CRAIG subset, in place.

    The subset of output nodes (rows of ``layer.weight``) and a per-node
    weight are obtained from ``get_layer_craig_subset``. Only the selected
    rows of the weight matrix (and the matching bias entries, if the layer
    has a bias) are kept. Each kept row and bias entry is multiplied by its
    subset weight. ``layer.out_features`` is then set to the subset size.

    Args:
        layer: Linear layer to prune. It is modified in place: ``weight``,
            ``bias`` and ``out_features`` are replaced.
        prune_percent_per_layer: Forwarded to ``get_layer_craig_subset``.
        similarity_metric: Forwarded to ``get_layer_craig_subset``.
        prune_type: Forwarded to ``get_layer_craig_subset``.
        **kwargs: Extra keyword arguments forwarded to
            ``get_layer_craig_subset``.

    Returns:
        ``(subset_nodes, subset_weights)`` exactly as returned by
        ``get_layer_craig_subset``.

    Note:
        Only ``layer`` is modified. Any layer that consumes this layer's
        output still has its old ``in_features``. Presumably the caller
        adjusts it separately (verify against the caller).
    """
    # Get CRAIG subset.
    subset_nodes, subset_weights = get_layer_craig_subset(
        layer=layer,
        original_num_nodes=layer.out_features,
        prune_percent_per_layer=prune_percent_per_layer,
        similarity_metric=similarity_metric,
        prune_type=prune_type,
        **kwargs,
    )
    num_nodes: int = len(subset_nodes)

    # Parameter surgery: no autograd history is needed for the new tensors.
    with torch.no_grad():
        # Build the scale on the layer's own dtype/device. A bare
        # torch.tensor(...) is always on the CPU, which fails against CUDA
        # weights. It may also be float64, which would silently promote
        # the layer's parameters to float64.
        scale = torch.as_tensor(
            subset_weights,
            dtype=layer.weight.dtype,
            device=layer.weight.device,
        )
        # Keep the selected output rows and scale each row by its subset weight.
        new_weight = layer.weight[subset_nodes] * scale.reshape(num_nodes, 1)
        new_bias = (
            None if layer.bias is None else layer.bias[subset_nodes] * scale
        )

    layer.weight = nn.Parameter(new_weight)
    if new_bias is not None:
        layer.bias = nn.Parameter(new_bias)
    layer.out_features = num_nodes
    return subset_nodes, subset_weights
def update_attributes(self, link: nn.Linear):
    """Re-sync ``link.in_features`` / ``link.out_features`` with its weight.

    ``nn.Linear`` stores its weight as an ``(out_features, in_features)``
    matrix. After the weight has been replaced, e.g. by pruning, the cached
    size attributes can be stale. This copies the current weight shape back
    into them. The weight must be 2-D; otherwise the unpacking raises
    ``ValueError``.

    ``self`` is not used. NOTE(review): presumably this is a method of an
    enclosing class that is not visible here; confirm.
    """
    n_rows, n_cols = link.weight.shape
    link.in_features = n_cols
    link.out_features = n_rows