def _apply_binary_mask_to_module_weight_and_bias(module, mask, module_name=""):
    """
    Mask a module's weight and, when present, its bias with the same mask.

    Both tensors are handed to ``inplace_apply_filter_binary_mask`` while
    gradient tracking is disabled.

    :param module: object exposing ``weight`` and ``bias`` attributes;
        ``bias`` may be ``None``.
    :param mask: mask forwarded unchanged to ``inplace_apply_filter_binary_mask``.
    :param module_name: name forwarded to ``inplace_apply_filter_binary_mask``
        (presumably used for diagnostics; confirm against that helper).
    """
    with torch.no_grad():
        # The weight is always masked.
        inplace_apply_filter_binary_mask(mask, module.weight, module_name)

        # The bias is optional; nothing more to do when it is absent.
        bias = module.bias
        if bias is None:
            return
        inplace_apply_filter_binary_mask(mask, bias, module_name)
def test_assert_broadcastable_mask_and_weight_shape():
    """
    Both mask-application functions must raise ``RuntimeError`` when given a
    length-10 mask together with the weight of ``NNCFConv2d(1, 2, 2)``.
    """
    conv = NNCFConv2d(1, 2, 2)
    fill_conv_weight(conv, 1)
    fill_bias(conv, 1)

    bad_mask = torch.zeros(10)

    # Same expectation for the in-place and the out-of-place variant.
    for mask_fn in (inplace_apply_filter_binary_mask, apply_filter_binary_mask):
        with pytest.raises(RuntimeError):
            mask_fn(bad_mask, conv.weight.data)
def _apply_binary_mask_to_module_weight_and_bias(module, mask, module_scope):
    """
    Mask a module's weight and, when present, its bias with the same mask.

    For ``_NNCFModuleMixin`` instances the weight is masked along
    ``module.target_weight_dim_for_compression``; for any other module the
    dimension passed is 0. The bias call does not pass a dimension, so the
    helper's default applies. Everything runs with gradient tracking disabled.

    :param module: object exposing ``weight`` and ``bias`` attributes;
        ``bias`` may be ``None``.
    :param mask: mask forwarded unchanged to ``inplace_apply_filter_binary_mask``.
    :param module_scope: scope forwarded to ``inplace_apply_filter_binary_mask``
        (presumably used for diagnostics; confirm against that helper).
    """
    with torch.no_grad():
        # Pick the weight dimension the mask is applied along.
        if isinstance(module, _NNCFModuleMixin):
            weight_dim = module.target_weight_dim_for_compression
        else:
            weight_dim = 0

        # The weight is always masked.
        inplace_apply_filter_binary_mask(
            mask, module.weight, module_scope, weight_dim)

        # The bias is optional; nothing more to do when it is absent.
        bias = module.bias
        if bias is None:
            return
        inplace_apply_filter_binary_mask(mask, bias, module_scope)
def test_inplace_apply_filter_binary_mask(mask, reference_weight, reference_bias):
    """
    Check that inplace_apply_filter_binary_mask returns the expected tensor and
    that the module parameter it was given matches the same reference afterwards.
    """
    conv = NNCFConv2d(1, 2, 2)
    fill_conv_weight(conv, 1)
    fill_bias(conv, 1)

    # Weight is checked first, then bias; each against its own reference.
    checks = (
        (conv.weight, reference_weight),
        (conv.bias, reference_bias),
    )
    for param, expected in checks:
        returned = inplace_apply_filter_binary_mask(mask, param.data)
        # The returned value matches the reference ...
        assert torch.allclose(returned, expected)
        # ... and so does the parameter stored on the module.
        assert torch.allclose(param, expected)