def post_epoch(self, epoch):
    """End-of-epoch housekeeping: rezero weights, run registered hooks,
    log non-zero parameter counts (DEBUG + rank 0 only), and step the
    learning-rate scheduler.
    """
    log_nnz = self.rank == 0 and self.logger.isEnabledFor(logging.DEBUG)
    nnz_before = None
    if log_nnz:
        params_sparse, nnz_before = count_nonzero_params(self.model)

    self.model.apply(rezero_weights)
    if self.post_epoch_hooks:
        for hook in self.post_epoch_hooks:
            self.model.apply(hook)

    if log_nnz:
        params_sparse, nnz_after = count_nonzero_params(self.model)
        self.logger.debug(
            "Params total/nnz before/nnz after %s %s / %s = %s",
            params_sparse, nnz_before, nnz_after,
            float(nnz_after) / params_sparse)

    self.logger.debug("End of epoch %s LR/weight decay before step: %s/%s",
                      epoch, self.get_lr(), self.get_weight_decay())

    # Update learning rate. OneCycleLR / ComposedLRScheduler are skipped
    # here — presumably they are stepped elsewhere (e.g. per batch); confirm.
    if not isinstance(self.lr_scheduler, (OneCycleLR, ComposedLRScheduler)):
        self.lr_scheduler.step()

    self.logger.debug("End of epoch %s LR/weight decay after step: %s/%s",
                      epoch, self.get_lr(), self.get_weight_decay())
def on_init_end(self, args, state, control, model, **kwargs):
    """Log sparsity of the model and the sparsity of just the encoder."""
    model.apply(rezero_weights)

    def _counts_and_sparsity(module):
        # Fraction of parameters that are exactly zero in `module`.
        total, nonzero = count_nonzero_params(module)
        return total, nonzero, 1 - (nonzero / total)

    num_total, num_nonzero, model_sparsity = _counts_and_sparsity(model)
    logging.info(
        f"Non-zero Params / Total Params, {num_nonzero:,} / {num_total:,}")
    logging.info(f" Model Sparsity={model_sparsity:.4f}")

    num_total, num_nonzero, encoder_sparsity = _counts_and_sparsity(
        model.bert.encoder)
    logging.info(f" Encoder Sparsity={encoder_sparsity:0.4f}")

    num_total, num_nonzero, bert_sparsity = _counts_and_sparsity(model.bert)
    logging.info(f" Bert Sparsity={bert_sparsity:0.4f}")

    if wandb.run is not None:
        wandb.run.summary.update(dict(
            bert_on_params_at_init=num_nonzero,
            bert_sparsity_at_init=bert_sparsity,
        ))
def test_metacl_experiment_with_rezero_weights(self):
    """Run the metacl experiment and check that at most half of the
    model's parameters remain non-zero after every epoch."""
    # Get experiment class.
    exp = metacl_experiment["experiment_class"]()

    # The classes are sampled randomly, so we need a way to make sure
    # the experiment will only sample from what's been randomly generated.
    metacl_experiment.pop("num_classes")
    dataset = exp.load_dataset(metacl_experiment, train=True)
    class_indices = exp.compute_class_indices(metacl_experiment, dataset)
    fast_slow = exp.create_train_sampler(metacl_experiment, dataset,
                                         class_indices=class_indices)
    replay = exp.create_replay_sampler(metacl_experiment, dataset,
                                       class_indices=class_indices)
    metacl_experiment.update(
        fast_and_slow_classes=list(fast_slow.task_indices.keys()),
        replay_classes=list(replay.task_indices.keys()),
    )

    # Setup experiment and initialize model.
    exp.setup_experiment(metacl_experiment)
    total_params, on_params = count_nonzero_params(exp.model)
    assert total_params == 160
    assert on_params <= 80  # Less than as some may be randomly zero.

    # Loop through some pseudo epochs; the bound must hold before and
    # after each epoch.
    for _ in range(10):
        assert count_nonzero_params(exp.model)[1] <= 80
        exp.run_epoch()
        assert count_nonzero_params(exp.model)[1] <= 80
