import copy
from typing import Dict, List, Optional

import torch
import torch.nn as nn

# Helper functions under test; the import path below assumes a recent PyTorch
# (in older releases these live under torch.ao.sparsity.sparsifier.utils).
from torch.ao.pruning.sparsifier.utils import (
    fqn_to_module,
    get_arg_info_from_tensor_fqn,
    module_to_fqn,
)

# Module types eligible for sparsification & quantization, as defined in
# torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py.
SUPPORTED_MODULES = {nn.Embedding, nn.EmbeddingBag}


    def _check_step(self, activation_sparsifier, data_agg_actual):
        """Checks if .step() works as expected. Specifically, checks if the mask is computed correctly.

        Args:
            activation_sparsifier (sparsifier object)
                activation sparsifier object that is being tested.

            data_agg_actual (torch tensor)
                aggregated torch tensor

        """
        model = activation_sparsifier.model
        layer_name = module_to_fqn(model, model.conv1)
        assert layer_name is not None

        reduce_fn = activation_sparsifier.data_groups[layer_name]['reduce_fn']

        data_reduce_actual = reduce_fn(data_agg_actual)
        mask_fn = activation_sparsifier.data_groups[layer_name]['mask_fn']
        sparse_config = activation_sparsifier.data_groups[layer_name][
            'sparse_config']
        mask_actual = mask_fn(data_reduce_actual, **sparse_config)

        mask_model = activation_sparsifier.get_mask(layer_name)

        assert torch.all(mask_model == mask_actual)

        for config in activation_sparsifier.data_groups.values():
            assert 'data' not in config
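
# Illustrative sketch (not part of the test above): the reduce -> mask pipeline
# that _check_step verifies, using hypothetical reduce_fn/mask_fn choices.
# example_reduce_fn averages the aggregated activations; example_mask_fn keeps
# entries whose magnitude clears a threshold.
import torch

def example_reduce_fn(agg):
    return agg.mean(dim=0)

def example_mask_fn(reduced, threshold):
    return (reduced.abs() >= threshold).float()

data_agg = torch.tensor([[0.1, 2.0], [0.3, 4.0]])
mask = example_mask_fn(example_reduce_fn(data_agg), threshold=1.0)
assert torch.equal(mask, torch.tensor([0.0, 1.0]))
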
    def _check_pre_forward_hook(self, activation_sparsifier, data_list):
        """Registering a layer attaches a pre-forward hook to that layer. This function
        checks if the pre-forward hook works as expected. Specifically, checks if the
        input is aggregated correctly.

        Basically, asserts that the aggregate of input activations is the same as what was
        computed in the sparsifier.

        Args:
            activation_sparsifier (sparsifier object)
                activation sparsifier object that is being tested.

            data_list (list of torch tensors)
                data input to the model attached to the sparsifier

        """
        # can only check for the first layer
        data_agg_actual = data_list[0]
        model = activation_sparsifier.model
        layer_name = module_to_fqn(model, model.conv1)
        agg_fn = activation_sparsifier.data_groups[layer_name]['aggregate_fn']

        for i in range(1, len(data_list)):
            data_agg_actual = agg_fn(data_agg_actual, data_list[i])

        assert 'data' in activation_sparsifier.data_groups[layer_name]
        assert torch.all(activation_sparsifier.data_groups[layer_name]['data']
                         == data_agg_actual)

        return data_agg_actual
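
# Illustrative sketch: the reduction pattern used in _check_pre_forward_hook,
# assuming (hypothetically) the sparsifier was registered with
# aggregate_fn=torch.add, so the fold collapses data_list into an
# elementwise sum.
import torch

def fold(agg_fn, data_list):
    acc = data_list[0]
    for t in data_list[1:]:
        acc = agg_fn(acc, t)
    return acc

batches = [torch.ones(2, 3), 2 * torch.ones(2, 3), 3 * torch.ones(2, 3)]
assert torch.all(fold(torch.add, batches) == 6 * torch.ones(2, 3))
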
    def test_module_to_fqn_root(self):
        """
        Tests that module_to_fqn returns '' when the model and the target
        module are the same.
        """
        # `model_list` is assumed to be defined elsewhere in the test file as
        # a list of small test-model classes.
        for model_class in model_list:
            model = model_class()
            fqn = module_to_fqn(model, model)
            self.assertEqual(fqn, "")

    def test_module_to_fqn_fail(self):
        """
        Tests that module_to_fqn returns None when given a module that does
        not correspond to a path to a node/tensor in the model.
        """
        for model_class in model_list:
            model = model_class()
            fqn = module_to_fqn(model, torch.nn.Linear(3, 3))
            self.assertEqual(fqn, None)

    def test_fqn_to_module(self):
        """
        Tests that fqn_to_module operates as the inverse of module_to_fqn.
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                fqn = module_to_fqn(model, module)
                check_module = fqn_to_module(model, fqn)
                self.assertEqual(module, check_module)

    def test_module_to_fqn(self):
        """
        Tests that module_to_fqn agrees with the known-good
        module.get_submodule(fqn) function.
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                fqn = module_to_fqn(model, module)
                check_module = model.get_submodule(fqn)
                self.assertEqual(module, check_module)
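
# Minimal round-trip sketch of the invariants exercised by the tests above,
# on a hypothetical two-layer model (not one of the classes in model_list).
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

tiny = TinyModel()
fqn = module_to_fqn(tiny, tiny.block[0])    # -> 'block.0'
assert fqn_to_module(tiny, fqn) is tiny.block[0]
assert module_to_fqn(tiny, tiny) == ""      # root maps to the empty string
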
def _fetch_all_embeddings(model):
    """Fetches Embedding and EmbeddingBag modules from the model
    """
    embedding_modules = []
    stack = [model]
    while stack:
        module = stack.pop()
        for _, child in module.named_children():
            fqn_name = module_to_fqn(model, child)
            if type(child) in SUPPORTED_MODULES:
                embedding_modules.append((fqn_name, child))
            else:
                stack.append(child)
    return embedding_modules
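
# Usage sketch for _fetch_all_embeddings on a hypothetical model: nested
# Embedding/EmbeddingBag modules come back as (fqn, module) pairs, and the
# traversal does not descend into the supported modules themselves.
import torch.nn as nn

class EmbModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)
        self.inner = nn.Sequential(nn.EmbeddingBag(10, 4))

found = dict(_fetch_all_embeddings(EmbModel()))
assert set(found) == {'emb', 'inner.0'}
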
    def _check_register_layer(self, activation_sparsifier, defaults,
                              sparse_config, layer_args_list):
        """Checks if layers in the model are correctly mapped to it's arguments.

        Args:
            activation_sparsifier (sparsifier object)
                activation sparsifier object that is being tested.

            defaults (Dict)
                all default config (except sparse_config)

            sparse_config (Dict)
                default sparse config passed to the sparsifier

            layer_args_list (list of tuples)
                Each entry in the list corresponds to the layer arguments.
                First entry in the tuple corresponds to all the arguments other than sparse_config
                Second entry in the tuple corresponds to sparse_config
        """
        # check args
        data_groups = activation_sparsifier.data_groups
        assert len(data_groups) == len(layer_args_list)
        for layer_args in layer_args_list:
            layer_arg, sparse_config_layer = layer_args

            # check sparse config
            sparse_config_actual = copy.deepcopy(sparse_config)
            sparse_config_actual.update(sparse_config_layer)

            name = module_to_fqn(activation_sparsifier.model,
                                 layer_arg['layer'])

            assert data_groups[name]['sparse_config'] == sparse_config_actual

            # assert the rest
            other_config_actual = copy.deepcopy(defaults)
            other_config_actual.update(layer_arg)
            other_config_actual.pop('layer')

            for key, value in other_config_actual.items():
                assert key in data_groups[name]
                assert value == data_groups[name][key]

            # get_mask should raise error
            with self.assertRaises(ValueError):
                activation_sparsifier.get_mask(name=name)
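
# Sketch of the config-merging convention that _check_register_layer verifies:
# per-layer settings override the sparsifier-wide defaults (all values below
# are made up for illustration).
import copy

defaults = {'features': None, 'feature_dim': None}
sparse_config = {'sparsity_level': 0.5}
layer_arg = {'layer': '<layer placeholder>', 'features': [0, 1]}
sparse_config_layer = {'sparsity_level': 0.8}

merged_sparse = copy.deepcopy(sparse_config)
merged_sparse.update(sparse_config_layer)   # per-layer value wins
assert merged_sparse == {'sparsity_level': 0.8}

merged_other = copy.deepcopy(defaults)
merged_other.update(layer_arg)
merged_other.pop('layer')                   # the layer itself is not config
assert merged_other == {'features': [0, 1], 'feature_dim': None}
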
    def test_fqn_to_module_for_tensors(self):
        """
        Tests that fqn_to_module works for tensors, i.e. all parameters of the
        model. This is tested by identifying a module with a tensor, and
        generating the tensor_fqn from module_to_fqn on the module plus the
        name of the tensor.
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                module_fqn = module_to_fqn(model, module)
                for tensor_name, tensor in module.named_parameters(recurse=False):
                    tensor_fqn = (  # string manip to handle tensors on root
                        module_fqn + ("." if module_fqn != "" else "") + tensor_name
                    )
                    check_tensor = fqn_to_module(model, tensor_fqn)
                    self.assertEqual(tensor, check_tensor)

    def test_get_arg_info_from_tensor_fqn(self):
        """
        Tests that get_arg_info_from_tensor_fqn works for all parameters of
        the model. Generates a tensor_fqn in the same way as
        test_fqn_to_module_for_tensors and then compares with the known
        (parent) module and tensor_name, as well as the module_fqn from
        module_to_fqn.
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                module_fqn = module_to_fqn(model, module)
                for tensor_name, tensor in module.named_parameters(recurse=False):
                    tensor_fqn = (
                        module_fqn + ("." if module_fqn != "" else "") + tensor_name
                    )
                    arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                    self.assertEqual(arg_info["module"], module)
                    self.assertEqual(arg_info["module_fqn"], module_fqn)
                    self.assertEqual(arg_info["tensor_name"], tensor_name)
                    self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)
def post_training_sparse_quantize(model,
                                  data_sparsifier_class,
                                  sparsify_first=True,
                                  select_embeddings: Optional[List[nn.Module]] = None,
                                  **sparse_config):
    """Takes in a model and applies sparsification and quantization only to its
    embeddings & embedding bags. The quantization step can happen before or
    after sparsification, depending on the `sparsify_first` argument.

    Args:
        - model (nn.Module)
            model whose embeddings need to be sparsified
        - data_sparsifier_class (type of data sparsifier)
            type of sparsification to apply to the model
        - sparsify_first (bool)
            if True, sparsifies first and then quantizes;
            otherwise, quantizes first and then sparsifies
        - select_embeddings (List of Embedding modules)
            list of embedding modules in the model to be sparsified & quantized.
            If None, all embedding modules will be sparsified
        - sparse_config (Dict)
            config that will be passed to the constructor of the data
            sparsifier object

    Note:
        1. When `sparsify_first=False`, quantization occurs first, followed by sparsification:
            - before sparsifying, the embedding layers are dequantized
            - scales and zero points are saved
            - embedding layers are sparsified and `squash_mask` is applied
            - embedding weights are requantized using the saved scales and zero points
        2. When `sparsify_first=True`, sparsification occurs first, followed by quantization:
            - embeddings are sparsified first
            - quantization is applied to the sparsified embeddings
    """
    data_sparsifier = data_sparsifier_class(**sparse_config)

    # if select_embeddings is None, sparsify & quantize all embeddings
    if select_embeddings is None:
        embedding_modules = _fetch_all_embeddings(model)

    else:
        embedding_modules = []
        assert isinstance(select_embeddings, list), \
            "select_embeddings must be a list of embedding modules"
        for emb in select_embeddings:
            assert type(emb) in SUPPORTED_MODULES, \
                "each module in select_embeddings must be an Embedding or EmbeddingBag"
            fqn_name = module_to_fqn(model, emb)
            assert fqn_name is not None, \
                "each embedding module must be part of the input model"
            embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify
        for name, emb_module in embedding_modules:
            valid_name = name.replace('.', '_')
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

        # retrieve scale & zero_points
        quantize_params: Dict[str, Dict] = {
            'scales': {},
            'zero_points': {},
            'dequant_weights': {},
            'axis': {},
            'dtype': {}
        }

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
            quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
            quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
            quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
            quantize_params['dtype'][name] = quantized_weight.dtype

            # attach data to sparsifier
            data_sparsifier.add_data(
                name=name.replace('.', '_'),
                data=quantize_params['dequant_weights'][name])

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(
                quantize_params['dequant_weights'][name],
                scales=quantize_params['scales'][name],
                zero_points=quantize_params['zero_points'][name],
                dtype=quantize_params['dtype'][name],
                axis=quantize_params['axis'][name])

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
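
# Usage sketch: applying post_training_sparse_quantize with DataNormSparsifier.
# Assumes DataNormSparsifier is importable from
# torch.ao.pruning._experimental.data_sparsifier (recent PyTorch) and that
# sparsity_level is a valid kwarg for its constructor.
import torch.nn as nn
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier

toy = nn.Sequential(nn.EmbeddingBag(100, 16))
post_training_sparse_quantize(
    toy,
    data_sparsifier_class=DataNormSparsifier,
    sparsify_first=True,
    sparsity_level=0.8,  # forwarded to DataNormSparsifier via **sparse_config
)
# toy[0] should now be a quantized EmbeddingBag with sparsified weights.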