def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind',
                    layers_to_ignore: str = ''):
    """Randomize this level's pruning mask and train a model with the result.

    The mask stored under ``self.level_root`` is loaded, randomized according
    to ``strategy``, saved under ``self.branch_root`` and then used to train a
    ``PrunedModel`` built from the state saved under ``self.level_root``.

    Args:
        seed: Seed passed to the shuffling helpers. For the 'even' strategy
            the position of each layer in sorted key order is added to it so
            that every layer is shuffled with a different seed.
        strategy: One of
            'layerwise' - shuffle within each layer, keeping each layer's
                proportions;
            'global' - shuffle across all prunable layers at once;
            'even' - rebuild every layer from the overall mask sparsity;
            'identity' - keep the mask unchanged.
        start_at: Which saved state to train from and where the schedule
            starts: 'init' (state and schedule at '0ep'), 'end' (state at
            ``train_end_step``, schedule at '0ep') or 'rewind' (state and
            schedule at ``train_start_step``).
        layers_to_ignore: Comma-separated mask keys whose masks are reset to
            all ones. The empty string means no layer is reset.

    Raises:
        ValueError: If ``strategy`` or ``start_at`` is not a supported value.
    """
    # Randomize the mask.
    mask = Mask.load(self.level_root)

    # Randomize while keeping the same layerwise proportions as the original mask.
    if strategy == 'layerwise':
        mask = Mask(shuffle_state_dict(mask, seed=seed))

    # Randomize globally throughout all prunable layers.
    elif strategy == 'global':
        mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

    # Randomize evenly across all layers.
    elif strategy == 'even':
        sparsity = mask.sparsity
        # enumerate() provides the per-layer seed offset; iterating the sorted
        # keys directly would try to unpack each key string into two names.
        for i, k in enumerate(sorted(mask.keys())):
            layer = mask[k]
            # Tensor.size is a method, so the element count and the shape
            # have to come from numel() and .shape respectively.
            count = layer.numel()
            # NOTE(review): as in the original expression, the first
            # ceil(sparsity * count) entries are set to one before shuffling.
            # Confirm that this matches the meaning of Mask.sparsity.
            num_ones = int(torch.ceil(torch.as_tensor(sparsity) * count))
            layer_mask = torch.zeros_like(layer).reshape(-1)
            layer_mask[:num_ones] = 1
            mask[k] = shuffle_tensor(layer_mask, seed=seed + i).reshape(layer.shape)

    # Identity.
    elif strategy == 'identity':
        pass

    # Error.
    else:
        raise ValueError(f'Invalid strategy: {strategy}')

    # Reset the masks of any layers that shouldn't be pruned.
    if layers_to_ignore:
        for k in layers_to_ignore.split(','):
            mask[k] = torch.ones_like(mask[k])

    # Save the new mask.
    mask.save(self.branch_root)

    # Determine the start step (for the schedule) and the step of the state to load.
    if start_at == 'init':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = start_step
    elif start_at == 'end':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = self.lottery_desc.train_end_step
    elif start_at == 'rewind':
        start_step = self.lottery_desc.train_start_step
        state_step = start_step
    else:
        raise ValueError(f'Invalid starting point {start_at}')

    # Train the model with the new mask.
    model = PrunedModel(
        models.registry.load(self.level_root, state_step, self.lottery_desc.model_hparams), mask)
    train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step,
                         verbose=self.verbose)
def test_save_load_exists(self):
    """Check that a saved mask is reported as existing and loads back unchanged."""
    # Nothing has been written under the root yet.
    self.assertFalse(Mask.exists(self.root))
    self.assertFalse(os.path.exists(paths.mask(self.root)))

    # Write a two-entry mask to disk.
    original = Mask({'hello': np.ones([2, 3]), 'world': np.zeros([5, 6])})
    original.save(self.root)

    # Both the file itself and the Mask helper should now see it.
    self.assertTrue(os.path.exists(paths.mask(self.root)))
    self.assertTrue(Mask.exists(self.root))

    # The reloaded mask has the same size, keys and contents.
    restored = Mask.load(self.root)
    for sized in (restored, restored.keys(), restored.values()):
        self.assertEqual(len(sized), 2)
    self.assertEqual(set(restored.keys()), {'hello', 'world'})

    for name, expected in (('hello', np.ones([2, 3])), ('world', np.zeros([5, 6]))):
        self.assertTrue(np.array_equal(expected, restored[name]))
