def from_float(cls, mod):
    r"""Create a quantized sparse dynamic module from a float module.

    We only care about the convert at this stage, no need for observers
    just yet.
    """
    assert type(mod) == cls._FLOAT_MODULE, \
        ' nnq.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
    # TODO: Need to add options to qconfig to avoid the calibration.
    # TODO: Add calibration for the sparsity
    assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'

    # A fused Linear+ReLU carries the actual Linear as its first submodule.
    if type(mod) == nni.LinearReLU:
        mod = mod[0]

    if mod.qconfig is not None and mod.qconfig.weight is not None:
        observer = mod.qconfig.weight()
    else:
        # We have the circular import issues if we import the qconfig in the
        # beginning of this file: https://github.com/pytorch/pytorch/pull/24231.
        # The current workaround is to postpone the import until we need it.
        from torch.ao.quantization.qconfig import default_dynamic_qconfig
        observer = default_dynamic_qconfig.weight()

    # It is important to multiply by the mask BEFORE calling the observer,
    # so the qparams reflect the sparsified weight.
    # TODO (zaf): Mask might not be part of the qconfig (T83295194)
    weight = mod.weight
    if getattr(mod.qconfig, 'mask', False):
        weight = mod.qconfig.mask * mod.weight
    observer(weight)

    dtype = observer.dtype
    assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
    w_sc, w_zp = observer.calculate_qparams()
    # Sparse kernels require a symmetric weight quantization (zero point 0).
    if isinstance(w_zp, torch.Tensor):
        assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
    else:
        assert w_zp == 0, 'Weight zero point must map to 0'

    qweight = _quantize_weight(weight.float(), observer)
    row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
    qlinear = cls(mod.in_features,
                  mod.out_features,
                  row_block_size,
                  col_block_size,
                  dtype=dtype)
    qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
    return qlinear
def from_float(cls, mod):
    r"""Create a quantized sparse module from a float module.

    We only care about the convert at this stage, no need for observers
    just yet.

    TODO: Need to figure out how to store the block shapes in the mod
    """
    assert type(mod) == cls._FLOAT_MODULE, \
        cls._get_name() + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
    # TODO: Need to add options to qconfig to avoid the calibration.
    # TODO: Add calibration for the sparsity
    assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'

    act_observer = mod.activation_post_process
    wt_observer = mod.qconfig.weight()

    # Assumption is that the weight is already sparsified by the
    # `sparsifier.convert`
    weight = mod.weight
    wt_observer(weight)

    dtype = wt_observer.dtype
    act_scale, act_zp = act_observer.calculate_qparams()
    assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
    w_sc, w_zp = wt_observer.calculate_qparams()
    # Sparse kernels require a symmetric weight quantization (zero point 0).
    if isinstance(w_zp, torch.Tensor):
        assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
    else:
        assert w_zp == 0, 'Weight zero point must map to 0'

    qweight = _quantize_weight(weight.float(), wt_observer)
    row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
    qlinear = cls(mod.in_features,
                  mod.out_features,
                  row_block_size,
                  col_block_size,
                  dtype=dtype)
    qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
    # Static quantization: bake the output activation qparams into the module.
    qlinear.scale = float(act_scale)
    qlinear.zero_point = int(act_zp)
    return qlinear
