def test_gmp(self):
        original_model = copy.deepcopy(self.sparse_model).half()
        deepspeed_model = copy.deepcopy(original_model)
        deepspeed_model.to(self.device)
        original_model.to(self.device)

        replace_sparse_transformer_layer(
            model=deepspeed_model.base_model,
            config=self.config,
            training=True,
            fp16=True,
        )

        deepspeed_model.apply(rezero_weights)
        original_model.apply(rezero_weights)

        # Make sure the model maintains the same sparsity as the original
        expected_sparsity = _compute_sparsity(original_model)
        actual_sparsity = _compute_sparsity(deepspeed_model)
        self.assertAlmostEqual(actual_sparsity, expected_sparsity)

        # Prune weights
        original_modules = [
            m for m in original_model.modules()
            if isinstance(m, SparseWeightsBase)
        ]
        deepspeed_modules = [
            m for m in deepspeed_model.modules()
            if isinstance(m, SparseWeightsBase)
        ]
        expected_removed = global_prune_by_abs_weight(original_modules, 0.3)
        actual_removed = global_prune_by_abs_weight(deepspeed_modules, 0.3)

        # Make sure the same number of weights were pruned.
        self.assertAlmostEqual(actual_removed, expected_removed)

        # Make sure the model maintains the same sparsity as the original after
        # pruning.
        deepspeed_model.apply(rezero_weights)
        original_model.apply(rezero_weights)
        expected_sparsity = _compute_sparsity(original_model)
        actual_sparsity = _compute_sparsity(deepspeed_model)
        self.assertAlmostEqual(actual_sparsity, expected_sparsity)
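The `_compute_sparsity` helper used by this test is not shown in the snippet. A minimal sketch, assuming it simply reports the fraction of zero-valued entries across the model's parameters:

def _compute_sparsity(model):
    # Assumed behavior: fraction of parameter entries that are exactly zero.
    total = 0
    zeros = 0
    for param in model.parameters():
        total += param.numel()
        zeros += (param == 0).sum().item()
    return zeros / total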
Example #2
    def training_step(self, model, inputs):
        """
        Prune every `prune_period` steps during the pruning phase.
        """

        model = unwrap_model(model)  # extract from DistributedDataParallel
        train_loss = super().training_step(model, inputs)

        if not self._setup_done:
            self.setup_pruning(self.args, model)
            self._setup_done = True

        # Track the steps by starting at `step=1`
        train_step = self.state.global_step + 1

        # Return if still in warm-up phase.
        if train_step < self.warmup_steps:
            return train_loss

        # Return if in cool-down phase.
        if train_step > self.max_steps - self.cooldown_steps:
            return train_loss

        # Return if step within pruning phase isn't divisible by pruning period.
        if (train_step - self.warmup_steps) % self.prune_period != 0:
            return train_loss

        # Fraction of the way through the pruning phase.
        fraction_through_pruning = self.prune_iteration / self.total_prune_iterations

        # Target number of on-params at time t.
        lambda_t = np.power(1 - fraction_through_pruning, 3)  # decays from 1 to 0
        target_on_params = self.end_on_params + (
            self.start_on_params - self.end_on_params) * lambda_t

        # Actual number of on-params at time t.
        current_on_params = calc_on_params(self.sparse_modules)
        remove_params = current_on_params - int(target_on_params)

        # Prune the specified number of params.
        actual_removed = global_prune_by_abs_weight(
            self.sparse_modules,
            num_remove=remove_params
        )

        # Log how much was pruned and the resulting sparsity.
        if self.verbose_gmp_logging:
            bert_sparsity = calc_model_sparsity(model.bert)
            logs = {
                "gmp/target_pruned_params": remove_params,
                "gmp/actual_pruned_params": actual_removed,
                "gmp/post_prune_bert_sparsity": bert_sparsity,
            }

            if wandb.run is not None:
                wandb.log(logs, commit=False)
                self.control.should_log = True

        self.prune_iteration += 1
        return train_loss
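The cubic schedule above implements gradual magnitude pruning: the target number of on-parameters decays from `start_on_params` to `end_on_params` as `(1 - t)^3` falls from 1 to 0 over the pruning phase. A standalone sketch of that schedule (the function name and arguments here are illustrative, not part of the original code):

def gmp_target_on_params(start_on_params, end_on_params,
                         prune_iteration, total_prune_iterations):
    # Cubic decay of the target on-parameter count, mirroring the step above.
    fraction = prune_iteration / total_prune_iterations
    lambda_t = (1 - fraction) ** 3  # 1 at the first iteration, 0 at the last
    return end_on_params + (start_on_params - end_on_params) * lambda_t

# Halfway through pruning from 1000 down to 100 on-params:
# gmp_target_on_params(1000, 100, 5, 10) == 100 + 900 * 0.125 == 212.5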
Example #3
    def training_step(self, model, inputs):
        """Prune and regrow weights every 'prune_freq' iterations."""

        train_loss = super().training_step(model, inputs)

        if self.state.global_step % self.prune_freq != 0:
            self.prune_scheduler.step()
            return train_loss

        # Retrieve sparse modules (e.g. SparseWeights) after the model has been
        # set up for distributed training, if applicable.
        if self.sparse_modules is None:
            self.sparse_modules = filter_modules(
                model, include_modules=[SparseWeightsBase]
            ).values()

        # Pre-prune sparsities.
        param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(self.sparse_modules)

        # Prune weights.
        model.apply(rezero_weights)
        prune_fraction = self.prune_scheduler.get_prune_fraction()
        num_removed = global_prune_by_abs_weight(self.sparse_modules, prune_fraction)
        model.apply(rezero_weights)

        # Post-prune sparsities.
        param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(self.sparse_modules)

        # Accumulate gradients over one batch.
        self.optimizer.zero_grad()
        train_dataloader = self.callback_handler.train_dataloader
        train_batch = next(iter(train_dataloader))
        inputs_to_device(train_batch, device=self.args.device)
        batch_loss = self.compute_loss(model, train_batch)
        batch_loss.backward()

        # Regrow weights
        num_add = self.prune_scheduler.get_num_add(num_removed)
        global_add_by_abs_grad(self.sparse_modules, num_add)
        self.prune_scheduler.step()

        # Post-grow sparsities.
        param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(self.sparse_modules)

        # Log pruning stats.
        actual_pruned = param_sparsity1 - param_sparsity0
        actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

        logging.info(f"RigLMixin:")
        logging.info(f"Target: remove {prune_fraction} frac of on params")
        logging.info(f"Actual: removed {actual_pruned_on_params} fraction of on params")

        # For now, the logging is deliberately thorough to verify that pruning
        # occurs as expected.
        # TODO: Remove non-essential logging.
        logs = {
            "rigl/target_pruned_on_params": prune_fraction,
            "rigl/actual_pruned_on_params": actual_pruned_on_params,
            "rigl/target_pruned_all_params": prune_fraction * mask_sparsity0,
            "rigl/actual_pruned_all_params": actual_pruned,
            "rigl/pre_prune_param_sparsity": param_sparsity0,
            "rigl/pre_prune_mask_sparsity": mask_sparsity0,
            "rigl/post_prune_param_sparsity": param_sparsity1,
            "rigl/post_prune_mask_sparsity": mask_sparsity1,
            "rigl/post_grow_param_sparsity": param_sparsity2,
            "rigl/post_grow_mask_sparsity": mask_sparsity2,
        }
        if wandb.run is not None:
            wandb.log(logs, step=self.state.global_step)

        return train_loss
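The `prune_scheduler` used in this step is assumed to expose `get_prune_fraction()`, `get_num_add(num_removed)`, and `step()`. A minimal hypothetical scheduler with that interface, pruning a constant fraction of on-weights and regrowing exactly as many as were removed (the real scheduler may use a decaying schedule and warmup logic):

class ConstantPruneScheduler:
    # Hypothetical scheduler matching the interface used above.

    def __init__(self, prune_fraction=0.3, warmup_steps=0):
        self.prune_fraction = prune_fraction
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def get_prune_fraction(self):
        # No pruning during warmup; otherwise a constant fraction of on-weights.
        if self.current_step < self.warmup_steps:
            return 0
        return self.prune_fraction

    def get_num_add(self, num_removed):
        # Keep overall sparsity fixed by regrowing as many weights as were pruned.
        return num_removed

    def step(self):
        self.current_step += 1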
    def test_global_rigl(self):
        """
        Test for globally pruning all sparse modules by their weights and adding back by
        gradients.
        """

        # -----------
        # Init model
        # -----------

        # Make sure there are no random zeros in model params.
        init_all_zero_params(self.model)
        sparsity = calc_sparsity(self.model)
        self.assertEqual(sparsity, 0)

        # Validate initial sparsity after rezeroing the weights.
        self.model.apply(rezero_weights)
        sparsity = calc_sparsity(self.model.bert)
        self.assertTrue(np.isclose(sparsity, 0.4701, atol=1e-4))

        # Get all the SparseWeightsBase modules. These will be pruned.
        sparse_modules = filter_modules(self.model,
                                        include_modules=[SparseWeightsBase])
        sparse_modules = sparse_modules.values()
        self.assertEqual(len(sparse_modules), 7)

        # Validate the initial number of off params within the sparse modules.
        total_sparse_params = np.sum(
            [m.weight.numel() for m in sparse_modules])
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])
        self.assertEqual(total_sparse_params, 168)
        self.assertEqual(total_off_params, 126)
        self.assertEqual(total_off_mask, 126)

        # --------------
        # Prune weights
        # --------------

        num_removed = global_prune_by_abs_weight(sparse_modules,
                                                 prune_fraction=1 / 3)

        self.model.apply(rezero_weights)
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])

        self.assertEqual(total_off_mask, 140)
        self.assertEqual(total_off_params, 140)

        # ---------------
        # Regrow weights
        # ---------------

        # Pseudo forward pass to accumulate gradients.
        batch_size = 2
        num_embeddings = self.config.max_position_embeddings
        attention_mask = torch.ones(batch_size, num_embeddings).float()
        input_ids = torch.ones(batch_size, num_embeddings).long()
        token_type_ids = torch.ones(batch_size, num_embeddings).long()
        labels = torch.ones(batch_size * num_embeddings).long()

        outputs = self.model(
            attention_mask=attention_mask,
            input_ids=input_ids,
            labels=labels,
            token_type_ids=token_type_ids,
        )
        loss = outputs.loss
        loss.backward()

        # Add weights according to the largest gradients of the model.
        global_add_by_abs_grad(sparse_modules, num_add=num_removed)

        # The new weights are initialized to zero.
        self.model.apply(rezero_weights)
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])

        # Validate number of off params after regrowing the weights.
        self.assertEqual(total_off_mask, 126)
        self.assertEqual(total_off_params, 140)

        # Pseudo training step where learning happens on the new zero weights.
        init_all_zero_params(self.model)
        self.model.apply(rezero_weights)

        # Validate number of off params after learning has occurred on new weights.
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])

        self.assertEqual(total_off_mask, 126)
        self.assertEqual(total_off_params, 126)
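The counts asserted above are internally consistent: 168 sparse weights with 126 masked off leaves 42 on-weights; pruning one third of them removes 14, and regrowing the same 14 restores the original mask while the new weights remain zero until trained:

total_params, off_params = 168, 126
on_params = total_params - off_params       # 42 on-weights
num_removed = on_params // 3                # prune 1/3 of on-weights -> 14
assert off_params + num_removed == 140      # off weights after pruning
assert 140 - num_removed == 126             # zero_mask entries after regrowing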
Example #5
    def training_step(self, model, inputs):
        """Prune and regrow weights every 'prune_freq' iterations."""

        train_loss = super().training_step(model, inputs)

        if self.state.global_step % self.prune_freq != 0:
            self.prune_scheduler.step()
            return train_loss

        # Retrieve sparse modules (e.g. SparseWeights) after the model has been
        # set up for distributed training, if applicable.
        if self.sparse_modules is None:
            self.sparse_modules = filter_modules(
                model, include_modules=[SparseWeightsBase]).values()
        sparse_modules = self.sparse_modules

        # Pre-prune sparsities (for verbose logging).
        model.apply(rezero_weights)
        if self.verbose_rigl_logging:
            param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(
                sparse_modules)

        # If prune fraction is 0, say for a warmup step, return and don't prune.
        prune_fraction = self.prune_scheduler.get_prune_fraction()
        if prune_fraction == 0:
            self.prune_scheduler.step()
            return train_loss

        # Prune weights.
        num_removed = global_prune_by_abs_weight(self.sparse_modules,
                                                 prune_fraction)
        model.apply(rezero_weights)

        # Post-prune sparsities (for verbose logging).
        if self.verbose_rigl_logging:
            param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(
                sparse_modules)

        # Accumulate gradients over one batch.
        self.optimizer.zero_grad()
        train_dataloader = self.callback_handler.train_dataloader
        train_batch = next(iter(train_dataloader))
        inputs_to_device(train_batch, device=self.args.device)
        batch_loss = self.compute_loss(model, train_batch)
        batch_loss.backward()

        # Regrow weights
        num_add = self.prune_scheduler.get_num_add(num_removed)
        global_add_by_abs_grad(self.sparse_modules, num_add)
        self.prune_scheduler.step()

        logs = {
            "rigl/target_pruned_on_params": prune_fraction,
        }

        # Post-grow sparsities (for verbose logging).
        if self.verbose_rigl_logging:
            param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(
                sparse_modules)

            # Log pruning stats.
            actual_pruned = param_sparsity1 - param_sparsity0
            actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

            logging.debug(f"Target: remove {prune_fraction} frac of on params")
            logging.debug(f"Actual: removed {actual_pruned_on_params} "
                          "fraction of on params")

            # This logging is deliberately thorough to verify that the actual
            # fraction and count of pruned params match the target amounts.
            logs.update({
                "rigl/actual_pruned_on_params": actual_pruned_on_params,
                "rigl/target_pruned_all_params": prune_fraction * mask_sparsity0,
                "rigl/actual_pruned_all_params": actual_pruned,
                "rigl/pre_prune_param_sparsity": param_sparsity0,
                "rigl/pre_prune_mask_sparsity": mask_sparsity0,
                "rigl/post_prune_param_sparsity": param_sparsity1,
                "rigl/post_prune_mask_sparsity": mask_sparsity1,
                "rigl/post_grow_param_sparsity": param_sparsity2,
                "rigl/post_grow_mask_sparsity": mask_sparsity2,
            })

        if wandb.run is not None:
            wandb.log(logs, commit=False)

        return train_loss
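The `inputs_to_device` helper is not shown in these examples; presumably it moves a batch (a dict of input tensors, HuggingFace-style) onto the training device in place. A minimal sketch under that assumption:

import torch

def inputs_to_device(inputs, device):
    # Assumed behavior: move every tensor in the batch dict to `device` in place.
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            inputs[key] = value.to(device)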
Example #6
    def on_step_end(
        self, args, state, control, model=None,
        train_dataloader=None, optimizer=None, **kwargs
    ):
        """Prune and regrow weights every 'prune_freq' iterations."""

        if state.global_step % self.prune_freq != 0:
            self.prune_scheduler.step()
            return

        # Pre-prune sparsities.
        param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(self.sparse_modules)

        # Prune weights.
        model.apply(rezero_weights)
        prune_fraction = self.prune_scheduler.get_prune_fraction()
        num_removed = global_prune_by_abs_weight(self.sparse_modules, prune_fraction)
        model.apply(rezero_weights)

        # Post-prune sparsities.
        param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(self.sparse_modules)

        # Accumulate gradients over one batch.
        optimizer.zero_grad()
        train_batch = next(iter(train_dataloader))
        inputs_to_device(train_batch, device=args.device)
        output = model(**train_batch)
        output.loss.backward()

        # Regrow weights
        num_add = self.prune_scheduler.get_num_add(num_removed)
        global_add_by_abs_grad(self.sparse_modules, num_add)
        self.prune_scheduler.step()

        # Post-grow sparsities.
        param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(self.sparse_modules)

        # Log pruning stats.
        actual_pruned = param_sparsity1 - param_sparsity0
        actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

        logging.info(f"RigLCallback:")
        logging.info(f"Target: remove {prune_fraction} frac of on params")
        logging.info(f"Actual: removed {actual_pruned_on_params} fraction of on params")

        # For now, the logging is deliberately thorough to verify that pruning
        # occurs as expected.
        # TODO: Remove non-essential logging.
        logs = {
            "rigl/target_pruned_on_params": prune_fraction,
            "rigl/actual_pruned_on_params": actual_pruned_on_params,
            "rigl/target_pruned_all_params": prune_fraction * mask_sparsity0,
            "rigl/actual_pruned_all_params": actual_pruned,
            "rigl/pre_prune_param_sparsity": param_sparsity0,
            "rigl/pre_prune_mask_sparsity": mask_sparsity0,
            "rigl/post_prune_param_sparsity": param_sparsity1,
            "rigl/post_prune_mask_sparsity": mask_sparsity1,
            "rigl/post_grow_param_sparsity": param_sparsity2,
            "rigl/post_grow_mask_sparsity": mask_sparsity2,
        }
        if wandb.run is not None:
            wandb.log(logs, step=state.global_step)
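`calc_cumulative_sparsity` is also not shown. Based on how `weight` and `zero_mask` are inspected in the tests here, a plausible sketch that returns the cumulative weight sparsity and mask sparsity across the given sparse modules:

def calc_cumulative_sparsity(sparse_modules):
    # Assumed behavior: fraction of zeroed weights and of masked-off entries.
    total = 0
    zero_weights = 0
    zero_masked = 0
    for module in sparse_modules:
        total += module.weight.numel()
        zero_weights += (module.weight == 0).sum().item()
        zero_masked += module.zero_mask.bool().sum().item()
    return zero_weights / total, zero_masked / total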
    def test_global_pruning_and_regrowing(self):

        # ----------
        # Prune
        # ----------

        sparse_modules = [self.model.lin1, self.model.lin2, self.model.lin3]
        global_prune_by_abs_weight(sparse_modules, prune_fraction=0.5)
        self.model.apply(rezero_weights)

        # Validate pruned weights
        expected_w1 = torch.tensor(
            [
                [0.0000, 0.3400, 0.0000, -0.4500],
                [-0.3700, 0.0000, 0.0000, -0.4500],
                [-0.3100, 0.0000, 0.0000, 0.0000],
                [0.4200, 0.0000, 0.0000, 0.0000],
            ]
        )
        self.assertTrue(all_equal(self.model.lin1.weight, expected_w1))

        expected_w2 = torch.tensor(
            [
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.2800, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000],
            ]
        )
        self.assertTrue(all_equal(self.model.lin2.weight, expected_w2))

        expected_w3 = torch.tensor(
            [
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, -0.3600, 0.0000, -0.4300],
                [0.0000, 0.3300, 0.0000, -0.5000],
                [0.0000, 0.0000, 0.4100, 0.0000],
            ]
        )
        self.assertTrue(all_equal(self.model.lin3.weight, expected_w3))

        # Validate that the pruned mask is a superset of the original mask
        # (every originally-off weight remains off).
        pruned_lin1_mask = self.model.lin1.zero_mask
        pruned_lin2_mask = self.model.lin2.zero_mask
        pruned_lin3_mask = self.model.lin3.zero_mask
        self.assertTrue((pruned_lin1_mask >= self.initial_lin1_mask).all())
        self.assertTrue((pruned_lin2_mask >= self.initial_lin2_mask).all())
        self.assertTrue((pruned_lin3_mask >= self.initial_lin3_mask).all())

        # Validate number of off weights.
        zero_masks = parameters_to_vector(self.model.buffers())
        weights = parameters_to_vector(self.model.parameters())
        self.assertEqual((weights == 0).sum(), 36)
        self.assertEqual(zero_masks.sum(), 36)

        # ----------
        # Regrow
        # ----------

        # Pseudo forward pass to accumulate gradients.
        x = torch.tensor(
            [[0.35, 0.94, 0.10, 0.31], [0.05, 0.16, 0.46, 0.11]], requires_grad=True
        )
        self.model(x).sum().backward()

        # Regrow weights per the largest abs gradients.
        global_add_by_abs_grad(sparse_modules, num_add=12)
        self.model.apply(rezero_weights)

        # Validate regrown weights
        expected_w1 = torch.tensor(
            [
                [0.0000, 0.3400, 0.0000, -0.4500],
                [-0.3700, 0.0000, 0.0000, -0.4500],
                [-0.3100, 0.0000, 0.0000, 0.0000],
                [0.4200, 0.0000, 0.0000, 0.0000],
            ]
        )
        self.assertTrue(all_close(self.model.lin1.weight, expected_w1))

        expected_w2 = torch.tensor(
            [
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.2800, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000],
            ]
        )
        self.assertTrue(all_close(self.model.lin2.weight, expected_w2))

        expected_w3 = torch.tensor(
            [
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, -0.3600, 0.0000, -0.4300],
                [0.0000, 0.3300, 0.0000, -0.5000],
                [0.0000, 0.0000, 0.4100, 0.0000],
            ]
        )
        self.assertTrue(all_close(self.model.lin3.weight, expected_w3))

        # Validate regrown mask is a subset of the pruned mask.
        self.assertTrue((self.model.lin1.zero_mask <= pruned_lin1_mask).all())
        self.assertTrue((self.model.lin2.zero_mask <= pruned_lin2_mask).all())
        self.assertTrue((self.model.lin3.zero_mask <= pruned_lin3_mask).all())

        # Validate number of off weights.
        zero_masks = parameters_to_vector(self.model.buffers())
        weights = parameters_to_vector(self.model.parameters())
        self.assertEqual((weights == 0).sum(), 36)
        self.assertEqual(zero_masks.sum(), 24)
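For reference, the final counts follow from three 4x4 layers (48 weights) at an inferred initial sparsity of 50% (24 weights off); the initial sparsity is not stated explicitly in the snippet:

total_weights = 3 * 4 * 4                  # lin1, lin2, lin3 are 4x4
on_before_prune = total_weights - 24       # 24 on-weights initially
num_removed = on_before_prune // 2         # prune_fraction=0.5 -> remove 12
assert 24 + num_removed == 36              # zero weights after pruning
# Regrowing num_add=12 clears 12 mask entries, but the regrown weights stay
# zero until training updates them, so (weights == 0).sum() remains 36.
assert 36 - 12 == 24                       # zero_mask total after regrowing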