def test_backend_op_config_set_num_tensor_args_to_observation_type(self):
    """The num-tensor-args mapping starts empty and stores exactly what was set."""
    config = BackendPatternConfig(torch.add)
    self.assertEqual(len(config._num_tensor_args_to_observation_type), 0)
    mapping = self._num_tensor_args_to_observation_type
    config._set_num_tensor_args_to_observation_type(mapping)
    self.assertEqual(config._num_tensor_args_to_observation_type, mapping)
def test_backend_op_config_set_overwrite_output_observer(self):
    """_overwrite_output_observer is None until explicitly set."""
    config = BackendPatternConfig(torch.sigmoid)
    self.assertIsNone(config._overwrite_output_observer)
    config._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)
    self.assertEqual(config._overwrite_output_observer, default_fixed_qparams_range_0to1_observer)
def test_backend_op_config_add_dtype_config(self):
    """Dtype configs accumulate in insertion order."""
    config = BackendPatternConfig(torch.nn.Linear)
    self.assertEqual(len(config.dtype_configs), 0)
    for dtype_config in (self.dtype_config1, self.dtype_config2):
        config.add_dtype_config(dtype_config)
    self.assertEqual(len(config.dtype_configs), 2)
    self.assertEqual(config.dtype_configs[0], self.dtype_config1)
    self.assertEqual(config.dtype_configs[1], self.dtype_config2)
def test_backend_op_config_set_observation_type(self):
    """Observation type defaults to OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT and is overridable."""
    config = BackendPatternConfig(torch.nn.Linear)
    self.assertEqual(
        config.observation_type,
        ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    config.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
    self.assertEqual(
        config.observation_type,
        ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
def _get_backend_op_config2(self):
    """Build a torch.add pattern config that exercises the private setters.

    The setters are fluent (each returns the config), so calling them
    sequentially yields the same object the original chained expression did.
    """
    config = BackendPatternConfig(torch.add)
    config.add_dtype_config(self.dtype_config2)
    config._set_root_node_getter(_default_root_node_getter)
    config._set_extra_inputs_getter(self._extra_inputs_getter)
    config._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type)
    config._set_input_type_to_index(self._input_type_to_index)
    config._set_input_output_observed(False)
    config._set_overwrite_output_fake_quantize(self._fake_quantize)
    config._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)
    return config
def _get_backend_op_config1(self):
    """Build a (ReLU, Linear) fusion pattern config via the public setters.

    The setters are fluent (each returns the config), so calling them
    sequentially yields the same object the original chained expression did.
    """
    config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
    config.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    config.add_dtype_config(self.dtype_config1)
    config.add_dtype_config(self.dtype_config2)
    config.set_root_module(torch.nn.Linear)
    config.set_qat_module(nnqat.Linear)
    config.set_reference_quantized_module(nnqr.Linear)
    config.set_fused_module(nni.LinearReLU)
    config.set_fuser_method(self._fuser_method)
    return config
