def post_epoch(self, epoch):
        """Finish an epoch: rezero weights, apply hooks, then step the LR.

        When DEBUG logging is enabled on rank 0, the total parameter count and
        the non-zero counts before and after rezeroing/hooks are logged. The
        LR and weight decay are logged (DEBUG) before and after the scheduler
        step. ``self.lr_scheduler`` is not stepped here when it is a
        ``OneCycleLR`` or ``ComposedLRScheduler``.
        """
        log_nnz = self.logger.isEnabledFor(logging.DEBUG) and self.rank == 0
        if log_nnz:
            _, nnz_before = count_nonzero_params(self.model)

        self.model.apply(rezero_weights)
        # An empty/None hook collection simply applies nothing.
        for extra_hook in self.post_epoch_hooks or ():
            self.model.apply(extra_hook)

        if log_nnz:
            total, nnz_after = count_nonzero_params(self.model)
            density = float(nnz_after) / total
            self.logger.debug(
                "Params total/nnz before/nnz after %s %s / %s = %s",
                total, nnz_before, nnz_after, density)

        self.logger.debug("End of epoch %s LR/weight decay before step: %s/%s",
                          epoch, self.get_lr(), self.get_weight_decay())

        # These scheduler types are excluded from the per-epoch step here.
        excluded = isinstance(self.lr_scheduler,
                              (OneCycleLR, ComposedLRScheduler))
        if not excluded:
            self.lr_scheduler.step()

        self.logger.debug("End of epoch %s LR/weight decay after step: %s/%s",
                          epoch, self.get_lr(), self.get_weight_decay())
Example #2
0
    def on_init_end(self, args, state, control, model, **kwargs):
        """Rezero the model, then log overall, encoder and BERT sparsity.

        Sparsity is ``1 - nonzero / total`` as reported by
        ``count_nonzero_params``. When a wandb run is active, the BERT
        non-zero count and BERT sparsity are stored in the run summary.
        """
        model.apply(rezero_weights)

        total, nonzero = count_nonzero_params(model)
        logging.info(
            f"Non-zero Params / Total Params, {nonzero:,} / {total:,}")
        logging.info(f"   Model Sparsity={1 - (nonzero / total):.4f}")

        total, nonzero = count_nonzero_params(model.bert.encoder)
        logging.info(f"   Encoder Sparsity={1 - (nonzero / total):0.4f}")

        bert_total, bert_nonzero = count_nonzero_params(model.bert)
        bert_sparsity = 1 - (bert_nonzero / bert_total)
        logging.info(f"   Bert Sparsity={bert_sparsity:0.4f}")

        if wandb.run is None:
            return
        wandb.run.summary.update({
            "bert_on_params_at_init": bert_nonzero,
            "bert_sparsity_at_init": bert_sparsity,
        })
Example #3
0
    def test_metacl_experiment_with_rezero_weights(self):
        """The MetaCL model keeps at most 80 of 160 params non-zero over epochs.

        The test removes ``num_classes`` and adds the sampled class lists to
        the experiment config. It does so on a shallow copy, so the shared
        module-level ``metacl_experiment`` dict is not mutated. Mutating it
        could leak into other tests, and a rerun would hit ``KeyError`` on
        ``pop("num_classes")``.
        """
        config = dict(metacl_experiment)

        # Get experiment class.
        exp = config["experiment_class"]()

        # The classes are sampled randomly, so we need a way to make sure
        # the experiment will only sample from what's been randomly generated.
        config.pop("num_classes")
        dataset = exp.load_dataset(config, train=True)
        class_indices = exp.compute_class_indices(config, dataset)
        fast_and_slow_sampler = exp.create_train_sampler(
            config, dataset, class_indices=class_indices)
        replay_sampler = exp.create_replay_sampler(
            config, dataset, class_indices=class_indices)
        config.update(
            fast_and_slow_classes=list(fast_and_slow_sampler.task_indices.keys()),
            replay_classes=list(replay_sampler.task_indices.keys()),
        )

        # Setup experiment and initialize model.
        exp.setup_experiment(config)
        total_params, on_params = count_nonzero_params(exp.model)
        assert total_params == 160
        assert on_params <= 80  # Less than as some may be randomly zero.

        # Loop through some pseudo epochs.
        for _ in range(10):

            _, on_params = count_nonzero_params(exp.model)
            assert on_params <= 80

            exp.run_epoch()

            _, on_params = count_nonzero_params(exp.model)
            assert on_params <= 80
