def prune(pruning_hparams: PruningHparams,
          trained_model: models.base.Model,
          current_mask: Mask = None,
          training_hparams: hparams.TrainingHparams = None,
          dataset_hparams: hparams.DatasetHparams = None,
          data_order_seed: int = None):
    """Globally prune the lowest-scoring surviving weights of ``trained_model``.

    Scores come from ``Strategy.get_score`` and are pooled across every scored
    tensor. Among the weights still unmasked, the ``number_of_weights_to_prune``
    with the smallest absolute score are zeroed in the returned mask. More may
    be zeroed if several scores tie at the threshold.

    Args:
        pruning_hparams: supplies ``pruning_fraction`` (fraction of the
            currently remaining weights to remove) and
            ``pruning_layers_to_ignore`` (comma-separated tensor names to
            exclude from pruning, or a falsy value).
        trained_model: model whose ``prunable_layer_names`` define the
            candidate tensors.
        current_mask: mask to prune further. When None,
            ``Mask.ones_like(trained_model)`` is used.
        training_hparams, dataset_hparams, data_order_seed: forwarded
            unchanged to ``Strategy.get_score``.

    Returns:
        A new ``Mask``. Entries of ``current_mask`` that received no score
        are copied through unchanged.
    """
    current_mask = Mask.ones_like(trained_model) if current_mask is None else current_mask
    current_mask_numpy = current_mask.numpy()

    # Number of weights to remove this round.
    # NOTE(review): this counts every tensor in the mask, not only the
    # prunable/scored ones that are actual pruning candidates below.
    number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask_numpy.values()])
    number_of_weights_to_prune = np.ceil(
        pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int)

    # Determine which layers can be pruned.
    prunable_tensors = set(trained_model.prunable_layer_names)
    if pruning_hparams.pruning_layers_to_ignore:
        prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

    # Score the prunable weights. The keys of `scores` are used to index
    # current_mask_numpy below (presumably tensor names in prunable_tensors).
    scores = Strategy.get_score(trained_model, current_mask, prunable_tensors,
                                training_hparams, dataset_hparams, data_order_seed)

    # Pool the scores of the weights that are still unpruned.
    # Guard the empty case: np.concatenate([]) raises ValueError.
    surviving = [v[current_mask_numpy[k] == 1] for k, v in scores.items()]
    score_vector = np.concatenate(surviving) if surviving else np.empty(0)

    # Choose a threshold so that keeping |score| > threshold removes the
    # number_of_weights_to_prune smallest scores.
    # Indexing sorted[number_of_weights_to_prune] would remove one extra
    # weight, would prune even when the count is 0, and would raise
    # IndexError when asked to prune everything.
    sorted_scores = np.sort(np.abs(score_vector))
    if number_of_weights_to_prune <= 0 or sorted_scores.size == 0:
        threshold = -np.inf  # nothing to prune
    else:
        threshold = sorted_scores[min(number_of_weights_to_prune, sorted_scores.size) - 1]

    # Where current_mask is already 0, both np.where branches yield 0,
    # so pruned weights stay pruned.
    new_mask = Mask({
        k: np.where(np.abs(v) > threshold, current_mask_numpy[k], np.zeros_like(v))
        for k, v in scores.items()
    })

    # Tensors that were not scored keep their current mask.
    for k in current_mask_numpy:
        if k not in new_mask:
            new_mask[k] = current_mask_numpy[k]
    return new_mask
def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
    """Globally prune the smallest-magnitude surviving weights of ``trained_model``.

    The weights of every prunable tensor in ``trained_model.state_dict()``
    are pooled. Among the weights still unmasked, the
    ``number_of_weights_to_prune`` with the smallest absolute value are zeroed
    in the returned mask. More may be zeroed if several magnitudes tie at the
    threshold.

    Args:
        pruning_hparams: supplies ``pruning_fraction`` (fraction of the
            currently remaining weights to remove) and
            ``pruning_layers_to_ignore`` (comma-separated tensor names to
            exclude from pruning, or a falsy value).
        trained_model: model whose ``prunable_layer_names`` and
            ``state_dict()`` provide the candidate weights.
        current_mask: mask to prune further. When None,
            ``Mask.ones_like(trained_model)`` is used.

    Returns:
        A new ``Mask``. Entries of ``current_mask`` for tensors that are not
        prunable are copied through unchanged.
    """
    import logging  # function-local so this block stays self-contained
    logger = logging.getLogger(__name__)

    current_mask = (Mask.ones_like(trained_model).numpy() if current_mask is None
                    else current_mask.numpy())

    # Number of weights to remove this round.
    # NOTE(review): this counts every tensor in the mask, not only the
    # prunable ones that are actual pruning candidates below.
    number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()])
    # These diagnostics used to be print() calls that dumped whole mask arrays
    # to stdout; they are now debug-level log messages.
    logger.debug("Remaining weights: %s", number_of_remaining_weights)
    number_of_weights_to_prune = np.ceil(
        pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int)

    # Determine which layers can be pruned.
    prunable_tensors = set(trained_model.prunable_layer_names)
    if pruning_hparams.pruning_layers_to_ignore:
        prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

    # Snapshot the prunable weights as NumPy arrays.
    weights = {
        k: v.clone().cpu().detach().numpy()
        for k, v in trained_model.state_dict().items()
        if k in prunable_tensors
    }

    # Pool the weights that are still unpruned.
    # Guard the empty case: np.concatenate([]) raises ValueError.
    surviving = [v[current_mask[k] == 1] for k, v in weights.items()]
    weight_vector = np.concatenate(surviving) if surviving else np.empty(0)

    # Choose a threshold so that keeping |w| > threshold removes the
    # number_of_weights_to_prune smallest magnitudes.
    # Indexing sorted[number_of_weights_to_prune] would remove one extra
    # weight, would prune even when the count is 0, and would raise
    # IndexError when asked to prune everything.
    sorted_magnitudes = np.sort(np.abs(weight_vector))
    if number_of_weights_to_prune <= 0 or sorted_magnitudes.size == 0:
        threshold = -np.inf  # nothing to prune
    else:
        threshold = sorted_magnitudes[min(number_of_weights_to_prune, sorted_magnitudes.size) - 1]

    # Where current_mask is already 0, both np.where branches yield 0,
    # so pruned weights stay pruned.
    new_mask = Mask({
        k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v))
        for k, v in weights.items()
    })

    # Non-prunable tensors keep their current mask.
    for k in current_mask:
        if k not in new_mask:
            new_mask[k] = current_mask[k]
    logger.debug("Requested pruning of %s weights (threshold %s)",
                 number_of_weights_to_prune, threshold)
    return new_mask
