def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
    """Test API functionality for nn.quantized.dynamic.Linear.

    Exercises, in order: construction and forward pass, the
    ``set_weight_bias``/``weight`` round trip, agreement with the
    ``quantized.linear_dynamic`` op, ``state_dict`` serialization,
    whole-module ``torch.save``/``torch.load``, scripting, and
    ``from_float`` conversion.

    Args:
        batch_size: Number of rows of the random input ``X``.
        in_features: Input width of the layer under test.
        out_features: Output width of the layer under test.
        use_bias: When true a random float bias is set on the module,
            otherwise the bias passed to ``set_weight_bias`` is ``None``.
        use_default_observer: When true, ``default_dynamic_qconfig`` is
            attached to the float module before ``from_float`` is called.
    """
    # NOTE(review): this chunk of the file contains a second definition named
    # `test_linear_api`; if both live in the same class the later one replaces
    # this one and this test never runs -- confirm.
    linear_unpack = torch.ops.quantized.linear_unpack

    W = torch.rand(out_features, in_features).float()
    W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
    W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
    X = torch.rand(batch_size, in_features).float()
    B = torch.rand(out_features).float() if use_bias else None

    qlinear = nnqd.Linear(in_features, out_features)
    # Run module with default-initialized parameters.
    # This tests that the constructor is correct.
    qlinear(X)

    # Smoke run with the explicitly provided weight and bias.
    qlinear.set_weight_bias(W_q, B)
    qlinear(X)

    # Simple round-trip test to ensure weight()/set_weight() API
    self.assertEqual(qlinear.weight(), W_q)
    W_pack = qlinear._packed_params._packed_params
    Z_dq = qlinear(X)

    # Check if the module implementation matches calling the
    # ops directly
    Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack, reduce_range=True)
    self.assertEqual(Z_ref, Z_dq)

    # Test serialization of dynamic quantized Linear Module using state_dict
    model_dict = qlinear.state_dict()
    b = io.BytesIO()
    torch.save(model_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    for key, saved in model_dict.items():
        restored = loaded_dict[key]
        if isinstance(saved, torch._C.ScriptObject):
            # Packed parameters are compared through their unpacked
            # (weight, bias) pair.
            self.assertIsInstance(restored, torch._C.ScriptObject)
            w_model, b_model = linear_unpack(saved)
            w_loaded, b_loaded = linear_unpack(restored)
            self.assertEqual(w_model, w_loaded)
            self.assertEqual(b_model, b_loaded)
        else:
            self.assertEqual(saved, restored)

    loaded_qlinear = nnqd.Linear(in_features, out_features)
    loaded_qlinear.load_state_dict(loaded_dict)
    self.assertEqual(
        linear_unpack(qlinear._packed_params._packed_params),
        linear_unpack(loaded_qlinear._packed_params._packed_params))
    if use_bias:
        self.assertEqual(qlinear.bias(), loaded_qlinear.bias())

    # Both modules must expose the same attribute names; assertListEqual
    # reports the differing entries on failure.
    self.assertListEqual(dir(qlinear), dir(loaded_qlinear))
    self.assertTrue(hasattr(qlinear, '_packed_params'))
    self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
    self.assertTrue(hasattr(qlinear, '_weight_bias'))
    self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

    self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
    self.assertEqual(
        qlinear._weight_bias(),
        linear_unpack(qlinear._packed_params._packed_params))

    # A second forward pass on the same input must reproduce the first.
    Z_dq2 = qlinear(X)
    self.assertEqual(Z_dq, Z_dq2)

    # Test serialization of the whole module with torch.save / torch.load
    b = io.BytesIO()
    torch.save(qlinear, b)
    b.seek(0)
    loaded = torch.load(b)
    self.assertEqual(qlinear.weight(), loaded.weight())
    self.assertEqual(qlinear.zero_point, loaded.zero_point)

    # Test JIT
    self.checkScriptable(qlinear, [[X]], check_save_load=True)

    # Test from_float
    float_linear = torch.nn.Linear(in_features, out_features).float()
    if use_default_observer:
        float_linear.qconfig = torch.quantization.default_dynamic_qconfig
    prepare_dynamic(float_linear)
    float_linear(X.float())
    quantized_float_linear = nnqd.Linear.from_float(float_linear)

    # Smoke test to make sure the module actually runs
    quantized_float_linear(X)

    # Smoke test extra_repr
    self.assertIn('QuantizedLinear', str(quantized_float_linear))
def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
    """Test API functionality for nn.quantized.dynamic.Linear.

    Exercises, in order: construction and forward pass, the
    ``set_weight_bias``/``weight`` round trip, agreement with the
    ``quantized.linear_dynamic`` op, ``state_dict`` serialization, the
    expected failure of whole-module ``torch.save``, scripting, and
    ``from_float`` conversion.

    Args:
        batch_size: Number of rows of the random input ``X``.
        in_features: Input width of the layer under test.
        out_features: Output width of the layer under test.
        use_bias: When true a random float bias is set on the module,
            otherwise the bias passed to ``set_weight_bias`` is ``None``.
        use_default_observer: When true, ``default_dynamic_qconfig`` is
            attached to the float module before ``from_float`` is called.
    """
    # NOTE(review): this chunk of the file contains another definition named
    # `test_linear_api`; if both live in the same class only the later one
    # is kept -- confirm which variant is intended.
    linear_unpack = torch.ops.quantized.linear_unpack

    W = torch.rand(out_features, in_features).float()
    W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
    W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
    X = torch.rand(batch_size, in_features).float()
    B = torch.rand(out_features).float() if use_bias else None

    qlinear = nnqd.Linear(in_features, out_features)
    # Run module with default-initialized parameters.
    # This tests that the constructor is correct.
    qlinear(X)

    # Smoke run with the explicitly provided weight and bias.
    qlinear.set_weight_bias(W_q, B)
    qlinear(X)

    # Simple round-trip test to ensure weight()/set_weight() API
    self.assertEqual(qlinear.weight(), W_q)
    W_pack = qlinear._packed_params._packed_params
    Z_dq = qlinear(X)

    # Check if the module implementation matches calling the
    # ops directly
    Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack)
    self.assertEqual(Z_ref, Z_dq)

    # Test serialization of dynamic quantized Linear Module using state_dict
    model_dict = qlinear.state_dict()
    self.assertEqual(model_dict['_packed_params.weight'], W_q)
    if use_bias:
        self.assertEqual(model_dict['_packed_params.bias'], B)
    b = io.BytesIO()
    torch.save(model_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    for key, saved in model_dict.items():
        self.assertEqual(saved, loaded_dict[key])

    loaded_qlinear = nnqd.Linear(in_features, out_features)
    loaded_qlinear.load_state_dict(loaded_dict)
    self.assertEqual(
        linear_unpack(qlinear._packed_params._packed_params),
        linear_unpack(loaded_qlinear._packed_params._packed_params))
    if use_bias:
        self.assertEqual(qlinear.bias(), loaded_qlinear.bias())

    # Both modules must expose the same attribute names; assertListEqual
    # reports the differing entries on failure.
    self.assertListEqual(dir(qlinear), dir(loaded_qlinear))
    self.assertTrue(hasattr(qlinear, '_packed_params'))
    self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
    self.assertTrue(hasattr(qlinear, '_weight_bias'))
    self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

    self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
    self.assertEqual(
        qlinear._weight_bias(),
        linear_unpack(qlinear._packed_params._packed_params))

    # A second forward pass on the same input must reproduce the first.
    Z_dq2 = qlinear(X)
    self.assertEqual(Z_dq, Z_dq2)

    # The below check is meant to ensure that `torch.save` and `torch.load`
    # serialization works, however it is currently broken by the following:
    # https://github.com/pytorch/pytorch/issues/24045
    #
    # Instead, we currently check that the proper exception is thrown on save.
    # <start code>
    # b = io.BytesIO()
    # torch.save(qlinear, b)
    # b.seek(0)
    # loaded = torch.load(b)
    # self.assertEqual(qlinear.weight(), loaded.weight())
    # self.assertEqual(qlinear.zero_point, loaded.zero_point)
    # <end code>
    #
    # The buffer is created outside the `with` body so that only the
    # `torch.save` call itself can satisfy the expected-exception check.
    b = io.BytesIO()
    with self.assertRaisesRegex(
            RuntimeError, r'torch\.save\(\) is not currently supported'):
        torch.save(qlinear, b)

    # Test JIT
    self.checkScriptable(qlinear, [(X, Z_ref)], check_save_load=True)

    # Test from_float
    float_linear = torch.nn.Linear(in_features, out_features).float()
    if use_default_observer:
        float_linear.qconfig = torch.quantization.default_dynamic_qconfig
    prepare_dynamic(float_linear)
    float_linear(X.float())
    quantized_float_linear = nnqd.Linear.from_float(float_linear)

    # Smoke test to make sure the module actually runs
    quantized_float_linear(X)

    # Smoke test extra_repr
    self.assertIn('QuantizedLinear', str(quantized_float_linear))
def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
    r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8.

    Exercises, in order: forward pass after ``set_weight``, agreement with
    the ``quantized.embedding_bag_byte`` op, ``state_dict`` serialization,
    eager serialization, scripting, and ``from_float`` conversion.

    Args:
        num_embeddings: Number of rows of the embedding table; also the
            exclusive upper bound of the random indices.
        embedding_dim: Width of each embedding row.
        num_offsets: Not read by this test body.
        set_qconfig: When true, ``float_qparams_dynamic_qconfig`` is
            attached to the float module before ``from_float`` is called.
    """
    embedding_unpack = torch.ops.quantized.embedding_bag_unpack

    # Random bag lengths (1 to 5 bags, each with 0 to 20 indices).
    num_lengths = np.random.randint(1, 6)
    lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
    num_indices = np.sum(lengths)
    indices = torch.from_numpy(
        np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

    offsets = lengths_to_offsets(lengths)
    # include the last offset
    offsets = torch.cat(
        (offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)

    # Random float32 weights shifted up by 1.
    weights = torch.from_numpy((np.random.random_sample(
        (num_embeddings, embedding_dim)) + 1).astype(np.float32))

    obs = default_float_qparams_observer()
    obs(weights)
    # Get the scale and zero point for the weight tensor
    qparams = obs.calculate_qparams()
    # Quantize the weights to 8bits
    qweight = torch.quantize_per_channel(
        weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)

    qemb = nnqd.EmbeddingBag(num_embeddings=num_embeddings,
                             embedding_dim=embedding_dim,
                             include_last_offset=True, mode='sum')
    qemb.set_weight(qweight)
    qemb(indices, offsets)

    # Ensure the module has the correct weights
    self.assertEqual(qweight, qemb.weight())

    w_packed = qemb._packed_params._packed_weight
    module_out = qemb(indices, offsets)

    # Call the qembedding_bag operator directly
    ref = torch.ops.quantized.embedding_bag_byte(
        w_packed, indices, offsets, mode=0,
        per_sample_weights=None, include_last_offset=True)
    self.assertEqual(module_out, ref)

    # Test serialization of dynamic EmbeddingBag module using state_dict
    emb_dict = qemb.state_dict()
    b = io.BytesIO()
    torch.save(emb_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)

    # Check unpacked weight values explicitly
    for key, saved in emb_dict.items():
        if isinstance(saved, torch._C.ScriptObject):
            restored = loaded_dict[key]
            self.assertIsInstance(restored, torch._C.ScriptObject)
            emb_weight = embedding_unpack(saved)
            loaded_weight = embedding_unpack(restored)
            self.assertEqual(emb_weight, loaded_weight)

    # Check state dict serialization and torch.save APIs
    loaded_qemb = nnqd.EmbeddingBag(num_embeddings=num_embeddings,
                                    embedding_dim=embedding_dim,
                                    include_last_offset=True, mode='sum')
    self.check_eager_serialization(qemb, loaded_qemb, [indices, offsets])

    loaded_qemb.load_state_dict(loaded_dict)
    self.assertEqual(
        embedding_unpack(qemb._packed_params._packed_weight),
        embedding_unpack(loaded_qemb._packed_params._packed_weight))

    # Test JIT serialization
    self.checkScriptable(qemb, [[indices, offsets]], check_save_load=True)

    # Test from_float call
    float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings,
                                            embedding_dim=embedding_dim,
                                            include_last_offset=True,
                                            scale_grad_by_freq=False,
                                            mode='sum')
    if set_qconfig:
        float_embedding.qconfig = float_qparams_dynamic_qconfig

    prepare_dynamic(float_embedding)

    float_embedding(indices, offsets)

    q_embeddingbag = nnqd.EmbeddingBag.from_float(float_embedding)

    q_embeddingbag(indices, offsets)

    self.assertIn('DynamicQuantizedEmbeddingBag', str(q_embeddingbag))