def test_acts_quant_params_linear(act1_type, act2_type, bn_out_stats):
    """Exercise get/set/update of linear quant params on ``LinearBNSplitAct``.

    Checks that ``PostTrainLinearQuantizer`` exposes exactly the expected
    ``linear_quant_params`` keys. It also checks that setting a single param
    (``set_linear_quant_param``) or several at once
    (``update_linear_quant_params``) writes the value through to the wrapped
    module. Finally, it checks that ``linear`` and ``act1`` end up with
    ``force_readjust`` raised.

    Args:
        act1_type, act2_type: activation types passed to ``LinearBNSplitAct``
            (presumably supplied by pytest parametrization -- not visible here).
        bn_out_stats: activation stats injected as the 'bn' module's output
            stats before quantization.
    """
    # prepare model:
    model = LinearBNSplitAct(act1_type, act2_type)
    stats = gen_stats_for_model(model)
    stats['bn']['output'] = bn_out_stats
    quantizer = PostTrainLinearQuantizer(
        model, model_activation_stats=deepcopy(stats), save_fp_weights=True)
    quantizer.prepare_model(torch.randn(10, 10))

    # get quant params:
    # NOTE(review): no 'bn.*' keys are expected -- presumably BN is folded
    # into 'linear' by prepare_model; confirm.
    expected_quant_params_keys = {
        'linear.output_zero_point',
        'linear.output_scale',
        'linear.w_scale',
        'linear.w_zero_point',
        'act1.output_zero_point',
        'act1.output_scale',
        'act2.output_zero_point',
        'act2.output_scale',
    }
    assert set(quantizer.linear_quant_params) == expected_quant_params_keys

    # set individual params and check they reach the wrapped module:
    quantizer.set_linear_quant_param('linear.output_zero_point', 2.)
    quantizer.set_linear_quant_param('linear.output_scale', 30.)
    assert model.linear.output_zero_point == 2.
    assert model.linear.output_scale == 30.
    # Explicit boolean check, matching the RNN test, instead of `== True`.
    assert model.linear.force_readjust.item() is True
    assert model.act1.force_readjust.item() is True

    expected_quant_param_linear_dict = {
        'output_zero_point': torch.tensor(2.),
        'output_scale': 30.,
        'w_scale': model.linear.w_scale.item(),
        'w_zero_point': model.linear.w_zero_point.item(),
    }
    assert (dict(model.linear.named_linear_quant_params())
            == expected_quant_param_linear_dict)

    # bulk-update params through the quantizer:
    new_config = {'linear.output_zero_point': 4., 'act2.output_scale': 50}
    quantizer.update_linear_quant_params(new_config)
    assert model.linear.output_zero_point == 4
    assert model.act2.output_scale == 50
    assert model.linear.force_readjust.item() is True
    assert model.act1.force_readjust.item() is True
def test_acts_quant_params_rnn(rnn_model):
    """Bulk-update linear quant params on an embedding + RNN word model.

    Overrides the output scale of ``act_o`` in the first RNN cell and the
    embedding's weight scale via ``update_linear_quant_params``. It then checks
    that both new values landed and that the first cell's ``act_o`` and
    ``act_f`` have ``force_readjust`` raised.

    Args:
        rnn_model: RNN module handed to ``DummyWordLangModel`` (presumably a
            pytest fixture/parameter -- not visible here).
    """
    # The same bound is used for the embedding table size and the random
    # token ids fed to prepare_model.
    vocab_size = 41
    model = DummyWordLangModel(nn.Embedding(vocab_size, 20), rnn_model)
    quantizer = PostTrainLinearQuantizer(
        model, model_activation_stats=deepcopy(gen_stats_for_model(model)))
    token_ids = torch.randint(0, vocab_size, size=(10, 1))
    quantizer.prepare_model(token_ids)

    overrides = {
        'rnn.rnn.cells.0.act_o.output_scale': 4,
        'embedding.w_scale': torch.tensor(59.0),
    }
    quantizer.update_linear_quant_params(overrides)

    first_cell = model.rnn.rnn.cells[0]
    assert first_cell.act_o.output_scale == 4
    assert model.embedding.w_scale == 59.0
    assert first_cell.act_o.force_readjust.item() is True
    assert first_cell.act_f.force_readjust.item() is True