def branch_function(
        self,
        strategy: str,
        prune_fraction: float,
        prune_experiment: str = 'main',
        prune_step: str = '0ep0it',  # The step for states used to prune.
        prune_highest: bool = False,
        prune_iterations: int = 1,
        randomize_layerwise: bool = False,
        state_experiment: str = 'main',
        state_step: str = '0ep0it',  # The step of the state to use alongside the pruning mask.
        start_step: str = '0ep0it',  # The step at which to start the learning rate schedule.
        seed: int = None,
        reinitialize: bool = False):
    """Build (or reuse) a pruning mask for this branch and train a model with it.

    If no mask can be loaded from ``self.branch_root``, the primary process
    scores the weights saved by ``prune_experiment`` at ``prune_step`` with
    the named strategy, prunes over ``prune_iterations`` rounds and saves the
    resulting mask. Every process then loads the mask after a barrier, wraps
    the chosen model state in a ``PrunedModel``, trains it and finally runs
    ``weights_analysis`` on the unpruned copy and on the trained pruned model.

    Args:
        strategy: Name of the pruning strategy; exactly one entry of
            ``strategies`` must accept it via ``valid_name``.
        prune_fraction: Overall pruning fraction, converted below into a
            per-round fraction.
        prune_experiment: Experiment whose saved weights are scored.
        prune_step: Step string of the weights used for scoring.
        prune_highest: Passed negated as the third argument of ``prune``.
            Presumably selects pruning of the highest-scoring weights -
            confirm against ``prune``.
        prune_iterations: Number of score-and-prune rounds.
        randomize_layerwise: If true, shuffle the final mask with
            ``shuffle_state_dict`` before saving.
        state_experiment: Experiment whose saved state is trained.
        state_step: Step string of the state loaded for training.
        start_step: Step string at which training is started.
        seed: Seed for the strategy and the shuffle; defaults to
            ``self.replicate`` when None.
        reinitialize: If true, train a freshly created model instead of a
            loaded state.

    Raises:
        ValueError: If no strategy, or more than one, matches ``strategy``.
    """
    # Get the steps for each part of the process.
    iterations_per_epoch = datasets.registry.iterations_per_epoch(self.desc.dataset_hparams)
    prune_step = Step.from_str(prune_step, iterations_per_epoch)
    state_step = Step.from_str(state_step, iterations_per_epoch)
    start_step = Step.from_str(start_step, iterations_per_epoch)
    seed = self.replicate if seed is None else seed

    # Try to load the mask. A failure only means it has to be computed, but
    # KeyboardInterrupt/SystemExit must not be swallowed, hence no bare except.
    try:
        mask = Mask.load(self.branch_root)
    except Exception:
        mask = None

    if not mask and get_platform().is_primary_process:
        # Gather the weights that will be used for pruning.
        prune_path = self.desc.run_path(self.replicate, prune_experiment)
        prune_model = models.registry.load(prune_path, prune_step, self.desc.model_hparams)

        # Ensure that a valid strategy is available.
        strategy_class = [s for s in strategies if s.valid_name(strategy)]
        if not strategy_class:
            raise ValueError(f'No such pruning strategy {strategy}')
        if len(strategy_class) > 1:
            raise ValueError('Multiple matching strategies')
        strategy_instance = strategy_class[0](strategy, self.desc, seed)

        # Run the strategy for each iteration.
        mask = Mask.ones_like(prune_model)
        # Per-round fraction f with (1 - f) ** prune_iterations == 1 - prune_fraction.
        iteration_fraction = 1 - (1 - prune_fraction) ** (1 / float(prune_iterations))
        if iteration_fraction > 0:
            for _ in range(prune_iterations):
                # Make a defensive copy of the model and mask out the pruned weights.
                prune_model2 = copy.deepcopy(prune_model)
                with torch.no_grad():
                    for k, v in prune_model2.named_parameters():
                        v.mul_(mask.get(k, 1))

                # Compute the scores.
                scores = strategy_instance.score(prune_model2, mask)

                # Prune.
                mask = unvectorize(
                    prune(vectorize(scores), iteration_fraction, not prune_highest,
                          mask=vectorize(mask)),
                    mask)

        # Shuffle randomly per layer.
        if randomize_layerwise:
            mask = shuffle_state_dict(mask, seed=seed)

        mask = Mask({k: v.clone().detach() for k, v in mask.items()})
        mask.save(self.branch_root)

    # Load the mask once the barrier has been passed.
    get_platform().barrier()
    mask = Mask.load(self.branch_root)

    # Pick the model to train: a fresh one or a saved state.
    state_path = self.desc.run_path(self.replicate, state_experiment)
    if reinitialize:
        model = models.registry.get(self.desc.model_hparams)
    else:
        model = models.registry.load(state_path, state_step, self.desc.model_hparams)

    # Keep an untouched copy for the analysis that follows training.
    original_model = copy.deepcopy(model)
    model = PrunedModel(model, mask)

    train.standard_train(model, self.branch_root, self.desc.dataset_hparams,
                         self.desc.training_hparams, start_step=start_step,
                         verbose=self.verbose,
                         evaluate_every_epoch=self.evaluate_every_epoch)

    weights_analysis(original_model, strategy, reinitialize, randomize_layerwise, "original")
    weights_analysis(model, strategy, reinitialize, randomize_layerwise, "pruned")