def test_acts_quant_params_linear(act1_type, act2_type, bn_out_stats):
    """Check reading and overriding linear quant params of a quantized ``LinearBNSplitAct``.

    Args (presumably supplied by pytest fixtures/parametrization -- confirm):
        act1_type: activation type passed as the first argument of ``LinearBNSplitAct``.
        act2_type: activation type passed as the second argument of ``LinearBNSplitAct``.
        bn_out_stats: stats injected as the ``'bn'`` module's ``'output'`` activation stats.
    """
    # prepare model:
    model = LinearBNSplitAct(act1_type, act2_type)
    stats = gen_stats_for_model(model)
    stats['bn']['output'] = bn_out_stats
    # deepcopy so the quantizer cannot mutate the local `stats` dict.
    quantizer = PostTrainLinearQuantizer(
        model, model_activation_stats=deepcopy(stats), save_fp_weights=True)
    quantizer.prepare_model(torch.randn(10, 10))

    # get quant params:
    expected_quant_params_keys = {
        'linear.output_zero_point', 'linear.output_scale', 'linear.w_scale',
        'linear.w_zero_point', 'act1.output_zero_point', 'act1.output_scale',
        'act2.output_zero_point', 'act2.output_scale'
    }
    assert set(quantizer.linear_quant_params) == expected_quant_params_keys

    # Override individual params by their fully-qualified name.
    quantizer.set_linear_quant_param('linear.output_zero_point', 2.)
    quantizer.set_linear_quant_param('linear.output_scale', 30.)
    assert model.linear.output_zero_point == 2.
    assert model.linear.output_scale == 30.
    # The test expects changing 'linear' params to flag both 'linear' and 'act1'
    # for readjustment.
    assert model.linear.force_readjust
    assert model.act1.force_readjust

    # Per-module view of the params; the 'linear.' prefix is stripped from names.
    expected_quant_param_linear_dict = {
        'output_zero_point': torch.tensor(2.),
        'output_scale': 30.,
        'w_scale': model.linear.w_scale.item(),
        'w_zero_point': model.linear.w_zero_point.item()
    }
    assert dict(model.linear.named_linear_quant_params()
                ) == expected_quant_param_linear_dict

    # Bulk update through a name -> value mapping.
    new_config = {'linear.output_zero_point': 4., 'act2.output_scale': 50}
    quantizer.update_linear_quant_params(new_config)
    assert model.linear.output_zero_point == 4
    assert model.act2.output_scale == 50
    # NOTE(review): force_readjust is already True from the set_linear_quant_param
    # calls above, so these checks cannot tell whether update_linear_quant_params
    # itself sets the flag.
    assert model.linear.force_readjust
    assert model.act1.force_readjust


def test_acts_quant_params_rnn(rnn_model):
    """Check bulk updates of linear quant params on a quantized word-language RNN model.

    ``rnn_model`` is presumably a pytest fixture providing the recurrent module that
    ``DummyWordLangModel`` wraps -- confirm against the fixture definition.
    """
    vocab_size = 41
    model = DummyWordLangModel(nn.Embedding(vocab_size, 20), rnn_model)
    quantizer = PostTrainLinearQuantizer(
        model, model_activation_stats=deepcopy(gen_stats_for_model(model)))
    # Token ids in [0, vocab_size), shaped (10, 1).
    quantizer.prepare_model(torch.randint(0, vocab_size, size=(10, 1)))

    quantizer.update_linear_quant_params({
        'rnn.rnn.cells.0.act_o.output_scale': 4,
        'embedding.w_scale': torch.tensor(59.0),
    })

    first_cell = model.rnn.rnn.cells[0]
    assert first_cell.act_o.output_scale == 4
    assert model.embedding.w_scale == 59.0
    # The test expects the update to flag act_o for readjustment, and act_f in the
    # same cell as well.
    assert first_cell.act_o.force_readjust.item() is True
    assert first_cell.act_f.force_readjust.item() is True