def test_gmp(self):
    original_model = copy.deepcopy(self.sparse_model).half()
    deepspeed_model = copy.deepcopy(original_model)

    deepspeed_model.to(self.device)
    original_model.to(self.device)

    replace_sparse_transformer_layer(
        model=deepspeed_model.base_model,
        config=self.config,
        training=True,
        fp16=True,
    )

    deepspeed_model.apply(rezero_weights)
    original_model.apply(rezero_weights)

    # Make sure the model maintains the same sparsity as the original.
    expected_sparsity = _compute_sparsity(original_model)
    actual_sparsity = _compute_sparsity(deepspeed_model)
    self.assertAlmostEqual(actual_sparsity, expected_sparsity)

    # Prune weights.
    original_modules = [
        m for m in original_model.modules() if isinstance(m, SparseWeightsBase)
    ]
    deepspeed_modules = [
        m for m in deepspeed_model.modules() if isinstance(m, SparseWeightsBase)
    ]
    expected_removed = global_prune_by_abs_weight(original_modules, 0.3)
    actual_removed = global_prune_by_abs_weight(deepspeed_modules, 0.3)

    # Make sure both models pruned the same number of weights.
    self.assertAlmostEqual(actual_removed, expected_removed)

    # Make sure the model maintains the same sparsity as the original after
    # pruning.
    deepspeed_model.apply(rezero_weights)
    original_model.apply(rezero_weights)
    expected_sparsity = _compute_sparsity(original_model)
    actual_sparsity = _compute_sparsity(deepspeed_model)
    self.assertAlmostEqual(actual_sparsity, expected_sparsity)
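
# Note: the comparisons above rely on `rezero_weights` applying each sparse
# module's `zero_mask` so that masked weights read back as exact zeros, which
# lets a sparsity helper simply count zero entries. The function below is a
# minimal illustrative sketch of such a helper; the actual `_compute_sparsity`
# used by the test may differ in detail.
def sparsity_of(model):
    """Fraction of weight entries that are exactly zero (illustrative only)."""
    total, zeros = 0, 0
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if weight is None or not hasattr(weight, "numel"):
            continue
        total += weight.numel()
        zeros += int((weight == 0).sum().item())
    return zeros / max(total, 1)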
def training_step(self, model, inputs):
    """
    Prune every `prune_period` steps during the pruning phase.
    """
    model = unwrap_model(model)  # extract from DistributedDataParallel
    train_loss = super().training_step(model, inputs)

    if not self._setup_done:
        self.setup_pruning(self.args, model)
        self._setup_done = True

    # Track the steps by starting at `step=1`.
    train_step = self.state.global_step + 1

    # Return if still in the warm-up phase.
    if train_step < self.warmup_steps:
        return train_loss

    # Return if in the cool-down phase.
    if train_step > self.max_steps - self.cooldown_steps:
        return train_loss

    # Return if the step within the pruning phase isn't divisible by the
    # pruning period.
    if (train_step - self.warmup_steps) % self.prune_period != 0:
        return train_loss

    # Fraction of the way through the pruning phase.
    fraction_through_pruning = self.prune_iteration / self.total_prune_iterations

    # Target number of on-params at time t.
    lambda_t = np.power(1 - fraction_through_pruning, 3)  # decays from 1 to 0
    target_on_params = self.end_on_params + (
        self.start_on_params - self.end_on_params) * lambda_t

    # Actual number of on-params at time t.
    current_on_params = calc_on_params(self.sparse_modules)
    remove_params = current_on_params - int(target_on_params)

    # Prune the specified number of params.
    actual_removed = global_prune_by_abs_weight(
        self.sparse_modules, num_remove=remove_params
    )

    # Log how much was pruned and the resulting sparsity.
    if self.verbose_gmp_logging:
        bert_sparsity = calc_model_sparsity(model.bert)
        logs = dict({
            "gmp/target_pruned_params": remove_params,
            "gmp/actual_pruned_params": actual_removed,
            "gmp/post_prune_bert_sparsity": bert_sparsity,
        })
        if wandb.run is not None:
            wandb.log(logs, commit=False)
        self.control.should_log = True

    self.prune_iteration += 1

    return train_loss
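
# For reference, the cubic decay above matches the gradual magnitude pruning
# schedule of Zhu & Gupta (2017), expressed here in terms of on-param counts
# rather than sparsity. The standalone sketch below just restates that
# arithmetic; the numbers in the usage comment are made up for illustration.
def gmp_target_on_params(start_on_params, end_on_params, fraction_through_pruning):
    """Target on-param count a fraction of the way through the pruning phase."""
    lambda_t = (1.0 - fraction_through_pruning) ** 3  # decays from 1 to 0
    return end_on_params + (start_on_params - end_on_params) * lambda_t

# e.g. pruning from 1000 down to 100 on-params: halfway through the phase the
# target is 100 + 900 * 0.5 ** 3 = 212.5, so most of the pruning happens early.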
def training_step(self, model, inputs):
    """Prune and regrow weights every `prune_freq` iterations."""
    train_loss = super().training_step(model, inputs)

    if self.state.global_step % self.prune_freq != 0:
        self.prune_scheduler.step()
        return train_loss

    # Retrieve sparse modules (e.g. SparseWeights) after the model has been
    # set up for distributed training, if it has.
    if self.sparse_modules is None:
        self.sparse_modules = filter_modules(
            model, include_modules=[SparseWeightsBase]
        ).values()

    # Pre-prune sparsities.
    param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(self.sparse_modules)

    # Prune weights.
    model.apply(rezero_weights)
    prune_fraction = self.prune_scheduler.get_prune_fraction()
    num_removed = global_prune_by_abs_weight(self.sparse_modules, prune_fraction)
    model.apply(rezero_weights)

    # Post-prune sparsities.
    param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(self.sparse_modules)

    # Accumulate gradients over one batch.
    self.optimizer.zero_grad()
    train_dataloader = self.callback_handler.train_dataloader
    train_batch = next(iter(train_dataloader))
    inputs_to_device(train_batch, device=self.args.device)
    batch_loss = self.compute_loss(model, train_batch)
    batch_loss.backward()

    # Regrow weights.
    num_add = self.prune_scheduler.get_num_add(num_removed)
    global_add_by_abs_grad(self.sparse_modules, num_add)
    self.prune_scheduler.step()

    # Post-grow sparsities.
    param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(self.sparse_modules)

    # Log pruning stats.
    actual_pruned = param_sparsity1 - param_sparsity0
    actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

    logging.info("RigLMixin:")
    logging.info(f"Target: remove {prune_fraction} frac of on params")
    logging.info(f"Actual: removed {actual_pruned_on_params} fraction of on params")

    # For now, the logging is deliberately thorough to verify that pruning
    # occurs as expected.
    # TODO: Remove non-essential logging.
    logs = dict({
        "rigl/target_pruned_on_params": prune_fraction,
        "rigl/actual_pruned_on_params": actual_pruned_on_params,
        "rigl/target_pruned_all_params": prune_fraction * mask_sparsity0,
        "rigl/actual_pruned_all_params": actual_pruned,
        "rigl/pre_prune_param_sparsity": param_sparsity0,
        "rigl/pre_prune_mask_sparsity": mask_sparsity0,
        "rigl/post_prune_param_sparsity": param_sparsity1,
        "rigl/post_prune_mask_sparsity": mask_sparsity1,
        "rigl/post_grow_param_sparsity": param_sparsity2,
        "rigl/post_grow_mask_sparsity": mask_sparsity2,
    })

    if wandb.run is not None:
        wandb.log(logs, step=self.state.global_step)

    return train_loss
def test_global_rigl(self):
    """
    Test for globally pruning all sparse modules by their weights and adding
    back by gradients.
    """

    # -----------
    # Init model
    # -----------

    # Make sure there are no random zeros in the model params.
    init_all_zero_params(self.model)
    sparsity = calc_sparsity(self.model)
    self.assertEqual(sparsity, 0)

    # Validate initial sparsity after rezeroing the weights.
    self.model.apply(rezero_weights)
    sparsity = calc_sparsity(self.model.bert)
    self.assertTrue(np.isclose(sparsity, 0.4701, atol=1e-4))

    # Get all the SparseWeightsBase modules. These will be pruned.
    sparse_modules = filter_modules(self.model, include_modules=[SparseWeightsBase])
    sparse_modules = sparse_modules.values()
    self.assertEqual(len(sparse_modules), 7)

    # Validate the initial number of off params within the sparse modules.
    total_sparse_params = np.sum([m.weight.numel() for m in sparse_modules])
    total_off_mask = np.sum([m.zero_mask.bool().sum() for m in sparse_modules])
    total_off_params = np.sum([(m.weight == 0).sum() for m in sparse_modules])
    self.assertEqual(total_sparse_params, 168)
    self.assertEqual(total_off_params, 126)
    self.assertEqual(total_off_mask, 126)

    # --------------
    # Prune weights
    # --------------

    num_removed = global_prune_by_abs_weight(sparse_modules, prune_fraction=1 / 3)
    self.model.apply(rezero_weights)

    total_off_mask = np.sum([m.zero_mask.bool().sum() for m in sparse_modules])
    total_off_params = np.sum([(m.weight == 0).sum() for m in sparse_modules])
    self.assertEqual(total_off_mask, 140)
    self.assertEqual(total_off_params, 140)

    # ---------------
    # Regrow weights
    # ---------------

    # Pseudo forward pass to accumulate gradients.
    batch_size = 2
    num_embeddings = self.config.max_position_embeddings
    attention_mask = torch.ones(batch_size, num_embeddings).float()
    input_ids = torch.ones(batch_size, num_embeddings).long()
    token_type_ids = torch.ones(batch_size, num_embeddings).long()
    labels = torch.ones(batch_size * num_embeddings).long()
    outputs = self.model(
        attention_mask=attention_mask,
        input_ids=input_ids,
        labels=labels,
        token_type_ids=token_type_ids,
    )
    loss = outputs.loss
    loss.backward()

    # Add weights according to the largest gradients of the model.
    global_add_by_abs_grad(sparse_modules, num_add=num_removed)

    # The new weights are initialized to zero.
    self.model.apply(rezero_weights)

    total_off_mask = np.sum([m.zero_mask.bool().sum() for m in sparse_modules])
    total_off_params = np.sum([(m.weight == 0).sum() for m in sparse_modules])

    # Validate the number of off params after regrowing the weights.
    self.assertEqual(total_off_mask, 126)
    self.assertEqual(total_off_params, 140)

    # Pseudo training step where learning happens on the new zero weights.
    init_all_zero_params(self.model)
    self.model.apply(rezero_weights)

    # Validate the number of off params after learning has occurred on the
    # new weights.
    total_off_mask = np.sum([m.zero_mask.bool().sum() for m in sparse_modules])
    total_off_params = np.sum([(m.weight == 0).sum() for m in sparse_modules])
    self.assertEqual(total_off_mask, 126)
    self.assertEqual(total_off_params, 126)
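
# The two library functions exercised above are not shown in this section. The
# sketches below only illustrate the contract the test depends on, namely a
# global magnitude-based prune and a gradient-based regrow across all sparse
# modules; they are not the actual implementations, and tie handling in
# particular is simplified. `zero_mask == 1` marks weights held at zero by
# `rezero_weights`.
import torch


def sketch_global_prune_by_abs_weight(sparse_modules, prune_fraction):
    """Turn off the smallest-magnitude fraction of currently-on weights, globally."""
    on_weights = torch.cat(
        [m.weight[m.zero_mask == 0].abs().flatten() for m in sparse_modules])
    num_remove = int(round(prune_fraction * on_weights.numel()))
    if num_remove == 0:
        return 0
    threshold = on_weights.sort().values[num_remove - 1]
    for m in sparse_modules:
        prune_here = (m.zero_mask == 0) & (m.weight.abs() <= threshold)
        m.zero_mask[prune_here] = 1
    return num_remove


def sketch_global_add_by_abs_grad(sparse_modules, num_add):
    """Turn on the `num_add` currently-off weights with the largest |gradient|."""
    off_grads = torch.cat(
        [m.weight.grad[m.zero_mask == 1].abs().flatten() for m in sparse_modules])
    if num_add == 0 or off_grads.numel() == 0:
        return
    num_add = min(num_add, off_grads.numel())
    threshold = off_grads.sort(descending=True).values[num_add - 1]
    for m in sparse_modules:
        grow_here = (m.zero_mask == 1) & (m.weight.grad.abs() >= threshold)
        m.zero_mask[grow_here] = 0
        m.weight.data[grow_here] = 0.0  # regrown weights start at zero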
def training_step(self, model, inputs):
    """Prune and regrow weights every `prune_freq` iterations."""
    train_loss = super().training_step(model, inputs)

    if self.state.global_step % self.prune_freq != 0:
        self.prune_scheduler.step()
        return train_loss

    # Retrieve sparse modules (e.g. SparseWeights) after the model has been
    # set up for distributed training, if it has.
    if self.sparse_modules is None:
        self.sparse_modules = filter_modules(
            model, include_modules=[SparseWeightsBase]).values()
    sparse_modules = self.sparse_modules

    # Pre-prune sparsities (for verbose logging).
    model.apply(rezero_weights)
    if self.verbose_rigl_logging:
        param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(sparse_modules)

    # If the prune fraction is 0, say for a warm-up step, return and don't prune.
    prune_fraction = self.prune_scheduler.get_prune_fraction()
    if prune_fraction == 0:
        self.prune_scheduler.step()
        return train_loss

    # Prune weights.
    num_removed = global_prune_by_abs_weight(self.sparse_modules, prune_fraction)
    model.apply(rezero_weights)

    # Post-prune sparsities (for verbose logging).
    if self.verbose_rigl_logging:
        param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(sparse_modules)

    # Accumulate gradients over one batch.
    self.optimizer.zero_grad()
    train_dataloader = self.callback_handler.train_dataloader
    train_batch = next(iter(train_dataloader))
    inputs_to_device(train_batch, device=self.args.device)
    batch_loss = self.compute_loss(model, train_batch)
    batch_loss.backward()

    # Regrow weights.
    num_add = self.prune_scheduler.get_num_add(num_removed)
    global_add_by_abs_grad(self.sparse_modules, num_add)
    self.prune_scheduler.step()

    logs = dict({
        "rigl/target_pruned_on_params": prune_fraction,
    })

    # Post-grow sparsities (for verbose logging).
    if self.verbose_rigl_logging:
        param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(sparse_modules)

        # Log pruning stats.
        actual_pruned = param_sparsity1 - param_sparsity0
        actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

        logging.debug(f"Target: remove {prune_fraction} frac of on params")
        logging.debug(f"Actual: removed {actual_pruned_on_params} "
                      "fraction of on params")

        # These logs are deliberately thorough to ensure the actual percentage
        # and count of pruned params match the target amounts.
        logs.update({
            "rigl/actual_pruned_on_params": actual_pruned_on_params,
            "rigl/target_pruned_all_params": prune_fraction * mask_sparsity0,
            "rigl/actual_pruned_all_params": actual_pruned,
            "rigl/pre_prune_param_sparsity": param_sparsity0,
            "rigl/pre_prune_mask_sparsity": mask_sparsity0,
            "rigl/post_prune_param_sparsity": param_sparsity1,
            "rigl/post_prune_mask_sparsity": mask_sparsity1,
            "rigl/post_grow_param_sparsity": param_sparsity2,
            "rigl/post_grow_mask_sparsity": mask_sparsity2,
        })

    if wandb.run is not None:
        wandb.log(logs, commit=False)

    return train_loss
def on_step_end(
    self,
    args,
    state,
    control,
    model=None,
    train_dataloader=None,
    optimizer=None,
    **kwargs
):
    """Prune and regrow weights every `prune_freq` iterations."""
    if state.global_step % self.prune_freq != 0:
        self.prune_scheduler.step()
        return

    # Pre-prune sparsities.
    param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(self.sparse_modules)

    # Prune weights.
    model.apply(rezero_weights)
    prune_fraction = self.prune_scheduler.get_prune_fraction()
    num_removed = global_prune_by_abs_weight(self.sparse_modules, prune_fraction)
    model.apply(rezero_weights)

    # Post-prune sparsities.
    param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(self.sparse_modules)

    # Accumulate gradients over one batch.
    optimizer.zero_grad()
    train_batch = next(iter(train_dataloader))
    inputs_to_device(train_batch, device=args.device)
    output = model(**train_batch)
    output.loss.backward()

    # Regrow weights.
    num_add = self.prune_scheduler.get_num_add(num_removed)
    global_add_by_abs_grad(self.sparse_modules, num_add)
    self.prune_scheduler.step()

    # Post-grow sparsities.
    param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(self.sparse_modules)

    # Log pruning stats.
    actual_pruned = param_sparsity1 - param_sparsity0
    actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

    logging.info("RigLCallback:")
    logging.info(f"Target: remove {prune_fraction} frac of on params")
    logging.info(f"Actual: removed {actual_pruned_on_params} fraction of on params")

    # For now, the logging is deliberately thorough to verify that pruning
    # occurs as expected.
    # TODO: Remove non-essential logging.
    logs = dict({
        "rigl/target_pruned_on_params": prune_fraction,
        "rigl/actual_pruned_on_params": actual_pruned_on_params,
        "rigl/target_pruned_all_params": prune_fraction * mask_sparsity0,
        "rigl/actual_pruned_all_params": actual_pruned,
        "rigl/pre_prune_param_sparsity": param_sparsity0,
        "rigl/pre_prune_mask_sparsity": mask_sparsity0,
        "rigl/post_prune_param_sparsity": param_sparsity1,
        "rigl/post_prune_mask_sparsity": mask_sparsity1,
        "rigl/post_grow_param_sparsity": param_sparsity2,
        "rigl/post_grow_mask_sparsity": mask_sparsity2,
    })

    if wandb.run is not None:
        wandb.log(logs, step=state.global_step)
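
# The `prune_scheduler` used by the mixin and callback above is not defined in
# this section. The class below is a hypothetical sketch of the interface they
# rely on (`get_prune_fraction`, `get_num_add`, `step`): a cosine-decayed prune
# fraction, as in the RigL paper, and regrowing exactly as many weights as were
# just pruned so overall sparsity stays constant. The class name, constructor
# arguments, and defaults are assumptions, not the project's actual scheduler.
import math


class CosinePruneSchedulerSketch:
    def __init__(self, prune_fraction=0.3, total_steps=10000, warmup_steps=0):
        self.prune_fraction = prune_fraction
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def get_prune_fraction(self):
        """Fraction of on-weights to prune at the current step (0 during warm-up)."""
        if self.current_step < self.warmup_steps:
            return 0
        progress = min(self.current_step / max(self.total_steps, 1), 1.0)
        return self.prune_fraction / 2 * (1 + math.cos(math.pi * progress))

    def get_num_add(self, num_removed):
        """Regrow as many weights as were just pruned to keep sparsity fixed."""
        return num_removed

    def step(self):
        self.current_step += 1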
def test_global_pruning_and_regrowing(self):

    # ----------
    # Prune
    # ----------

    sparse_modules = [self.model.lin1, self.model.lin2, self.model.lin3]
    global_prune_by_abs_weight(sparse_modules, prune_fraction=0.5)
    self.model.apply(rezero_weights)

    # Validate pruned weights.
    expected_w1 = torch.tensor(
        [
            [0.0000, 0.3400, 0.0000, -0.4500],
            [-0.3700, 0.0000, 0.0000, -0.4500],
            [-0.3100, 0.0000, 0.0000, 0.0000],
            [0.4200, 0.0000, 0.0000, 0.0000],
        ]
    )
    self.assertTrue(all_equal(self.model.lin1.weight, expected_w1))

    expected_w2 = torch.tensor(
        [
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.2800, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
        ]
    )
    self.assertTrue(all_equal(self.model.lin2.weight, expected_w2))

    expected_w3 = torch.tensor(
        [
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, -0.3600, 0.0000, -0.4300],
            [0.0000, 0.3300, 0.0000, -0.5000],
            [0.0000, 0.0000, 0.4100, 0.0000],
        ]
    )
    self.assertTrue(all_equal(self.model.lin3.weight, expected_w3))

    # Validate that the pruned mask is a subset of the original mask.
    pruned_lin1_mask = self.model.lin1.zero_mask
    pruned_lin2_mask = self.model.lin2.zero_mask
    pruned_lin3_mask = self.model.lin3.zero_mask
    self.assertTrue((pruned_lin1_mask >= self.initial_lin1_mask).all())
    self.assertTrue((pruned_lin2_mask >= self.initial_lin2_mask).all())
    self.assertTrue((pruned_lin3_mask >= self.initial_lin3_mask).all())

    # Validate the number of off weights.
    zero_masks = parameters_to_vector(self.model.buffers())
    weights = parameters_to_vector(self.model.parameters())
    self.assertEqual((weights == 0).sum(), 36)
    self.assertEqual(zero_masks.sum(), 36)

    # ----------
    # Regrow
    # ----------

    # Pseudo forward pass to accumulate gradients.
    x = torch.tensor(
        [[0.35, 0.94, 0.10, 0.31], [0.05, 0.16, 0.46, 0.11]], requires_grad=True
    )
    self.model(x).sum().backward()

    # Regrow weights per the largest abs gradients.
    global_add_by_abs_grad(sparse_modules, num_add=12)
    self.model.apply(rezero_weights)

    # Validate regrown weights.
    expected_w1 = torch.tensor(
        [
            [0.0000, 0.3400, 0.0000, -0.4500],
            [-0.3700, 0.0000, 0.0000, -0.4500],
            [-0.3100, 0.0000, 0.0000, 0.0000],
            [0.4200, 0.0000, 0.0000, 0.0000],
        ]
    )
    self.assertTrue(all_close(self.model.lin1.weight, expected_w1))

    expected_w2 = torch.tensor(
        [
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.2800, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
        ]
    )
    self.assertTrue(all_close(self.model.lin2.weight, expected_w2))

    expected_w3 = torch.tensor(
        [
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, -0.3600, 0.0000, -0.4300],
            [0.0000, 0.3300, 0.0000, -0.5000],
            [0.0000, 0.0000, 0.4100, 0.0000],
        ]
    )
    self.assertTrue(all_close(self.model.lin3.weight, expected_w3))

    # Validate that the regrown mask is a subset of the pruned mask.
    self.assertTrue((self.model.lin1.zero_mask <= pruned_lin1_mask).all())
    self.assertTrue((self.model.lin2.zero_mask <= pruned_lin2_mask).all())
    self.assertTrue((self.model.lin3.zero_mask <= pruned_lin3_mask).all())

    # Validate the number of off weights.
    zero_masks = parameters_to_vector(self.model.buffers())
    weights = parameters_to_vector(self.model.parameters())
    self.assertEqual((weights == 0).sum(), 36)
    self.assertEqual(zero_masks.sum(), 24)