def test_sparse_qlinear_serdes(self):
    # Round-trip test for sparse quantized linear: quantize a model with a
    # block-sparsified weight, TorchScript-serialize it, reload it, and check
    # the reloaded module matches the non-sparse reference numerically.
    batch_size = 12
    input_channels = 4
    output_channels = 7
    model = self.SparseQuantizedModel(input_channels, output_channels)

    # For sparse kernels both the activation and weight ZP = 0
    X_scale = 0.2
    X_zp = 0
    W_scale = 1e-2
    W_zp = 0

    with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
        X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32)
        # NOTE(review): float_bias is never used below — presumably leftover
        # from an earlier version of the test.
        float_bias = torch.randn(output_channels, dtype=torch.float32)

        # Quantize-dequantize the input so the float reference sees exactly
        # the values the quantized path will see.
        X_q = torch.quantize_per_tensor(
            X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
        X_fp32 = X_q.dequantize()

        # Build a (randomly) block-sparse weight by zeroing entries with a
        # random 0/1 mask, then quantize-dequantize it as well.
        W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32)
        mask = torch.randint(0, 2, W_fp32.shape)
        W_fp32 *= mask
        W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8)

        model.linear.weight = nn.Parameter(W_q.dequantize())
        model.linear.sparse_params = {'sparse_block_shape': (1, 4)}
        model.eval()

        # Note: At the moment, for sparse kernels
        # fbgemm supports only static quantized sparse linear
        # qnnpack supports only dynamically quantized sparse linear
        # Hence we have two different tests.
        # fbgemm tests static flow, qnnpack tests dynamic.
        # Should be unified later on and tests should be fixed
        # appropriately.
        if qengine_is_fbgemm():
            model.qconfig = tq.get_default_qconfig('fbgemm')
            qmodel = copy.deepcopy(model)      # dense reference
            sqmodel = copy.deepcopy(model)     # sparse variant under test

            tq.prepare(qmodel, inplace=True)
            tq.prepare(sqmodel, inplace=True)

            # Calibrate both copies with the same input.
            with torch.no_grad():
                qmodel(X_fp32)
                sqmodel(X_fp32)

            # Make sure the quantization parameters are computed the same way
            qparams = qmodel.linear.qconfig.weight().calculate_qparams()
            sqparams = sqmodel.linear.qconfig.weight().calculate_qparams()
            self.assertEqual(qparams, sqparams)

            # Make sure mapping of sparse kernels does not affect the non-sparse
            sparse_mapping = tq.get_default_static_quant_module_mappings()
            sparse_mapping[nn.Linear] = ao_nn_sq.Linear
            tq.convert(sqmodel, inplace=True, mapping=sparse_mapping)
            tq.convert(qmodel, inplace=True)

            assert isinstance(sqmodel.linear, ao_nn_sq.Linear), "Convert failed"
            assert isinstance(qmodel.linear, nn.quantized.Linear), "Mapping failed"

            # Serialize the sparse model through TorchScript and reload it.
            scripted_sqmodel = torch.jit.script(sqmodel)
            scripted_sqmodel.eval()
            buffer = io.BytesIO()
            torch.jit.save(scripted_sqmodel, buffer)
            buffer.seek(0)
            sqmodel = torch.jit.load(buffer)

            # Make sure numerics are right
            Y_ref = qmodel(X_q)
            Y_hat = sqmodel(X_q)
            self.assertEqual(Y_ref.dequantize(), Y_hat.dequantize())
        elif qengine_is_qnnpack():
            qconfig = {nn.Linear: tq.qconfig.default_dynamic_qconfig}
            dqmodel = copy.deepcopy(model)     # dense dynamic reference
            sdqmodel = copy.deepcopy(model)    # sparse dynamic variant under test

            tq.propagate_qconfig_(dqmodel, qconfig)
            tq.propagate_qconfig_(sdqmodel, qconfig)

            # Make sure the quantization parameters are computed the same way
            qparams = dqmodel.linear.qconfig.weight().calculate_qparams()
            sqparams = sdqmodel.linear.qconfig.weight().calculate_qparams()
            self.assertEqual(qparams, sqparams)

            # Make sure mapping of sparse kernels does not affect the non-sparse
            sparse_mapping = copy.deepcopy(
                tq.get_default_dynamic_quant_module_mappings())
            sparse_mapping[nn.Linear] = ao_nn_sq.dynamic.Linear
            # The block-sparse pattern context is only needed while converting
            # the sparse copy; the dense convert does not use it.
            with LinearBlockSparsePattern(1, 4):
                tq.convert(sdqmodel, inplace=True, mapping=sparse_mapping)
            tq.convert(
                dqmodel,
                mapping=tq.get_default_dynamic_quant_module_mappings(),
                inplace=True)

            assert isinstance(sdqmodel.linear, ao_nn_sq.dynamic.Linear), "Convert failed"
            assert isinstance(
                dqmodel.linear, nn.quantized.dynamic.Linear), "Mapping failed"

            # Serialize the sparse model through TorchScript and reload it.
            scripted_sdqmodel = torch.jit.script(sdqmodel)
            scripted_sdqmodel.eval()
            buffer = io.BytesIO()
            torch.jit.save(scripted_sdqmodel, buffer)
            buffer.seek(0)
            sdqmodel = torch.jit.load(buffer)

            # Make sure numerics are right
            Y_ref = dqmodel(X_fp32)
            Y_hat = sdqmodel(X_fp32)
            self.assertEqual(Y_ref, Y_hat)