def _check_step(self, activation_sparsifier, data_agg_actual):
    """Checks if .step() works as expected. Specifically, checks if the mask
    is computed correctly.

    Args:
        activation_sparsifier (sparsifier object)
            activation sparsifier object that is being tested.

        data_agg_actual (torch tensor)
            aggregated torch tensor
    """
    model = activation_sparsifier.model
    layer_name = module_to_fqn(model, model.conv1)
    assert layer_name is not None

    reduce_fn = activation_sparsifier.data_groups[layer_name]['reduce_fn']
    data_reduce_actual = reduce_fn(data_agg_actual)

    mask_fn = activation_sparsifier.data_groups[layer_name]['mask_fn']
    sparse_config = activation_sparsifier.data_groups[layer_name]['sparse_config']
    mask_actual = mask_fn(data_reduce_actual, **sparse_config)

    mask_model = activation_sparsifier.get_mask(layer_name)
    assert torch.all(mask_model == mask_actual)

    # after .step(), the aggregated data should have been cleared
    for _, config in activation_sparsifier.data_groups.items():
        assert 'data' not in config
def _check_pre_forward_hook(self, activation_sparsifier, data_list):
    """Registering a layer attaches a pre-forward hook to that layer.
    This function checks if the pre-forward hook works as expected.
    Specifically, checks if the input is aggregated correctly.

    Basically, asserts that the aggregate of input activations is the same
    as what was computed in the sparsifier.

    Args:
        activation_sparsifier (sparsifier object)
            activation sparsifier object that is being tested.

        data_list (list of torch tensors)
            data input to the model attached to the sparsifier
    """
    # can only check for the first layer
    data_agg_actual = data_list[0]
    model = activation_sparsifier.model
    layer_name = module_to_fqn(model, model.conv1)
    agg_fn = activation_sparsifier.data_groups[layer_name]['aggregate_fn']

    for i in range(1, len(data_list)):
        data_agg_actual = agg_fn(data_agg_actual, data_list[i])

    assert 'data' in activation_sparsifier.data_groups[layer_name]
    assert torch.all(activation_sparsifier.data_groups[layer_name]['data'] == data_agg_actual)

    return data_agg_actual
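# For reference, the flow that the two helpers above verify: the pre-forward
# hook folds each input batch into a running aggregate via aggregate_fn, and
# .step() turns that aggregate into a mask via reduce_fn and mask_fn. Below is
# a minimal sketch of that pipeline; the concrete fns (elementwise add, mean
# over the batch dim, magnitude threshold) are illustrative assumptions, not
# the sparsifier's defaults.
def _pipeline_sketch(self):
    def reduce_fn(agg):
        return torch.mean(agg, dim=0)

    def mask_fn(x, threshold):
        return (x.abs() > threshold).float()

    batches = [torch.randn(4, 8) for _ in range(3)]
    agg = batches[0]
    for batch in batches[1:]:
        agg = torch.add(agg, batch)  # plays the role of aggregate_fn
    return mask_fn(reduce_fn(agg), threshold=0.5)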
def test_module_to_fqn_root(self):
    """
    Tests that module_to_fqn returns '' when model and target module are the same
    """
    for model_class in model_list:
        model = model_class()
        fqn = module_to_fqn(model, model)
        self.assertEqual(fqn, "")
def test_module_to_fqn_fail(self):
    """
    Tests that module_to_fqn returns None when the target module does not
    exist anywhere inside the model
    """
    for model_class in model_list:
        model = model_class()
        fqn = module_to_fqn(model, torch.nn.Linear(3, 3))
        self.assertEqual(fqn, None)
def test_fqn_to_module(self):
    """
    Tests that fqn_to_module operates as inverse of module_to_fqn
    """
    for model_class in model_list:
        model = model_class()
        list_of_modules = [m for _, m in model.named_modules()] + [model]
        for module in list_of_modules:
            fqn = module_to_fqn(model, module)
            check_module = fqn_to_module(model, fqn)
            self.assertEqual(module, check_module)
def test_module_to_fqn(self):
    """
    Tests that module_to_fqn works as expected when compared to the known-good
    module.get_submodule(fqn) function
    """
    for model_class in model_list:
        model = model_class()
        list_of_modules = [m for _, m in model.named_modules()] + [model]
        for module in list_of_modules:
            fqn = module_to_fqn(model, module)
            check_module = model.get_submodule(fqn)
            self.assertEqual(module, check_module)
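# The '' convention for the root module mirrors the stock nn.Module API; a
# minimal standalone check of that convention (the Sequential below is an
# illustrative assumption, not a member of model_list): named_modules()
# yields ('', model) for the root, and get_submodule('') returns the model
# itself.
def _fqn_root_convention_sketch(self):
    model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU())
    for fqn, module in model.named_modules():
        self.assertIs(model.get_submodule(fqn), module)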
def _fetch_all_embeddings(model):
    """Fetches Embedding and EmbeddingBag modules from the model
    """
    embedding_modules = []
    stack = [model]
    while stack:
        module = stack.pop()
        for _, child in module.named_children():
            fqn_name = module_to_fqn(model, child)
            if type(child) in SUPPORTED_MODULES:
                embedding_modules.append((fqn_name, child))
            else:
                stack.append(child)
    return embedding_modules
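# A minimal usage sketch of _fetch_all_embeddings; the toy model below is an
# illustrative assumption. With SUPPORTED_MODULES covering nn.Embedding and
# nn.EmbeddingBag, the traversal returns (fqn, module) pairs and does not
# descend into the embeddings themselves.
def _fetch_all_embeddings_sketch():
    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(10, 4)
            self.fc = nn.Linear(4, 2)

    return _fetch_all_embeddings(_ToyModel())  # -> [('emb', Embedding(10, 4))]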
def _check_register_layer(self, activation_sparsifier, defaults, sparse_config, layer_args_list):
    """Checks if layers in the model are correctly mapped to their arguments.

    Args:
        activation_sparsifier (sparsifier object)
            activation sparsifier object that is being tested.

        defaults (Dict)
            all default config (except sparse_config)

        sparse_config (Dict)
            default sparse config passed to the sparsifier

        layer_args_list (list of tuples)
            Each entry in the list corresponds to the layer arguments.
            The first entry in the tuple holds all the arguments other than
            sparse_config; the second entry holds sparse_config.
    """
    # check args
    data_groups = activation_sparsifier.data_groups
    assert len(data_groups) == len(layer_args_list)
    for layer_args in layer_args_list:
        layer_arg, sparse_config_layer = layer_args

        # check sparse config
        sparse_config_actual = copy.deepcopy(sparse_config)
        sparse_config_actual.update(sparse_config_layer)

        name = module_to_fqn(activation_sparsifier.model, layer_arg['layer'])

        assert data_groups[name]['sparse_config'] == sparse_config_actual

        # assert the rest
        other_config_actual = copy.deepcopy(defaults)
        other_config_actual.update(layer_arg)
        other_config_actual.pop('layer')

        for key, value in other_config_actual.items():
            assert key in data_groups[name]
            assert value == data_groups[name][key]

        # get_mask should raise an error before .step() has been called
        with self.assertRaises(ValueError):
            activation_sparsifier.get_mask(name=name)
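# A hedged sketch of the layer_args_list structure this helper consumes; the
# conv layer, reduce_fn, and threshold value are illustrative assumptions
# about the caller, not requirements of the helper:
#
#     layer_args_list = [
#         # (everything except sparse_config,   per-layer sparse_config)
#         ({'layer': model.conv1, 'reduce_fn': torch.mean}, {'threshold': 0.5}),
#     ]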
def test_fqn_to_module_for_tensors(self):
    """
    Tests that fqn_to_module works for tensors, actually all parameters of
    the model. This is tested by identifying a module with a tensor, and
    generating the tensor_fqn using module_to_fqn on the module + the name
    of the tensor.
    """
    for model_class in model_list:
        model = model_class()
        list_of_modules = [m for _, m in model.named_modules()] + [model]
        for module in list_of_modules:
            module_fqn = module_to_fqn(model, module)
            for tensor_name, tensor in module.named_parameters(recurse=False):
                tensor_fqn = (  # string manip to handle tensors on root
                    module_fqn + ("." if module_fqn != "" else "") + tensor_name
                )
                check_tensor = fqn_to_module(model, tensor_fqn)
                self.assertEqual(tensor, check_tensor)
def test_get_arg_info_from_tensor_fqn(self):
    """
    Tests that get_arg_info_from_tensor_fqn works for all parameters of the
    model. Generates a tensor_fqn in the same way as
    test_fqn_to_module_for_tensors and then compares with the known (parent)
    module and tensor_name as well as module_fqn from module_to_fqn.
    """
    for model_class in model_list:
        model = model_class()
        list_of_modules = [m for _, m in model.named_modules()] + [model]
        for module in list_of_modules:
            module_fqn = module_to_fqn(model, module)
            for tensor_name, tensor in module.named_parameters(recurse=False):
                tensor_fqn = (
                    module_fqn + ("." if module_fqn != "" else "") + tensor_name
                )
                arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                self.assertEqual(arg_info["module"], module)
                self.assertEqual(arg_info["module_fqn"], module_fqn)
                self.assertEqual(arg_info["tensor_name"], tensor_name)
                self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)
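# For reference, the dict shape that the assertions above pin down, for a
# tensor_fqn such as 'conv1.weight' (example values illustrative):
#
#     {'module_fqn': 'conv1', 'module': <the conv1 module>,
#      'tensor_name': 'weight', 'tensor_fqn': 'conv1.weight'}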
def post_training_sparse_quantize(model,
                                  data_sparsifier_class,
                                  sparsify_first=True,
                                  select_embeddings: Optional[List[nn.Module]] = None,
                                  **sparse_config):
    """Takes in a model and applies sparsification and quantization to only
    embeddings & embeddingbags. The quantization step can happen before or
    after sparsification depending on the `sparsify_first` argument.

    Args:
        - model (nn.Module)
            model whose embeddings need to be sparsified
        - data_sparsifier_class (type of data sparsifier)
            Type of sparsification that needs to be applied to the model
        - sparsify_first (bool)
            if true, sparsifies first and then quantizes
            otherwise, quantizes first and then sparsifies.
        - select_embeddings (List of Embedding modules)
            List of embedding modules in the model to be sparsified & quantized.
            If None, all embedding modules will be sparsified
        - sparse_config (Dict)
            config that will be passed to the constructor of the data sparsifier object.

    Note:
        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
            - before sparsifying, the embedding layers are dequantized.
            - scales and zero-points are saved
            - embedding layers are sparsified and `squash_mask` is applied
            - embedding weights are requantized using the saved scales and zero-points
        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
            - embeddings are sparsified first
            - quantization is applied on the sparsified embeddings
    """
    data_sparsifier = data_sparsifier_class(**sparse_config)

    # if select_embeddings is None, perform it on all embeddings
    if select_embeddings is None:
        embedding_modules = _fetch_all_embeddings(model)
    else:
        embedding_modules = []
        assert isinstance(select_embeddings, list), \
            "the embedding_modules must be a list of embedding modules"
        for emb in select_embeddings:
            assert type(emb) in SUPPORTED_MODULES, \
                "the embedding_modules list must contain only embeddings or embedding bags"
            fqn_name = module_to_fqn(model, emb)
            assert fqn_name is not None, \
                "the embedding modules must be part of the input model"
            embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify
        for name, emb_module in embedding_modules:
            valid_name = name.replace('.', '_')
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)
    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

        # retrieve scales & zero_points
        quantize_params: Dict[str, Dict] = {
            'scales': {},
            'zero_points': {},
            'dequant_weights': {},
            'axis': {},
            'dtype': {}
        }
        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
            quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
            quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
            quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
            quantize_params['dtype'][name] = quantized_weight.dtype

            # attach the dequantized weights to the sparsifier
            data_sparsifier.add_data(
                name=name.replace('.', '_'),
                data=quantize_params['dequant_weights'][name])

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # requantize using the saved scales and zero-points
        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(
                quantize_params['dequant_weights'][name],
                scales=quantize_params['scales'][name],
                zero_points=quantize_params['zero_points'][name],
                dtype=quantize_params['dtype'][name],
                axis=quantize_params['axis'][name])

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
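# A minimal usage sketch, assuming DataNormSparsifier as the sparsifier class
# (import path per torch.ao.pruning._experimental); the toy model and the
# sparsity_level / sparse_block_shape values are illustrative, not defaults.
def _post_training_sparse_quantize_sketch():
    from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier

    model = nn.Sequential(nn.EmbeddingBag(100, 16))
    post_training_sparse_quantize(
        model,
        data_sparsifier_class=DataNormSparsifier,
        sparsify_first=True,
        sparsity_level=0.8,
        sparse_block_shape=(1, 1),
    )
    return model  # embeddings are now sparsified and quantized in-place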