Example No. 1
    def setup_experiment(self, config):
        """
        Save the names of which modules to track. The forward hooks are registered in
        `clone_model` at the start of every epoch of training.

        The subsets of modules to track are defined via `include_*` params. See
        `filter_modules`_ for further details.

        :param config:
            - track_input_sparsity_args:
                - include_modules: a list of module types to track
                - include_names: a list of module names to track e.g. "features.stem"
                - include_patterns: a list of regex patterns to compare to the names;
                                    for instance, all feature parameters in ResNet can
                                    be included through "features.*"
            - track_output_sparsity_args: same as track_input_sparsity_args
        """
        super().setup_experiment(config)

        # The hook managers will be initialized for each cloned model, prior to the
        # fast steps. For now, just the names of the modules to track are saved.
        self.output_hook_manager = None
        self.input_hook_manager = None

        # Save the names of the modules whose inputs will be tracked.
        input_tracking_args = config.get("track_input_sparsity_args", {})
        named_modules = filter_modules(self.model, **input_tracking_args)
        self.track_input_of_names = list(named_modules.keys())

        # Log the names of the modules with tracked inputs.
        tracked_names = pformat(self.track_input_of_names)
        self.logger.info(
            f"Tracking input sparsity of modules: {tracked_names}")

        # Save the names of the modules whose outputs will be tracked.
        output_tracking_args = config.get("track_output_sparsity_args", {})
        named_modules = filter_modules(self.model, **output_tracking_args)
        self.track_output_of_names = list(named_modules.keys())

        # Log the names of the modules with tracked outputs.
        tracked_names = pformat(self.track_output_of_names)
        self.logger.info(
            f"Tracking output sparsity of modules: {tracked_names}")

        # Throw a warning when no modules are being tracked.
        if not input_tracking_args and not output_tracking_args:
            self.logger.warning(
                "No modules specified to track input/output sparsity.")
Example No. 2
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        ha_args = config.get("plot_hidden_activations_args", {})
        ha_plot_freq, filter_args, ha_max_samples = self.process_ha_args(
            ha_args)

        self.ha_plot_freq = ha_plot_freq
        self.ha_max_samples = ha_max_samples

        # Register hook for tracking hidden activations
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.ha_max_samples)
        self.ha_hook = ModelHookManager(named_modules,
                                        TrackHiddenActivationsHook,
                                        hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking hidden activations of modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.ha_targets = torch.tensor([]).long()
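`TrackHiddenActivationsHook` is imported from elsewhere in the framework; a minimal stand-in for this kind of forward hook might look like the sketch below (the class name, constructor signature, and buffer-trimming behavior are assumptions, not the library's actual API):

import torch


class TrackActivationsHookSketch:
    """Forward hook keeping the most recent `max_samples_to_track` activations."""

    def __init__(self, name, max_samples_to_track=100):
        self.name = name
        self.max_samples = max_samples_to_track
        self.activations = torch.tensor([])

    def __call__(self, module, module_in, module_out):
        batch = module_out.detach().flatten(start_dim=1)
        if self.activations.numel() == 0:
            self.activations = batch
        else:
            self.activations = torch.cat([self.activations, batch])
        self.activations = self.activations[-self.max_samples:]


# usage: register on a layer, run a forward pass, inspect the buffer
layer = torch.nn.Linear(4, 3)
hook = TrackActivationsHookSketch("fc", max_samples_to_track=10)
handle = layer.register_forward_hook(hook)
layer(torch.randn(2, 4))
print(hook.activations.shape)  # torch.Size([2, 3])
handle.remove()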
Example No. 3
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        gradient_metrics_args = config.get("gradient_metrics_args", {})
        self.gradient_metrics_plot_freq, self.gradient_metrics_filter_args, \
            self.gradient_metrics_max_samples, self.gradient_metrics = \
            process_gradient_metrics_args(
                gradient_metrics_args)

        # Register hook for tracking gradients
        named_modules = filter_modules(self.model,
                                       **self.gradient_metrics_filter_args)
        hook_args = dict(
            max_samples_to_track=self.gradient_metrics_max_samples)
        self.gradient_metric_hooks = ModelHookManager(named_modules,
                                                      TrackGradientsHook,
                                                      hook_type="backward",
                                                      hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(f"Tracking gradients for modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.gradient_metric_targets = torch.tensor([]).long()
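`TrackGradientsHook` is registered with `hook_type="backward"`; a minimal stand-in built on PyTorch's full backward hooks could look like this (again a sketch, assuming the hook records gradients with respect to the module's output):

import torch


class TrackGradientsHookSketch:
    """Backward hook recording the latest gradients w.r.t. a module's output."""

    def __init__(self, name, max_samples_to_track=100):
        self.name = name
        self.max_samples = max_samples_to_track
        self.gradients = torch.tensor([])

    def __call__(self, module, grad_input, grad_output):
        grads = grad_output[0].detach().flatten(start_dim=1)
        self.gradients = grads[-self.max_samples:]


layer = torch.nn.Linear(4, 3)
hook = TrackGradientsHookSketch("fc")
layer.register_full_backward_hook(hook)
layer(torch.randn(2, 4)).sum().backward()
print(hook.gradients.shape)  # torch.Size([2, 3])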
Example No. 4
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Unpack, validate, and process the default arguments.
        metric_args = config.get("plot_dendrite_metrics_args", {})
        self.metric_args, filter_args, max_samples = self.process_args(metric_args)

        # The maximum 'max_samples_to_track' will be tracked by all of the hooks.
        self.max_samples_to_track = max_samples
        hook_args = dict(max_samples_to_track=self.max_samples_to_track)

        # The 'filter_args' specify which modules to track.
        named_modules = filter_modules(self.model, **filter_args)
        self.dendrite_hooks = ModelHookManager(named_modules,
                                               ApplyDendritesHook,
                                               hook_args=hook_args)

        # The hook is specifically made for `ApplyDendritesBase` modules.
        for module in named_modules.values():
            assert isinstance(module, ApplyDendritesBase)

        # Log the names of the modules being tracked, and warn when there's none.
        names = list(named_modules.keys())
        self.logger.info(f"Dendrite Metric Setup: Tracking modules: {names}")
        if len(names) == 0:
            self.logger.warning("Dendrite Metric Setup: "
                                "No modules found for tracking.")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.targets = torch.tensor([]).long()
Example No. 5
    def setup_experiment(self, config):
        """
        Register forward hooks to track the input and output sparsities.

        The subsets of modules to track are defined via `include_*` params. See
        `filter_modules`_ for further details.

        :param config:
            - track_input_sparsity_args:
                - include_modules: a list of module types to track
                - include_names: a list of module names to track e.g. "features.stem"
                - include_patterns: a list of regex patterns to compare to the names;
                                    for instance, all feature parameters in ResNet can
                                    be included through "features.*"
            - track_output_sparsity_args: same as track_input_sparsity_args

        .. _filter_modules: nupic.research.frameworks.pytorch.model_utils.filter_modules
        """
        super().setup_experiment(config)

        # Register hooks to track input sparsities.
        input_tracking_args = config.get("track_input_sparsity_args", {})
        named_modules = filter_modules(self.model, **input_tracking_args)
        self.input_hook_manager = ModelHookManager(named_modules,
                                                   TrackSparsityHook)

        # Log the names of the modules with tracked inputs.
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking input sparsity of modules: {tracked_names}")

        # Register hooks to track output sparsities.
        output_tracking_args = config.get("track_output_sparsity_args", {})
        named_modules = filter_modules(self.model, **output_tracking_args)
        self.output_hook_manager = ModelHookManager(named_modules,
                                                    TrackSparsityHook)

        # Log the names of the modules with tracked outputs.
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking output sparsity of modules: {tracked_names}")

        # Throw a warning when no modules are being tracked.
        if not input_tracking_args and not output_tracking_args:
            self.logger.warning(
                "No modules specified to track input/output sparsity.")
Example No. 6
    def test_name_not_included(self):
        """
        Test the case when a param name does not exist in the network.
        """

        resnet = resnets.resnet50(num_classes=10)
        named_params = filter_modules(resnet, include_names=["adaptation.1"])
        self.assertEqual(len(named_params), 0)
Example No. 7
    def test_filter_out_resnet_linear_params(self):
        """
        Filter for only the linear params of resnet.
        """
        resnet = resnets.resnet50(num_classes=10)
        named_modules = filter_modules(resnet,
                                       include_modules=[torch.nn.Linear])
        self.assertEqual(len(named_modules), 1)
        self.assertIn("classifier", named_modules)
Example No. 8
    def test_name_not_included(self):
        """
        Test the case when a param name does not exist in the network.
        """

        oml = OMLNetwork(num_classes=10)

        named_params = filter_modules(oml, include_names=["adaptation.1"])
        self.assertEqual(len(named_params), 0)
Example No. 9
    def on_init_end(self, args, state, control, model=None, **kwargs):
        """
        Save the sparse modules and a dict of their initial on-param counts.
        """
        self.sparse_modules = filter_modules(model, include_modules=[SparseWeightsBase])
        for name, module in self.sparse_modules.items():
            zero_mask = module.zero_mask.bool()
            self.initial_on_params[name] = zero_mask.numel() - count_nonzero(zero_mask)
        initial_sparsity = getattr(args, "config_kwargs", {}).get("sparsity", 0)
        self.initial_density = 1 - initial_sparsity
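The on-param bookkeeping follows the `zero_mask` convention, where a `1` marks a weight that is held at zero. A quick sketch of the arithmetic (the mask values here are made up):

import torch

zero_mask = torch.tensor([[1, 0, 1], [0, 0, 1]]).bool()  # 1 => weight pinned to zero
on_params = zero_mask.numel() - torch.count_nonzero(zero_mask)
print(on_params.item())  # 3 of the 6 weights are "on"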
Example No. 10
    def test_include_name(self):
        """
        Test use of `include_names`.
        """

        resnet = resnets.resnet50(num_classes=10)
        named_modules = filter_modules(resnet, include_names=["classifier"])
        self.assertEqual(len(named_modules), 1)
        self.assertIn("classifier", named_modules)
        self.assertIsInstance(named_modules["classifier"], torch.nn.Linear)
Example No. 11
    def test_include_name(self):
        """
        Test use of `include_names`.
        """

        oml = OMLNetwork(num_classes=10)
        named_modules = filter_modules(oml, include_names=["adaptation.0"])
        self.assertEqual(len(named_modules), 1)
        self.assertIn("adaptation.0", named_modules)
        self.assertIsInstance(named_modules["adaptation.0"], torch.nn.Linear)
Example No. 12
    def on_init_end(self, args, state, control, model, optimizer=None, **kwargs):
        """Save a list of the sparse modules and initialize the pruning schedule"""

        warmup_steps = self.warmup_steps or self.prune_freq
        self.prune_scheduler = CosineDecayPruneScheduler(
            total_steps=args.max_steps,
            prune_fraction=self.prune_fraction,
            warmup_steps=warmup_steps - 1
        )
        self.sparse_modules = filter_modules(
            model, include_modules=[SparseWeightsBase]
        ).values()
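`CosineDecayPruneScheduler` is not defined on this page; a minimal sketch of a cosine-decayed prune fraction, assuming it anneals from `prune_fraction` down to zero between the warmup and `total_steps` (the real scheduler's API and exact curve may differ):

import math


def cosine_prune_fraction(step, total_steps, prune_fraction, warmup_steps=0):
    """Prune fraction decaying from `prune_fraction` to 0 along a cosine curve."""
    if step < warmup_steps:
        return 0.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return prune_fraction * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))


print(cosine_prune_fraction(5, 100, 0.3, warmup_steps=10))   # 0.0, still in warmup
print(cosine_prune_fraction(55, 100, 0.3, warmup_steps=10))  # 0.15, halfway decayed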
Example No. 13
    def test_get_conv_modules_by_pattern_and_type(self):
        """
        Ensure `include_patterns` and `include_modules` yield the same result
        when they are meant to identify the same params.
        """
        oml = OMLNetwork(num_classes=10)

        include_even_numbers = ["representation.\\d*[02468]"]
        named_modules1 = filter_modules(oml,
                                        include_patterns=include_even_numbers)
        self.assertEqual(len(named_modules1), 7)  # 6 convs and a flatten layer

        include_conv_and_flatten = [torch.nn.Conv2d, torch.nn.Flatten]
        named_modules2 = filter_modules(
            oml, include_modules=include_conv_and_flatten)
        self.assertEqual(len(named_modules2), 7)

        names1 = list(named_modules1.keys())
        names2 = list(named_modules2.keys())
        self.assertEqual(names1, names2)
Example No. 14
    def test_get_conv_modules_by_pattern_and_type(self):
        """
        Ensure `include_patterns` and `include_modules` yield the same result
        when they are meant to identify the same params.
        """
        resnet = resnets.resnet50(num_classes=10)

        include_pooling_layers = ["features\\..*pool.*"]
        named_modules1 = filter_modules(
            resnet, include_patterns=include_pooling_layers)
        self.assertEqual(len(named_modules1), 2)

        pooling_layers_types = [
            torch.nn.modules.pooling.AdaptiveAvgPool2d,
            torch.nn.modules.pooling.MaxPool2d,
        ]
        named_modules2 = filter_modules(resnet,
                                        include_modules=pooling_layers_types)
        self.assertEqual(len(named_modules2), 2)

        names1 = list(named_modules1.keys())
        names2 = list(named_modules2.keys())
        self.assertEqual(names1, names2)
Example No. 15
    def sparsify_model(self):
        """
        Sparsify all linear layers in encoder as well as the word embedding layer.
        """

        encoder = self.encoder
        sparsity = self.config.sparsity
        device = self.device

        # Use `getattr` here for backwards compatibility with configs lacking this param.
        sparsify_all_embeddings = getattr(self.config,
                                          "sparsify_all_embeddings", False)

        def get_sparsity(name):
            if isinstance(sparsity, dict):
                if name in sparsity:
                    return sparsity[name]
                else:
                    raise KeyError(
                        f"Layer {name} not included in sparsity dict.")
            else:
                return sparsity

        # Perform model surgery by replacing the linear layers with `SparseWeights`.
        linear_modules = filter_modules(encoder,
                                        include_modules=[torch.nn.Linear])
        for name, module in linear_modules.items():
            layer_sparsity = get_sparsity("bert.encoder." + name)
            sparse_module = SparseWeights(
                module,
                sparsity=layer_sparsity,
                allow_extremes=True  # this allows the model to start fully dense
            )
            set_module_attr(self.encoder, name, sparse_module.to(device))

        # Replace the embedding layers in a similar fashion.
        if sparsify_all_embeddings:
            embeddings = [
                "word_embeddings", "position_embeddings",
                "token_type_embeddings"
            ]
        else:
            embeddings = ["word_embeddings"]

        for embedding_name in embeddings:
            dense_module = getattr(self.embeddings, embedding_name)
            layer_sparsity = get_sparsity(f"bert.embeddings.{embedding_name}")
            sparse_module = SparseEmbeddings(dense_module,
                                             sparsity=layer_sparsity)
            setattr(self.embeddings, embedding_name, sparse_module.to(device))
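`SparseWeights` and `SparseEmbeddings` wrap a dense module and pin a fixed fraction of its weights at zero via a mask. A stripped-down sketch of that wrapper pattern (the real classes also support options like `allow_extremes` and handle rezeroing during training):

import torch
from torch import nn


class SparseWeightsSketch(nn.Module):
    """Wrap a layer and pin a random `sparsity` fraction of its weights at zero."""

    def __init__(self, module, sparsity):
        super().__init__()
        self.module = module
        n_zero = int(round(sparsity * module.weight.numel()))
        flat = torch.zeros(module.weight.numel(), dtype=torch.bool)
        flat[torch.randperm(module.weight.numel())[:n_zero]] = True
        self.register_buffer("zero_mask", flat.view_as(module.weight))
        self.rezero_weights()

    def rezero_weights(self):
        # Call after each optimizer step to snap masked weights back to zero.
        with torch.no_grad():
            self.module.weight[self.zero_mask] = 0

    def forward(self, x):
        return self.module(x)


sparse_fc = SparseWeightsSketch(nn.Linear(10, 10), sparsity=0.8)
print((sparse_fc.module.weight == 0).float().mean().item())  # ~0.8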
Example No. 16
    def sparsify_model(self):
        """
        Sparsify all linear layers in encoder.
        """

        encoder = self.encoder
        sparsity = self.config.sparsity
        device = self.device

        # Perform model surgery by replacing the linear layers with `SparseWeights`.
        linear_modules = filter_modules(encoder,
                                        include_modules=[torch.nn.Linear])
        for name, module in linear_modules.items():
            sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
            set_module_attr(self.encoder, name, sparse_module)
Example No. 17
    def sparsify_model(self):
        """
        Sparsify all linear layers in encoder as well as the word embedding layer.
        """

        encoder = self.encoder
        sparsity = self.config.sparsity
        device = self.device

        # Perform model surgery by replacing the linear layers with `SparseWeights`.
        linear_modules = filter_modules(encoder,
                                        include_modules=[torch.nn.Linear])
        for name, module in linear_modules.items():
            sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
            set_module_attr(self.encoder, name, sparse_module)

        # Replace the embedding layer in a similar fashion.
        dense_embeddings = self.embeddings.word_embeddings
        sparse_embeddings = SparseEmbeddings(dense_embeddings,
                                             sparsity=sparsity)
        self.embeddings.word_embeddings = sparse_embeddings
Example No. 18
def resize_position_embeddings(model, new_seq_length):
    """
    Resizes model's position embeddings matrices if the size of max position embedding
    doesn't match new sequence length.
    (size of position embedding equals size of the attention window)

    :param new_seq_length: Tokenizer sequence length.
    """

    position_embeddings = filter_modules(
        model, include_patterns=[".*position_embeddings.*"])
    for module_name, module in position_embeddings.items():
        original_embed_data = module.weight.data
        max_position_embeddings, embed_hidden_size = original_embed_data.size()
        if max_position_embeddings != new_seq_length:
            new_embed = torch.nn.Embedding(new_seq_length, embed_hidden_size)
            # Copy the overlapping rows; when growing, the extra rows keep their
            # fresh random initialization.
            num_copy = min(max_position_embeddings, new_seq_length)
            new_embed.weight.data[:num_copy, :] = original_embed_data[:num_copy, :]
            set_module_attr(model, module_name, new_embed)

    return model
Example No. 19
    def create_model(self, config, device):
        sample_data = self.get_sample_data(config)
        model = super().create_model(config, device)
        # Process config args
        greedy_infomax_args = config.get("greedy_infomax_args", {})
        gim_hooks_args = greedy_infomax_args.get("greedy_infomax_blocks", {})
        info_estimate_args = greedy_infomax_args.get("info_estimate_args", {})
        patchify_inputs_args = greedy_infomax_args.get("patchify_inputs_args",
                                                       {})

        # Collect information about which modules to apply hooks to
        include_names = gim_hooks_args.pop("include_names", [])
        include_modules = gim_hooks_args.pop("include_modules", [])
        include_patterns = gim_hooks_args.pop("include_patterns", [])
        filter_args = dict(
            include_names=include_names,
            include_modules=include_modules,
            include_patterns=include_patterns,
        )

        # Get named modules for GreedyInfoMaxBlock and BilinearInfo parameters
        named_modules = filter_modules(model, **filter_args)

        # Get the size of the output of each module (for BilinearInfo)
        modules_and_channel_sizes = get_channel_sizes(model, named_modules,
                                                      sample_data)
        # Update the config with the channel sizes
        config["classifier_config"]["model_args"].update(
            in_channels=list(modules_and_channel_sizes.values()))

        n_patches_x, n_patches_y = get_patch_dimensions(
            sample_data, **patchify_inputs_args)

        greedy_infomax_model = GreedyInfoMaxModel(model,
                                                  modules_and_channel_sizes,
                                                  **info_estimate_args,
                                                  n_patches_x=n_patches_x,
                                                  n_patches_y=n_patches_y)

        return greedy_infomax_model
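`get_channel_sizes` comes from the GreedyInfoMax utilities and is not shown; a plausible sketch, assuming it runs `sample_data` through the model once and records each tracked module's output channel count with temporary forward hooks (channels-first layout assumed):

import torch


def get_channel_sizes_sketch(model, named_modules, sample_data):
    """Map each tracked module to its output channel count via one forward pass."""
    channel_sizes, handles = {}, []
    for name, module in named_modules.items():
        def hook(mod, inp, out, name=name):
            channel_sizes[name] = out.shape[1]  # dim 1 is channels in NCHW
        handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(sample_data)
    for handle in handles:
        handle.remove()
    return channel_sizes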
Example No. 20
    def setup_pruning(self, args, model):
        """
        This infers things about the pruning setup, such as how many steps to prune
        and how many params to start and end with. Note that we're concerned with
        pruning the parameters within the sparse modules. Thus, any SparseWeightsBase
        will be involved in global pruning. However, the start and end sparsity will
        be measured over all parameters within `model.bert`.
        """

        # Calculate the number of steps in the pruning phase.
        pruning_steps = self.max_steps - self.warmup_steps - self.cooldown_steps
        assert pruning_steps > 0
        assert pruning_steps % self.prune_period == 0

        # Calculate how many times we'll prune.
        self.total_prune_iterations = pruning_steps // self.prune_period + 1

        # Calculate the params that belong to BERT and how many to start and end with.
        total_params, _ = count_nonzero_params(model.bert)
        model_start_params = total_params * (1 - self.start_sparsity)
        model_end_params = total_params * (1 - self.end_sparsity)

        # Get all the sparse modules in BERT.
        sparse_modules = filter_modules(model.bert, include_modules=[SparseWeightsBase])
        self.sparse_modules = list(sparse_modules.values())

        # Calculate the number of params that belong to the non-sparse modules.
        sparse_module_params = 0
        for m in self.sparse_modules:
            sparse_module_params += m.weight.numel()
        non_sparse_module_params = total_params - sparse_module_params

        # Solve for the number of params that will be on. Specifically for the
        # sparse module parameters.
        self.start_on_params = model_start_params - non_sparse_module_params
        self.end_on_params = model_end_params - non_sparse_module_params
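The start/end on-param solve at the bottom is plain arithmetic; a worked example with made-up counts:

total_params = 1_000_000         # every param in model.bert
sparse_module_params = 900_000   # params inside SparseWeightsBase modules
start_sparsity, end_sparsity = 0.8, 0.9

model_start_params = total_params * (1 - start_sparsity)  # 200,000 on at start
model_end_params = total_params * (1 - end_sparsity)      # 100,000 on at end

non_sparse_module_params = total_params - sparse_module_params  # 100,000 always dense

# The dense params are always on, so the sparse modules must supply the rest.
start_on_params = model_start_params - non_sparse_module_params  # 100,000
end_on_params = model_end_params - non_sparse_module_params      # 0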
Example No. 21
    def test_global_rigl(self):
        """
        Test for globally pruning all sparse modules by their weights and adding back by
        gradients.
        """

        # -----------
        # Init model
        # -----------

        # Make sure there are no random zeros in model params.
        init_all_zero_params(self.model)
        sparsity = calc_sparsity(self.model)
        self.assertEqual(sparsity, 0)

        # Validate initial sparsity after rezeroing the weights.
        self.model.apply(rezero_weights)
        sparsity = calc_sparsity(self.model.bert)
        self.assertTrue(np.isclose(sparsity, 0.4701, atol=1e-4))

        # Get all the SparseWeightsBase modules. These will be pruned.
        sparse_modules = filter_modules(self.model,
                                        include_modules=[SparseWeightsBase])
        sparse_modules = sparse_modules.values()
        self.assertEqual(len(sparse_modules), 7)

        # Validate the initial number of off params within the sparse modules.
        total_sparse_params = np.sum(
            [m.weight.numel() for m in sparse_modules])
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])
        self.assertEqual(total_sparse_params, 168)
        self.assertEqual(total_off_params, 126)
        self.assertEqual(total_off_mask, 126)

        # --------------
        # Prune weights
        # --------------

        num_removed = global_prune_by_abs_weight(sparse_modules,
                                                 prune_fraction=1 / 3)

        self.model.apply(rezero_weights)
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])

        self.assertEqual(total_off_mask, 140)
        self.assertEqual(total_off_params, 140)

        # ---------------
        # Regrow weights
        # ---------------

        # Pseudo forward pass to accumulate gradients.
        batch_size = 2
        num_embeddings = self.config.max_position_embeddings
        attention_mask = torch.ones(batch_size, num_embeddings).float()
        input_ids = torch.ones(batch_size, num_embeddings).long()
        token_type_ids = torch.ones(batch_size, num_embeddings).long()
        labels = torch.ones(batch_size * num_embeddings).long()

        outputs = self.model(
            attention_mask=attention_mask,
            input_ids=input_ids,
            labels=labels,
            token_type_ids=token_type_ids,
        )
        loss = outputs.loss
        loss.backward()

        # Add weights according to the largest gradients of the model.
        global_add_by_abs_grad(sparse_modules, num_add=num_removed)

        # The new weights are initialized to zero.
        self.model.apply(rezero_weights)
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])

        # Validate number of off params after regrowing the weights.
        self.assertEqual(total_off_mask, 126)
        self.assertEqual(total_off_params, 140)

        # Pseudo training step where learning happens on the new zero weights.
        init_all_zero_params(self.model)
        self.model.apply(rezero_weights)

        # Validate number of off params after learning has occurred on new weights.
        total_off_mask = np.sum(
            [m.zero_mask.bool().sum() for m in sparse_modules])
        total_off_params = np.sum([(m.weight == 0).sum()
                                   for m in sparse_modules])

        self.assertEqual(total_off_mask, 126)
        self.assertEqual(total_off_params, 126)
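For reference, a minimal sketch of what global magnitude pruning entails, assuming `global_prune_by_abs_weight` pools all currently-on weights across modules, ranks them by absolute value, and masks off the smallest fraction (same `weight`/`zero_mask` attribute convention as the test above; the real function differs in details such as tie handling):

import torch
from types import SimpleNamespace


def global_prune_by_abs_weight_sketch(sparse_modules, prune_fraction):
    """Mask off the smallest `prune_fraction` of on-weights, pooled across modules."""
    on_weights = torch.cat([m.weight[~m.zero_mask.bool()].abs().flatten()
                            for m in sparse_modules])
    num_remove = int(prune_fraction * on_weights.numel())
    if num_remove == 0:
        return 0
    threshold = on_weights.kthvalue(num_remove).values
    for m in sparse_modules:
        newly_off = (~m.zero_mask.bool()) & (m.weight.abs() <= threshold)
        m.zero_mask[newly_off] = 1  # ties may push this slightly past num_remove
    return num_remove


# duck-typed stand-in for a SparseWeightsBase module
m = SimpleNamespace(weight=torch.randn(4, 4), zero_mask=torch.zeros(4, 4))
removed = global_prune_by_abs_weight_sketch([m], prune_fraction=0.25)
print(removed, m.zero_mask.sum().item())  # 4 newly masked-off weights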
Example No. 22
    def training_step(self, model, inputs):
        """Prune and regrow weights every 'prune_freq' iterations."""

        train_loss = super().training_step(model, inputs)

        if self.state.global_step % self.prune_freq != 0:
            self.prune_scheduler.step()
            return train_loss

        # Retrieve sparse modules (e.g. SparseWeights) after the model has been
        # set up for distributed training, if it has.
        if self.sparse_modules is None:
            self.sparse_modules = filter_modules(
                model, include_modules=[SparseWeightsBase]
            ).values()

        # Pre-prune sparsities.
        param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(self.sparse_modules)

        # Prune weights.
        model.apply(rezero_weights)
        prune_fraction = self.prune_scheduler.get_prune_fraction()
        num_removed = global_prune_by_abs_weight(self.sparse_modules, prune_fraction)
        model.apply(rezero_weights)

        # Post-prune sparsities.
        param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(self.sparse_modules)

        # Accumulate gradients over one batch.
        self.optimizer.zero_grad()
        train_dataloader = self.callback_handler.train_dataloader
        train_batch = next(iter(train_dataloader))
        inputs_to_device(train_batch, device=self.args.device)
        batch_loss = self.compute_loss(model, train_batch)
        batch_loss.backward()

        # Regrow weights
        num_add = self.prune_scheduler.get_num_add(num_removed)
        global_add_by_abs_grad(self.sparse_modules, num_add)
        self.prune_scheduler.step()

        # Post-grow sparsities.
        param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(self.sparse_modules)

        # Log pruning stats.
        actual_pruned = param_sparsity1 - param_sparsity0
        actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

        logging.info(f"RigLMixin:")
        logging.info(f"Target: remove {prune_fraction} frac of on params")
        logging.info(f"Actual: removed {actual_pruned_on_params} fraction of on params")

        # For now, the logging is verbose so we can verify pruning occurs as expected.
        # TODO: Remove non-essential logging.
        logs = {
            "rigl/target_pruned_on_params": prune_fraction,
            "rigl/actual_pruned_on_params": actual_pruned_on_params,
            "rigl/target_pruned_all_params": prune_fraction * mask_sparsity0,
            "rigl/actual_pruned_all_params": actual_pruned,
            "rigl/pre_prune_param_sparsity": param_sparsity0,
            "rigl/pre_prune_mask_sparsity": mask_sparsity0,
            "rigl/post_prune_param_sparsity": param_sparsity1,
            "rigl/post_prune_mask_sparsity": mask_sparsity1,
            "rigl/pre_grow_param_sparsity": param_sparsity2,
            "rigl/post_grow_mask_sparsity": mask_sparsity2,
        }
        if wandb.run is not None:
            wandb.log(logs, step=self.state.global_step)

        return train_loss
Example No. 23
def calculate_sparsity_param(sparsity_desired,
                             parameters_desired,
                             experiment,
                             test_sparsity=False):
    """
    :param sparsity_desired: desired sparsity of model
    :param parameters_desired: desired number of on-params;
                               can't be used with sparsity_desired
    :param experiment: name of experiment config with a sparse architecture
    :param test_sparsity: whether to test the calculated sparsity param, this test loads
                          the model and calculates the resulting sparsity.
    """

    # Ensure exactly one of sparsity_desired or parameters_desired is specified.
    assert not (sparsity_desired is None and parameters_desired is None)
    assert sparsity_desired is None or parameters_desired is None

    print(bold("Initializing model... ") + "(this may take a minute)")
    print(f"   experiment: {experiment}")

    # Load and parse model args from config.
    exp_config = CONFIGS[experiment]
    exp_parser = HfArgumentParser(ModelArguments)
    model_args = exp_parser.parse_dict(exp_config)[0]
    model_args = replace(model_args, cache_dir=None)  # clear cache_dir to run locally
    print(bold("\n\nModel parameters:\n") + pdict(model_args.__dict__))
    print()

    # Initialize model.
    config = init_config(model_args)
    tokenizer = init_tokenizer(model_args)
    model = AutoModelForMaskedLM.from_config(config)
    model.resize_token_embeddings(len(tokenizer))

    print(bold("Calculating target sparsity..."))

    # Get sparse modules and calculate total number of sparsifiable params.
    sparse_modules = filter_modules(model.bert,
                                    include_modules=[SparseWeightsBase])
    sparsifiable_params = 0
    for _, m in sparse_modules.items():
        sparsifiable_params += m.zero_mask.numel()

    # Calculate the total number of params and the needed sparsity.
    total_params, _ = count_nonzero_params(model.bert)

    if parameters_desired is None:
        parameters_desired = total_params * (1 - sparsity_desired)
    elif sparsity_desired is None:
        sparsity_desired = 1 - parameters_desired / total_params

    dense_params = total_params - sparsifiable_params
    target_sparsity = 1 - (parameters_desired -
                           dense_params) / sparsifiable_params

    print(f"   sparsity_desired: {sparsity_desired}")
    print(f"   parameters_desired: {parameters_desired}")
    print(f"   sparsifiable_params: {sparsifiable_params}")
    print(f"   total_params: {total_params}")
    print(f"   target_sparsity: {target_sparsity} (set your sparsity to this)")
    print()

    if not test_sparsity:
        return

    print(bold("Testing target sparsity..."))

    # Edit config to use the new sparsity param (sparsity=target_sparsity).
    exp_config["config_kwargs"]["sparsity"] = target_sparsity
    exp_parser = HfArgumentParser(ModelArguments)
    model_args = exp_parser.parse_dict(exp_config)[0]
    model_args = replace(model_args, cache_dir=None)  # clear cache_dir to run locally

    # Initialize model; this time with the new sparsity param.
    config = init_config(model_args)
    tokenizer = init_tokenizer(model_args)
    model = AutoModelForMaskedLM.from_config(config)
    model.resize_token_embeddings(len(tokenizer))

    # Set all on-weights to one to make sure none are randomly off.
    sparse_modules = filter_modules(model.bert,
                                    include_modules=[SparseWeightsBase])
    for _, m in sparse_modules.items():
        m.weight.data[:] = 1
    model.apply(rezero_weights)  # set off weights to zero.

    resulting_sparsity = calc_model_sparsity(model.bert)
    _, nz_params = count_nonzero_params(model.bert)
    print(
        f"    Resulting sparsity of model.bert using sparsity={target_sparsity}\n"
        f"       actual_sparsity={resulting_sparsity}\n"
        f"       num_nonzero_params={nz_params}\n")
    print(f"    Note this may not be exactly as desired as there are "
          "discrete levels of allowable sparsity")
    print()
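The target-sparsity formula is easiest to sanity check with toy numbers; note that the sparse modules must end up sparser than the model overall, because the always-dense params dilute the total:

total_params = 1_000_000
sparsifiable_params = 800_000
dense_params = total_params - sparsifiable_params  # 200,000 always-on params

sparsity_desired = 0.7
parameters_desired = total_params * (1 - sparsity_desired)  # 300,000 on-params

target_sparsity = 1 - (parameters_desired - dense_params) / sparsifiable_params
print(target_sparsity)  # 0.875 > 0.7: the sparse modules carry all of the sparsity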
Example No. 24
    def training_step(self, model, inputs):
        """Prune and regrow weights every 'prune_freq' iterations."""

        train_loss = super().training_step(model, inputs)

        if self.state.global_step % self.prune_freq != 0:
            self.prune_scheduler.step()
            return train_loss

        # Retrieve sparse modules (e.g. SparseWeights) after the model has been
        # set up for distributed training, if it has.
        if self.sparse_modules is None:
            self.sparse_modules = filter_modules(
                model, include_modules=[SparseWeightsBase]).values()
        sparse_modules = self.sparse_modules

        # Pre-prune sparsities (for verbose logging).
        model.apply(rezero_weights)
        if self.verbose_rigl_logging:
            param_sparsity0, mask_sparsity0 = calc_cumulative_sparsity(
                sparse_modules)

        # If prune fraction is 0, say for a warmup step, return and don't prune.
        prune_fraction = self.prune_scheduler.get_prune_fraction()
        if prune_fraction == 0:
            self.prune_scheduler.step()
            return train_loss

        # Prune weights.
        num_removed = global_prune_by_abs_weight(self.sparse_modules,
                                                 prune_fraction)
        model.apply(rezero_weights)

        # Post-prune sparsities (for verbose logging).
        if self.verbose_rigl_logging:
            param_sparsity1, mask_sparsity1 = calc_cumulative_sparsity(
                sparse_modules)

        # Accumulate gradients over one batch.
        self.optimizer.zero_grad()
        train_dataloader = self.callback_handler.train_dataloader
        train_batch = next(iter(train_dataloader))
        inputs_to_device(train_batch, device=self.args.device)
        batch_loss = self.compute_loss(model, train_batch)
        batch_loss.backward()

        # Regrow weights
        num_add = self.prune_scheduler.get_num_add(num_removed)
        global_add_by_abs_grad(self.sparse_modules, num_add)
        self.prune_scheduler.step()

        logs = {
            "rigl/target_pruned_on_params": prune_fraction,
        }

        # Post-grow sparsities (for verbose logging).
        if self.verbose_rigl_logging:
            param_sparsity2, mask_sparsity2 = calc_cumulative_sparsity(
                sparse_modules)

            # Log pruning stats.
            actual_pruned = param_sparsity1 - param_sparsity0
            actual_pruned_on_params = actual_pruned / (1 - mask_sparsity0)

            logging.debug(f"Target: remove {prune_fraction} frac of on params")
            logging.debug(f"Actual: removed {actual_pruned_on_params} "
                          "fraction of on params")

            # These logs are verbose so we can verify the actual percentage and
            # count of pruned params match the target amounts.
            logs = {
                "rigl/actual_pruned_on_params": actual_pruned_on_params,
                "rigl/target_pruned_all_params":
                prune_fraction * mask_sparsity0,
                "rigl/actual_pruned_all_params": actual_pruned,
                "rigl/pre_prune_param_sparsity": param_sparsity0,
                "rigl/pre_prune_mask_sparsity": mask_sparsity0,
                "rigl/post_prune_param_sparsity": param_sparsity1,
                "rigl/post_prune_mask_sparsity": mask_sparsity1,
                "rigl/pre_grow_param_sparsity": param_sparsity2,
                "rigl/post_grow_mask_sparsity": mask_sparsity2,
            }

        if wandb.run is not None:
            wandb.log(logs, commit=False)

        return train_loss
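For completeness, the regrow half of the RigL step: a sketch assuming `global_add_by_abs_grad` re-enables the `num_add` masked weights with the largest gradient magnitudes, pooled across modules after a backward pass (same illustrative `weight`/`zero_mask` convention; regrown weights start at zero and only change on subsequent updates):

import torch
from types import SimpleNamespace


def global_add_by_abs_grad_sketch(sparse_modules, num_add):
    """Unmask the `num_add` masked weights with the largest |gradient|."""
    scores = torch.cat([m.weight.grad[m.zero_mask.bool()].abs().flatten()
                        for m in sparse_modules])
    num_add = min(num_add, scores.numel())
    if num_add == 0:
        return
    threshold = scores.topk(num_add).values.min()
    for m in sparse_modules:
        regrow = m.zero_mask.bool() & (m.weight.grad.abs() >= threshold)
        m.zero_mask[regrow] = 0  # weight stays zero until training updates it


# duck-typed stand-in with a fully masked weight matrix and fake gradients
weight = torch.randn(4, 4)
weight.grad = torch.randn(4, 4)
m = SimpleNamespace(weight=weight, zero_mask=torch.ones(4, 4))
global_add_by_abs_grad_sketch([m], num_add=3)
print(m.zero_mask.sum().item())  # 13.0: three weights regrown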