예제 #1
0
 def test_backend_op_config_set_num_tensor_args_to_observation_type(self):
     """Check that `_set_num_tensor_args_to_observation_type` stores the mapping.

     A freshly built config starts with an empty mapping; after the setter
     runs, the config holds the fixture mapping.
     """
     conf = BackendPatternConfig(torch.add)
     self.assertEqual(len(conf._num_tensor_args_to_observation_type), 0)
     expected_mapping = self._num_tensor_args_to_observation_type
     conf._set_num_tensor_args_to_observation_type(expected_mapping)
     self.assertEqual(conf._num_tensor_args_to_observation_type, expected_mapping)
예제 #2
0
 def test_backend_op_config_set_overwrite_output_observer(self):
     """Check that `_set_overwrite_output_observer` stores the given observer.

     A freshly built config has no override observer; after the setter runs
     it holds `default_fixed_qparams_range_0to1_observer`.
     """
     conf = BackendPatternConfig(torch.sigmoid)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf._overwrite_output_observer)
     conf._set_overwrite_output_observer(
         default_fixed_qparams_range_0to1_observer)
     self.assertEqual(conf._overwrite_output_observer,
                      default_fixed_qparams_range_0to1_observer)
예제 #3
0
 def test_backend_op_config_add_dtype_config(self):
     """Check that `add_dtype_config` appends dtype configs in call order."""
     conf = BackendPatternConfig(torch.nn.Linear)
     self.assertEqual(len(conf.dtype_configs), 0)
     added = (self.dtype_config1, self.dtype_config2)
     for dtype_config in added:
         conf.add_dtype_config(dtype_config)
     self.assertEqual(len(conf.dtype_configs), len(added))
     # Each stored entry must sit at the same position it was added in.
     for index, dtype_config in enumerate(added):
         self.assertEqual(conf.dtype_configs[index], dtype_config)
예제 #4
0
 def test_backend_op_config_set_observation_type(self):
     """Check the default observation type and that `set_observation_type` overrides it.

     A freshly built config reports OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT;
     after the setter it reports OUTPUT_SHARE_OBSERVER_WITH_INPUT.
     """
     conf = BackendPatternConfig(torch.nn.Linear)
     default_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
     self.assertEqual(conf.observation_type, default_type)
     shared_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
     conf.set_observation_type(shared_type)
     self.assertEqual(conf.observation_type, shared_type)
예제 #5
0
 def _get_backend_op_config2(self):
     """Build a `torch.add` pattern config with every internal option set.

     Adds `self.dtype_config2`, then chains each underscore-prefixed setter
     using the fixture values on `self`, with
     `_default_root_node_getter` as the root node getter and
     `default_fixed_qparams_range_0to1_observer` as the output observer override.
     """
     return (
         BackendPatternConfig(torch.add)
         .add_dtype_config(self.dtype_config2)
         ._set_root_node_getter(_default_root_node_getter)
         ._set_extra_inputs_getter(self._extra_inputs_getter)
         ._set_num_tensor_args_to_observation_type(
             self._num_tensor_args_to_observation_type)
         ._set_input_type_to_index(self._input_type_to_index)
         ._set_input_output_observed(False)
         ._set_overwrite_output_fake_quantize(self._fake_quantize)
         ._set_overwrite_output_observer(
             default_fixed_qparams_range_0to1_observer)
     )
예제 #6
0
 def _get_backend_op_config1(self):
     """Build a (ReLU, Linear) pattern config with every public option set.

     Sets the observation type, both fixture dtype configs (in order), and the
     root, QAT, reference-quantized and fused modules plus the fixture
     fuser method.
     """
     return (
         BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
         .set_observation_type(
             ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
         .add_dtype_config(self.dtype_config1)
         .add_dtype_config(self.dtype_config2)
         .set_root_module(torch.nn.Linear)
         .set_qat_module(nnqat.Linear)
         .set_reference_quantized_module(nnqr.Linear)
         .set_fused_module(nni.LinearReLU)
         .set_fuser_method(self._fuser_method)
     )
예제 #7
0
 def test_backend_op_config_from_dict(self):
     """Check that `BackendPatternConfig.from_dict` restores each field.

     The first fixture dict populates the public fields (pattern, observation
     type, root/QAT/reference/fused modules, fuser method) and leaves the
     internal underscore-prefixed fields unset or empty. The second dict
     populates the internal fields and leaves the public module fields unset.
     """
     conf_dict1 = self._get_backend_pattern_config_dict1()
     conf1 = BackendPatternConfig.from_dict(conf_dict1)
     self.assertEqual(conf1.pattern, (torch.nn.ReLU, torch.nn.Linear))
     self.assertEqual(
         conf1.observation_type,
         ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
     self.assertEqual(conf1.root_module, torch.nn.Linear)
     self.assertEqual(conf1.qat_module, nnqat.Linear)
     self.assertEqual(conf1.reference_quantized_module, nnqr.Linear)
     self.assertEqual(conf1.fused_module, nni.LinearReLU)
     self.assertEqual(conf1.fuser_method, self._fuser_method)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf1._root_node_getter)
     self.assertIsNone(conf1._extra_inputs_getter)
     self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0)
     self.assertEqual(len(conf1._input_type_to_index), 0)
     self.assertIsNone(conf1._input_output_observed)
     self.assertIsNone(conf1._overwrite_output_fake_quantize)
     self.assertIsNone(conf1._overwrite_output_observer)
     # Test temporary/internal keys
     conf_dict2 = self._get_backend_pattern_config_dict2()
     conf2 = BackendPatternConfig.from_dict(conf_dict2)
     self.assertEqual(conf2.pattern, torch.add)
     self.assertEqual(
         conf2.observation_type,
         ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
     self.assertIsNone(conf2.root_module)
     self.assertIsNone(conf2.qat_module)
     self.assertIsNone(conf2.reference_quantized_module)
     self.assertIsNone(conf2.fused_module)
     self.assertIsNone(conf2.fuser_method)
     self.assertEqual(conf2._root_node_getter, _default_root_node_getter)
     self.assertEqual(conf2._extra_inputs_getter, self._extra_inputs_getter)
     self.assertEqual(conf2._num_tensor_args_to_observation_type,
                      self._num_tensor_args_to_observation_type)
     self.assertEqual(conf2._input_type_to_index, self._input_type_to_index)
     self.assertEqual(conf2._input_output_observed, False)
     self.assertEqual(conf2._overwrite_output_fake_quantize,
                      self._fake_quantize)
     self.assertEqual(conf2._overwrite_output_observer,
                      default_fixed_qparams_range_0to1_observer)
예제 #8
0
 def test_backend_op_config_set_extra_inputs_getter(self):
     """Check that `_set_extra_inputs_getter` stores the fixture getter.

     A freshly built config has no extra-inputs getter.
     """
     conf = BackendPatternConfig(torch.nn.Linear)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf._extra_inputs_getter)
     conf._set_extra_inputs_getter(self._extra_inputs_getter)
     self.assertEqual(conf._extra_inputs_getter, self._extra_inputs_getter)
