Example #1
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        gradient_metrics_args = config.get("gradient_metrics_args", {})
        (self.gradient_metrics_plot_freq,
         self.gradient_metrics_filter_args,
         self.gradient_metrics_max_samples,
         self.gradient_metrics) = process_gradient_metrics_args(
            gradient_metrics_args)

        # Register hook for tracking hidden activations
        named_modules = filter_modules(self.model,
                                       **self.gradient_metrics_filter_args)
        hook_args = dict(
            max_samples_to_track=self.gradient_metrics_max_samples)
        self.gradient_metric_hooks = ModelHookManager(named_modules,
                                                      TrackGradientsHook,
                                                      hook_type="backward",
                                                      hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(f"Tracking gradients for modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.gradient_metric_targets = torch.tensor([]).long()
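
`process_gradient_metrics_args` is referenced above but not shown in these excerpts. A minimal sketch of what it plausibly returns, assuming it mirrors `process_ha_args` below and pairs each requested metric with the configured gradient transformation (the framework's actual helper may differ):

```
from copy import deepcopy


def process_gradient_metrics_args(gradient_metrics_args):
    """Hypothetical sketch; the real helper lives in nupic.research."""
    args = deepcopy(gradient_metrics_args)

    # Collect information about which modules to apply hooks to.
    filter_args = dict(
        include_names=args.pop("include_names", []),
        include_modules=args.pop("include_modules", []),
        include_patterns=args.pop("include_patterns", []),
    )

    # Defaults taken from the `GradientMetrics` docstring below.
    plot_freq = args.get("plot_freq", 1)
    max_samples = args.get("max_samples_to_track", 100)
    metrics = args.get("metrics", ["cosine"])
    gradient_values = args.get("gradient_values", "real")

    # Pair every metric with the gradient transformation; the mixin's
    # `calculate_gradient_metrics_stats` iterates over these tuples.
    gradient_metrics = [(metric, gradient_values) for metric in metrics]

    return plot_freq, filter_args, max_samples, gradient_metrics
```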
Example #2
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        ha_args = config.get("plot_hidden_activations_args", {})
        ha_plot_freq, filter_args, ha_max_samples = self.process_ha_args(
            ha_args)

        self.ha_plot_freq = ha_plot_freq
        self.ha_max_samples = ha_max_samples

        # Register hook for tracking hidden activations
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.ha_max_samples)
        self.ha_hook = ModelHookManager(named_modules,
                                        TrackHiddenActivationsHook,
                                        hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking hidden activations of modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.ha_targets = torch.tensor([]).long()
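
For orientation, `ModelHookManager` and `TrackHiddenActivationsHook` come from the nupic.research framework and are not shown in these excerpts. A minimal sketch of what such a forward hook might do, assuming the same newest-first buffering used for the targets in `error_loss` (the framework's actual class differs in detail):

```
import torch


class TrackHiddenActivationsHookSketch:
    """Hypothetical sketch of a forward hook that buffers recent outputs."""

    def __init__(self, name, module, max_samples_to_track=1000):
        self.name = name
        self.max_samples_to_track = max_samples_to_track
        self.activations = torch.tensor([])
        self.tracking = False
        module.register_forward_hook(self)

    def __call__(self, module, inputs, output):
        if not self.tracking:
            return
        out = output.detach().flatten(start_dim=1)
        if self.activations.numel() == 0:
            self.activations = out
        else:
            # Prepend the newest batch and drop the oldest samples, mirroring
            # how targets are buffered in `error_loss`.
            self.activations = torch.cat([out, self.activations], dim=0)
        self.activations = self.activations[:self.max_samples_to_track]
```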
Example #3
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Unpack, validate, and process the default arguments.
        metric_args = config.get("plot_dendrite_metrics_args", {})
        self.metric_args, filter_args, max_samples = self.process_args(metric_args)

        # The maximum 'max_samples_to_plot' across all plots will be tracked
        # by the hooks.
        self.max_samples_to_track = max_samples
        hook_args = dict(max_samples_to_track=self.max_samples_to_track)

        # The 'filter_args' specify which modules to track.
        named_modules = filter_modules(self.model, **filter_args)
        self.dendrite_hooks = ModelHookManager(named_modules,
                                               ApplyDendritesHook,
                                               hook_args=hook_args)

        # The hook is specifically made for `ApplyDendritesBase` modules.
        for module in named_modules.values():
            assert isinstance(module, ApplyDendritesBase)

        # Log the names of the modules being tracked, and warn when there's none.
        names = list(named_modules.keys())
        self.logger.info(f"Dendrite Metric Setup: Tracking modules: {names}")
        if len(names) == 0:
            self.logger.warning("Dendrite Metric Setup: "
                                "No modules found for tracking.")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.targets = torch.tensor([]).long()
Example #4
    def setup_experiment(self, config):
        """
        Register forward hooks to track the input and output sparsities.

        The subsets of modules to track are defined via `include_*` params. See
        `filter_modules`_ for further details.

        :param config:
            - track_input_sparsity_args:
                - include_modules: a list of module types to track
                - include_names: a list of module names to track e.g. "features.stem"
                - include_patterns: a list of regex patterns to compare to the names;
                                    for instance, all feature parameters in ResNet can
                                    be included through "features.*"
            - track_output_sparsity_args: same as track_input_sparsity_args

        .. _filter_modules: nupic.research.frameworks.pytorch.model_utils.filter_modules
        """
        super().setup_experiment(config)

        # Register hooks to track input sparsities.
        input_tracking_args = config.get("track_input_sparsity_args", {})
        named_modules = filter_modules(self.model, **input_tracking_args)
        self.input_hook_manager = ModelHookManager(named_modules,
                                                   TrackSparsityHook)

        # Log the names of the modules with tracked inputs.
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking input sparsity of modules: {tracked_names}")

        # Register hooks to track output sparsities.
        output_tracking_args = config.get("track_output_sparsity_args", {})
        named_modules = filter_modules(self.model, **output_tracking_args)
        self.output_hook_manager = ModelHookManager(named_modules,
                                                    TrackSparsityHook)

        # Log the names of the modules with tracked outputs.
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking output sparsity of modules: {tracked_names}")

        # Throw a warning when no modules are being tracked.
        if not input_tracking_args and not output_tracking_args:
            self.logger.warning(
                "No modules specified to track input/output sparsity.")
Example #5
    def clone_model(self):
        """
        Clones self.model and registers forward hooks on the new model.
        """
        model = super().clone_model()

        # Register hooks to track input sparsities.
        track_modules = get_modules_by_names(model, self.track_input_of_names)
        self.input_hook_manager = ModelHookManager(track_modules,
                                                   TrackSparsityHook)

        # Register hooks to track output sparsities.
        track_modules = get_modules_by_names(model, self.track_output_of_names)
        self.output_hook_manager = ModelHookManager(track_modules,
                                                    TrackSparsityHook)

        # Enable tracking the sparsity statistics.
        self.input_hook_manager.start_tracking()
        self.output_hook_manager.start_tracking()

        return model
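
`get_modules_by_names` is not shown in these excerpts. A plausible minimal implementation, assuming the names are the dotted names produced by `model.named_modules()` (as collected by `filter_modules`):

```
def get_modules_by_names(model, names):
    """Hypothetical sketch: look up modules by their dotted names."""
    named_modules = dict(model.named_modules())
    return {name: named_modules[name] for name in names}
```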
Example #6
class PlotHiddenActivations(metaclass=abc.ABCMeta):
    """
    Mixin for creating custom plots of a module's output/hidden activations.

    :param config: a dict containing the following

        - plot_hidden_activations_args: a dict containing the following

            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track e.g.
                             "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to the
                                names; for instance, all feature parameters in ResNet
                                can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in training
                         iterations; defaults to 1
            - max_samples_to_plot: (optional) how many samples to use for plotting;
                                   only the newest will be used; defaults to 1000

    Example config:
    ```
    config=dict(
        plot_hidden_activations_args=dict(
            include_modules=[torch.nn.Linear, KWinners],
            plot_freq=1,
            max_samples_to_plot=2000
        )
    )
    ```
    """
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        ha_args = config.get("plot_hidden_activations_args", {})
        ha_plot_freq, filter_args, ha_max_samples = self.process_ha_args(
            ha_args)

        self.ha_plot_freq = ha_plot_freq
        self.ha_max_samples = ha_max_samples

        # Register hook for tracking hidden activations
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.ha_max_samples)
        self.ha_hook = ModelHookManager(named_modules,
                                        TrackHiddenActivationsHook,
                                        hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking hidden activations of modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.ha_targets = torch.tensor([]).long()

    def process_ha_args(self, ha_args):

        ha_args = deepcopy(ha_args)

        # Collect information about which modules to apply hooks to
        include_names = ha_args.pop("include_names", [])
        include_modules = ha_args.pop("include_modules", [])
        include_patterns = ha_args.pop("include_patterns", [])
        filter_args = dict(
            include_names=include_names,
            include_modules=include_modules,
            include_patterns=include_patterns,
        )

        # Others args
        plot_freq = ha_args.get("plot_freq", 1)
        max_samples = ha_args.get("max_samples_to_plot", 1000)

        assert isinstance(plot_freq, int)
        assert isinstance(max_samples, int)
        assert plot_freq > 0
        assert max_samples > 0

        return plot_freq, filter_args, max_samples

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting hidden
        activations collected by the `TrackHiddenActivationsHook` object is plotted by
        calling a plotting function.
        """

        # Run the epoch with tracking enabled.
        with self.ha_hook:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Create visualization, and update results dict.
        if iteration % self.ha_plot_freq == 0:

            for name, _, ha in self.ha_hook.get_statistics():

                visual = plot_hidden_activations_by_unit(ha, self.ha_targets)
                results.update({f"hidden_activations/{name}": visual})
                results.update(
                    {f"_activations/{name}": ha.detach().cpu().numpy()})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss. This
        mixin assumes these targets correspond, in a 1:1 fashion, to the samples seen
        in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.ha_hook.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.ha_targets = self.ha_targets.to(target.device)

            # Concatenate and discard the older targets.
            self.ha_targets = torch.cat([target, self.ha_targets], dim=0)
            self.ha_targets = self.ha_targets[:self.ha_max_samples]

        return loss
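
These mixins rely on cooperative multiple inheritance: each override calls `super()`, so the whole chain of mixins plus the base experiment runs. A minimal sketch of composing `PlotHiddenActivations` with a base experiment class (the import path is assumed):

```
from nupic.research.frameworks.vernon import SupervisedExperiment  # assumed


class PlottingExperiment(PlotHiddenActivations, SupervisedExperiment):
    """
    The mixin is listed first so that its `setup_experiment`, `run_epoch`,
    and `error_loss` wrap the base implementations through `super()`.
    """
    pass
```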
Example #7
class PlotRepresentationOverlap(metaclass=abc.ABCMeta):
    """
    Mixin for plotting a module's inter- and intra-class representation overlap.

    :param config: a dict containing the following

        - plot_representation_overlap_args: a dict containing the following

            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track e.g.
                             "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to the
                                names; for instance, all feature parameters in ResNet
                                can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in training
                         iterations; defaults to 1
            - plot_args: (optional) either a dictionary or a callable that takes no
                             arguments and returns a dictionary; for instance this may
                             be used to return a random sample of integers specifying
                             units to plot; called only once at setup
            - max_samples_to_plot: (optional) how many samples to use for plotting;
                                   only the newest will be used; defaults to 5000

    Example config:
    ```
    config=dict(
        plot_representation_overlap_args=dict(
            include_modules=[torch.nn.ReLU, KWinners],
            plot_freq=1,
            plot_args=dict(annotate=False),
            max_samples_to_plot=2000
        )
    )
    ```
    """
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        ro_args = config.get("plot_representation_overlap_args", {})
        ro_plot_freq, filter_args, ro_max_samples = process_ro_args(ro_args)

        self.ro_plot_freq = ro_plot_freq
        self.ro_max_samples = ro_max_samples

        # Register hook for tracking hidden activations - useful for representation
        # overlap
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.ro_max_samples)
        self.ro_hook = ModelHookManager(named_modules,
                                        TrackHiddenActivationsHook,
                                        hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking representation overlap of modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.ro_targets = torch.tensor([]).long()

    def run_epoch(self):

        # Run the epoch with tracking enabled.
        with self.ro_hook:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Create visualization, and update results dict.
        if iteration % self.ro_plot_freq == 0:

            for name, _, activations in self.ro_hook.get_statistics():

                metric_str = "representation_overlap"

                # Representation overlap matrix
                visual = plot_representation_overlap_matrix(
                    activations, self.ro_targets)
                results.update({f"{metric_str}_matrix/{name}": visual})

                # Representation overlap distributions:
                #  * inter-class pairs
                #  * intra-class pairs
                plots = plot_representation_overlap_distributions(
                    activations, self.ro_targets)
                inter_overlaps, intra_overlaps = plots
                results.update(
                    {f"{metric_str}_interclass/{name}": inter_overlaps})
                results.update(
                    {f"{metric_str}_intraclass/{name}": intra_overlaps})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss. This
        mixin assumes these targets correspond, in a 1:1 fashion, to the samples seen
        in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.ro_hook.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.ro_targets = self.ro_targets.to(target.device)

            # Concatenate and discard the older targets.
            self.ro_targets = torch.cat([target, self.ro_targets], dim=0)
            self.ro_targets = self.ro_targets[:self.ro_max_samples]

        return loss
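
For intuition, the inter-/intra-class overlap underlying these plots can be sketched as the pairwise cosine similarity of binarized representations, grouped by whether a pair of samples shares a class label. A hedged toy version (the framework's `plot_representation_overlap_*` functions may compute this differently):

```
import torch


def pairwise_overlap(activations, targets):
    """Hypothetical sketch of the statistic behind the overlap plots."""
    binary = (activations > 0).float()
    binary = binary / binary.norm(dim=1, keepdim=True).clamp(min=1e-8)
    overlaps = binary @ binary.t()  # cosine overlap for every sample pair
    same_class = targets.unsqueeze(0) == targets.unsqueeze(1)
    off_diagonal = ~torch.eye(
        len(targets), dtype=torch.bool, device=targets.device)
    intra = overlaps[same_class & off_diagonal]  # same-class pairs
    inter = overlaps[~same_class]                # different-class pairs
    return inter, intra
```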
Example #8
class TrackRepresentationSparsityMetaCL:
    """
    This mixin tracks and reports the average sparsities observed in the model's
    representations. It is similar to `TrackRepresentationSparsity` above;
    however, the functionality is adapted for the MetaCL setting. Specifically,
    the hooks must be re-registered on the cloned model every epoch, prior to
    the fast steps.

    Tracked statistics are returned from `run_epoch` and are reset before each
    subsequent epoch. By default, no modules are tracked; only those explicitly
    specified will be.
    """
    def setup_experiment(self, config):
        """
        Save the names of which modules to track. The forward hooks are registered in
        `clone_model` at the start of every epoch of training.

        The subsets of modules to track are defined via `include_*` params. See
        `filter_modules`_ for further details.

        :param config:
            - track_input_sparsity_args:
                - include_modules: a list of module types to track
                - include_names: a list of module names to track e.g. "features.stem"
                - include_patterns: a list of regex patterns to compare to the names;
                                    for instance, all feature parameters in ResNet can
                                    be included through "features.*"
            - track_output_sparsity_args: same as track_input_sparsity_args
        """
        super().setup_experiment(config)

        # The hooks managers will be initialized for each cloned model, prior to the
        # fast steps. For now, just the names of the modules to track will be saved.
        self.output_hook_manager = None
        self.input_hook_manager = None

        # Save the names of the modules whose inputs will be tracked.
        input_tracking_args = config.get("track_input_sparsity_args", {})
        named_modules = filter_modules(self.model, **input_tracking_args)
        self.track_input_of_names = list(named_modules.keys())

        # Log the names of the modules with tracked inputs.
        tracked_names = pformat(self.track_input_of_names)
        self.logger.info(
            f"Tracking input sparsity of modules: {tracked_names}")

        # Save the names of the modules whose outputs will be tracked.
        output_tracking_args = config.get("track_output_sparsity_args", {})
        named_modules = filter_modules(self.model, **output_tracking_args)
        self.track_output_of_names = list(named_modules.keys())

        # Log the names of the modules with tracked outputs.
        tracked_names = pformat(self.track_output_of_names)
        self.logger.info(
            f"Tracking output sparsity of modules: {tracked_names}")

        # Throw a warning when no modules are being tracked.
        if not input_tracking_args and not output_tracking_args:
            self.logger.warning(
                "No modules specified to track input/output sparsity.")

    def run_epoch(self):
        """Run one epoch and log the observed sparsity statistics."""

        # The start of super's `run_epoch` will call `clone_model` to
        # initiate tracking.
        ret = super().run_epoch()

        # Log sparsity statistics collected from the input and output hooks.
        update_results_dict(ret, self.input_hook_manager,
                            self.output_hook_manager)

        return ret

    def clone_model(self):
        """
        Clones self.model and registers forward hooks on the new model.
        """
        model = super().clone_model()

        # Register hooks to track input sparsities.
        track_modules = get_modules_by_names(model, self.track_input_of_names)
        self.input_hook_manager = ModelHookManager(track_modules,
                                                   TrackSparsityHook)

        # Register hooks to track output sparsities.
        track_modules = get_modules_by_names(model, self.track_output_of_names)
        self.output_hook_manager = ModelHookManager(track_modules,
                                                    TrackSparsityHook)

        # Enable tracking the sparsity statistics.
        self.input_hook_manager.start_tracking()
        self.output_hook_manager.start_tracking()

        return model

    @classmethod
    def get_execution_order(cls):
        # TODO: Update eo.
        eo = super().get_execution_order()
        mixin = "TrackRepresentationSparsityMetaCL: "
        eo["setup_experiment"].append(
            mixin + "Save the names of which modules to track.")
        eo["clone_model"].append(mixin + "Register hooks to the cloned model.")
        eo["run_epoch"].append(
            mixin +
            "Calculate and log sparsity statistics of representations.")
        return eo
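
`update_results_dict` is referenced in `run_epoch` above but not defined in these excerpts. A hedged sketch, assuming each sparsity hook reports through `get_statistics()` in the same (name, module, value) format used by the other hook managers:

```
def update_results_dict(results, input_hook_manager, output_hook_manager):
    """Hypothetical sketch: log the observed sparsities under distinct keys."""
    for name, _, sparsity in input_hook_manager.get_statistics():
        results[f"input_sparsity/{name}"] = sparsity
    for name, _, sparsity in output_hook_manager.get_statistics():
        results[f"output_sparsity/{name}"] = sparsity
```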
Example #9
class PlotContextSignal(metaclass=abc.ABCMeta):
    """
    Mixin for creating plots of the context vectors used to modulate dendritic
    networks.

    :param config: a dict containing the following

        - plot_context_args: a dict containing the following

            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track e.g.
                             "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to the
                                names; for instance, all feature parameters in ResNet
                                can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in training
                         iterations; defaults to 1
            - max_samples_to_plot: (optional) how many samples to use for plotting;
                                   only the newest will be used; defaults to 5000

    Example config:
    ```
    config=dict(
        plot_context_args=dict(
            include_names=["context_net"],
            plot_freq=1,
            max_samples_to_plot=2000
        ),
    )
    ```
    where `context_net` is assumed to refer to the module whose outputs are context
    vectors used to modulate a dendritic network.
    """

    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        context_args = config.get("plot_context_args", {})
        context_plot_freq, filter_args, max_samples = process_context_args(context_args)

        self.context_plot_freq = context_plot_freq
        self.context_max_samples = max_samples

        # Register hook for tracking context vectors
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.context_max_samples)
        self.context_hook = ModelHookManager(named_modules,
                                             TrackHiddenActivationsHook,
                                             hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(f"Tracking context signals output from: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.context_targets = torch.tensor([]).long()

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting context
        vectors collected by the `TrackHiddenActivationsHook` object are plotted by
        calling a plotting function.
        """

        # Run the epoch with tracking enabled.
        with self.context_hook:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Create visualization, and update results dict.
        if iteration % self.context_plot_freq == 0:

            for name, _, contexts in self.context_hook.get_statistics():

                visual = plot_contexts_by_class(contexts, self.context_targets)
                results.update({f"contexts/{name}": visual})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss. This
        mixin assumes these targets correspond, in a 1:1 fashion, to the samples seen
        in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.context_hook.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.context_targets = self.context_targets.to(target.device)

            # Concatenate and discard the older targets.
            self.context_targets = torch.cat([target, self.context_targets], dim=0)
            self.context_targets = self.context_targets[:self.context_max_samples]

        return loss
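
`process_ro_args` and `process_context_args` are referenced above but not defined in these excerpts. Assuming they mirror `process_ha_args` from `PlotHiddenActivations`, a minimal sketch:

```
from copy import deepcopy


def _process_plot_args(args, default_max_samples=5000):
    """Hypothetical shared helper mirroring `process_ha_args`."""
    args = deepcopy(args)
    filter_args = dict(
        include_names=args.pop("include_names", []),
        include_modules=args.pop("include_modules", []),
        include_patterns=args.pop("include_patterns", []),
    )
    plot_freq = args.get("plot_freq", 1)
    max_samples = args.get("max_samples_to_plot", default_max_samples)
    assert isinstance(plot_freq, int) and plot_freq > 0
    assert isinstance(max_samples, int) and max_samples > 0
    return plot_freq, filter_args, max_samples


# The two helpers could then be thin aliases of the shared logic.
process_ro_args = _process_plot_args
process_context_args = _process_plot_args
```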
Example #10
class GradientMetrics(object):
    """
    Mixin for tracking and plotting module gradient metrics during training.

    :param config: a dict containing the following

        - gradient_metrics_args: a dict containing the following
            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track e.g.
                             "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to the
                                names; for instance, all feature parameters in ResNet
                                can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in training
                         iterations; defaults to 1
            - metrics: a list of metric options from ["cosine", "dot", "pearson"];
                       defaults to ["cosine"]
            - gradient_values: (optional) one of "real", "sign", "mask".
                "real" corresponds to the real values of the gradients,
                "sign" corresponds to collecting the sign of the gradients,
                "mask" results in a binary mask corresponding to nonzero gradients;
                 defaults to "real"
            - max_samples_to_track: (optional) how many samples to use for plotting;
                                    only the newest will be used; defaults to 100

    Example config:
    ```
    config=dict(
        gradient_metrics_args=dict(
            include_modules=[torch.nn.Linear, KWinners],
            plot_freq=1,
            max_samples_to_track=150,
            metrics=["dot", "pearson"],
            gradient_values="mask",

            metric1 = "mask/dot"
            metric2 = "sign/pearson"
        )
    )
    ```
    """
    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        gradient_metrics_args = config.get("gradient_metrics_args", {})
        (self.gradient_metrics_plot_freq,
         self.gradient_metrics_filter_args,
         self.gradient_metrics_max_samples,
         self.gradient_metrics) = process_gradient_metrics_args(
            gradient_metrics_args)

        # Register hook for tracking hidden activations
        named_modules = filter_modules(self.model,
                                       **self.gradient_metrics_filter_args)
        hook_args = dict(
            max_samples_to_track=self.gradient_metrics_max_samples)
        self.gradient_metric_hooks = ModelHookManager(named_modules,
                                                      TrackGradientsHook,
                                                      hook_type="backward",
                                                      hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(f"Tracking gradients for modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.gradient_metric_targets = torch.tensor([]).long()

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting gradients
        collected by the `TrackGradientsHook` object are plotted by calling a
        plotting function.
        """

        # Run the epoch with tracking enabled.
        with self.gradient_metric_hooks:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Calculate metrics, create visualization, and update results dict.
        if iteration % self.gradient_metrics_plot_freq == 0:
            gradient_stats = self.gradient_metric_hooks.get_statistics()
            gradient_metrics_stats = self.calculate_gradient_metrics_stats(
                gradient_stats)
            gradient_metric_heatmaps = self.plot_gradient_metric_heatmaps(
                gradient_metrics_stats)
            for (name, _, gradient_metric, gradient_value, _, figure) in \
                    gradient_metric_heatmaps:
                results.update(
                    {f"{gradient_metric}/{gradient_value}/{name}": figure})
        return results

    def calculate_gradient_metrics_stats(self, gradients_stats):
        """
        This function calculates statistics given the gradients_stats which are being
        tracked by the TrackGradients backwards hook.

        This function accesses self.gradient_metrics, which is also a list of
        tuples. Each tuple is the combination of a metric function ("cosine",
        "dot", "pearson") and a gradient transformation ("real", "sign",
        "mask"). By default, the statistics that are calculated are ("cosine",
        "mask") and ("cosine", "real"): ("cosine", "mask") corresponds to the
        overlap between two different gradients, and ("cosine", "real") to the
        standard cosine similarity between two gradients.

        Args:
            gradients_stats: A list of tuples, each of which contains a named module
            and its gradients

        Returns:
            A list of tuples, each of which contains a named module, the statistics
            calculated from its gradients, and the metric function/transformation
            used on the gradients
        """
        all_stats = []
        for (name, module, gradients) in gradients_stats:
            for gradient_metric, gradient_value in self.gradient_metrics:
                # Apply the gradient transformation to a separate variable so
                # that later (metric, value) pairs still see the raw gradients.
                if gradient_value == "sign":
                    transformed = torch.sign(gradients)
                elif gradient_value == "mask":
                    transformed = torch.abs(torch.sign(gradients))
                else:
                    transformed = gradients
                # Calculate the metric function on the transformed gradients.
                if gradient_metric == "cosine":
                    stats = [
                        torch.cosine_similarity(x, y, dim=0)
                        if not torch.equal(x, y) else 0
                        for x in transformed for y in transformed
                    ]
                elif gradient_metric == "dot":
                    stats = [
                        x.dot(y) if not torch.equal(x, y) else 0
                        for x in transformed for y in transformed
                    ]
                elif gradient_metric == "pearson":
                    stats = [
                        torch.cosine_similarity(x - x.mean(), y - y.mean(),
                                                dim=0)
                        if not torch.equal(x, y) else 0
                        for x in transformed for y in transformed
                    ]
                stats = torch.tensor(stats)
                gradient_dim = len(transformed)
                stats = stats.view(gradient_dim, gradient_dim)
                all_stats.append(
                    (name, module, gradient_metric, gradient_value, stats))
        return all_stats

    def plot_gradient_metric_heatmaps(self, gradient_metrics_stats):
        order_by_class = torch.argsort(self.gradient_metric_targets)
        sorted_gradient_metric_targets = self.gradient_metric_targets[
            order_by_class]
        class_change_indices = \
            (sorted_gradient_metric_targets - sorted_gradient_metric_targets.roll(
                1)).nonzero(as_tuple=True)[0].cpu()
        class_labels = [
            int(_x) for _x in sorted_gradient_metric_targets.unique()
        ]
        class_change_indices_right = class_change_indices.roll(-1).cpu()
        class_change_indices_right[-1] = len(sorted_gradient_metric_targets)
        tick_locations = (class_change_indices +
                          class_change_indices_right) / 2.0 - 0.5
        tick_locations = tick_locations.cpu()

        stats_and_figures = []
        for (name, module, gradient_metric, gradient_value, stats) in \
                gradient_metrics_stats:
            stats = stats[order_by_class, :][:, order_by_class]
            figure, ax = plt.subplots()  # a fresh figure for each heatmap
            max_val = np.abs(stats).max()
            img = ax.imshow(stats, cmap="bwr", vmin=-max_val, vmax=max_val)
            ax.set_xlabel("class")
            ax.set_ylabel("class")
            for idx in class_change_indices:
                ax.axvline(idx - 0.5, color="black")
                ax.axhline(idx - 0.5, color="black")
            ax.set_xticks(tick_locations)
            ax.set_yticks(tick_locations)
            ax.set_xticklabels(class_labels)
            ax.set_yticklabels(class_labels)
            ax.set_title(f"{gradient_metric}:{gradient_value}:{name}")
            plt.colorbar(img, ax=ax)
            figure.tight_layout()
            stats_and_figures.append(
                (name, module, gradient_metric, gradient_value, stats, figure))
        return stats_and_figures

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss. This
        mixin assumes these targets correspond, in a 1:1 fashion, to the samples seen
        in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.gradient_metric_hooks.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.gradient_metric_targets = self.gradient_metric_targets.to(
                target.device)

            # Concatenate and discard the older targets.
            self.gradient_metric_targets = torch.cat(
                [target, self.gradient_metric_targets], dim=0)
            self.gradient_metric_targets = \
                self.gradient_metric_targets[:self.gradient_metrics_max_samples]

        return loss
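
To make the ("cosine", "mask") statistic concrete: the "mask" transformation reduces each per-sample gradient to a binary mask of its nonzero entries, and the cosine similarity of two masks measures how much their nonzero positions overlap. A small worked example:

```
import torch

# Two flattened per-sample gradients.
g = torch.tensor([[1.0, 0.0, -2.0],
                  [3.0, 1.0, 0.0]])

# "mask" transformation: 1 where the gradient is nonzero, 0 elsewhere.
masks = torch.abs(torch.sign(g))  # [[1., 0., 1.], [1., 1., 0.]]

# One shared nonzero position, norms sqrt(2) each: overlap = 1 / 2 = 0.5.
overlap = torch.cosine_similarity(masks[0], masks[1], dim=0)
print(overlap)  # tensor(0.5000)
```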
Example #11
class PlotDendriteMetrics(metaclass=abc.ABCMeta):
    """
    This is a mixin for creating custom plots of metrics for apply-dendrites
    modules (those of type `ApplyDendritesBase`_). The user supplies a
    plotting function which is then called with the following arguments

        - dendrite_activations: the input activations passed to the apply-dendrites
                                module; these are meant to be the output of a
                                `DendriteSegments` module; they will be of shape
                                batch_size x num_units x num_segments
        - winning_mask: the mask of the winning segments, those chosen by the
                        apply-dendrites modules; this will be of shape
                        batch_size x num_units
        - targets: the targets that correspond to each sample in the batch

    Plots can be configured to use fewer samples (helpful for plotting small batches
    of individual samples) and to plot every so many epochs (so that training isn't
    slowed down too much). Whenever a plot is made, the raw data used to create it is
    saved so it may be reproduced and edited off-line.

    .. warning:: When using this mixin with Ray, be careful to have 'plot_func' return
                 an object that can be logged. Often, Ray will attempt to create a
                 deepcopy prior to logging which can't be done on most plots. Try
                 wrapping the plot as done in `prep_plot_for_wandb`_.

    .. _ApplyDendritesBase: nupic/research/frameworks/dendrites/modules
    .. _prep_plot_for_wandb: nupic/research/frameworks/wandb/ray_wandb


    :param config: a dict containing the following

        - plot_dendrite_metrics_args: a dict containing the following

            - include_modules: (optional) a list of module types to track; defaults to
                                          include all `ApplyDendritesBase` modules if no
                                          other `include_*` arguments are given.
            - include_names: (optional) a list of module names to track e.g.
                             "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to the
                                names; for instance, all feature parameters in ResNet
                                can be included through "features.*"

            <insert any plot name here>: This can be any string and maps to a dictionary
                                         of the plot arguments below. The
                                         resulting plot will be logged under
                                         "<plot_name>/<module_name>" in the results
                                         dictionary.

                - plot_func: the function called for plotting; must take three
                             arguments: 'dendrite_activations', 'winning_mask', and
                             'targets' (see above)
                - plot_freq: (optional) how often to create the plot, measured in
                             training iterations; defaults to 1
                - plot_args: (optional) either a dictionary or a callable that takes no
                             arguments and returns a dictionary; for instance this may
                             be used to return a random sample of integers specifying
                             units to plot; called only once at setup
                - max_samples_to_plot: (optional) how many samples to use for
                                       plotting; only the newest will be used;
                                       defaults to 1000

    Example config:
    ```
    config=dict(
        plot_dendrite_metrics_args=dict(
            include_modules=[DendriticGate1d],
            mean_selected=dict(
                plot_func=plot_mean_selected_activations,
            )
        )
    )
    ```
    """

    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Unpack, validate, and process the default arguments.
        metric_args = config.get("plot_dendrite_metrics_args", {})
        self.metric_args, filter_args, max_samples = self.process_args(metric_args)

        # The maximum 'max_samples_to_plot' across all plots will be tracked
        # by the hooks.
        self.max_samples_to_track = max_samples
        hook_args = dict(max_samples_to_track=self.max_samples_to_track)

        # The 'filter_args' specify which modules to track.
        named_modules = filter_modules(self.model, **filter_args)
        self.dendrite_hooks = ModelHookManager(named_modules,
                                               ApplyDendritesHook,
                                               hook_args=hook_args)

        # The hook is specifically made for `ApplyDendritesBase` modules.
        for module in named_modules.values():
            assert isinstance(module, ApplyDendritesBase)

        # Log the names of the modules being tracked, and warn when there's none.
        names = list(named_modules.keys())
        self.logger.info(f"Dendrite Metric Setup: Tracking modules: {names}")
        if len(names) == 0:
            self.logger.warning("Dendrite Metric Setup: "
                                "No modules found for tracking.")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.targets = torch.tensor([]).long()

    def process_args(self, metric_args):

        metric_args = deepcopy(metric_args)

        # Remove and collect information about which modules to track.
        include_names = metric_args.pop("include_names", [])
        include_modules = metric_args.pop("include_modules", [])
        include_patterns = metric_args.pop("include_patterns", [])

        # Default to track all `ApplyDendritesBase` modules.
        if len(include_names) == len(include_modules) == len(include_patterns) == 0:
            include_modules = [ApplyDendritesBase]

        filter_args = dict(
            include_names=include_names,
            include_modules=include_modules,
            include_patterns=include_patterns,
        )

        # Gather and validate the metric arguments. The max of the 'max_samples_to_plot'
        # will be saved to dictate how many samples will be tracked by the hooks.
        all_max_num_samples = []
        new_metric_args = {}
        for metric_name, plotting_args in metric_args.items():

            plot_func = plotting_args.get("plot_func", None)
            plot_freq = plotting_args.get("plot_freq", 1)
            plot_args = plotting_args.get("plot_args", {})
            max_samples_to_plot = plotting_args.get("max_samples_to_plot", 1000)

            assert callable(plot_func)
            assert isinstance(plot_freq, int)
            assert isinstance(max_samples_to_plot, int)
            assert plot_freq > 0
            assert max_samples_to_plot > 0

            # The arguments may be given as a callable; useful for sampling random
            # values that dictate plotting to, say, only plot a subset of units
            if callable(plot_args):
                plot_args = plot_args()
            assert isinstance(plot_args, dict)

            new_metric_args[metric_name] = dict(
                plot_func=plot_func,
                plot_freq=plot_freq,
                plot_args=plot_args,
                max_samples_to_plot=max_samples_to_plot,
            )

            all_max_num_samples.append(max_samples_to_plot)

        # Fall back to the documented default when no plots are configured.
        max_samples_to_plot = max(all_max_num_samples, default=1000)
        return new_metric_args, filter_args, max_samples_to_plot

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting 'activations'
        and 'winning_masks' collected by these hooks are plotted via each 'plot_func'
        along with their corresponding targets.
        """

        # Run the epoch with tracking enabled.
        with self.dendrite_hooks:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Gather and plot the statistics.
        for name, _, activations, winners in self.dendrite_hooks.get_statistics():

            if len(activations) == 0 or len(winners) == 0:
                self.logger.warning(f"Skipping plots for module '{name}';"
                                    " no data could be collected.")
                continue

            # Keep track of whether a plot is made below. If so, save the raw data.
            plot_made = False

            # Each 'plot_func' will be applied to each module being tracked.
            for metric_name, plotting_args in self.metric_args.items():

                # All of the defaults were set in `process_args`.
                plot_func = plotting_args["plot_func"]
                plot_freq = plotting_args["plot_freq"]
                plot_args = plotting_args["plot_args"]
                max_samples_to_plot = plotting_args["max_samples_to_plot"]

                if iteration % plot_freq != 0:
                    continue

                # Only use up to the max number of samples for plotting.
                targets = self.targets[:max_samples_to_plot]
                activations = activations[:max_samples_to_plot]
                winners = winners[:max_samples_to_plot]

                # Call and log the results of the plot function.
                # Here, "{name}" is the name of the module.
                visual = plot_func(activations, winners, targets, **plot_args)
                results.update({f"{metric_name}/{name}": visual})
                plot_made = True

            # Log the raw data.
            if plot_made:
                targets = self.targets[:self.max_samples_to_track].cpu().numpy()
                activations = activations[:self.max_samples_to_track].cpu().numpy()
                winners = winners[:self.max_samples_to_track].cpu().numpy()
                results.update({f"targets/{name}": targets})
                results.update({f"dendrite_activations/{name}": activations})
                results.update({f"winning_mask/{name}": winners})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss. This
        mixin assumes these targets correspond, in a 1:1 fashion, to the images seen in
        the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.dendrite_hooks.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.targets = self.targets.to(target.device)

            # Concatenate and discard the older targets.
            self.targets = torch.cat([target, self.targets], dim=0)
            self.targets = self.targets[:self.max_samples_to_track]

        return loss
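
The `plot_func` entries described above must accept the three tracked tensors. A hedged sketch of a custom plotting function and its config (the body is illustrative; the framework's `plot_mean_selected_activations` may differ):

```
import matplotlib.pyplot as plt


def plot_mean_activations(dendrite_activations, winning_mask, targets, **kwargs):
    """
    Hypothetical sketch: heatmap of the mean dendrite activation for each
    (unit, segment) pair, averaged over the tracked samples.
    """
    mean_activations = dendrite_activations.mean(dim=0).detach().cpu().numpy()
    figure, ax = plt.subplots()
    img = ax.imshow(mean_activations, cmap="bwr")
    ax.set_xlabel("segment")
    ax.set_ylabel("unit")
    figure.colorbar(img, ax=ax)
    return figure


config = dict(
    plot_dendrite_metrics_args=dict(
        mean_activations=dict(
            plot_func=plot_mean_activations,
            plot_freq=5,
        ),
    )
)
```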