Example #1
    def test_lifecycle(
        self,
        modifier_lambda,
        model_lambda,
        optim_lambda,
        test_steps_per_epoch,  # noqa: F811
    ):
        modifier = modifier_lambda()
        model = model_lambda()
        optimizer = optim_lambda(model)
        self.initialize_helper(modifier, model)
        assert len(modifier._layer_modules) > 0
        if modifier.start_epoch > 0:
            for (name, mod) in modifier._layer_modules.items():
                assert mod is None
                assert not isinstance(get_layer(name, model), Identity)

            # the modifier must not report an update as ready before start_epoch
            for epoch in range(int(modifier.start_epoch)):
                assert not modifier.update_ready(epoch, test_steps_per_epoch)

            for (name, mod) in modifier._layer_modules.items():
                assert mod is None
                assert not isinstance(get_layer(name, model), Identity)

            epoch = int(modifier.start_epoch)
            assert modifier.update_ready(epoch, test_steps_per_epoch)
            modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)
        else:
            epoch = 0

        for (name, mod) in modifier._layer_modules.items():
            assert mod is not None
            assert isinstance(get_layer(name, model), Identity)

        # check forward pass
        input_shape = model_lambda.layer_descs()[0].input_size
        test_batch = torch.randn(10, *input_shape)
        _ = model(test_batch)

        end_epoch = (
            modifier.end_epoch if modifier.end_epoch > -1 else modifier.start_epoch + 10
        )

        while epoch < end_epoch - 0.1:
            epoch += 0.1
            assert modifier.update_ready(epoch, test_steps_per_epoch)
            modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

        _ = model(test_batch)  # check forward pass

        if modifier.end_epoch > -1:
            epoch = int(modifier.end_epoch)
            assert modifier.update_ready(epoch, test_steps_per_epoch)
            modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

            for (name, mod) in modifier._layer_modules.items():
                assert mod is None
                assert not isinstance(get_layer(name, model), Identity)
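
The test above exercises the modifier lifecycle contract: update_ready gates on the (possibly fractional) epoch, and scheduled_update applies the change. A minimal sketch of a training-loop driver honoring that contract, using a hypothetical stub modifier (the stub class and its epoch bounds are illustrative, not part of the library):

class StubModifier:
    """Hypothetical stand-in following the update_ready /
    scheduled_update contract exercised by test_lifecycle above."""

    def __init__(self, start_epoch: float, end_epoch: float):
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch

    def update_ready(self, epoch: float, steps_per_epoch: int) -> bool:
        # ready only inside the [start_epoch, end_epoch] window
        return self.start_epoch <= epoch <= self.end_epoch

    def scheduled_update(self, model, optimizer, epoch, steps_per_epoch):
        pass  # a real modifier would mutate model/optimizer here


modifier = StubModifier(start_epoch=1.0, end_epoch=3.0)
steps_per_epoch = 10
for step in range(5 * steps_per_epoch):
    epoch = step / steps_per_epoch  # fractional epochs, as in the test
    if modifier.update_ready(epoch, steps_per_epoch):
        modifier.scheduled_update(None, None, epoch, steps_per_epoch)
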
Example #2
    def initialize(
        self,
        module: Module,
        epoch: float = 0,
        loggers: Optional[List[BaseLogger]] = None,
        **kwargs,
    ):
        """
        Grab the layers whose activation sparsity will be controlled

        :param module: the PyTorch model/module to modify
        :param epoch: The epoch to initialize the modifier and module at.
            Defaults to 0 (start of the training process)
        :param loggers: Optional list of loggers to log the modification process to
        :param kwargs: Optional kwargs to support specific arguments
            for individual modifiers.
        """
        super(ASRegModifier, self).initialize(module, epoch, loggers, **kwargs)
        layers = (get_terminal_layers(module) if self._layers == ALL_TOKEN else
                  [get_layer(name, module) for name in self._layers])

        for layer in layers:
            self._trackers.append(
                ASLayerTracker(
                    layer,
                    track_input=self._reg_tens == "inp",
                    track_output=self._reg_tens == "out",
                    input_func=self._regularize_tracked,
                    output_func=self._regularize_tracked,
                ))
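
ASLayerTracker attaches to a layer so a regularization function can observe its inputs or outputs. A minimal sketch of the same idea with a plain forward hook, here measuring the fraction of zero activations; the tracker helper below is illustrative, not the library class:

import torch
from torch.nn import Linear, ReLU, Sequential


def attach_sparsity_tracker(layer, results):
    # record the fraction of exactly-zero activations per forward pass
    def _hook(module, inputs, output):
        results.append((output == 0).float().mean().item())

    return layer.register_forward_hook(_hook)


model = Sequential(Linear(8, 8), ReLU())
sparsities = []
handle = attach_sparsity_tracker(model[1], sparsities)
model(torch.randn(4, 8))
handle.remove()  # detach once tracking is no longer needed
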
Example #3
def _test_const(module, name, param_name):
    layer = get_layer(name, module)
    analyzer = ModulePruningAnalyzer(layer, name, param_name)
    assert analyzer.module == layer
    assert analyzer.name == name
    assert analyzer.param_name == param_name
    assert analyzer.tag == "{}.{}".format(name, param_name)
    assert isinstance(analyzer.param, Parameter)
Example #4
def test_param_sparsity(module, name, param_name, param_data,
                        expected_sparsity):
    layer = get_layer(name, module)
    analyzer = ModulePruningAnalyzer(layer, name, param_name)
    analyzer.param.data = param_data

    assert torch.sum(
        (analyzer.param_sparsity - expected_sparsity).abs()) < 0.00001
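
param_sparsity in the assertion above is the fraction of zero-valued elements in the parameter. A minimal sketch of that computation (the helper name is ours, not the analyzer's API):

import torch


def tensor_sparsity(tensor: torch.Tensor) -> torch.Tensor:
    # fraction of exactly-zero elements in the tensor
    return (tensor == 0).float().mean()


weight = torch.tensor([[0.0, 1.5], [0.0, -2.0]])
assert torch.isclose(tensor_sparsity(weight), torch.tensor(0.5))
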
Example #5
    def layer_desc(self, name: Union[str, None] = None) -> AnalyzedLayerDesc:
        """
        Get a specific layer's description within the Module.
        Set to None to get the overall Module's description.

        :param name: name of the layer to get a description for,
            None for an overall description
        :return: the analyzed layer description for the given name
        """
        if not self._forward_called:
            raise RuntimeError(
                "module must have forward called with sample input "
                "before getting a layer desc"
            )

        mod = get_layer(name, self._module) if name is not None else self._module

        return ModuleAnalyzer._mod_desc(mod)
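
layer_desc resolves name with get_layer, which looks a layer up by its dotted module path. A rough re-implementation to show what such a lookup does; this sketch is illustrative and the library's own get_layer may differ:

from torch.nn import Conv2d, Module, Sequential


def get_layer_sketch(name: str, module: Module) -> Module:
    # walk a dotted path such as "features.0" one attribute at a time;
    # child modules (even numeric Sequential keys) resolve via getattr
    layer = module
    for part in name.split("."):
        layer = getattr(layer, part)
    return layer


model = Sequential(Conv2d(3, 8, 3), Conv2d(8, 8, 3))
assert isinstance(get_layer_sketch("1", model), Conv2d)
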
Example #6
    def initialize(self, module: Module, optimizer: Optimizer):
        """
        Grab the layers whose activation sparsity will be controlled

        :param module: module to modify
        :param optimizer: optimizer to modify
        """
        super(ASRegModifier, self).initialize(module, optimizer)
        layers = (get_terminal_layers(module) if self._layers == ALL_TOKEN else
                  [get_layer(name, module) for name in self._layers])

        for layer in layers:
            self._trackers.append(
                ASLayerTracker(
                    layer,
                    track_input=self._reg_tens == "inp",
                    track_output=self._reg_tens == "out",
                    input_func=self._regularize_tracked,
                    output_func=self._regularize_tracked,
                ))
Example #7
    def _measure_layer(
        self, layer: Union[str, None], module: Module, device: str, desc: str
    ) -> Tuple[ModuleRunResults, Tensor]:
        layer = get_layer(layer, module) if layer else None
        as_analyzer = None

        if layer:
            as_analyzer = ModuleASAnalyzer(layer,
                                           dim=None,
                                           track_outputs_sparsity=True)
            as_analyzer.enable()

        tester = ModuleTester(module, device, self._loss)
        data_loader = DataLoader(self._dataset, self._batch_size,
                                 **self._dataloader_kwargs)
        results = tester.run(data_loader,
                             desc=desc,
                             show_progress=True,
                             track_results=True)

        if as_analyzer:
            as_analyzer.disable()

        return results, as_analyzer.outputs_sparsity_mean if as_analyzer else None
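
_measure_layer follows an enable / run / disable pattern around the analyzer. A context-manager sketch of the same pattern that guarantees the hook is removed even if the run fails; the helper below is ours, not the library's:

import torch
from contextlib import contextmanager
from torch.nn import Linear, ReLU, Sequential


@contextmanager
def track_output_sparsity(layer):
    # enable: register a hook recording per-batch output sparsity
    results = []
    handle = layer.register_forward_hook(
        lambda mod, inp, out: results.append((out == 0).float().mean().item())
    )
    try:
        yield results  # run: forward passes happen inside the with-block
    finally:
        handle.remove()  # disable: always detach, even on error


model = Sequential(Linear(4, 4), ReLU())
with track_output_sparsity(model[1]) as sparsities:
    model(torch.randn(2, 4))
mean_sparsity = sum(sparsities) / len(sparsities)
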
Example #8
    def _binary_search_fat(
        self,
        layer: str,
        module: Module,
        device: str,
        min_thresh: float,
        max_thresh: float,
        max_target_metric_loss: float,
        metric_key: str,
        metric_increases: bool,
        precision: float = 0.001,
    ):
        print("\n\n\nstarting binary search for layer {} between ({}, {})...".
              format(layer, min_thresh, max_thresh))
        base_res, base_as = self._measure_layer(
            layer, module, device, "baseline for layer: {}".format(layer))

        fat_relu = get_layer(layer, module)  # type: FATReLU
        init_thresh = fat_relu.get_threshold()

        if min_thresh > init_thresh:
            min_thresh = init_thresh

        while True:
            thresh = ModuleASOneShootBooster._get_mid_point(
                min_thresh, max_thresh)
            fat_relu.set_threshold(thresh)

            thresh_res, thresh_as = self._measure_layer(
                layer,
                module,
                device,
                "threshold for layer: {} @ {:.4f} ({:.4f}, {:.4f})".format(
                    layer, thresh, min_thresh, max_thresh),
            )

            if ModuleASOneShootBooster._passes_loss(
                    base_res,
                    thresh_res,
                    max_target_metric_loss,
                    metric_key,
                    metric_increases,
            ):
                min_thresh = thresh
                print("loss check passed for max change: {:.4f}".format(
                    max_target_metric_loss))
                print("   current loss: {:.4f} baseline loss: {:.4f}".format(
                    thresh_res.result_mean(metric_key),
                    base_res.result_mean(metric_key),
                ))
                print("   current AS: {:.4f} baseline AS: {:.4f}".format(
                    thresh_as, base_as))
            else:
                max_thresh = thresh
                print("loss check failed for max change: {:.4f}".format(
                    max_target_metric_loss))
                print("   current loss: {:.4f} baseline loss: {:.4f}".format(
                    thresh_res.result_mean(metric_key),
                    base_res.result_mean(metric_key),
                ))
                print("   current AS: {:.4f} baseline AS: {:.4f}".format(
                    thresh_as, base_as))

            if max_thresh - min_thresh <= precision:
                break

        print("completed binary search for layer {}")
        print("   found threshold: {}".format(thresh))
        print("   AS delta: {:.4f} ({:.4f} => {:.4f})".format(
            thresh_as - base_as, base_as, thresh_as))

        return LayerBoostResults(layer, thresh, thresh_as, thresh_res, base_as,
                                 base_res)
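
Stripped of the measurement and logging, the loop above is a standard midpoint search for the largest threshold that still passes the loss check, assuming the check flips from pass to fail exactly once as the threshold grows. A condensed, self-contained sketch:

def binary_search_max_threshold(passes, lo: float, hi: float,
                                precision: float = 0.001) -> float:
    # tighten [lo, hi] around the largest value for which passes() is True
    while hi - lo > precision:
        mid = (lo + hi) / 2.0
        if passes(mid):
            lo = mid  # still within the loss budget: search higher
        else:
            hi = mid  # too lossy: search lower
    return lo


# e.g. if thresholds up to 0.42 pass the check, the search converges there
assert abs(binary_search_max_threshold(lambda t: t <= 0.42, 0.0, 1.0) - 0.42) < 0.01
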