def test_lifecycle(
    self,
    modifier_lambda,
    model_lambda,
    optim_lambda,
    test_steps_per_epoch,  # noqa: F811
):
    modifier = modifier_lambda()
    model = model_lambda()
    optimizer = optim_lambda(model)
    self.initialize_helper(modifier, model)
    assert len(modifier._layer_modules) > 0

    if modifier.start_epoch > 0:
        for (name, mod) in modifier._layer_modules.items():
            assert mod is None
            assert not isinstance(get_layer(name, model), Identity)

        # check sparsity is not set before start_epoch
        for epoch in range(int(modifier.start_epoch)):
            assert not modifier.update_ready(epoch, test_steps_per_epoch)

            for (name, mod) in modifier._layer_modules.items():
                assert mod is None
                assert not isinstance(get_layer(name, model), Identity)

        epoch = int(modifier.start_epoch)
        assert modifier.update_ready(epoch, test_steps_per_epoch)
        modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)
    else:
        epoch = 0

    for (name, mod) in modifier._layer_modules.items():
        assert mod is not None
        assert isinstance(get_layer(name, model), Identity)

    # check forward pass
    input_shape = model_lambda.layer_descs()[0].input_size
    test_batch = torch.randn(10, *input_shape)
    _ = model(test_batch)

    end_epoch = (
        modifier.end_epoch if modifier.end_epoch > -1 else modifier.start_epoch + 10
    )

    while epoch < end_epoch - 0.1:
        epoch += 0.1
        assert modifier.update_ready(epoch, test_steps_per_epoch)
        modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)
        _ = model(test_batch)  # check forward pass

    if modifier.end_epoch > -1:
        epoch = int(modifier.end_epoch)
        assert modifier.update_ready(epoch, test_steps_per_epoch)
        modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

        for (name, mod) in modifier._layer_modules.items():
            assert mod is None
            assert not isinstance(get_layer(name, model), Identity)
def initialize(
    self,
    module: Module,
    epoch: float = 0,
    loggers: Optional[List[BaseLogger]] = None,
    **kwargs,
):
    """
    Grab the layers to control activation sparsity for.

    :param module: the PyTorch model/module to modify
    :param epoch: The epoch to initialize the modifier and module at.
        Defaults to 0 (start of the training process)
    :param loggers: Optional list of loggers to log the modification process to
    :param kwargs: Optional kwargs to support specific arguments
        for individual modifiers.
    """
    super(ASRegModifier, self).initialize(module, epoch, loggers, **kwargs)
    layers = (
        get_terminal_layers(module)
        if self._layers == ALL_TOKEN
        else [get_layer(name, module) for name in self._layers]
    )

    for layer in layers:
        self._trackers.append(
            ASLayerTracker(
                layer,
                track_input=self._reg_tens == "inp",
                track_output=self._reg_tens == "out",
                input_func=self._regularize_tracked,
                output_func=self._regularize_tracked,
            )
        )
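# A minimal, self-contained sketch of the pattern above: resolve a set of layers
# and attach forward hooks that feed their outputs to a regularization function.
# SimpleActivationTracker and l1_activation_penalty are hypothetical illustrations,
# not part of the library's API.
import torch
from torch.nn import Module


class SimpleActivationTracker:
    """Attach a forward hook to a layer and accumulate a penalty on its output."""

    def __init__(self, layer: Module, penalty_func):
        self._penalty_func = penalty_func
        self.penalty = torch.tensor(0.0)
        self._handle = layer.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # accumulate the regularization term for the current forward pass;
        # the result can be added to the training loss before backward()
        self.penalty = self.penalty + self._penalty_func(output)

    def clear(self):
        self.penalty = torch.tensor(0.0)

    def remove(self):
        self._handle.remove()


def l1_activation_penalty(tensor: torch.Tensor) -> torch.Tensor:
    # an L1 penalty drives activations toward zero, increasing activation sparsity
    return tensor.abs().mean()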
def _test_const(module, name, param_name):
    layer = get_layer(name, module)
    analyzer = ModulePruningAnalyzer(layer, name, param_name)

    assert analyzer.module == layer
    assert analyzer.name == name
    assert analyzer.param_name == param_name
    assert analyzer.tag == "{}.{}".format(name, param_name)
    assert isinstance(analyzer.param, Parameter)
def test_param_sparsity(module, name, param_name, param_data, expected_sparsity):
    layer = get_layer(name, module)
    analyzer = ModulePruningAnalyzer(layer, name, param_name)
    analyzer.param.data = param_data

    assert torch.sum((analyzer.param_sparsity - expected_sparsity).abs()) < 0.00001
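# For context on the assertion above: parameter sparsity is the fraction of
# entries in the parameter tensor that are exactly zero. A minimal stand-in
# (tensor_sparsity is a hypothetical helper, not the library implementation):
import torch


def tensor_sparsity(tensor: torch.Tensor) -> torch.Tensor:
    # fraction of elements equal to zero
    return (tensor == 0).float().mean()


# example: half of the entries are zero, so the sparsity is 0.5
assert tensor_sparsity(torch.tensor([[0.0, 1.0, 0.0, 2.0], [3.0, 0.0, 4.0, 0.0]])) == 0.5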
def layer_desc(self, name: Union[str, None] = None) -> AnalyzedLayerDesc:
    """
    Get a specific layer's description within the Module.
    Set to None to get the overall Module's description.

    :param name: name of the layer to get a description for,
        None for an overall description
    :return: the analyzed layer description for the given name
    """
    if not self._forward_called:
        raise RuntimeError(
            "module must have forward called with sample input "
            "before getting a layer desc"
        )

    mod = get_layer(name, self._module) if name is not None else self._module

    return ModuleAnalyzer._mod_desc(mod)
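# get_layer is used throughout to look a submodule up by name. A minimal sketch
# under the assumption that names are dotted module paths (e.g. "features.0.conv")
# following the standard torch.nn.Module naming scheme; resolve_layer is a
# hypothetical helper, not the library implementation:
from torch.nn import Module


def resolve_layer(name: str, module: Module) -> Module:
    layer = module
    for attr in name.split("."):
        # submodules (including indexed containers such as Sequential) are
        # registered by name, so plain attribute access walks the path
        layer = getattr(layer, attr)
    return layer


# recent PyTorch versions expose the same lookup as Module.get_submodule(name)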
def initialize(self, module: Module, optimizer: Optimizer):
    """
    Grab the layers to control activation sparsity for.

    :param module: module to modify
    :param optimizer: optimizer to modify
    """
    super(ASRegModifier, self).initialize(module, optimizer)
    layers = (
        get_terminal_layers(module)
        if self._layers == ALL_TOKEN
        else [get_layer(name, module) for name in self._layers]
    )

    for layer in layers:
        self._trackers.append(
            ASLayerTracker(
                layer,
                track_input=self._reg_tens == "inp",
                track_output=self._reg_tens == "out",
                input_func=self._regularize_tracked,
                output_func=self._regularize_tracked,
            )
        )
def _measure_layer(
    self, layer: Union[str, None], module: Module, device: str, desc: str
) -> Tuple[ModuleRunResults, Tensor]:
    layer = get_layer(layer, module) if layer else None
    as_analyzer = None

    if layer:
        as_analyzer = ModuleASAnalyzer(layer, dim=None, track_outputs_sparsity=True)
        as_analyzer.enable()

    tester = ModuleTester(module, device, self._loss)
    data_loader = DataLoader(
        self._dataset, self._batch_size, **self._dataloader_kwargs
    )
    results = tester.run(
        data_loader, desc=desc, show_progress=True, track_results=True
    )

    if as_analyzer:
        as_analyzer.disable()

    return results, as_analyzer.outputs_sparsity_mean if as_analyzer else None
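# ModuleASAnalyzer above measures how sparse a layer's outputs are while the
# tester runs. A minimal sketch of the same idea using a plain forward hook
# (OutputSparsityTracker is a hypothetical class, not the library implementation):
import torch
from torch.nn import Module


class OutputSparsityTracker:
    def __init__(self, layer: Module):
        self._sparsities = []
        self._handle = layer.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # record the fraction of zero activations produced for this batch
        self._sparsities.append((output == 0).float().mean())

    @property
    def mean_sparsity(self) -> torch.Tensor:
        return torch.stack(self._sparsities).mean()

    def remove(self):
        self._handle.remove()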
def _binary_search_fat(
    self,
    layer: str,
    module: Module,
    device: str,
    min_thresh: float,
    max_thresh: float,
    max_target_metric_loss: float,
    metric_key: str,
    metric_increases: bool,
    precision: float = 0.001,
):
    print(
        "\n\n\nstarting binary search for layer {} between ({}, {})...".format(
            layer, min_thresh, max_thresh
        )
    )
    base_res, base_as = self._measure_layer(
        layer, module, device, "baseline for layer: {}".format(layer)
    )
    fat_relu = get_layer(layer, module)  # type: FATReLU
    init_thresh = fat_relu.get_threshold()

    if min_thresh > init_thresh:
        min_thresh = init_thresh

    while True:
        thresh = ModuleASOneShootBooster._get_mid_point(min_thresh, max_thresh)
        fat_relu.set_threshold(thresh)
        thresh_res, thresh_as = self._measure_layer(
            layer,
            module,
            device,
            "threshold for layer: {} @ {:.4f} ({:.4f}, {:.4f})".format(
                layer, thresh, min_thresh, max_thresh
            ),
        )

        if ModuleASOneShootBooster._passes_loss(
            base_res,
            thresh_res,
            max_target_metric_loss,
            metric_key,
            metric_increases,
        ):
            min_thresh = thresh
            print(
                "loss check passed for max change: {:.4f}".format(
                    max_target_metric_loss
                )
            )
        else:
            max_thresh = thresh
            print(
                "loss check failed for max change: {:.4f}".format(
                    max_target_metric_loss
                )
            )

        print(
            "   current loss: {:.4f} baseline loss: {:.4f}".format(
                thresh_res.result_mean(metric_key),
                base_res.result_mean(metric_key),
            )
        )
        print("   current AS: {:.4f} baseline AS: {:.4f}".format(thresh_as, base_as))

        if max_thresh - min_thresh <= precision:
            break

    print("completed binary search for layer {}".format(layer))
    print("   found threshold: {}".format(thresh))
    print(
        "   AS delta: {:.4f} ({:.4f} => {:.4f})".format(
            thresh_as - base_as, base_as, thresh_as
        )
    )

    return LayerBoostResults(
        layer, thresh, thresh_as, thresh_res, base_as, base_res
    )
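# The search above bisects the (min_thresh, max_thresh) interval, moving to the
# upper half while the measured loss stays within max_target_metric_loss of the
# baseline and to the lower half otherwise. A sketch of that control flow with
# the measurement factored out into a callable; binary_search_threshold is a
# hypothetical helper, not the library implementation, it assumes the metric
# grows as quality degrades, and like the code above it returns the last
# midpoint probed:
from typing import Callable


def binary_search_threshold(
    measure_loss: Callable[[float], float],
    baseline_loss: float,
    min_thresh: float,
    max_thresh: float,
    max_loss_increase: float,
    precision: float = 0.001,
) -> float:
    thresh = min_thresh
    while max_thresh - min_thresh > precision:
        thresh = (min_thresh + max_thresh) / 2.0
        if measure_loss(thresh) - baseline_loss <= max_loss_increase:
            # loss still within budget: try a more aggressive (larger) threshold
            min_thresh = thresh
        else:
            # loss degraded too much: back off to a smaller threshold
            max_thresh = thresh
    return thresh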