from typing import List, Optional

import numpy as np
import torch
from torch.nn import Linear


def prune_fc(module: Linear, keep_idxes: List[int],
             bn_num_channels: Optional[int] = None) -> List[int]:
    """Prunes the input features of a fully connected layer in place.

    Args:
        module: the Linear layer to prune.
        keep_idxes: indices of the input features (or upstream channels) to keep.
        bn_num_channels: num_channels of the preceding BatchNorm layer; if given,
            keep_idxes are channel indices and are expanded to feature indices.

    Returns:
        The (possibly expanded) list of kept input-feature indices.
    """
    if bn_num_channels is not None:
        assert module.in_features % bn_num_channels == 0
        # Each upstream channel maps to a contiguous block of flattened features.
        channel_step = module.in_features // bn_num_channels
        _keep_idxes = []
        for idx in keep_idxes:
            _keep_idxes.extend(np.arange(channel_step) + idx * channel_step)
        keep_idxes = _keep_idxes
    module.in_features = len(keep_idxes)
    # Keep only the weight columns for the surviving input features.
    module.weight = torch.nn.Parameter(module.weight.data[:, keep_idxes])
    module.weight.grad = None
    return keep_idxes
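# A minimal usage sketch (hypothetical shapes and variable names). With an
# upstream BatchNorm of 4 channels feeding 8 flattened features (2 features
# per channel), keeping channels [1, 3] keeps features [2, 3, 6, 7].
fc = Linear(8, 4)
kept = prune_fc(fc, keep_idxes=[1, 3], bn_num_channels=4)
assert list(kept) == [2, 3, 6, 7]
assert fc.in_features == 4 and fc.weight.shape == (4, 4)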
def test_linear_layer_feed_forward(self):
    num_hidden_units = 4
    num_input_data_features = 2
    data = np.array([[1, 2], [2, 3], [3, 4]])
    self.assertEqual(data.shape[-1], num_input_data_features)
    torch_data = torch.Tensor(data)
    initial_weights = xavier_uniform((2, num_hidden_units))
    initial_bias = np.ones(num_hidden_units)
    linear = LinearLayer(num_input_data_features, num_hidden_units,
                         initial_weights=initial_weights,
                         initial_bias=initial_bias,
                         bias_exist=False)
    output = linear.forward(data)
    with torch.no_grad():
        # torch.nn.Linear stores weight as (out_features, in_features),
        # hence the transpose of initial_weights.
        torch_linear = Linear(num_input_data_features, num_hidden_units, False)
        torch_linear.weight = torch.nn.Parameter(
            torch.Tensor(initial_weights.transpose()))
        output_torch = torch_linear(torch_data)
    epsilon = 1e-5
    # Compare absolute differences so errors in either direction fail the test.
    self.assertTrue(
        np.all(np.abs(output - output_torch.numpy()) < epsilon))
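# Why the transpose above: torch.nn.Linear keeps `weight` with shape
# (out_features, in_features) and computes y = x @ weight.T + bias.
# A quick self-contained check (hypothetical shapes):
import torch
from torch.nn import Linear

layer = Linear(2, 4, bias=False)
x = torch.randn(3, 2)
assert layer.weight.shape == (4, 2)
assert torch.allclose(layer(x), x @ layer.weight.T)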
from typing import Dict, List, Text, Tuple, Union

import torch
import torch.nn as nn


def prune_fc_layer_with_craig(layer: nn.Linear,
                              prune_percent_per_layer: float,
                              similarity_metric: Union[Text, Dict] = "",
                              prune_type: Text = "craig",
                              **kwargs) -> Tuple[List[int], List[float]]:
    # Get the CRAIG subset of output nodes and their per-node weights.
    subset_nodes: List
    subset_weights: List
    subset_nodes, subset_weights = get_layer_craig_subset(
        layer=layer,
        original_num_nodes=layer.out_features,
        prune_percent_per_layer=prune_percent_per_layer,
        similarity_metric=similarity_metric,
        prune_type=prune_type,
        **kwargs)

    num_nodes: int = len(subset_nodes)

    # Prune the current layer: keep only the subset rows, scaling both the
    # weights and the biases by the corresponding subset weights.
    subset_weights_tensor = torch.tensor(subset_weights)
    layer.weight = nn.Parameter(
        layer.weight[subset_nodes] *
        subset_weights_tensor.reshape((num_nodes, 1)))
    if layer.bias is not None:
        layer.bias = nn.Parameter(
            layer.bias[subset_nodes] * subset_weights_tensor)
    layer.out_features = num_nodes

    return subset_nodes, subset_weights
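# Illustrative pairing with prune_fc above (hypothetical sizes; the metric
# name "euclidean_distance" is a guess, and get_layer_craig_subset is
# assumed to be defined elsewhere in the repo): after pruning fc1's
# outputs, the matching input columns of fc2 must be dropped as well.
fc1, fc2 = nn.Linear(16, 8), nn.Linear(8, 4)
kept_nodes, kept_weights = prune_fc_layer_with_craig(
    fc1, prune_percent_per_layer=0.5,
    similarity_metric="euclidean_distance")
prune_fc(fc2, keep_idxes=kept_nodes)
assert fc2.in_features == fc1.out_features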
def _create_mean_predictor(self):
    """Creates a new predictor using the mean parameters across the predictor heads."""
    weights = []
    biases = []
    # Collect the weights/biases from each predictor head and stack them
    # into (n_participants, ...) tensors.
    for i in range(self.n_participants):
        weights.append(self.predictor_heads[i].weight)
        biases.append(self.predictor_heads[i].bias)
    weights = torch.stack(weights)
    biases = torch.stack(biases)
    # Create a new linear predictor and set its weights/biases to the means.
    predictor_heads_mean = Linear(self.hidden_dim, self.output_dim)
    predictor_heads_mean.weight = Parameter(weights.mean(0))
    predictor_heads_mean.bias = Parameter(biases.mean(0))
    return predictor_heads_mean
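# Standalone sketch of the same averaging, with hypothetical dimensions.
# Because a Linear layer is affine, averaging the parameters is equivalent
# to averaging the heads' outputs: mean_i(W_i x + b_i) = mean(W) x + mean(b).
import torch
from torch.nn import Linear, Parameter

heads = [Linear(4, 2) for _ in range(3)]
mean_head = Linear(4, 2)
mean_head.weight = Parameter(torch.stack([h.weight.detach() for h in heads]).mean(0))
mean_head.bias = Parameter(torch.stack([h.bias.detach() for h in heads]).mean(0))

x = torch.randn(5, 4)
assert torch.allclose(mean_head(x), torch.stack([h(x) for h in heads]).mean(0), atol=1e-6)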
def _create_sampled_predictor(self):
    """Creates a new predictor using parameters sampled from the prior for random effects."""
    # Sample a flat parameter vector and extract the weight and bias
    # entries: each of the output_dim rows holds hidden_dim weights
    # followed by one bias.
    sampled_params = MultivariateNormal(self.mean, self.cov).sample()
    flattened_mlp_params = sampled_params[
        : ((self.hidden_dim + 1) * self.output_dim)
    ]
    mlp_params = flattened_mlp_params.reshape(
        (self.output_dim, self.hidden_dim + 1)
    )
    weight, bias = mlp_params[:, :-1], mlp_params[:, -1]
    # Create a new linear predictor and set its weights/biases to the
    # sampled values.
    predictor_head_sampled = Linear(self.hidden_dim, self.output_dim)
    predictor_head_sampled.weight = Parameter(weight)
    predictor_head_sampled.bias = Parameter(bias)
    return predictor_head_sampled
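# Standalone sketch of the unflattening, with hypothetical dimensions
# hidden_dim=3, output_dim=2: each row of the reshaped sample is
# [w_1, w_2, w_3, b], so slicing off the last column separates weight
# from bias.
import torch
from torch.distributions import MultivariateNormal

hidden_dim, output_dim = 3, 2
n = (hidden_dim + 1) * output_dim
sample = MultivariateNormal(torch.zeros(n), torch.eye(n)).sample()
rows = sample.reshape(output_dim, hidden_dim + 1)
weight, bias = rows[:, :-1], rows[:, -1]
assert weight.shape == (output_dim, hidden_dim) and bias.shape == (output_dim,)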
import numpy
import torch
from torch.nn import Linear, Sequential
# Assuming torch's DistributedDataParallel, which matches the call below.
from torch.nn.parallel import DistributedDataParallel as DDP


def weight_sharing(rank, checkpoint):
    # Build a model whose two layers share one weight Parameter, wrap it
    # with DDP, and run forward/backward through the given checkpoint fn.
    set_random_seed(31415)
    l1 = Linear(2000, 2000)
    l2 = Linear(2000, 2000)
    l1.weight = l2.weight  # both layers now hold the same Parameter
    model = Sequential(l1, l2)
    model.to("cuda")
    model = DDP(model, device_ids=[rank])
    input_tensor = torch.rand((64, 2000)).cuda()
    input_tensor.requires_grad = True
    output_tensor = checkpoint(model, input_tensor)
    output_tensor.sum().backward()

    # Every parameter should have received a gradient, and the total
    # gradient norm should be reproducible given the fixed seed.
    norm = 0.0
    for p in model.parameters():
        assert p.grad is not None
        norm += p.grad.norm().item()
    assert numpy.allclose(norm, 57004.34228515625), norm
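# Minimal CPU sketch of the weight-sharing semantics (hypothetical sizes):
# after `l1.weight = l2.weight` both modules hold the same Parameter, so
# model.parameters() yields it only once and gradients from both layers
# accumulate into the one shared tensor.
l1, l2 = Linear(4, 4), Linear(4, 4)
l1.weight = l2.weight
model = Sequential(l1, l2)
model(torch.randn(2, 4)).sum().backward()
assert sum(p is l1.weight for p in model.parameters()) == 1
assert l1.weight.grad is not None and l2.weight.grad is l1.weight.grad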