def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
    """Globally prune smallest-magnitude weights, independently per initialization.

    Every mask tensor is expected to share a leading axis that indexes
    separate initializations, and the prunable weights are indexed along that
    same axis. Pruning is global across tensors but separate for each
    initialization:

    * each init removes ``pruning_fraction`` of its own remaining weights;
    * each init uses its own magnitude threshold.

    More weights than requested may be zeroed if magnitudes tie at an init's
    threshold.

    Args:
        pruning_hparams: supplies ``pruning_fraction`` and
            ``pruning_layers_to_ignore`` (comma-separated tensor names to
            exclude from pruning, or a falsy value).
        trained_model: model whose ``prunable_layer_names`` and
            ``state_dict()`` provide the candidate weights.
        current_mask: mask to prune further. When None,
            ``Mask.ones_like(trained_model)`` is used.

    Returns:
        A new ``Mask``. Entries of ``current_mask`` for tensors that are not
        prunable are copied through unchanged.

    Raises:
        ValueError: if the mask tensors disagree on the size of the leading
            (initialization) axis.
    """
    current_mask = (Mask.ones_like(trained_model).numpy() if current_mask is None
                    else current_mask.numpy())

    # Number of initializations: the size of the shared leading axis.
    # Raise instead of assert so the check survives `python -O`.
    num_inits = next(iter(current_mask.values())).shape[0]
    if any(v.shape[0] != num_inits for v in current_mask.values()):
        raise ValueError(
            'All mask tensors must share the same leading (initialization) dimension.')

    # Weights to remove per init, based on that init's own remaining count.
    # Using total // num_inits would silently assume every init has the same
    # count, which ties in earlier rounds can break.
    # NOTE(review): counts every tensor in the mask, not only prunable ones.
    remaining_per_init = sum(
        np.sum(v, axis=tuple(range(1, v.ndim))) for v in current_mask.values())
    prune_counts = np.ceil(
        pruning_hparams.pruning_fraction * np.asarray(remaining_per_init)).astype(int)

    # Determine which layers can be pruned.
    prunable_tensors = set(trained_model.prunable_layer_names)
    if pruning_hparams.pruning_layers_to_ignore:
        prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

    # Snapshot the prunable weights as NumPy arrays.
    weights = {k: v.clone().cpu().detach().numpy()
               for k, v in trained_model.state_dict().items()
               if k in prunable_tensors}

    def _threshold(values, count):
        """Return t such that keeping |x| > t drops the `count` smallest |values|."""
        magnitudes = np.sort(np.abs(values))
        if count <= 0 or magnitudes.size == 0:
            return -np.inf  # nothing to prune
        # Indexing magnitudes[count] would drop one weight too many and
        # raise IndexError when count == magnitudes.size.
        return magnitudes[min(count, magnitudes.size) - 1]

    # One threshold per init, computed from that init's surviving weights.
    thresholds = np.empty(num_inits)
    for init_id in range(num_inits):
        surviving = [v[init_id][current_mask[k][init_id] == 1] for k, v in weights.items()]
        weight_vector = np.concatenate(surviving) if surviving else np.empty(0)
        thresholds[init_id] = _threshold(weight_vector, prune_counts[init_id])

    mask_dict = {}
    for k, v in weights.items():
        # Shape (num_inits, 1, ..., 1) so each init's threshold broadcasts
        # over that init's slice of v.
        per_init_threshold = thresholds.reshape(-1, *([1] * (v.ndim - 1)))
        mask_dict[k] = np.where(np.abs(v) > per_init_threshold,
                                current_mask[k], np.zeros_like(v))
    new_mask = Mask(mask_dict)

    # Non-prunable tensors keep their current mask.
    for k in current_mask:
        if k not in new_mask:
            new_mask[k] = current_mask[k]
    return new_mask