def setup_experiment(self, config): """ Register forward hooks to track the input and output sparsities. The subsets of modules to track are defined via `include_*` params. See `filter_modules`_ for further details. :param config: - track_input_sparsity_args: - include_modules: a list of module types to track - include_names: a list of module names to track e.g. "features.stem" - include_patterns: a list of regex patterns to compare to the names; for instance, all feature parameters in ResNet can be included through "features.*" - track_output_sparsity_args: same as track_input_sparsity_args .. _filter_modules: nupic.research.frameworks.pytorch.model_utils.filter_modules """ super().setup_experiment(config) # Register hooks to track input sparsities. input_tracking_args = config.get("track_input_sparsity_args", {}) named_modules = filter_modules(self.model, **input_tracking_args) self.input_hook_manager = ModelHookManager(named_modules, TrackSparsityHook) # Log the names of the modules with tracked inputs. tracked_names = pformat(list(named_modules.keys())) self.logger.info( f"Tracking input sparsity of modules: {tracked_names}") # Register hooks to track output sparsities. output_tracking_args = config.get("track_output_sparsity_args", {}) named_modules = filter_modules(self.model, **output_tracking_args) self.output_hook_manager = ModelHookManager(named_modules, TrackSparsityHook) # Log the names of the modules with tracked outputs. tracked_names = pformat(list(named_modules.keys())) self.logger.info( f"Tracking output sparsity of modules: {tracked_names}") # Throw a warning when no modules are being tracked. if not input_tracking_args and not output_tracking_args: self.logger.warning( "No modules specified to track input/output sparsity.")
def clone_model(self): """ Clones self.model and register forward hooks on new model. """ model = super().clone_model() # Register hooks to track input sparsities. track_modules = get_modules_by_names(model, self.track_input_of_names) self.input_hook_manager = ModelHookManager(track_modules, TrackSparsityHook) # Register hooks to track output sparsities. track_modules = get_modules_by_names(model, self.track_output_of_names) self.output_hook_manager = ModelHookManager(track_modules, TrackSparsityHook) # Enable tracking the sparsity statistics. self.input_hook_manager.start_tracking() self.output_hook_manager.start_tracking() return model
class PlotHiddenActivations(metaclass=abc.ABCMeta):
    """
    Mixin for creating custom plots of a module's output/hidden activations.

    :param config: a dict containing the following

        - plot_hidden_activations_args: a dict containing the following

            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track
                             e.g. "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to
                                the names; for instance, all feature parameters in
                                ResNet can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in
                         training iterations; defaults to 1
            - max_samples_to_plot: (optional) how many samples to use for plotting;
                                   only the newest will be used; defaults to 1000

    Example config:
    ```
    config=dict(
        plot_hidden_activations_args=dict(
            include_modules=[torch.nn.Linear, KWinners],
            plot_freq=1,
            max_samples_to_plot=2000
        )
    )
    ```
    """

    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        ha_args = config.get("plot_hidden_activations_args", {})
        ha_plot_freq, filter_args, ha_max_samples = self.process_ha_args(ha_args)
        self.ha_plot_freq = ha_plot_freq
        self.ha_max_samples = ha_max_samples

        # Register hook for tracking hidden activations
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.ha_max_samples)
        self.ha_hook = ModelHookManager(named_modules,
                                        TrackHiddenActivationsHook,
                                        hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking hidden activations of modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.ha_targets = torch.tensor([]).long()

    def process_ha_args(self, ha_args):

        ha_args = deepcopy(ha_args)

        # Collect information about which modules to apply hooks to
        include_names = ha_args.pop("include_names", [])
        include_modules = ha_args.pop("include_modules", [])
        include_patterns = ha_args.pop("include_patterns", [])
        filter_args = dict(
            include_names=include_names,
            include_modules=include_modules,
            include_patterns=include_patterns,
        )

        # Other args
        plot_freq = ha_args.get("plot_freq", 1)
        max_samples = ha_args.get("max_samples_to_plot", 1000)

        assert isinstance(plot_freq, int)
        assert isinstance(max_samples, int)
        assert plot_freq > 0
        assert max_samples > 0

        return plot_freq, filter_args, max_samples

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting hidden
        activations collected by the `TrackHiddenActivationsHook` object are
        plotted by calling a plotting function.
        """

        # Run the epoch with tracking enabled.
        with self.ha_hook:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Create visualization, and update results dict.
        if iteration % self.ha_plot_freq == 0:
            for name, _, ha in self.ha_hook.get_statistics():
                visual = plot_hidden_activations_by_unit(ha, self.ha_targets)
                results.update({f"hidden_activations/{name}": visual})
                results.update({f"_activations/{name}": ha.detach().cpu().numpy()})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss.
        This mixin assumes these targets correspond, in a 1:1 fashion, to the
        samples seen in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.ha_hook.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.ha_targets = self.ha_targets.to(target.device)

            # Concatenate and discard the older targets.
            self.ha_targets = torch.cat([target, self.ha_targets], dim=0)
            self.ha_targets = self.ha_targets[:self.ha_max_samples]

        return loss
class PlotRepresentationOverlap(metaclass=abc.ABCMeta):
    """
    Mixin for plotting a module's inter- and intra-class representation overlap.

    :param config: a dict containing the following

        - plot_representation_overlap_args: a dict containing the following

            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track
                             e.g. "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to
                                the names; for instance, all feature parameters in
                                ResNet can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in
                         training iterations; defaults to 1
            - plot_args: (optional) either a dictionary or a callable that takes no
                         arguments and returns a dictionary; for instance, this may
                         be used to return a random sample of integers specifying
                         units to plot; called only once at setup
            - max_samples_to_plot: (optional) how many samples to use for plotting;
                                   only the newest will be used; defaults to 5000

    Example config:
    ```
    config=dict(
        plot_representation_overlap_args=dict(
            include_modules=[torch.nn.ReLU, KWinners],
            plot_freq=1,
            plot_args=dict(annotate=False),
            max_samples_to_plot=2000
        )
    )
    ```
    """

    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        ro_args = config.get("plot_representation_overlap_args", {})
        ro_plot_freq, filter_args, ro_max_samples = process_ro_args(ro_args)
        self.ro_plot_freq = ro_plot_freq
        self.ro_max_samples = ro_max_samples

        # Register hook for tracking hidden activations - useful for
        # representation overlap
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.ro_max_samples)
        self.ro_hook = ModelHookManager(named_modules,
                                        TrackHiddenActivationsHook,
                                        hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(
            f"Tracking representation overlap of modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.ro_targets = torch.tensor([]).long()

    def run_epoch(self):

        # Run the epoch with tracking enabled.
        with self.ro_hook:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Create visualization, and update results dict.
        if iteration % self.ro_plot_freq == 0:
            for name, _, activations in self.ro_hook.get_statistics():

                metric_str = "representation_overlap"

                # Representation overlap matrix
                visual = plot_representation_overlap_matrix(activations,
                                                            self.ro_targets)
                results.update({f"{metric_str}_matrix/{name}": visual})

                # Representation overlap distributions:
                #   * inter-class pairs
                #   * intra-class pairs
                plots = plot_representation_overlap_distributions(activations,
                                                                  self.ro_targets)
                inter_overlaps, intra_overlaps = plots
                results.update({f"{metric_str}_interclass/{name}": inter_overlaps})
                results.update({f"{metric_str}_intraclass/{name}": intra_overlaps})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss.
        This mixin assumes these targets correspond, in a 1:1 fashion, to the
        samples seen in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.ro_hook.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.ro_targets = self.ro_targets.to(target.device)

            # Concatenate and discard the older targets.
            self.ro_targets = torch.cat([target, self.ro_targets], dim=0)
            self.ro_targets = self.ro_targets[:self.ro_max_samples]

        return loss
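# A hedged usage sketch: these plotting mixins are designed to be composed
# with an experiment class through cooperative multiple inheritance, so that
# `super()` chains through `setup_experiment`, `run_epoch`, and `error_loss`.
# The base class name below follows the nupic.research convention (e.g.
# vernon's `SupervisedExperiment`) and is an assumption here, so the sketch is
# left commented out:
#
#     class RepresentationOverlapExperiment(PlotRepresentationOverlap,
#                                           SupervisedExperiment):
#         pass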
class TrackRepresentationSparsityMetaCL:
    """
    This mixin tracks and reports the average sparsities observed in the model's
    representations. This is similar to `TrackRepresentationSparsity` above;
    however, the functionality is reformatted for the MetaCL setting.
    Specifically, the hooks must be re-registered every epoch on the cloned model
    prior to the fast steps. Tracked statistics are returned from `run_epoch` and
    are reset before each subsequent epoch.

    The default is to track none of the modules; only those specified can be
    tracked.
    """

    def setup_experiment(self, config):
        """
        Save the names of which modules to track. The forward hooks are registered
        in `clone_model` at the start of every epoch of training.

        The subsets of modules to track are defined via `include_*` params. See
        `filter_modules`_ for further details.

        :param config:
            - track_input_sparsity_args:
                - include_modules: a list of module types to track
                - include_names: a list of module names to track
                                 e.g. "features.stem"
                - include_patterns: a list of regex patterns to compare to the
                                    names; for instance, all feature parameters
                                    in ResNet can be included through "features.*"
            - track_output_sparsity_args: same as track_input_sparsity_args
        """
        super().setup_experiment(config)

        # The hook managers will be initialized for each cloned model, prior to the
        # fast steps. For now, just the names of the modules to track will be saved.
        self.output_hook_manager = None
        self.input_hook_manager = None

        # Save the names of the modules whose inputs will be tracked.
        input_tracking_args = config.get("track_input_sparsity_args", {})
        named_modules = filter_modules(self.model, **input_tracking_args)
        self.track_input_of_names = list(named_modules.keys())

        # Log the names of the modules with tracked inputs.
        tracked_names = pformat(self.track_input_of_names)
        self.logger.info(f"Tracking input sparsity of modules: {tracked_names}")

        # Save the names of the modules whose outputs will be tracked.
        output_tracking_args = config.get("track_output_sparsity_args", {})
        named_modules = filter_modules(self.model, **output_tracking_args)
        self.track_output_of_names = list(named_modules.keys())

        # Log the names of the modules with tracked outputs.
        tracked_names = pformat(self.track_output_of_names)
        self.logger.info(f"Tracking output sparsity of modules: {tracked_names}")

        # Throw a warning when no modules are being tracked.
        if not input_tracking_args and not output_tracking_args:
            self.logger.warning(
                "No modules specified to track input/output sparsity.")

    def run_epoch(self):
        """Run one epoch and log the observed sparsity statistics."""

        # The start of super's `run_epoch` will call `clone_model` to initiate
        # tracking.
        ret = super().run_epoch()

        # Log sparsity statistics collected from the input and output hooks.
        update_results_dict(ret, self.input_hook_manager, self.output_hook_manager)

        return ret

    def clone_model(self):
        """
        Clones self.model and registers forward hooks on the new model.
        """
        model = super().clone_model()

        # Register hooks to track input sparsities.
        track_modules = get_modules_by_names(model, self.track_input_of_names)
        self.input_hook_manager = ModelHookManager(track_modules,
                                                   TrackSparsityHook)

        # Register hooks to track output sparsities.
        track_modules = get_modules_by_names(model, self.track_output_of_names)
        self.output_hook_manager = ModelHookManager(track_modules,
                                                    TrackSparsityHook)

        # Enable tracking the sparsity statistics.
        self.input_hook_manager.start_tracking()
        self.output_hook_manager.start_tracking()

        return model

    @classmethod
    def get_execution_order(cls):
        # TODO: Update eo.
        eo = super().get_execution_order()
        mixin = "TrackRepresentationSparsityMetaCL: "
        eo["setup_experiment"].append(
            mixin + "Save the names of which modules to track.")
        eo["clone_model"].append(mixin + "Register hooks to the cloned model.")
        eo["run_epoch"].append(
            mixin + "Calculate and log sparsity statistics of representations.")
        return eo
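# A hedged example config for the MetaCL variant above (its docstring has no
# example; the `include_*` values here are purely illustrative):
example_metacl_sparsity_config = dict(
    track_input_sparsity_args=dict(include_patterns=["adaptation.*"]),
    track_output_sparsity_args=dict(include_modules=[torch.nn.Linear]),
)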
class PlotContextSignal(metaclass=abc.ABCMeta):
    """
    Mixin for creating plots of the context vectors used to modulate dendritic
    networks.

    :param config: a dict containing the following

        - plot_context_args: a dict containing the following

            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track
                             e.g. "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to
                                the names; for instance, all feature parameters in
                                ResNet can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in
                         training iterations; defaults to 1
            - max_samples_to_plot: (optional) how many samples to use for plotting;
                                   only the newest will be used; defaults to 5000

    Example config:
    ```
    config=dict(
        plot_context_args=dict(
            include_names=["context_net"],
            plot_freq=1,
            max_samples_to_plot=2000
        ),
    )
    ```
    where `context_net` is assumed to refer to the module whose outputs are
    context vectors used to modulate a dendritic network.
    """

    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        context_args = config.get("plot_context_args", {})
        context_plot_freq, filter_args, max_samples = \
            process_context_args(context_args)
        self.context_plot_freq = context_plot_freq
        self.context_max_samples = max_samples

        # Register hook for tracking context vectors
        named_modules = filter_modules(self.model, **filter_args)
        hook_args = dict(max_samples_to_track=self.context_max_samples)
        self.context_hook = ModelHookManager(named_modules,
                                             TrackHiddenActivationsHook,
                                             hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(f"Tracking context signals output from: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.context_targets = torch.tensor([]).long()

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting context
        vectors collected by the `TrackHiddenActivationsHook` object are plotted by
        calling a plotting function.
        """

        # Run the epoch with tracking enabled.
        with self.context_hook:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Create visualization, and update results dict.
        if iteration % self.context_plot_freq == 0:
            for name, _, contexts in self.context_hook.get_statistics():
                visual = plot_contexts_by_class(contexts, self.context_targets)
                results.update({f"contexts/{name}": visual})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss.
        This mixin assumes these targets correspond, in a 1:1 fashion, to the
        samples seen in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.context_hook.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.context_targets = self.context_targets.to(target.device)

            # Concatenate and discard the older targets.
            self.context_targets = torch.cat([target, self.context_targets], dim=0)
            self.context_targets = self.context_targets[:self.context_max_samples]

        return loss
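# A hedged sketch of a model layout that the example config above would match:
# `include_names=["context_net"]` selects the submodule assigned to the
# attribute named "context_net", and the registered hook then records its
# output context vectors. (Illustrative only; not part of the framework.)
class SketchDendriticModel(torch.nn.Module):

    def __init__(self, dim_in=10, dim_context=8, num_classes=2):
        super().__init__()
        self.context_net = torch.nn.Linear(dim_in, dim_context)  # tracked
        self.classifier = torch.nn.Linear(dim_context, num_classes)

    def forward(self, x):
        context = self.context_net(x)  # the hook records these activations
        return self.classifier(context)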
class GradientMetrics(object):
    """
    Mixin for tracking and plotting module gradient metrics during training.

    :param config: a dict containing the following

        - gradient_metrics_args: a dict containing the following

            - include_modules: (optional) a list of module types to track
            - include_names: (optional) a list of module names to track
                             e.g. "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to
                                the names; for instance, all feature parameters in
                                ResNet can be included through "features.*"
            - plot_freq: (optional) how often to create the plot, measured in
                         training iterations; defaults to 1
            - metrics: a list of metric options from ["cosine", "dot", "pearson"];
                       defaults to ["cosine"]
            - gradient_values: (optional) one of "real", "sign", "mask". "real"
                               corresponds to the real values of the gradients,
                               "sign" corresponds to collecting the sign of the
                               gradients, "mask" results in a binary mask
                               corresponding to nonzero gradients; defaults to
                               "real"
            - max_samples_to_track: (optional) how many samples to use for
                                    plotting; only the newest will be used;
                                    defaults to 100

    Example config:
    ```
    config=dict(
        gradient_metrics_args=dict(
            include_modules=[torch.nn.Linear, KWinners],
            plot_freq=1,
            max_samples_to_track=150,
            metrics=["dot", "pearson"],
            gradient_values="mask",
        )
    )
    ```
    """

    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Process config args
        gradient_metrics_args = config.get("gradient_metrics_args", {})
        self.gradient_metrics_plot_freq, self.gradient_metrics_filter_args, \
            self.gradient_metrics_max_samples, self.gradient_metrics = \
            process_gradient_metrics_args(gradient_metrics_args)

        # Register hook for tracking gradients
        named_modules = filter_modules(self.model,
                                       **self.gradient_metrics_filter_args)
        hook_args = dict(max_samples_to_track=self.gradient_metrics_max_samples)
        self.gradient_metric_hooks = ModelHookManager(named_modules,
                                                      TrackGradientsHook,
                                                      hook_type="backward",
                                                      hook_args=hook_args)

        # Log the names of the modules being tracked
        tracked_names = pformat(list(named_modules.keys()))
        self.logger.info(f"Tracking gradients for modules: {tracked_names}")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.gradient_metric_targets = torch.tensor([]).long()

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting
        gradients collected by the `TrackGradientsHook` object are plotted by
        calling a plotting function.
        """

        # Run the epoch with tracking enabled.
        with self.gradient_metric_hooks:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Calculate metrics, create visualization, and update results dict.
        if iteration % self.gradient_metrics_plot_freq == 0:
            gradient_stats = self.gradient_metric_hooks.get_statistics()
            gradient_metrics_stats = self.calculate_gradient_metrics_stats(
                gradient_stats)
            gradient_metric_heatmaps = self.plot_gradient_metric_heatmaps(
                gradient_metrics_stats)
            for (name, _, gradient_metric, gradient_value, _, figure) in \
                    gradient_metric_heatmaps:
                results.update(
                    {f"{gradient_metric}/{gradient_value}/{name}": figure})

        return results

    def calculate_gradient_metrics_stats(self, gradients_stats):
        """
        This function calculates statistics given the gradients_stats which are
        being tracked by the TrackGradients backward hook. This function accesses
        `self.gradient_metrics`, which is also a list of tuples. Each tuple is the
        combination of a metric function ("cosine", "dot", "pearson") and a
        gradient transformation ("real", "sign", "mask").

        By default, the statistics that are calculated are ("cosine", "mask") and
        ("cosine", "real"). ("cosine", "mask") corresponds to the overlap between
        two different gradients, and ("cosine", "real") corresponds to the standard
        cosine similarity between two gradients.

        Args:
            gradients_stats: A list of tuples, each of which contains a named
                             module and its gradients

        Returns:
            A list of tuples, each of which contains a named module, the statistics
            calculated from its gradients, and the metric function/transformation
            used on the gradients
        """
        all_stats = []
        for (name, module, gradients) in gradients_stats:
            for gradient_metric, gradient_value in self.gradient_metrics:

                # Apply the gradient-value transformation, keeping the raw
                # gradients intact for any subsequent metric pairs.
                if gradient_value == "sign":
                    transformed = torch.sign(gradients)
                elif gradient_value == "mask":
                    transformed = torch.abs(torch.sign(gradients))
                else:
                    transformed = gradients

                # Calculate the metric function on the transformed gradients; a
                # gradient paired with itself (the diagonal) is set to 0.
                if gradient_metric == "cosine":
                    stats = [
                        torch.cosine_similarity(x, y, dim=0)
                        if not torch.equal(x, y) else 0
                        for x in transformed for y in transformed
                    ]
                elif gradient_metric == "dot":
                    stats = [
                        x.dot(y) if not torch.equal(x, y) else 0
                        for x in transformed for y in transformed
                    ]
                elif gradient_metric == "pearson":
                    stats = [
                        torch.cosine_similarity(x - x.mean(), y - y.mean(), dim=0)
                        if not torch.equal(x, y) else 0
                        for x in transformed for y in transformed
                    ]
                stats = torch.tensor(stats)
                gradient_dim = len(transformed)
                stats = stats.view(gradient_dim, gradient_dim)
                all_stats.append(
                    (name, module, gradient_metric, gradient_value, stats))
        return all_stats

    def plot_gradient_metric_heatmaps(self, gradient_metrics_stats):
        order_by_class = torch.argsort(self.gradient_metric_targets)
        sorted_gradient_metric_targets = \
            self.gradient_metric_targets[order_by_class]
        class_change_indices = \
            (sorted_gradient_metric_targets
             - sorted_gradient_metric_targets.roll(1)).nonzero(
                as_tuple=True)[0].cpu()
        class_labels = [int(_x) for _x in sorted_gradient_metric_targets.unique()]
        class_change_indices_right = class_change_indices.roll(-1).cpu()
        class_change_indices_right[-1] = len(sorted_gradient_metric_targets)
        tick_locations = \
            (class_change_indices + class_change_indices_right) / 2.0 - 0.5
        tick_locations = tick_locations.cpu()

        stats_and_figures = []
        for (name, module, gradient_metric, gradient_value, stats) in \
                gradient_metrics_stats:

            # Sort the rows and columns of the matrix by class.
            stats = stats[order_by_class, :][:, order_by_class]

            # Plot each heatmap on its own figure with a symmetric color scale.
            figure, ax = plt.subplots()
            max_val = np.abs(stats).max()
            img = ax.imshow(stats, cmap="bwr", vmin=-max_val, vmax=max_val)
            ax.set_xlabel("class")
            ax.set_ylabel("class")

            # Draw the class boundaries and label the ticks by class.
            for idx in class_change_indices:
                ax.axvline(idx - 0.5, color="black")
                ax.axhline(idx - 0.5, color="black")
            ax.set_xticks(tick_locations)
            ax.set_yticks(tick_locations)
            ax.set_xticklabels(class_labels)
            ax.set_yticklabels(class_labels)

            ax.set_title(f"{gradient_metric}:{gradient_value}:{name}")
            figure.colorbar(img, ax=ax)
            figure.tight_layout()
            stats_and_figures.append(
                (name, module, gradient_metric, gradient_value, stats, figure))
        return stats_and_figures

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss.
        This mixin assumes these targets correspond, in a 1:1 fashion, to the
        samples seen in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.gradient_metric_hooks.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.gradient_metric_targets = self.gradient_metric_targets.to(
                target.device)

            # Concatenate and discard the older targets.
            self.gradient_metric_targets = torch.cat(
                [target, self.gradient_metric_targets], dim=0)
            self.gradient_metric_targets = \
                self.gradient_metric_targets[:self.gradient_metrics_max_samples]

        return loss
class PlotDendriteMetrics(metaclass=abc.ABCMeta):
    """
    This is a mixin for creating custom plots of metrics for apply-dendrite
    modules (those of type `ApplyDendritesBase`_). The user defines and gives a
    plotting function which is then called on the following arguments

        - dendrite_activations: the input activations passed to the
                                apply-dendrites module; these are meant to be the
                                output of a `DendriteSegments` module; they will
                                be of shape batch_size x num_units x num_segments
        - winning_mask: the mask of the winning segments, those chosen by the
                        apply-dendrites modules; this will be of shape
                        batch_size x num_units
        - targets: the targets that correspond to each sample in the batch

    Plots can be configured to use fewer samples (helpful for plotting small
    batches of individual samples) and to plot every so many epochs (so that
    training isn't slowed down too much). Whenever a plot is made, the raw data
    used to create it is saved so it may be reproduced and edited off-line.

    .. warning:: When using this mixin with Ray, be careful to have 'plot_func'
                 return an object that can be logged. Often, Ray will attempt to
                 create a deepcopy prior to logging which can't be done on most
                 plots. Try wrapping the plot as done in `prep_plot_for_wandb`_.

    .. _ApplyDendritesBase: nupic/research/frameworks/dendrites/modules
    .. _prep_plot_for_wandb: nupic/research/frameworks/wandb/ray_wandb

    :param config: a dict containing the following

        - plot_dendrite_metrics_args: a dict containing the following

            - include_modules: (optional) a list of module types to track;
                               defaults to include all `ApplyDendritesBase`
                               modules if no other `include_*` arguments are given
            - include_names: (optional) a list of module names to track
                             e.g. "features.stem"
            - include_patterns: (optional) a list of regex patterns to compare to
                                the names; for instance, all feature parameters in
                                ResNet can be included through "features.*"
            - <insert any plot name here>: this can be any string and maps to a
              dictionary of the plot arguments below; the resulting plot will be
              logged under "<plot_name>/<module_name>" in the results dictionary

                - plot_func: the function called for plotting; must take three
                             arguments: 'dendrite_activations', 'winning_mask',
                             and 'targets' (see above)
                - plot_freq: (optional) how often to create the plot, measured in
                             training iterations; defaults to 1
                - plot_args: (optional) either a dictionary or a callable that
                             takes no arguments and returns a dictionary; for
                             instance, this may be used to return a random sample
                             of integers specifying units to plot; called only
                             once at setup
                - max_samples_to_plot: (optional) how many samples to use for
                                       plotting; only the newest will be used;
                                       defaults to 1000

    Example config:
    ```
    config=dict(
        plot_dendrite_metrics_args=dict(
            include_modules=[DendriticGate1d],
            mean_selected=dict(
                plot_func=plot_mean_selected_activations,
            )
        )
    )
    ```
    """

    def setup_experiment(self, config):
        super().setup_experiment(config)

        # Unpack, validate, and process the default arguments.
        metric_args = config.get("plot_dendrite_metrics_args", {})
        self.metric_args, filter_args, max_samples = self.process_args(metric_args)

        # The maximum 'max_samples_to_track' will be tracked by all the hooks.
        self.max_samples_to_track = max_samples
        hook_args = dict(max_samples_to_track=self.max_samples_to_track)

        # The 'filter_args' specify which modules to track.
        named_modules = filter_modules(self.model, **filter_args)
        self.dendrite_hooks = ModelHookManager(named_modules,
                                               ApplyDendritesHook,
                                               hook_args=hook_args)

        # The hook is specifically made for `ApplyDendritesBase` modules.
        for module in named_modules.values():
            assert isinstance(module, ApplyDendritesBase)

        # Log the names of the modules being tracked, and warn when there's none.
        names = list(named_modules.keys())
        self.logger.info(f"Dendrite Metric Setup: Tracking modules: {names}")
        if len(names) == 0:
            self.logger.warning("Dendrite Metric Setup: "
                                "No modules found for tracking.")

        # The targets will be collected in `self.error_loss` in a 1:1 fashion
        # to the tensors being collected by the hooks.
        self.targets = torch.tensor([]).long()

    def process_args(self, metric_args):

        metric_args = deepcopy(metric_args)

        # Remove and collect information about which modules to track.
        include_names = metric_args.pop("include_names", [])
        include_modules = metric_args.pop("include_modules", [])
        include_patterns = metric_args.pop("include_patterns", [])

        # Default to track all `ApplyDendritesBase` modules.
        if len(include_names) == len(include_modules) == len(include_patterns) == 0:
            include_modules = [ApplyDendritesBase]

        filter_args = dict(
            include_names=include_names,
            include_modules=include_modules,
            include_patterns=include_patterns,
        )

        # Gather and validate the metric arguments. The max of the
        # 'max_samples_to_plot' values will be saved to dictate how many samples
        # will be tracked by the hooks.
        all_max_num_samples = []
        new_metric_args = {}
        for metric_name, plotting_args in metric_args.items():

            plot_func = plotting_args.get("plot_func", None)
            plot_freq = plotting_args.get("plot_freq", 1)
            plot_args = plotting_args.get("plot_args", {})
            max_samples_to_plot = plotting_args.get("max_samples_to_plot", 1000)

            assert callable(plot_func)
            assert isinstance(plot_freq, int)
            assert isinstance(max_samples_to_plot, int)
            assert plot_freq > 0
            assert max_samples_to_plot > 0

            # The arguments may be given as a callable; useful for sampling random
            # values that dictate plotting to, say, only plot a subset of units.
            if callable(plot_args):
                plot_args = plot_args()
                assert isinstance(plot_args, dict)

            new_metric_args[metric_name] = dict(
                plot_func=plot_func,
                plot_freq=plot_freq,
                plot_args=plot_args,
                max_samples_to_plot=max_samples_to_plot,
            )
            all_max_num_samples.append(max_samples_to_plot)

        # The default guards the case where no plots were configured.
        max_samples_to_plot = max(all_max_num_samples, default=1000)
        return new_metric_args, filter_args, max_samples_to_plot

    def run_epoch(self):
        """
        This runs the epoch with the hooks in tracking mode. The resulting
        'activations' and 'winning_masks' collected by these hooks are plotted via
        each 'plot_func' along with their corresponding targets.
        """

        # Run the epoch with tracking enabled.
        with self.dendrite_hooks:
            results = super().run_epoch()

        # The epoch was iterated in `run_epoch` so epoch 0 is really epoch 1 here.
        iteration = self.current_epoch + 1

        # Gather and plot the statistics.
        for name, _, activations, winners in self.dendrite_hooks.get_statistics():

            if len(activations) == 0 or len(winners) == 0:
                self.logger.warning(f"Skipping plots for module '{name}';"
                                    " no data could be collected.")
                continue

            # Keep track of whether a plot is made below. If so, save the raw data.
            plot_made = False

            # Each 'plot_func' will be applied to each module being tracked.
            for metric_name, plotting_args in self.metric_args.items():

                # All of the defaults were set in `process_args`.
                plot_func = plotting_args["plot_func"]
                plot_freq = plotting_args["plot_freq"]
                plot_args = plotting_args["plot_args"]
                max_samples_to_plot = plotting_args["max_samples_to_plot"]

                if iteration % plot_freq != 0:
                    continue

                # Only use up to the max number of samples for plotting. Slice
                # into new locals so the full collections stay intact for other
                # plots and for the raw-data logging below.
                plot_targets = self.targets[:max_samples_to_plot]
                plot_activations = activations[:max_samples_to_plot]
                plot_winners = winners[:max_samples_to_plot]

                # Call and log the results of the plot function.
                # Here, "{name}" is the name of the module.
                visual = plot_func(plot_activations, plot_winners, plot_targets,
                                   **plot_args)
                results.update({f"{metric_name}/{name}": visual})
                plot_made = True

            # Log the raw data.
            if plot_made:
                targets = self.targets[:self.max_samples_to_track].cpu().numpy()
                activations = activations[:self.max_samples_to_track].cpu().numpy()
                winners = winners[:self.max_samples_to_track].cpu().numpy()
                results.update({f"targets/{name}": targets})
                results.update({f"dendrite_activations/{name}": activations})
                results.update({f"winning_mask/{name}": winners})

        return results

    def error_loss(self, output, target, reduction="mean"):
        """
        This computes the loss and then saves the targets computed on this loss.
        This mixin assumes these targets correspond, in a 1:1 fashion, to the
        images seen in the forward pass.
        """
        loss = super().error_loss(output, target, reduction=reduction)
        if self.dendrite_hooks.tracking:

            # Targets were initialized on the cpu which could differ from the
            # targets collected during the forward pass.
            self.targets = self.targets.to(target.device)

            # Concatenate and discard the older targets.
            self.targets = torch.cat([target, self.targets], dim=0)
            self.targets = self.targets[:self.max_samples_to_track]

        return loss