Example #4
0
    def update_auxillary_metrics(self, args, **kwargs):
        """
        Track sparsity, learning rate, and run checks to ensure sparsity
        is not changing too much. Run only after updating metrics for all
        eval sets.

        :param args: object whose ``eval_steps`` is added to
            ``self.step_counter`` on every call.
        :param kwargs: must contain ``"model"`` (whose parameters are counted)
            and ``"lr_scheduler"`` (may be ``None``).
        """
        # track sparsity information
        num_total, num_nonzero = count_nonzero_params(kwargs["model"])
        model_sparsity = 1 - (num_nonzero / num_total)
        self.eval_metrics["num_total_params"].append(num_total)
        self.eval_metrics["num_nonzero_params"].append(num_nonzero)
        self.eval_metrics["sparsity"].append(model_sparsity)

        # guarantee that everything stayed sparse, up to specified tolerance;
        # the first recorded sparsity is the reference for the latest one.
        sparsity_history = self.eval_metrics["sparsity"]
        if self.sparsity_tolerance < 1 and len(sparsity_history) > 1:
            sparse_diff = sparsity_history[0] - sparsity_history[-1]
            if abs(sparse_diff) > self.sparsity_tolerance:
                # logging.warn is a deprecated alias; use logging.warning.
                logging.warning(
                    "Model sparsity fluctuated beyond acceptable range. "
                    f"Current sparsity level: {sparsity_history[-1]}"
                )

        # track learning rate
        # get_last_lr() returns lr for each parameter group. For now,
        # assume lrs are the same for all and just track one.
        if kwargs["lr_scheduler"] is not None:
            last_lr = kwargs["lr_scheduler"].get_last_lr()
            self.eval_metrics["lr"].append(last_lr[0])

        self.step_counter += args.eval_steps
        self.steps.append(self.step_counter)
Example #5
0
    def test_simple_linear(self):
        """Count non-zero params in a simple linear net.

        The expected total (32 * 16 + 2 * 16 + 16 + 2) presumably matches two
        linear layers 32->16 and 16->2 with biases; confirm against
        ``simple_linear_net``. Zeroing four weight entries must lower the
        non-zero count by exactly four and leave the total unchanged.
        """
        model = simple_linear_net()
        expected_params = 32 * 16 + 2 * 16 + 16 + 2
        total_params, total_nonzero_params = count_nonzero_params(model)
        self.assertEqual(total_nonzero_params, expected_params)
        self.assertEqual(total_params, expected_params)

        # Write through ``.data``: item assignment on a parameter that
        # requires grad is an in-place op on a leaf and autograd rejects it.
        model[0].weight.data[0, 0] = 0.0
        model[0].weight.data[0, 1] = 0.0
        model[2].weight.data[0, 0] = 0.0
        model[2].weight.data[1, 0] = 0.0

        total_params, total_nonzero_params = count_nonzero_params(model)
        self.assertEqual(total_nonzero_params, expected_params - 4)
        self.assertEqual(total_params, expected_params)
    def post_epoch(self):
        """Run the parent's post-epoch logic, then rezero the model weights.

        On rank 0 with DEBUG logging enabled, logs the total parameter count,
        the non-zero counts before and after rezeroing, and the resulting
        density.
        """
        super().post_epoch()

        debug_nnz = self.logger.isEnabledFor(logging.DEBUG) and self.rank == 0
        if debug_nnz:
            _, nnz_before = count_nonzero_params(self.model)

        self.model.apply(rezero_weights)

        if not debug_nnz:
            return
        total, nnz_after = count_nonzero_params(self.model)
        self.logger.debug(
            "Params total/nnz before/nnz after %s %s / %s = %s",
            total, nnz_before, nnz_after, float(nnz_after) / total)
