def test_histogram_observer(self, qdtype, qscheme, reduce_range):
    myobs = HistogramObserver(
        bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range
    )
    x = torch.tensor([2.0, 3.0, 4.0, 5.0])
    y = torch.tensor([5.0, 6.0, 7.0, 8.0])
    myobs(x)
    myobs(y)
    self.assertEqual(myobs.min_val, 2.0)
    self.assertEqual(myobs.max_val, 8.0)
    self.assertEqual(myobs.histogram, [2., 3., 3.])

    qparams = myobs.calculate_qparams()
    if reduce_range:
        if qscheme == torch.per_tensor_symmetric:
            ref_scale = 0.0470588 * 255 / 127
            ref_zero_point = 0 if qdtype is torch.qint8 else 128
        else:
            ref_scale = 0.0235294 * 255 / 127
            ref_zero_point = -64 if qdtype is torch.qint8 else 0
    else:
        if qscheme == torch.per_tensor_symmetric:
            ref_scale = 0.0470588
            ref_zero_point = 0 if qdtype is torch.qint8 else 128
        else:
            ref_scale = 0.0235294
            ref_zero_point = -128 if qdtype is torch.qint8 else 0
    self.assertEqual(qparams[1].item(), ref_zero_point)
    self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

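# Editorial note (not part of the original test) on where the reference constants above
# come from: PyTorch per-tensor observers extend the observed range to include zero, then
# use scale = (max - min) / (qmax - qmin) for affine quantization and
# scale = max(|min|, |max|) / ((qmax - qmin) / 2) for symmetric quantization. The
# histogram search here appears to settle on an effective range of width ~6.0, so
# 0.0235294 ~= 6.0 / 255 (affine) and 0.0470588 ~= 6.0 / 127.5 (symmetric). With
# reduce_range the 8-bit quantized range shrinks from 256 to 128 values, which scales
# both references by 255 / 127 and moves the qint8 affine zero_point from -128 to -64.
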
def test_histogram_observer_one_sided(self):
    myobs = HistogramObserver(
        bins=8, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True
    )
    x = torch.tensor([0.0, 0.3, 1.2, 1.7])
    y = torch.tensor([0.1, 1.3, 2.0, 2.7])
    myobs(x)
    myobs(y)
    self.assertEqual(myobs.min_val, 0)
    qparams = myobs.calculate_qparams()
    self.assertEqual(qparams[1].item(), 0)

def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, reduce_range):
    ref_obs = _ReferenceHistogramObserver(
        bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range
    )
    my_obs = HistogramObserver(
        bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range
    )
    for _ in range(10):
        X = torch.randn(N)
        my_obs(X)
        ref_obs(X)
    ref_qparams = ref_obs.calculate_qparams()
    my_qparams = my_obs.calculate_qparams()
    self.assertEqual(ref_qparams, my_qparams)

def test_histogram_observer_same_inputs(self):
    myobs = HistogramObserver(
        bins=3, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
    )
    # w and x are constant-valued (min == max), so only y and z end up reflected in the
    # final statistics asserted below.
    w = torch.ones(4, requires_grad=True)
    x = torch.zeros(4, requires_grad=True)
    y = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
    z = torch.tensor([5.0, 6.0, 7.0, 8.0])
    myobs(w)
    myobs(x)
    myobs(x)
    myobs(y)
    myobs(z)
    qparams = myobs.calculate_qparams()
    self.assertEqual(myobs.min_val, 2.0)
    self.assertEqual(myobs.max_val, 8.0)
    self.assertEqual(myobs.histogram, [2., 3., 3.])

def quantize_statically(model, inputs, data_loader, linear_only=False):
    if (
        hasattr(model, "encoder")
        and isinstance(model.encoder, RoBERTaEncoder)
        and linear_only
    ):
        qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False),
            weight=default_weight_observer,
        )
        qconfig_dict = {"": None}
        for layer_idx in range(len(model.encoder.encoder.transformer.layers)):
            qconfig_dict[
                "encoder.encoder.transformer.layers.{}.attention.input_projection".format(layer_idx)
            ] = qconfig
            qconfig_dict[
                "encoder.encoder.transformer.layers.{}.attention.output_projection".format(layer_idx)
            ] = qconfig
            for mlp_idx, m in enumerate(
                model.encoder.encoder.transformer.layers[layer_idx].residual_mlp.mlp
            ):
                if type(m) == torch.nn.Linear:
                    qconfig_dict[
                        "encoder.encoder.transformer.layers.{}.residual_mlp.mlp.{}".format(layer_idx, mlp_idx)
                    ] = qconfig
        trace = model.graph_mode_quantize(
            inputs, data_loader, qconfig_dict=qconfig_dict, force_quantize=True
        )
    else:
        trace = model.graph_mode_quantize(inputs, data_loader)
    return trace

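# Editorial sketch (not output captured from a real model): for a hypothetical two-layer
# RoBERTa encoder, the qconfig_dict assembled above would look roughly like the dict
# below. The empty-string key disables quantization globally; the per-module keys opt
# individual Linear submodules back in.
#
# {
#     "": None,
#     "encoder.encoder.transformer.layers.0.attention.input_projection": qconfig,
#     "encoder.encoder.transformer.layers.0.attention.output_projection": qconfig,
#     "encoder.encoder.transformer.layers.0.residual_mlp.mlp.0": qconfig,
#     ...
#     "encoder.encoder.transformer.layers.1.attention.input_projection": qconfig,
#     ...
# }
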
def test_histogram_observer_save_load_state_dict(self):
    """
    Smoke test on saving/loading state_dict
    """
    obs1 = HistogramObserver()
    obs1(torch.randn(4, 4, 4, 4))
    obs2 = HistogramObserver()
    obs2.load_state_dict(obs1.state_dict())
    self.assertEqual(obs2.min_val.shape, torch.Size([]))
    self.assertEqual(obs2.max_val.shape, torch.Size([]))

def test_histogram_observer_consistent_buffer_shape(self):
    """
    Ensures that the buffer shapes do not change from uninitialized to
    initialized states for HistogramObserver.
    """
    obs = HistogramObserver()
    min_shape_before = obs.min_val.shape
    max_shape_before = obs.max_val.shape
    for _ in range(2):
        obs(torch.randn(4, 4, 4, 4))
    self.assertEqual(min_shape_before, obs.min_val.shape)
    self.assertEqual(max_shape_before, obs.max_val.shape)

def quantize_fx(model, inputs, data_loader, dynamic=True):
    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):
        static = not dynamic
        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )
        # Only linear layers
        qconfig_dict = {"": None}
        qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers, qconfig_dict
        )  # fuse modules and insert observers
        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            calibrate(model, data_loader)  # run calibration on sample data
        model.encoder.encoder.transformer.layers.layers = convert_fx(prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers, (input1, input2)
            )
            model.encoder.encoder.transformer.layers.layers = traced

    # Trace the overall module
    trace = model.trace(inputs)
    return trace

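# The helper above follows the standard FX graph-mode recipe: prepare_fx to fuse modules
# and insert observers, a calibration pass, then convert_fx. The function below is an
# editorial, self-contained sketch of that same flow on a toy module (its name, the toy
# model, and the calibration loop are illustrative assumptions, not from the original
# source).
def _fx_quantization_recipe_sketch():
    """Condensed prepare_fx / calibrate / convert_fx flow on a toy module."""
    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    toy_model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
    qconfig_dict = {
        "": None,  # no quantization by default
        "object_type": [(torch.nn.Linear, get_default_qconfig("fbgemm"))],  # only Linear
    }
    prepared = prepare_fx(toy_model, qconfig_dict)  # fuse modules and insert observers
    with torch.no_grad():
        for _ in range(8):
            prepared(torch.randn(4, 16))  # calibration pass on sample data
    quantized = convert_fx(prepared)  # swap observed modules for quantized ones
    return quantized
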
class ObserverTest(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_minmax_observer(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric:
            reduce_range = False
        myobs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
        result = myobs(x)
        result = myobs(y)
        self.assertEqual(result, y)
        self.assertEqual(myobs.min_val, 1.0)
        self.assertEqual(myobs.max_val, 8.0)

        qparams = myobs.calculate_qparams()
        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.062745 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0313725 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.062745
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0313725
                ref_zero_point = -128 if qdtype is torch.qint8 else 0
        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

    @given(obs=st.sampled_from((torch.quantization.default_observer()(),
                                HistogramObserver(bins=10))))
    def test_observer_scriptable(self, obs):
        scripted = torch.jit.script(obs)
        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

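# Editorial sketch (not part of the test suite) restating the MinMaxObserver reference
# arithmetic used in test_minmax_observer above for the affine case. The helper name and
# its defaults are illustrative assumptions; the observer widens the observed range to
# include zero, so the observed [1.0, 8.0] becomes [0.0, 8.0].
def _minmax_affine_qparams_sketch(min_val, max_val, qmin=0, qmax=255):
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)  # range must contain zero
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - min_val / scale))
    return scale, zero_point

# _minmax_affine_qparams_sketch(1.0, 8.0) -> (~0.0313725, 0), matching ref_scale and
# ref_zero_point for quint8 without reduce_range; passing qmin=0, qmax=127 reproduces
# the reduce_range reference 0.0313725 * 255 / 127.
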
def test_observer_scriptable(self, qdtype, qscheme):
    ob_list = [
        HistogramObserver(dtype=qdtype, qscheme=qscheme),
        default_histogram_observer(),
    ]
    for obs in ob_list:
        scripted = torch.jit.script(obs)
        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertTrue(torch.equal(obs.histogram, scripted.histogram))

        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertTrue(torch.equal(obs.histogram, loaded.histogram))

def quantize_statically(model, inputs, data_loader, linear_only=False, module_swap=False):
    log_feature_usage("export.quantize.statically")
    if (
        hasattr(model, "encoder")
        and isinstance(model.encoder, RoBERTaEncoder)
        and linear_only
    ):
        log_accelerator_feature_usage("quantize.statically")
        qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False),
            weight=default_weight_observer,
        )
        qconfig_dict = {"": None}
        if module_swap:
            layers = model.encoder.encoder.transformer.layers.layers
            layers_str = "encoder.encoder.transformer.layers.layers"
        else:
            layers = model.encoder.encoder.transformer.layers
            layers_str = "encoder.encoder.transformer.layers"

        # skip the first layer
        for layer_idx in range(1, len(layers)):
            qconfig_dict[
                layers_str + ".{}.attention.input_projection".format(layer_idx)
            ] = qconfig
            qconfig_dict[
                layers_str + ".{}.attention.output_projection".format(layer_idx)
            ] = qconfig
            for mlp_idx, m in enumerate(layers[layer_idx].residual_mlp.mlp):
                # Only quantize the first linear layer; otherwise there are accuracy issues
                if type(m) == torch.nn.Linear and mlp_idx < 1:
                    qconfig_dict[
                        layers_str + ".{}.residual_mlp.mlp.{}".format(layer_idx, mlp_idx)
                    ] = qconfig
        trace = model.graph_mode_quantize(
            inputs, data_loader, qconfig_dict=qconfig_dict, force_quantize=True
        )
    else:
        trace = model.graph_mode_quantize(inputs, data_loader)
    return trace

def test_histogram_observer(self, qdtype, qscheme, reduce_range):
    myobs = HistogramObserver(
        bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range
    )
    x = torch.tensor([2.0, 3.0, 4.0, 5.0])
    y = torch.tensor([5.0, 6.0, 7.0, 8.0])
    myobs(x)
    myobs(y)
    self.assertEqual(myobs.min_val, 2.0)
    self.assertEqual(myobs.max_val, 8.0)
    self.assertEqual(myobs.histogram, [2., 3., 3.])

    qparams = myobs.calculate_qparams()
    if reduce_range:
        if qscheme == torch.per_tensor_symmetric:
            ref_scale = 0.0470588 * 255 / 127
            ref_zero_point = 0 if qdtype is torch.qint8 else 128
        else:
            ref_scale = 0.0235294 * 255 / 127
            ref_zero_point = -64 if qdtype is torch.qint8 else 0
    else:
        if qscheme == torch.per_tensor_symmetric:
            ref_scale = 0.0470588
            ref_zero_point = 0 if qdtype is torch.qint8 else 128
        else:
            ref_scale = 0.0235294
            ref_zero_point = -128 if qdtype is torch.qint8 else 0
    self.assertEqual(qparams[1].item(), ref_zero_point)
    self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

    # Test for serializability
    state_dict = myobs.state_dict()
    b = io.BytesIO()
    torch.save(state_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    for key in state_dict:
        self.assertEqual(state_dict[key], loaded_dict[key])
    loaded_obs = HistogramObserver(
        bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range
    )
    loaded_obs.load_state_dict(loaded_dict)
    loaded_qparams = loaded_obs.calculate_qparams()
    self.assertEqual(myobs.min_val, loaded_obs.min_val)
    self.assertEqual(myobs.max_val, loaded_obs.max_val)
    self.assertEqual(myobs.histogram, loaded_obs.histogram)
    self.assertEqual(myobs.bins, loaded_obs.bins)
    self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

def quantize_fx(model, inputs, data_loader, dynamic=True, selective=False):
    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):
        static = not dynamic
        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )
        # Only linear layers
        qconfig_dict = {"": None}
        if static and selective:
            qconfig_dict["module_name"] = []
            layers = model.encoder.encoder.transformer.layers.layers.layers
            layers_str = "layers"
            # skip the first layer
            for layer_idx in range(1, len(layers)):
                qconfig_dict["module_name"].append((
                    layers_str + ".{}.attention.input_projection".format(layer_idx),
                    qconfig,
                ))
                qconfig_dict["module_name"].append((
                    layers_str + ".{}.attention.output_projection".format(layer_idx),
                    qconfig,
                ))
                for mlp_idx, m in enumerate(layers[layer_idx].residual_mlp.mlp):
                    # Only quantize the first linear layer; otherwise there are
                    # accuracy issues with static quantization
                    if type(m) == torch.nn.Linear and mlp_idx < 1:
                        qconfig_dict["module_name"].append((
                            layers_str + ".{}.residual_mlp.mlp.{}".format(layer_idx, mlp_idx),
                            qconfig,
                        ))
        else:
            qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers, qconfig_dict
        )  # fuse modules and insert observers
        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            calibrate(model, data_loader)  # run calibration on sample data
        model.encoder.encoder.transformer.layers.layers = convert_fx(prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers, (input1, input2)
            )
            model.encoder.encoder.transformer.layers.layers = traced

    # Trace the overall module
    trace = model.trace(inputs)
    return trace