def test_compare_model_stub_lstm_dynamic_fx(self):
    r"""Compare the output of dynamic quantized LSTM layer and its float shadow module"""
    # Only nn.LSTM modules are dynamically quantized; everything else stays float.
    qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}

    float_model = LSTMwithHiddenDynamicModel()
    float_model.eval()

    prepared_model = prepare_fx(float_model, qconfig_dict)
    # Snapshot the prepared (still float) model before convert_fx, so it can
    # serve as the float reference for the shadow comparison.
    prepared_float_model = copy.deepcopy(prepared_model)
    prepared_float_model.eval()
    q_model = convert_fx(prepared_model)

    # Module types to be wrapped by a float shadow in the quantized model.
    module_swap_list = [nn.LSTM]
    lstm_input = torch.rand((1, 1, 2))
    lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))

    expected_ob_dict_keys = {"lstm.stats"}
    self.compare_and_validate_model_stub_results_fx(
        prepared_float_model,
        q_model,
        module_swap_list,
        expected_ob_dict_keys,
        lstm_input,
        lstm_hidden,
    )
def test_compare_model_outputs_lstm_dynamic_fx(self):
    r"""Compare the output of LSTM layer in dynamic quantized model and
    corresponding output of LSTM layer in float model
    """
    # Only nn.LSTM modules are dynamically quantized; everything else stays float.
    qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}

    float_model = LSTMwithHiddenDynamicModel()
    float_model.eval()

    prepared_model = prepare_fx(float_model, qconfig_dict)
    # Snapshot the prepared (still float) model before convert_fx, so it can
    # serve as the float reference for the output comparison.
    prepared_float_model = copy.deepcopy(prepared_model)
    # Put the float reference in eval mode, matching the sibling stub/weight
    # comparison tests.
    prepared_float_model.eval()
    q_model = convert_fx(prepared_model)

    lstm_input = torch.rand((1, 1, 2))
    lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))

    expected_act_compare_dict_keys = {"x.stats", "hid.stats", "lstm_1.stats"}
    self.compare_and_validate_model_outputs_results_fx(
        prepared_float_model,
        q_model,
        expected_act_compare_dict_keys,
        lstm_input,
        lstm_hidden,
    )
def test_compare_weights_lstm_dynamic_fx(self):
    r"""Check that the weights of a float LSTM layer line up with those of its
    dynamically quantized counterpart produced by FX graph mode quantization.
    """
    lstm_dynamic_qconfig = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}

    reference = LSTMwithHiddenDynamicModel()
    reference.eval()

    prepared = prepare_fx(reference, lstm_dynamic_qconfig)
    # Deep-copy before conversion: the copy is the float side of the comparison.
    float_side = copy.deepcopy(prepared)
    float_side.eval()
    quantized_side = convert_fx(prepared)

    # Dynamic quantized LSTM exposes its packed weights under this key.
    weight_keys = {"lstm._all_weight_values.0.param"}
    self.compare_and_validate_model_weights_results_fx(
        float_side, quantized_side, weight_keys
    )