def update_auxillary_metrics(self, args, **kwargs):
    """
    Track sparsity, learning rate, and run checks to ensure sparsity is
    not changing too much. Run only after updating metrics for all eval
    sets.

    :param args: training args; only ``args.eval_steps`` is read here.
    :param kwargs: must contain ``"model"`` and ``"lr_scheduler"``
        (the latter may be None).
    """
    # Track sparsity information.
    num_total, num_nonzero = count_nonzero_params(kwargs["model"])
    model_sparsity = 1 - (num_nonzero / num_total)
    self.eval_metrics["num_total_params"].append(num_total)
    self.eval_metrics["num_nonzero_params"].append(num_nonzero)
    self.eval_metrics["sparsity"].append(model_sparsity)

    # Guarantee that everything stayed sparse, up to specified tolerance:
    # compare the first recorded sparsity against the latest one.
    if (self.sparsity_tolerance < 1) and len(
            self.eval_metrics["sparsity"]) > 1:
        sparse_diff = (self.eval_metrics["sparsity"][0]
                       - self.eval_metrics["sparsity"][-1])
        if abs(sparse_diff) > self.sparsity_tolerance:
            # Fix: `logging.warn` is deprecated (use `logging.warning`),
            # and the message was missing the space between sentences.
            logging.warning(
                "Model sparsity fluctuated beyond acceptable range. "
                f"Current sparsity level: {self.eval_metrics['sparsity'][-1]}"
            )

    # Track learning rate.
    # get_last_lr() returns lr for each parameter group. For now,
    # assume lrs are the same for all and just track one.
    if kwargs["lr_scheduler"] is not None:
        last_lr = kwargs["lr_scheduler"].get_last_lr()
        self.eval_metrics["lr"].append(last_lr[0])

    self.step_counter += args.eval_steps
    self.steps.append(self.step_counter)
def test_simple_linear(self):
    """Count non-zero params in a simple linear net."""
    model = simple_linear_net()
    # 32x16 linear (+16 bias) followed by 16x2 linear (+2 bias).
    expected_params = 32 * 16 + 2 * 16 + 16 + 2
    total_params, total_nonzero_params = count_nonzero_params(model)
    self.assertEqual(total_nonzero_params, expected_params)
    self.assertEqual(total_params, expected_params)

    # Fix: in-place assignment into a leaf tensor that requires grad
    # raises a RuntimeError in modern PyTorch; zero the weights under
    # no_grad() instead.
    with torch.no_grad():
        model[0].weight[0, 0] = 0.0
        model[0].weight[0, 1] = 0.0
        model[2].weight[0, 0] = 0.0
        model[2].weight[1, 0] = 0.0

    total_params, total_nonzero_params = count_nonzero_params(model)
    self.assertEqual(total_nonzero_params, expected_params - 4)
    self.assertEqual(total_params, expected_params)
def post_epoch(self):
    """After the parent's post-epoch work, rezero the model's weights,
    logging non-zero parameter counts before/after when DEBUG is enabled
    on rank 0."""
    super().post_epoch()

    should_log = self.rank == 0 and self.logger.isEnabledFor(logging.DEBUG)
    nnz_before = None
    if should_log:
        total, nnz_before = count_nonzero_params(self.model)

    self.model.apply(rezero_weights)

    if should_log:
        total, nnz_after = count_nonzero_params(self.model)
        self.logger.debug(
            "Params total/nnz before/nnz after %s %s / %s = %s",
            total, nnz_before, nnz_after, float(nnz_after) / total)
def test_simple_conv_net(self):
    """Count non-zero params in a simple conv net."""
    # Fix: docstring previously said "linear net" (copy-paste error).
    model = simple_conv_net()
    expected_params = 75 + 3 + 3 * 111 + 3 + 6 + 2
    total_params, total_nonzero_params = count_nonzero_params(model)
    self.assertEqual(total_nonzero_params, expected_params)
    self.assertEqual(total_params, expected_params)

    # Fix: in-place assignment into a leaf tensor that requires grad
    # raises a RuntimeError in modern PyTorch; zero the weights under
    # no_grad() instead.
    with torch.no_grad():
        model[0].weight[0, 0, 3, 3] = 0.0
        model[0].weight[1, 0, 1, 1] = 0.0
        model[4].weight[0, 0] = 0.0
        model[4].weight[1, 0] = 0.0

    total_params, total_nonzero_params = count_nonzero_params(model)
    self.assertEqual(total_nonzero_params, expected_params - 4)
    self.assertEqual(total_params, expected_params)