Example #7
0
    def test_simple_conv_net(self):
        """Count non-zero params in a simple conv net.

        First checks that all parameters are non-zero. It then zeroes four
        weight entries (two in the conv layer ``model[0]``, two in
        ``model[4]``) and checks that the non-zero count drops by exactly
        four while the total is unchanged.
        """
        model = simple_conv_net()
        expected_params = 75 + 3 + 3 * 111 + 3 + 6 + 2
        total_params, total_nonzero_params = count_nonzero_params(model)
        self.assertEqual(total_nonzero_params, expected_params)
        self.assertEqual(total_params, expected_params)

        # Write through ``.data``: item assignment on a parameter that
        # requires grad is an in-place op on a leaf and autograd rejects it.
        model[0].weight.data[0, 0, 3, 3] = 0.0
        model[0].weight.data[1, 0, 1, 1] = 0.0
        model[4].weight.data[0, 0] = 0.0
        model[4].weight.data[1, 0] = 0.0

        total_params, total_nonzero_params = count_nonzero_params(model)
        self.assertEqual(total_nonzero_params, expected_params - 4)
        self.assertEqual(total_params, expected_params)
    def test_params_count(self):
        """
        Test the number of non-zero parameters for default dense and sparse networks

        Both nets run one forward pass on a random (2, 3, 32, 32) batch. The
        sparse net must have the same total parameter count as the dense one,
        but fewer than 10M non-zero parameters.
        """
        # Plain tensors are passed: ``torch.autograd.Variable`` is deprecated
        # and ``Variable(t)`` just returns a tensor.
        dense_net = resnet50(config=dict(num_classes=10))
        dense_net(torch.randn(2, 3, 32, 32))

        sparse_net = resnet50(config=dict(num_classes=10, defaults_sparse=True))
        sparse_net(torch.randn(2, 3, 32, 32))

        total_params_dense, total_nonzero_params_dense = count_nonzero_params(dense_net)
        self.assertGreater(total_params_dense, 23500000)
        self.assertGreaterEqual(total_params_dense, total_nonzero_params_dense)

        params_sparse, nonzero_params_sparse = count_nonzero_params(sparse_net)

        self.assertEqual(params_sparse, total_params_dense)
        self.assertLess(nonzero_params_sparse, 10000000)
    def setup_experiment(self, config):
        """Set up the experiment via the parent class, then log model density.

        Only on rank 0, and only when the logger would actually emit DEBUG
        records, the total and non-zero parameter counts and their ratio are
        logged. This avoids counting the whole model's parameters when the
        message would be discarded, and matches the guard used in
        ``post_epoch``.
        """
        super().setup_experiment(config)

        if self.rank == 0 and self.logger.isEnabledFor(logging.DEBUG):
            params_sparse, nonzero_params_sparse2 = count_nonzero_params(
                self.model)
            self.logger.debug("Params total/nnz %s / %s = %s ", params_sparse,
                              nonzero_params_sparse2,
                              float(nonzero_params_sparse2) / params_sparse)
Example #10
0
    def setup_experiment(self, config):
        """Set up the experiment via the parent class, then log model density.

        The non-zero/total parameter counts are only computed when the logger
        is not disabled and DEBUG records would actually be emitted, because
        the result is only ever logged at DEBUG level.
        """
        super().setup_experiment(config)

        # ``disabled`` is checked explicitly as well: older Pythons'
        # ``isEnabledFor`` does not consult it.
        if (not self.logger.disabled
                and self.logger.isEnabledFor(logging.DEBUG)):
            params_sparse, nonzero_params_sparse2 = count_nonzero_params(
                self.model)
            self.logger.debug("Params nnz/total %s / %s = %s ",
                              nonzero_params_sparse2, params_sparse,
                              float(nonzero_params_sparse2) / params_sparse)
Example #11
0
    def test_supervised_experiment_with_rezero_weights(self):
        """The supervised model keeps at most 80 of 160 params non-zero.

        Checked right after setup, then before and after each of ten
        pseudo-epochs.
        """
        # Build the experiment from its config and initialize the model.
        experiment = supervised_experiment["experiment_class"]()
        experiment.setup_experiment(supervised_experiment)

        total, nonzero = count_nonzero_params(experiment.model)
        assert total == 160
        assert nonzero <= 80  # Less than as some may be randomly zero.

        for _epoch in range(10):
            _, nonzero_before = count_nonzero_params(experiment.model)
            assert nonzero_before <= 80

            experiment.run_epoch()

            _, nonzero_after = count_nonzero_params(experiment.model)
            assert nonzero_after <= 80
Example #12
0
    def test_simple_model_is_half_sparse(self):
        """A fully rezeroed classifier leaves exactly 80 non-zero params.

        Done for the supervised experiment first, then the MetaCL one: the
        classifier weights are all set to 1, rezeroed, and the model's
        non-zero parameters counted.
        """
        for experiment_config in (supervised_experiment, metacl_experiment):
            experiment_class = experiment_config["experiment_class"]
            model = experiment_class.create_model(experiment_config, "cpu")
            model.classifier.module.weight.data[:] = 1
            model.classifier.rezero_weights()
            _, nonzero = count_nonzero_params(model)
            assert nonzero == 80
Example #13
0
    def on_step_end(self, args, state, control, model, **kwargs):
        """Rezero weights and log sparsity.

        After rezeroing, and only when a wandb run is active, logs the
        sparsity of the whole model, of ``model.bert`` and of
        ``model.bert.encoder`` with ``commit=False``.
        """
        model.apply(rezero_weights)

        if wandb.run is None:
            return

        def sparsity_of(module):
            # Fraction of counted parameters that are zero.
            total, nonzero = count_nonzero_params(module)
            return 1 - (nonzero / total)

        logs = {
            "model_sparsity": sparsity_of(model),
            "bert_sparsity": sparsity_of(model.bert),
            "encoder_sparsity": sparsity_of(model.bert.encoder),
        }
        wandb.log(logs, commit=False)