예제 #9
0
 def test_backend_op_config_set_overwrite_output_fake_quantize(self):
     """Check that `_set_overwrite_output_fake_quantize` stores the fixture value.

     A freshly built config has no output fake-quantize override.
     """
     conf = BackendPatternConfig(torch.sigmoid)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf._overwrite_output_fake_quantize)
     conf._set_overwrite_output_fake_quantize(self._fake_quantize)
     self.assertEqual(conf._overwrite_output_fake_quantize,
                      self._fake_quantize)
예제 #10
0
 def test_backend_op_config_set_root_module(self):
     """Check that `set_root_module` stores the given module class.

     A freshly built config has no root module.
     """
     conf = BackendPatternConfig(nni.LinearReLU)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf.root_module)
     conf.set_root_module(torch.nn.Linear)
     self.assertEqual(conf.root_module, torch.nn.Linear)
예제 #11
0
 def test_backend_op_config_set_input_type_to_index(self):
     """Check that `_set_input_type_to_index` stores the fixture mapping.

     A freshly built config starts with an empty mapping.
     """
     conf = BackendPatternConfig(torch.addmm)
     self.assertEqual(len(conf._input_type_to_index), 0)
     expected_index_map = self._input_type_to_index
     conf._set_input_type_to_index(expected_index_map)
     self.assertEqual(conf._input_type_to_index, expected_index_map)
예제 #12
0
 def test_backend_op_config_set_input_output_observed(self):
     """Check that `_set_input_output_observed(False)` is stored on the config.

     A freshly built config leaves the flag unset (None).
     """
     conf = BackendPatternConfig(torch.nn.Embedding)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf._input_output_observed)
     conf._set_input_output_observed(False)
     self.assertEqual(conf._input_output_observed, False)
예제 #13
0
 def test_backend_op_config_set_qat_module(self):
     """Check that `set_qat_module` stores the given QAT module class.

     A freshly built config has no QAT module.
     """
     conf = BackendPatternConfig(torch.nn.Linear)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf.qat_module)
     conf.set_qat_module(nnqat.Linear)
     self.assertEqual(conf.qat_module, nnqat.Linear)
예제 #14
0
 def test_backend_op_config_set_reference_quantized_module(self):
     """Check that `set_reference_quantized_module` stores the given module class.

     A freshly built config has no reference quantized module.
     """
     conf = BackendPatternConfig(torch.nn.Linear)
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf.reference_quantized_module)
     conf.set_reference_quantized_module(nnqr.Linear)
     self.assertEqual(conf.reference_quantized_module, nnqr.Linear)
예제 #15
0
 def test_backend_op_config_set_root_node_getter(self):
     """Check that `_set_root_node_getter` stores the given getter.

     A freshly built config has no root node getter.
     """
     conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf._root_node_getter)
     conf._set_root_node_getter(_default_root_node_getter)
     self.assertEqual(conf._root_node_getter, _default_root_node_getter)
예제 #16
0
 def test_backend_op_config_set_fuser_method(self):
     """Check that `set_fuser_method` stores the fixture fuser method.

     A freshly built config has no fuser method.
     """
     conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf.fuser_method)
     conf.set_fuser_method(self._fuser_method)
     self.assertEqual(conf.fuser_method, self._fuser_method)
예제 #17
0
 def test_backend_op_config_set_fused_module(self):
     """Check that `set_fused_module` stores the given fused module class.

     A freshly built config has no fused module.
     """
     conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
     # assertIsNone reports the unexpected value on failure, unlike assertTrue(x is None).
     self.assertIsNone(conf.fused_module)
     conf.set_fused_module(nni.LinearReLU)
     self.assertEqual(conf.fused_module, nni.LinearReLU)