def test_params_count(self):
    """
    Test the number of non-zero parameters for default dense and sparse
    networks
    """
    # Fix: torch.autograd.Variable is deprecated since PyTorch 0.4 —
    # plain tensors carry autograd state, so pass them directly.
    dense_net = resnet50(config=dict(num_classes=10))
    dense_net(torch.randn(2, 3, 32, 32))

    sparse_net = resnet50(config=dict(num_classes=10, defaults_sparse=True))
    sparse_net(torch.randn(2, 3, 32, 32))

    total_params_dense, total_nonzero_params_dense = count_nonzero_params(
        dense_net)
    # ResNet-50 has >23.5M params; some may be exactly zero by chance.
    self.assertGreater(total_params_dense, 23500000)
    self.assertGreaterEqual(total_params_dense, total_nonzero_params_dense)

    params_sparse, nonzero_params_sparse = count_nonzero_params(sparse_net)
    # Sparse variant keeps the same architecture (same total param count)
    # but with far fewer non-zero weights.
    self.assertEqual(params_sparse, total_params_dense)
    self.assertLess(nonzero_params_sparse, 10000000)
def setup_experiment(self, config):
    """Run the parent setup, then (rank 0 only) log the model's total
    and non-zero parameter counts."""
    super().setup_experiment(config)
    if self.rank != 0:
        return
    total, nonzero = count_nonzero_params(self.model)
    self.logger.debug("Params total/nnz %s / %s = %s ",
                      total, nonzero, float(nonzero) / total)
def setup_experiment(self, config):
    """Run the parent setup and, unless the logger is disabled, log the
    model's non-zero / total parameter counts."""
    super().setup_experiment(config)
    if self.logger.disabled:
        return
    total, nonzero = count_nonzero_params(self.model)
    self.logger.debug("Params nnz/total %s / %s = %s ",
                      nonzero, total, float(nonzero) / total)
def test_supervised_experiment_with_rezero_weights(self):
    """Run the supervised experiment and check that at most half of the
    model's parameters remain non-zero after every epoch."""
    # Setup experiment and initialize model.
    exp = supervised_experiment["experiment_class"]()
    exp.setup_experiment(supervised_experiment)

    total_params, on_params = count_nonzero_params(exp.model)
    assert total_params == 160
    assert on_params <= 80  # Less than as some may be randomly zero.

    # Loop through some pseudo epochs; the bound must hold before and
    # after each epoch.
    for _ in range(10):
        assert count_nonzero_params(exp.model)[1] <= 80
        exp.run_epoch()
        assert count_nonzero_params(exp.model)[1] <= 80
def test_simple_model_is_half_sparse(self):
    """For both experiment types, an all-ones classifier must have
    exactly 80 on-params once fully rezeroed."""

    def _on_params_after_rezero(config):
        # Build the model, force every classifier weight on, rezero,
        # and count the remaining non-zero parameters.
        model = config["experiment_class"].create_model(config, "cpu")
        model.classifier.module.weight.data[:] = 1
        model.classifier.rezero_weights()
        return count_nonzero_params(model)[1]

    # Supervised Experiment:
    # Validate that the fully rezeroed model has exactly 80 on-params
    assert _on_params_after_rezero(supervised_experiment) == 80

    # MetaCL Experiment:
    # Validate that the fully rezeroed model has exactly 80 on-params
    assert _on_params_after_rezero(metacl_experiment) == 80
def on_step_end(self, args, state, control, model, **kwargs):
    """Rezero weights and log sparsity."""
    model.apply(rezero_weights)

    if wandb.run is None:
        return

    # Log sparsity to wandb for the whole model, the bert submodule,
    # and just the encoder.
    logs = {}
    for key, module in (("model_sparsity", model),
                        ("bert_sparsity", model.bert),
                        ("encoder_sparsity", model.bert.encoder)):
        total, nonzero = count_nonzero_params(module)
        logs[key] = 1 - (nonzero / total)
    wandb.log(logs, commit=False)