Example #14
0
    def on_train_end(self, args, state, control, model, **kwargs):
        """Log the final sparsity of ``model.bert``.

        When a wandb run is active, the BERT non-zero parameter count and
        sparsity are also stored in the run summary.
        """
        total, nonzero = count_nonzero_params(model.bert)
        sparsity = 1 - (nonzero / total)
        logging.info(f"   Bert Sparsity={sparsity:0.4f}")

        if wandb.run is not None:
            summary = {
                "bert_on_params_at_end": nonzero,
                "bert_sparsity_at_end": sparsity,
            }
            wandb.run.summary.update(summary)
Example #15
0
    def test_simple_sparse_conv_net(self):
        """Count non-zero params in a simple sparse conv net.

        ``expected_params`` sums weights plus biases of four layers.
        ``expected_nonzero_params`` keeps all biases and scales the weights of
        the first three layers by rounded density factors (0.2, 0.5, 0.2);
        the last layer is counted as fully dense. The factors presumably
        mirror the sparsity settings inside ``simple_sparse_conv_net`` —
        confirm there.
        """
        model = simple_sparse_conv_net()
        expected_params = ((4 * 5 * 5 + 4) + (4 * 5 * 5 * 6 + 6)
                           + (6 * 100 + 100) + (100 * 2 + 2))
        expected_nonzero_params = ((4 * round(0.2 * 5 * 5) + 4)
                                   + (6 * round(0.5 * 5 * 5 * 4) + 6)
                                   + (100 * round(0.2 * 6) + 100)
                                   + (100 * 2 + 2))
        total_params, total_nonzero_params = count_nonzero_params(model)
        self.assertEqual(total_nonzero_params, expected_nonzero_params)
        self.assertEqual(total_params, expected_params)
Example #16
0
    def post_epoch(self, epoch):
        """Rezero weights and advance the learning-rate schedulers.

        ``epoch`` is accepted but not used. On rank 0 with DEBUG logging, the
        total parameter count and the non-zero counts before and after
        rezeroing are logged. ``self.lr_scheduler`` is stepped unless it is a
        ``OneCycleLR`` or ``ComposedLRScheduler``. Rank 0 then logs the
        current LR at INFO level, and ``self.scaled_lr_scheduler`` is stepped
        when it is not ``None``.
        """
        debug_nnz = self.logger.isEnabledFor(logging.DEBUG) and self.rank == 0
        if debug_nnz:
            _, nnz_before = count_nonzero_params(self.model)

        self.model.apply(rezero_weights)

        if debug_nnz:
            total, nnz_after = count_nonzero_params(self.model)
            self.logger.debug("Params total/nnz before/nnz after %s %s / %s",
                              total, nnz_before, nnz_after)

        # Update learning rate (these scheduler types are not stepped here).
        skip_step = isinstance(self.lr_scheduler,
                               (OneCycleLR, ComposedLRScheduler))
        if not skip_step:
            self.lr_scheduler.step()

        if self.rank == 0:
            self.logger.info("LR Scheduler: %s", self.get_lr())

        scaled_scheduler = self.scaled_lr_scheduler
        if scaled_scheduler is not None:
            scaled_scheduler.step()
    def test_custom_auto_params(self):
        """Create sparse ResNets with custom auto params.

        After one forward pass on a random (2, 3, 32, 32) batch, the fraction
        of non-zero parameters must be 0.42 +/- 0.01.
        """
        net = ResNet(
            config=dict(num_classes=10,
                        defaults_sparse=True,
                        activation_params_func=my_auto_sparse_activation_params,
                        conv_params_func=my_auto_sparse_conv_params)
        )
        # ``torch.autograd.Variable`` is deprecated; a plain tensor is equivalent.
        net(torch.randn(2, 3, 32, 32))

        params_sparse, nonzero_params_sparse = count_nonzero_params(net)
        self.assertAlmostEqual(float(nonzero_params_sparse) / params_sparse,
                               0.42, delta=0.01)

        self.assertIsInstance(net, ResNet, "Loads ResNet50 with custom auto params")
Example #18
0
    def test(self, test_loader=None):
        """Evaluate the model and return the metrics dict.

        If ``self.validation`` is truthy, ``self.validation_loader`` is used.
        Otherwise the given ``test_loader`` is used (defaulting to
        ``self.gen_test_loader``) and ``self.validation`` is set to ``False``.
        ``mean_accuracy`` is scaled by 100. The entropy, the number of samples
        in the loader's sampler and the model's non-zero parameter count are
        added to the result.
        """
        loader = self.gen_test_loader if test_loader is None else test_loader
        if self.validation:
            loader = self.validation_loader
        else:
            self.validation = False

        results = evaluate_model(self.model, loader, self.device)
        results["mean_accuracy"] = 100.0 * results["mean_accuracy"]
        results.update(
            entropy=float(self.entropy()),
            total_samples=len(loader.sampler),
            non_zero_parameters=count_nonzero_params(self.model)[1],
        )

        return results