def initialize(self, module: Module, optimizer: Optimizer):
    """
    Grab the params to control kernel sparsity for.

    :param module: module to modify
    :param optimizer: optimizer to modify
    """
    super().initialize(module, optimizer)

    # resolve which parameter names to prune based on the params setting
    if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
        param_names = ["re:.*"]
    elif self._params == ALL_PRUNABLE_TOKEN or ALL_PRUNABLE_TOKEN in self._params:
        param_names = [
            name + ".weight" for (name, _) in get_prunable_layers(module)
        ]
    else:
        param_names = self._params

    named_layers_and_params = get_named_layers_and_params_by_regex(
        module,
        param_names,
        params_strict=True,
    )

    layers = [
        named_layer_param.layer for named_layer_param in named_layers_and_params
    ]
    param_names = [
        named_layer_param.param_name for named_layer_param in named_layers_and_params
    ]
    layer_names = [
        named_layer_param.layer_name for named_layer_param in named_layers_and_params
    ]
    # the mask object owns and applies the pruning masks for the matched params
    self._module_masks = ModuleParamPruningMask(
        layers,
        param_names,
        layer_names=layer_names,
        global_sparsity=self._global_sparsity,
    )

    self._analyzers = []
    for layer_name, layer, param_name, _ in named_layers_and_params:
        self._analyzers.append(
            ModulePruningAnalyzer(layer, layer_name, param_name)
        )

    if len(self._analyzers) == 0:
        raise ValueError(
            "Could not find any params matching {} in {}".format(
                self._params, self.__class__.__name__
            )
        )
def _create_named_layers_and_params(self, module: Module) -> List[NamedLayerParam]:
    """
    Resolve this modifier's params setting into concrete NamedLayerParam
    objects for the given module.

    :param module: module to search for named layers and params in
    :return: the named layers and params matching the params setting
    """
    if self._check_params_match(ALL_TOKEN):
        param_names = ["re:.*"]
    elif self._check_params_match(ALL_PRUNABLE_TOKEN):
        param_names = [
            name + ".weight" for (name, _) in get_prunable_layers(module)
        ]
    else:
        param_names = self._params

    return get_named_layers_and_params_by_regex(
        module,
        param_names,
        params_strict=True,
    )
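# Usage sketch (illustrative, not part of the module): how the three supported
# `params` forms resolve for a toy model, assuming get_prunable_layers yields
# the module's conv/linear layers by name.
#
#   import torch.nn as nn
#
#   toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
#   # ALL_TOKEN          -> param_names == ["re:.*"] (regex matching every param)
#   # ALL_PRUNABLE_TOKEN -> param_names == ["0.weight", "2.weight"]
#   # explicit list      -> passed through unchanged, e.g. ["0.weight"]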
def pruning_loss_sens_magnitude(
    module: Module,
    sparsity_levels: Union[List[float], Tuple[float, ...]] = default_pruning_sparsities_loss(True),
) -> PruningLossSensitivityAnalysis:
    """
    Approximated kernel sparsity (pruning) loss analysis for a given model.
    Returns the results for each prunable param (conv, linear) in the model.

    :param module: the model to calculate the sparse sensitivity analysis for
    :param sparsity_levels: the sparsity levels to calculate the loss at
        for each param
    :return: the analysis results for the model
    """
    prunable = get_prunable_layers(module)
    analysis = PruningLossSensitivityAnalysis()

    for index, (name, layer) in enumerate(prunable):
        weight = getattr(layer, "weight")
        name = "{}.weight".format(name)
        values, _ = weight.view(-1).abs().sort()
        prev_index = 0

        for sparsity in sparsity_levels:
            val_index = round(sparsity * len(values))
            if val_index >= len(values):
                val_index = len(values) - 1

            if sparsity <= 1e-9:
                baseline = True
                sparsity = 0.0
                sparse_avg = 0.0
            else:
                baseline = False

                if val_index > prev_index:
                    # mean magnitude of the weights newly pruned at this level
                    sparse_avg = values[prev_index:val_index].mean().item()
                    prev_index = val_index
                else:
                    sparse_avg = values[val_index].item()
                    prev_index = val_index + 1

            analysis.add_result(None, name, index, sparsity, sparse_avg, baseline)

    return analysis
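# Minimal standalone sketch (illustrative) of the bucketed magnitude average
# computed above: for each successive sparsity level, the loss proxy is the
# mean |weight| of the values that would be newly pruned at that level.
import torch

_values, _ = torch.randn(1000).abs().sort()
_prev = 0
for _sparsity in (0.0, 0.4, 0.8):
    _idx = min(round(_sparsity * len(_values)), len(_values) - 1)
    _avg = _values[_prev:_idx].mean().item() if _idx > _prev else 0.0
    _prev = max(_prev, _idx)
    print(f"sparsity={_sparsity:.1f} newly pruned mean |w|={_avg:.4f}")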
def initialize(self, module: Module, optimizer: Optimizer):
    """
    Grab the params to control kernel sparsity for.

    :param module: module to modify
    :param optimizer: optimizer to modify
    """
    super().initialize(module, optimizer)

    if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
        param_names = ["re:.*"]
    elif self._params == ALL_PRUNABLE_TOKEN or ALL_PRUNABLE_TOKEN in self._params:
        param_names = [
            name + ".weight" for (name, _) in get_prunable_layers(module)
        ]
    else:
        param_names = self._params

    named_layers_and_params = get_named_layers_and_params_by_regex(
        module,
        param_names,
        params_strict=True,
    )

    layers = [
        named_layer_param.layer for named_layer_param in named_layers_and_params
    ]
    param_names = [
        named_layer_param.param_name for named_layer_param in named_layers_and_params
    ]
    layer_names = [
        named_layer_param.layer_name for named_layer_param in named_layers_and_params
    ]
    self._module_masks = ModuleParamPruningMask(
        layers, param_names, layer_names=layer_names
    )

    self._analyzers = []
    for layer_name, layer, param_name, _ in named_layers_and_params:
        self._analyzers.append(
            ModulePruningAnalyzer(layer, layer_name, param_name)
        )
def ks_layer_descs(self) -> List[AnalyzedLayerDesc]:
    """
    Get the descriptions for all layers in the module that support
    kernel sparsity (model pruning). Ex: all convolutions and linear layers.

    :return: a list of descriptions for all layers in the module that support ks
    """
    descs = []

    for (name, _) in get_prunable_layers(self._module):
        desc = self.layer_desc(name)

        if desc is None:
            print("analyzer: no description found for {}".format(name))
        else:
            descs.append(desc)

    descs.sort(key=lambda val: val.execution_order)

    return descs
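# Usage sketch (hypothetical `analyzer` instance of the enclosing class):
# descriptions come back sorted by execution order, so they can be dumped
# as a quick per-layer report.
#
#   for desc in analyzer.ks_layer_descs():
#       print(desc)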
def model_prunability_magnitude(module: Module) -> float:
    """
    Calculate the approximate sensitivity for an overall model.
    The values are not scaled to any fixed range, so they must be taken
    in context with other known models.

    :param module: the model to calculate the sensitivity for
    :return: the approximated sensitivity
    """
    prunable = get_prunable_layers(module)
    tensors = []

    for (_, layer) in prunable:
        weight = getattr(layer, "weight")
        values = weight.view(-1).abs()
        tensors.append(values)

    all_weights = torch.cat(tensors)
    avg = all_weights.mean().item()

    return avg
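# Usage sketch (illustrative, assuming the sparseml imports used above): the
# score is just the mean absolute weight over all prunable (conv/linear)
# layers, so it is only meaningful when compared across similar architectures.
import torch.nn as nn

_toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
print(f"prunability: {model_prunability_magnitude(_toy):.4f}")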
def from_sparse_model(model: Module) -> List[ScheduledModifier]:
    """
    Create constant ks modifiers for all prunable params in the given model
    (conv, linear) that have been artificially sparsified (sparsity > 40%).
    Useful for transfer learning from a pruned model.

    :param model: the model to create constant ks modifiers for
    :return: the list of created constant ks modifiers
    """
    prunable = get_prunable_layers(model)
    modifiers = []

    for name, layer in prunable:
        weight = getattr(layer, "weight")
        sparsity = tensor_sparsity(weight)

        if sparsity > 0.4:
            modifiers.append(
                ConstantPruningModifier(params=["{}.{}".format(name, "weight")])
            )

    return modifiers
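# Usage sketch (illustrative, assuming the sparseml imports used above):
# manually sparsify a toy layer past the 40% threshold, then build the
# constant pruning modifiers that will hold that pattern during fine-tuning.
import torch
import torch.nn as nn

_sparse = nn.Sequential(nn.Linear(16, 16))
with torch.no_grad():
    _w = _sparse[0].weight
    _thresh = _w.abs().flatten().kthvalue(int(0.6 * _w.numel())).values
    _w[_w.abs() <= _thresh] = 0.0  # roughly 60% of the weights zeroed
print(from_sparse_model(_sparse))  # one modifier, since sparsity > 40%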
def pruning_loss_sens_one_shot(
    module: Module,
    data: DataLoader,
    loss: Union[LossWrapper, Callable[[Any, Any], Tensor]],
    device: str,
    steps_per_measurement: int,
    sparsity_levels: List[float] = default_pruning_sparsities_loss(False),
    loss_key: str = DEFAULT_LOSS_KEY,
    tester_run_funcs: ModuleRunFuncs = None,
    tester_loggers: List[BaseLogger] = None,
    show_progress: bool = True,
) -> PruningLossSensitivityAnalysis:
    """
    Run a one-shot sensitivity analysis for kernel sparsity.
    It does not retrain; instead it puts the model in eval mode.
    Moves layer by layer to calculate the sensitivity analysis for each
    and resets the previously run layers.
    Note: by default the data is cached, so data loading is not parallel
    and the first run can take longer. Subsequent sparsity checks for
    layers and levels will be much faster.

    :param module: the module to run the kernel sparsity sensitivity analysis
        over; will extract all prunable layers out
    :param data: the data to run through the module for calculating the
        sensitivity analysis
    :param loss: the loss function to use for the sensitivity analysis
    :param device: the device to run the analysis on; ex: cpu, cuda
    :param steps_per_measurement: the number of samples or items to take
        for each measurement at each sparsity level
    :param sparsity_levels: the sparsity levels to check for each layer to
        calculate sensitivity
    :param loss_key: the key for the loss function to track in the returned dict
    :param tester_run_funcs: override functions to use in the ModuleTester
        that runs
    :param tester_loggers: loggers to log data to while running the analysis
    :param show_progress: track progress of the runs if True
    :return: the sensitivity results for every layer that is prunable
    """
    analysis = PruningLossSensitivityAnalysis()
    tester = ModuleTester(
        module,
        device,
        loss,
        loggers=tester_loggers,
        log_summary=False,
        log_steps=max(1, round(steps_per_measurement / 10)),
    )
    layers = get_prunable_layers(module)
    batch_end = _sensitivity_callback(
        layers, sparsity_levels, steps_per_measurement, analysis, loss_key
    )
    batch_end_hook = tester.run_hooks.register_batch_end_hook(batch_end)
    if tester_run_funcs is not None:
        tester.run_funcs.copy(tester_run_funcs)

    data_loader = infinite_data_loader(
        data, early_stop_steps=steps_per_measurement, cache=True
    )
    tester.run(
        data_loader,
        desc="KS Analysis",
        show_progress=show_progress,
        track_results=False,
        max_steps=steps_per_measurement * len(sparsity_levels) * len(layers),
    )
    batch_end_hook.remove()

    return analysis
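# Usage sketch (hypothetical names): `my_model`, `my_val_loader`, and `my_loss`
# stand in for real objects; a small steps_per_measurement keeps the run short.
#
#   analysis = pruning_loss_sens_one_shot(
#       module=my_model,
#       data=my_val_loader,
#       loss=my_loss,
#       device="cuda",
#       steps_per_measurement=15,
#   )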
def train(
    train_args: TrainingArguments,
    model: Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_shape: Tuple[int, ...],
    save_dir: str,
    loggers: List[Any],
) -> None:
    """
    Utility function to drive the training process.

    :param train_args: A TrainingArguments object with arguments for the
        current training task
    :param model: model architecture to train
    :param train_loader: A DataLoader for training data
    :param val_loader: A DataLoader for validation data
    :param input_shape: A tuple of integers representing the shape of inputs
    :param save_dir: Directory to store checkpoints in during training
    :param loggers: List of loggers to use during training
    """
    # loss setup
    val_loss = helpers.get_loss_wrapper(arch_key=train_args.arch_key, training=True)
    LOGGER.info(f"created loss for validation: {val_loss}")

    train_loss = helpers.get_loss_wrapper(arch_key=train_args.arch_key, training=True)
    LOGGER.info(f"created loss for training: {train_loss}")

    # training setup
    if not train_args.eval_mode:
        epoch, optim, manager = helpers.create_scheduled_optimizer(
            train_args,
            model,
            train_loader,
            loggers,
        )
    else:
        epoch = 0
        train_loss = None
        optim = None
        manager = None

    # device setup
    if train_args.rank == -1:
        device = train_args.device
        ddp = False
    else:
        torch.cuda.set_device(train_args.local_rank)
        device = train_args.local_rank
        ddp = True

    model, device, device_ids = model_to_device(model, device, ddp=ddp)
    LOGGER.info(f"running on device {device} for ids {device_ids}")

    trainer = (
        ModuleTrainer(
            model,
            device,
            train_loss,
            optim,
            loggers=loggers,
            device_context=ModuleDeviceContext(
                use_mixed_precision=train_args.use_mixed_precision,
                world_size=train_args.world_size,
            ),
        )
        if not train_args.eval_mode
        else None
    )

    if train_args.is_main_process:  # only test on one DDP process if using DDP
        tester = ModuleTester(model, device, val_loss, loggers=loggers, log_steps=-1)

        # initial baseline eval run
        tester.run_epoch(val_loader, epoch=epoch - 1, max_steps=train_args.debug_steps)

    if not train_args.eval_mode:
        helpers.save_recipe(recipe_manager=manager, save_dir=save_dir)
        LOGGER.info(f"starting training from epoch {epoch}")

        if epoch > 0:
            LOGGER.info("adjusting ScheduledOptimizer to restore point")
            optim.adjust_current_step(epoch, 0)

        target_metric = (
            "top1acc" if "top1acc" in tester.loss.available_losses else DEFAULT_LOSS_KEY
        )
        best_metric = None
        val_res = None

        while epoch < manager.max_epochs:
            if train_args.debug_steps > 0:
                # correct since all optimizer steps are not
                # taken in the epochs for debug mode
                optim.adjust_current_step(epoch, 0)

            if train_args.rank != -1:  # sync DDP dataloaders
                train_loader.sampler.set_epoch(epoch)

            trainer.run_epoch(
                train_loader,
                epoch,
                max_steps=train_args.debug_steps,
                show_progress=train_args.is_main_process,
            )

            # testing steps
            if train_args.is_main_process:
                # only test and save on main process
                val_res = tester.run_epoch(
                    val_loader, epoch, max_steps=train_args.debug_steps
                )
                val_metric = val_res.result_mean(target_metric).item()

                if epoch >= train_args.save_best_after and (
                    best_metric is None
                    or (
                        val_metric <= best_metric
                        if target_metric != "top1acc"
                        else val_metric >= best_metric
                    )
                ):
                    helpers.save_model_training(
                        model,
                        optim,
                        input_shape,
                        "checkpoint-best",
                        save_dir,
                        epoch,
                        val_res,
                    )
                    best_metric = val_metric

            # save checkpoints
            _save_epoch = (
                train_args.is_main_process
                and train_args.save_epochs
                and epoch in train_args.save_epochs
            )
            if _save_epoch:
                helpers.save_model_training(
                    model,
                    optim,
                    input_shape,
                    f"checkpoint-{epoch:04d}-{val_metric:.04f}",
                    save_dir,
                    epoch,
                    val_res,
                )

            epoch += 1
    # export the final model
    LOGGER.info("completed...")

    if train_args.is_main_process:
        # only convert qat -> quantized ONNX graph for finalized model
        # TODO: change this to all checkpoints when conversion times improve
        helpers.save_model_training(
            model, optim, input_shape, "model", save_dir, epoch - 1, val_res, True
        )

        LOGGER.info("layer sparsities:")
        for (name, layer) in get_prunable_layers(model):
            LOGGER.info(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

    # close DDP
    if train_args.rank != -1:
        torch.distributed.destroy_process_group()
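# Usage sketch (hypothetical names): driving a full run. The TrainingArguments
# fields mirror the attributes accessed above; the loaders and save_dir are
# stand-ins for real objects.
#
#   train(
#       train_args=my_train_args,
#       model=my_model,
#       train_loader=my_train_loader,
#       val_loader=my_val_loader,
#       input_shape=(3, 224, 224),
#       save_dir="./checkpoints",
#       loggers=[],
#   )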