Example #1
    def initialize(self, module: Module, optimizer: Optimizer):
        """
        Grab the params to control kernel sparsity for.

        :param module: module to modify
        :param optimizer: optimizer to modify
        """
        super().initialize(module, optimizer)

        if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
            param_names = ["re:.*"]
        elif self._params == ALL_PRUNABLE_TOKEN or ALL_PRUNABLE_TOKEN in self._params:
            param_names = [
                name + ".weight" for (name, _) in get_prunable_layers(module)
            ]
        else:
            param_names = self._params

        named_layers_and_params = get_named_layers_and_params_by_regex(
            module,
            param_names,
            params_strict=True,
        )

        layers = [
            named_layer_param.layer
            for named_layer_param in named_layers_and_params
        ]
        param_names = [
            named_layer_param.param_name
            for named_layer_param in named_layers_and_params
        ]
        layer_names = [
            named_layer_param.layer_name
            for named_layer_param in named_layers_and_params
        ]
        self._module_masks = ModuleParamPruningMask(
            layers,
            param_names,
            layer_names=layer_names,
            global_sparsity=self._global_sparsity,
        )

        self._analyzers = []
        for layer_name, layer, param_name, _ in named_layers_and_params:
            self._analyzers.append(
                ModulePruningAnalyzer(layer, layer_name, param_name))

        if len(self._analyzers) == 0:
            raise ValueError(
                "Could not find any params matching {} in {}".format(
                    self._params, self.__class__.__name__))
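
For context on the branches above: ALL_TOKEN masks every parameter ("re:.*"), while ALL_PRUNABLE_TOKEN resolves to the ".weight" parameter of every convolution and linear layer. A minimal pure-torch sketch of that resolution, with prunable_param_names as a hypothetical stand-in for get_prunable_layers:

from typing import List

import torch.nn as nn


def prunable_param_names(module: nn.Module) -> List[str]:
    # list the ".weight" parameter name of every conv / linear layer,
    # mirroring what the ALL_PRUNABLE_TOKEN branch expands to
    return [
        f"{name}.weight"
        for name, layer in module.named_modules()
        if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear))
    ]


if __name__ == "__main__":
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(128, 10),
    )
    print(prunable_param_names(model))  # ['0.weight', '3.weight']
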
Example #2
    def _create_named_layers_and_params(self, module: Module) -> List[NamedLayerParam]:
        if self._check_params_match(ALL_TOKEN):
            param_names = ["re:.*"]
        elif self._check_params_match(ALL_PRUNABLE_TOKEN):
            param_names = [
                name + ".weight" for (name, _) in get_prunable_layers(module)
            ]
        else:
            param_names = self._params

        return get_named_layers_and_params_by_regex(
            module,
            param_names,
            params_strict=True,
        )
def pruning_loss_sens_magnitude(
    module: Module,
    sparsity_levels: Union[
        List[float], Tuple[float, ...]
    ] = default_pruning_sparsities_loss(True),
) -> PruningLossSensitivityAnalysis:
    """
    Approximated kernel sparsity (pruning) loss analysis for a given model.
    Returns the results for each prunable param (conv, linear) in the model.

    :param module: the model to calculate the sparse sensitivity analysis for
    :param sparsity_levels: the sparsity levels at which to calculate the loss for each param
    :return: the analysis results for the model
    """
    prunable = get_prunable_layers(module)
    analysis = PruningLossSensitivityAnalysis()

    for index, (name, layer) in enumerate(prunable):
        weight = getattr(layer, "weight")
        name = "{}.weight".format(name)
        values, _ = weight.view(-1).abs().sort()
        prev_index = 0

        for sparsity in sparsity_levels:
            val_index = round(sparsity * len(values))

            if val_index >= len(values):
                val_index = len(values) - 1

            if sparsity <= 1e-9:
                baseline = True
                sparsity = 0.0
                sparse_avg = 0.0
            else:
                baseline = False

                if val_index > prev_index:
                    sparse_avg = values[prev_index:val_index].mean().item()
                    prev_index = val_index
                else:
                    sparse_avg = values[val_index].item()
                    prev_index = val_index + 1

            analysis.add_result(None, name, index, sparsity, sparse_avg,
                                baseline)

    return analysis
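
The inner loop above is a magnitude heuristic: for each sparsity step it records the average absolute value of the weights that step would newly remove. A self-contained pure-torch sketch of the same heuristic for a single tensor; layer_magnitude_sensitivity is a hypothetical helper, not the library function:

import torch


def layer_magnitude_sensitivity(weight: torch.Tensor, sparsities=(0.2, 0.4, 0.6, 0.8)) -> dict:
    # sort absolute weights ascending; each sparsity level removes a prefix
    values, _ = weight.view(-1).abs().sort()
    prev_index = 0
    results = {}

    for sparsity in sparsities:
        val_index = min(round(sparsity * len(values)), len(values) - 1)

        if val_index > prev_index:
            # average magnitude of the weights this level would newly remove
            results[sparsity] = values[prev_index:val_index].mean().item()
            prev_index = val_index
        else:
            results[sparsity] = values[val_index].item()
            prev_index = val_index + 1

    return results


if __name__ == "__main__":
    conv = torch.nn.Conv2d(16, 32, kernel_size=3)
    print(layer_magnitude_sensitivity(conv.weight.detach()))
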
Example #4
    def initialize(self, module: Module, optimizer: Optimizer):
        """
        Grab the params to control kernel sparsity for.

        :param module: module to modify
        :param optimizer: optimizer to modify
        """
        super().initialize(module, optimizer)

        if self._params == ALL_TOKEN or ALL_TOKEN in self._params:
            param_names = ["re:.*"]
        elif self._params == ALL_PRUNABLE_TOKEN or ALL_PRUNABLE_TOKEN in self._params:
            param_names = [
                name + ".weight" for (name, _) in get_prunable_layers(module)
            ]
        else:
            param_names = self._params

        named_layers_and_params = get_named_layers_and_params_by_regex(
            module,
            param_names,
            params_strict=True,
        )

        layers = [
            named_layer_param.layer
            for named_layer_param in named_layers_and_params
        ]
        param_names = [
            named_layer_param.param_name
            for named_layer_param in named_layers_and_params
        ]
        layer_names = [
            named_layer_param.layer_name
            for named_layer_param in named_layers_and_params
        ]
        self._module_masks = ModuleParamPruningMask(layers,
                                                    param_names,
                                                    layer_names=layer_names)

        self._analyzers = []
        for layer_name, layer, param_name, _ in named_layers_and_params:
            self._analyzers.append(
                ModulePruningAnalyzer(layer, layer_name, param_name))
Example #5
    def ks_layer_descs(self) -> List[AnalyzedLayerDesc]:
        """
        Get the descriptions for all layers in the module that support kernel sparsity
        (model pruning). Ex: all convolutions and linear layers.

        :return: a list of descriptions for all layers in the module that support ks
        """
        descs = []

        for (name, _) in get_prunable_layers(self._module):
            desc = self.layer_desc(name)

            if desc is None:
                print("analyzer: no description found for {}".format(name))
            else:
                descs.append(desc)

        descs.sort(key=lambda val: val.execution_order)

        return descs
def model_prunability_magnitude(module: Module):
    """
    Calculate the approximate sensitivity for an overall model.
    The range of the values is not scaled to anything, so it must be taken in
    context with other known models.

    :param module: the model to calculate the sensitivity for
    :return: the approximated sensitivity
    """
    prunable = get_prunable_layers(module)
    tensors = []

    for (name, layer) in prunable:
        weight = getattr(layer, "weight")
        values = weight.view(-1).abs()
        tensors.append(values)

    all_weights = torch.cat(tensors)
    avg = all_weights.mean().item()

    return avg
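
A short usage sketch, assuming model_prunability_magnitude and its imports are in scope; the torchvision architectures are only illustrative:

from torchvision.models import mobilenet_v2, resnet18

for arch, model in (("resnet18", resnet18()), ("mobilenet_v2", mobilenet_v2())):
    # scores are unscaled; compare them across models rather than in isolation
    print(arch, model_prunability_magnitude(model))
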
Example #7
    @staticmethod
    def from_sparse_model(model: Module) -> List[ScheduledModifier]:
        """
        Create constant ks modifiers for all prunable params in the given model
        (conv, linear) that have been artificially sparsified (sparsity > 40%).
        Useful for transfer learning from a pruned model.

        :param model: the model to create constant ks modifiers for
        :return: the list of created constant ks modifiers
        """
        prunable = get_prunable_layers(model)
        modifiers = []

        for name, layer in prunable:
            weight = getattr(layer, "weight")
            sparsity = tensor_sparsity(weight)

            if sparsity > 0.4:
                modifiers.append(
                    ConstantPruningModifier(
                        params=["{}.{}".format(name, "weight")]))

        return modifiers
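
A usage sketch for transfer learning from an already-pruned checkpoint, assuming from_sparse_model is exposed as a staticmethod of ConstantPruningModifier (as the surrounding class suggests); the checkpoint path below is hypothetical:

import torch
from torchvision.models import resnet50

model = resnet50()
# hypothetical checkpoint of a model that was previously pruned
model.load_state_dict(torch.load("pruned_resnet50.pth"))
modifiers = ConstantPruningModifier.from_sparse_model(model)
print(f"created {len(modifiers)} constant pruning modifiers")
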
def pruning_loss_sens_one_shot(
    module: Module,
    data: DataLoader,
    loss: Union[LossWrapper, Callable[[Any, Any], Tensor]],
    device: str,
    steps_per_measurement: int,
    sparsity_levels: List[float] = default_pruning_sparsities_loss(False),
    loss_key: str = DEFAULT_LOSS_KEY,
    tester_run_funcs: ModuleRunFuncs = None,
    tester_loggers: List[BaseLogger] = None,
    show_progress: bool = True,
) -> PruningLossSensitivityAnalysis:
    """
    Run a one shot sensitivity analysis for kernel sparsity.
    It does not retrain; instead it puts the model into eval mode.
    Moves layer by layer to calculate the sensitivity analysis for each and
    resets the previously run layers.
    Note that the data is cached by default, so data loading is not parallel
    and the first run can take longer; subsequent sparsity checks for layers
    and levels will be much faster.

    :param module: the module to run the kernel sparsity sensitivity analysis
        over; all prunable layers will be extracted from it
    :param data: the data to run through the module for calculating the sensitivity
        analysis
    :param loss: the loss function to use for the sensitivity analysis
    :param device: the device to run the analysis on; ex: cpu, cuda
    :param steps_per_measurement: the number of samples or items to take for each
        measurement at each sparsity level
    :param sparsity_levels: the sparsity levels to check for each layer to calculate
        sensitivity
    :param loss_key: the key for the loss function to track in the returned dict
    :param tester_run_funcs: override functions to use in the ModuleTester
        that runs the analysis
    :param tester_loggers: loggers to log data to while running the analysis
    :param show_progress: track progress of the runs if True
    :return: the sensitivity results for every layer that is prunable
    """
    analysis = PruningLossSensitivityAnalysis()
    tester = ModuleTester(
        module,
        device,
        loss,
        loggers=tester_loggers,
        log_summary=False,
        log_steps=max(1, round(steps_per_measurement / 10)),
    )
    layers = get_prunable_layers(module)
    batch_end = _sensitivity_callback(layers, sparsity_levels,
                                      steps_per_measurement, analysis,
                                      loss_key)
    batch_end_hook = tester.run_hooks.register_batch_end_hook(batch_end)
    if tester_run_funcs is not None:
        tester.run_funcs.copy(tester_run_funcs)

    data_loader = infinite_data_loader(data,
                                       early_stop_steps=steps_per_measurement,
                                       cache=True)
    tester.run(
        data_loader,
        desc="KS Analysis",
        show_progress=show_progress,
        track_results=False,
        max_steps=steps_per_measurement * len(sparsity_levels) * len(layers),
    )
    batch_end_hook.remove()

    return analysis
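
The sweep above is driven through ModuleTester batch hooks; the per-layer, per-level measurement it performs can be sketched in pure torch as below. one_shot_layer_loss is a hypothetical helper, not the library implementation, and it assumes a classification loss:

import torch
import torch.nn.functional as F


def one_shot_layer_loss(model, layer, sparsity, inputs, targets):
    # temporarily zero the smallest-magnitude weights of one layer, measure
    # the eval loss on a batch, then restore the weights so the next layer
    # and sparsity level start from the original model
    model.eval()
    original = layer.weight.detach().clone()
    num_zero = round(sparsity * layer.weight.numel())

    with torch.no_grad():
        if num_zero > 0:
            threshold = layer.weight.abs().view(-1).kthvalue(num_zero).values
            layer.weight.mul_((layer.weight.abs() > threshold).float())
        loss = F.cross_entropy(model(inputs), targets).item()
        layer.weight.copy_(original)

    return loss
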
Example #9
def train(
    train_args: TrainingArguments,
    model: Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_shape: Tuple[int, ...],
    save_dir: str,
    loggers: List[Any],
) -> None:
    """
    Utility function to drive the training process

    :param train_args: A TrainingArguments object with
        arguments for the current training task
    :param model: model architecture to train
    :param train_loader: A DataLoader for training data
    :param val_loader: A DataLoader for validation data
    :param input_shape: A tuple of integers representing the shape of inputs
    :param save_dir: Directory to store checkpoints in during the training process
    :param loggers: List of loggers to use during training process
    """
    # loss setup
    val_loss = helpers.get_loss_wrapper(arch_key=train_args.arch_key,
                                        training=False)
    LOGGER.info(f"created loss for validation: {val_loss}")

    train_loss = helpers.get_loss_wrapper(arch_key=train_args.arch_key,
                                          training=True)
    LOGGER.info(f"created loss for training: {train_loss}")

    # training setup
    if not train_args.eval_mode:
        epoch, optim, manager = helpers.create_scheduled_optimizer(
            train_args,
            model,
            train_loader,
            loggers,
        )
    else:
        epoch = 0
        train_loss = None
        optim = None
        manager = None

    # device setup
    if train_args.rank == -1:
        device = train_args.device
        ddp = False
    else:
        torch.cuda.set_device(train_args.local_rank)
        device = train_args.local_rank
        ddp = True

    model, device, device_ids = model_to_device(model, device, ddp=ddp)
    LOGGER.info(f"running on device {device} for ids {device_ids}")

    trainer = (ModuleTrainer(
        model,
        device,
        train_loss,
        optim,
        loggers=loggers,
        device_context=ModuleDeviceContext(
            use_mixed_precision=train_args.use_mixed_precision,
            world_size=train_args.world_size,
        ),
    ) if not train_args.eval_mode else None)

    if train_args.is_main_process:  # only test on one DDP process if using DDP
        tester = ModuleTester(model,
                              device,
                              val_loss,
                              loggers=loggers,
                              log_steps=-1)

        # initial baseline eval run
        tester.run_epoch(val_loader,
                         epoch=epoch - 1,
                         max_steps=train_args.debug_steps)

    if not train_args.eval_mode:
        helpers.save_recipe(recipe_manager=manager, save_dir=save_dir)
        LOGGER.info(f"starting training from epoch {epoch}")

        if epoch > 0:
            LOGGER.info("adjusting ScheduledOptimizer to restore point")
            optim.adjust_current_step(epoch, 0)

        target_metric = ("top1acc" if "top1acc" in tester.loss.available_losses
                         else DEFAULT_LOSS_KEY)
        best_metric = None
        val_res = None

        while epoch < manager.max_epochs:
            if train_args.debug_steps > 0:
                # correct since all optimizer steps are not
                # taken in the epochs for debug mode
                optim.adjust_current_step(epoch, 0)

            if train_args.rank != -1:  # sync DDP dataloaders
                train_loader.sampler.set_epoch(epoch)

            trainer.run_epoch(
                train_loader,
                epoch,
                max_steps=train_args.debug_steps,
                show_progress=train_args.is_main_process,
            )

            # testing steps
            if train_args.is_main_process:
                # only test and save on main process
                val_res = tester.run_epoch(val_loader,
                                           epoch,
                                           max_steps=train_args.debug_steps)
                val_metric = val_res.result_mean(target_metric).item()

                if epoch >= train_args.save_best_after and (
                        best_metric is None or
                    (val_metric <= best_metric if target_metric != "top1acc"
                     else val_metric >= best_metric)):
                    helpers.save_model_training(
                        model,
                        optim,
                        input_shape,
                        "checkpoint-best",
                        save_dir,
                        epoch,
                        val_res,
                    )
                    best_metric = val_metric

            # save checkpoints
            _save_epoch = (train_args.is_main_process
                           and train_args.save_epochs
                           and epoch in train_args.save_epochs)
            if _save_epoch:
                helpers.save_model_training(
                    model,
                    optim,
                    input_shape,
                    f"checkpoint-{epoch:04d}-{val_metric:.04f}",
                    save_dir,
                    epoch,
                    val_res,
                )

            epoch += 1

        # export the final model
        LOGGER.info("completed...")
        if train_args.is_main_process:
            # only convert qat -> quantized ONNX graph for finalized model
            # TODO: change this to all checkpoints when conversion times improve
            helpers.save_model_training(model, optim, input_shape, "model",
                                        save_dir, epoch - 1, val_res, True)

            LOGGER.info("layer sparsities:")
            for (name, layer) in get_prunable_layers(model):
                LOGGER.info(
                    f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}"
                )

    # close DDP
    if train_args.rank != -1:
        torch.distributed.destroy_process_group()
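
The best-checkpoint condition inside the epoch loop is dense; unrolled, it amounts to the comparison below (is_better is a hypothetical helper, not part of the script):

def is_better(val_metric: float, best_metric: float, target_metric: str) -> bool:
    # top-1 accuracy is maximized; any other tracked loss key is minimized
    if target_metric == "top1acc":
        return val_metric >= best_metric
    return val_metric <= best_metric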