def on_train_end(self, args, state, control, model, **kwargs):
    """At the end of training, log BERT sparsity and record it in the
    wandb run summary when a run is active."""
    num_total, num_nonzero = count_nonzero_params(model.bert)
    bert_sparsity = 1 - (num_nonzero / num_total)
    logging.info(f" Bert Sparsity={bert_sparsity:0.4f}")

    if wandb.run is not None:
        summary = dict(
            bert_on_params_at_end=num_nonzero,
            bert_sparsity_at_end=bert_sparsity,
        )
        wandb.run.summary.update(summary)
def test_simple_sparse_conv_net(self):
    """Count non-zero params in a simple sparse conv net."""
    model = simple_sparse_conv_net()

    # Dense counts per layer: conv1, conv2, linear1, linear2 (weights + biases).
    expected_params = sum([
        4 * 5 * 5 + 4,
        4 * 5 * 5 * 6 + 6,
        6 * 100 + 100,
        100 * 2 + 2,
    ])
    # Expected non-zero counts — the 0.2 / 0.5 factors presumably mirror
    # each layer's weight density in simple_sparse_conv_net; biases dense.
    expected_nonzero_params = sum([
        4 * round(0.2 * 5 * 5) + 4,
        6 * round(0.5 * 5 * 5 * 4) + 6,
        100 * round(0.2 * 6) + 100,
        100 * 2 + 2,
    ])

    total_params, total_nonzero_params = count_nonzero_params(model)
    self.assertEqual(total_nonzero_params, expected_nonzero_params)
    self.assertEqual(total_params, expected_params)
def post_epoch(self, epoch):
    """Rezero weights at epoch end, log non-zero parameter counts
    (DEBUG + rank 0 only), and step the learning-rate schedulers.

    NOTE(review): the original source was whitespace-mangled; the trailing
    `if` statements are reconstructed as sequential siblings — confirm
    against the original indentation.
    """
    log_nnz = self.rank == 0 and self.logger.isEnabledFor(logging.DEBUG)
    nnz_before = None
    if log_nnz:
        total, nnz_before = count_nonzero_params(self.model)

    self.model.apply(rezero_weights)

    if log_nnz:
        total, nnz_after = count_nonzero_params(self.model)
        self.logger.debug("Params total/nnz before/nnz after %s %s / %s",
                          total, nnz_before, nnz_after)

    # Update learning rate. OneCycleLR / ComposedLRScheduler are skipped —
    # presumably stepped elsewhere (e.g. per batch); confirm.
    if not isinstance(self.lr_scheduler, (OneCycleLR, ComposedLRScheduler)):
        self.lr_scheduler.step()
    if self.rank == 0:
        self.logger.info("LR Scheduler: %s", self.get_lr())
    if self.scaled_lr_scheduler is not None:
        self.scaled_lr_scheduler.step()
def test_custom_auto_params(self):
    """Create sparse ResNets with custom auto params."""
    net = ResNet(
        config=dict(num_classes=10,
                    defaults_sparse=True,
                    activation_params_func=my_auto_sparse_activation_params,
                    conv_params_func=my_auto_sparse_conv_params)
    )
    # Fix: torch.autograd.Variable is deprecated since PyTorch 0.4 —
    # plain tensors work directly in a forward pass.
    net(torch.randn(2, 3, 32, 32))

    # With the custom auto params, roughly 42% of weights should be on.
    params_sparse, nonzero_params_sparse = count_nonzero_params(net)
    self.assertAlmostEqual(float(nonzero_params_sparse) / params_sparse,
                           0.42, delta=0.01)

    self.assertIsInstance(net, ResNet,
                          "Loads ResNet50 with custom auto params")
def test(self, test_loader=None):
    """Evaluate the model and return its metrics plus entropy, sample
    count, and the number of non-zero parameters."""
    if test_loader is None:
        test_loader = self.gen_test_loader

    # NOTE(review): when self.validation is truthy, the supplied/default
    # loader is replaced by the validation loader; otherwise the flag is
    # (re)set to False. (The original's `test_loader = test_loader`
    # no-op is dropped.)
    if self.validation:
        test_loader = self.validation_loader
    else:
        self.validation = False

    ret = evaluate_model(self.model, test_loader, self.device)
    ret["mean_accuracy"] = 100.0 * ret["mean_accuracy"]

    entropy = self.entropy()
    ret.update({
        "entropy": float(entropy),
        "total_samples": len(test_loader.sampler),
        "non_zero_parameters": count_nonzero_params(self.model)[1],
    })
    return ret