def test_backend_op_config_from_dict(self):
    """from_dict restores the public keys, and separately the temporary/internal ones."""
    cfg1 = BackendPatternConfig.from_dict(self._get_backend_pattern_config_dict1())
    self.assertEqual(cfg1.pattern, (torch.nn.ReLU, torch.nn.Linear))
    self.assertEqual(
        cfg1.observation_type,
        ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    self.assertEqual(cfg1.root_module, torch.nn.Linear)
    self.assertEqual(cfg1.qat_module, nnqat.Linear)
    self.assertEqual(cfg1.reference_quantized_module, nnqr.Linear)
    self.assertEqual(cfg1.fused_module, nni.LinearReLU)
    self.assertEqual(cfg1.fuser_method, self._fuser_method)
    # Internal fields stay at their defaults when the dict omits them
    self.assertIsNone(cfg1._root_node_getter)
    self.assertIsNone(cfg1._extra_inputs_getter)
    self.assertEqual(len(cfg1._num_tensor_args_to_observation_type), 0)
    self.assertEqual(len(cfg1._input_type_to_index), 0)
    self.assertIsNone(cfg1._input_output_observed)
    self.assertIsNone(cfg1._overwrite_output_fake_quantize)
    self.assertIsNone(cfg1._overwrite_output_observer)
    # Test temporary/internal keys
    cfg2 = BackendPatternConfig.from_dict(self._get_backend_pattern_config_dict2())
    self.assertEqual(cfg2.pattern, torch.add)
    self.assertEqual(
        cfg2.observation_type,
        ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    self.assertIsNone(cfg2.root_module)
    self.assertIsNone(cfg2.qat_module)
    self.assertIsNone(cfg2.reference_quantized_module)
    self.assertIsNone(cfg2.fused_module)
    self.assertIsNone(cfg2.fuser_method)
    self.assertEqual(cfg2._root_node_getter, _default_root_node_getter)
    self.assertEqual(cfg2._extra_inputs_getter, self._extra_inputs_getter)
    self.assertEqual(
        cfg2._num_tensor_args_to_observation_type,
        self._num_tensor_args_to_observation_type)
    self.assertEqual(cfg2._input_type_to_index, self._input_type_to_index)
    self.assertEqual(cfg2._input_output_observed, False)
    self.assertEqual(cfg2._overwrite_output_fake_quantize, self._fake_quantize)
    self.assertEqual(
        cfg2._overwrite_output_observer,
        default_fixed_qparams_range_0to1_observer)
def test_backend_op_config_set_extra_inputs_getter(self):
    """_extra_inputs_getter is unset by default and stored verbatim."""
    config = BackendPatternConfig(torch.nn.Linear)
    self.assertIsNone(config._extra_inputs_getter)
    config._set_extra_inputs_getter(self._extra_inputs_getter)
    self.assertEqual(config._extra_inputs_getter, self._extra_inputs_getter)
def test_backend_op_config_set_overwrite_output_fake_quantize(self):
    """_overwrite_output_fake_quantize is None until explicitly set."""
    config = BackendPatternConfig(torch.sigmoid)
    self.assertIsNone(config._overwrite_output_fake_quantize)
    config._set_overwrite_output_fake_quantize(self._fake_quantize)
    self.assertEqual(config._overwrite_output_fake_quantize, self._fake_quantize)
def test_backend_op_config_set_root_module(self):
    """root_module is unset by default and stored verbatim."""
    config = BackendPatternConfig(nni.LinearReLU)
    self.assertIsNone(config.root_module)
    config.set_root_module(torch.nn.Linear)
    self.assertEqual(config.root_module, torch.nn.Linear)
def test_backend_op_config_set_input_type_to_index(self):
    """The input-type-to-index mapping starts empty and stores exactly what was set."""
    config = BackendPatternConfig(torch.addmm)
    self.assertEqual(len(config._input_type_to_index), 0)
    mapping = self._input_type_to_index
    config._set_input_type_to_index(mapping)
    self.assertEqual(config._input_type_to_index, mapping)
def test_backend_op_config_set_input_output_observed(self):
    """_input_output_observed is None by default and stores the value set."""
    config = BackendPatternConfig(torch.nn.Embedding)
    self.assertIsNone(config._input_output_observed)
    config._set_input_output_observed(False)
    self.assertEqual(config._input_output_observed, False)
def test_backend_op_config_set_qat_module(self):
    """qat_module is unset by default and stored verbatim."""
    config = BackendPatternConfig(torch.nn.Linear)
    self.assertIsNone(config.qat_module)
    config.set_qat_module(nnqat.Linear)
    self.assertEqual(config.qat_module, nnqat.Linear)
def test_backend_op_config_set_reference_quantized_module(self):
    """reference_quantized_module is unset by default and stored verbatim."""
    config = BackendPatternConfig(torch.nn.Linear)
    self.assertIsNone(config.reference_quantized_module)
    config.set_reference_quantized_module(nnqr.Linear)
    self.assertEqual(config.reference_quantized_module, nnqr.Linear)
def test_backend_op_config_set_root_node_getter(self):
    """_root_node_getter is unset by default and stored verbatim."""
    config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
    self.assertIsNone(config._root_node_getter)
    config._set_root_node_getter(_default_root_node_getter)
    self.assertEqual(config._root_node_getter, _default_root_node_getter)
def test_backend_op_config_set_fuser_method(self):
    """fuser_method is unset by default and stored verbatim."""
    config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
    self.assertIsNone(config.fuser_method)
    config.set_fuser_method(self._fuser_method)
    self.assertEqual(config.fuser_method, self._fuser_method)
def test_backend_op_config_set_fused_module(self):
    """fused_module is unset by default and stored verbatim."""
    config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
    self.assertIsNone(config.fused_module)
    config.set_fused_module(nni.LinearReLU)
    self.assertEqual(config.fused_module, nni.LinearReLU)