def insert_observer(
        node: Node,
        observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any],
        observed_graph: Graph,
        load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Attach ``observer`` to ``model`` and route ``node`` through it in ``observed_graph``.

    Args:
        node: node whose output gets observed.
        observer: observer/fake_quantize module instance.
        model: module that receives ``observer`` as a new attribute.
        activation_post_process_map: updated in place so that
            ``node.name`` maps to ``observer``.
        env: updated in place so that ``node.name`` maps to the newly
            created ``call_module`` node.
        observed_graph: graph in which the ``call_module`` node is created.
        load_arg: applied to ``node`` to build the observer call's argument.
        observed_node_names_set: ``node.name`` is added to this set.
    """
    # Move the observer to the model's device when one is reported.
    device = assert_and_get_unique_device(model)
    if device:
        observer.to(device)

    # Register the observer on the model; the attribute name is produced by
    # get_new_attr_name_with_prefix (presumably unique on ``model`` — confirm).
    name_factory = get_new_attr_name_with_prefix(
        node.name + '_activation_post_process_')
    attr_name = name_factory(model)
    setattr(model, attr_name, observer)

    # Record which observer belongs to this node.
    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer

    # Emit the observer call on the loaded argument and mark the node observed.
    observer_call = observed_graph.create_node(
        'call_module', attr_name, (load_arg(node),), {})
    env[node.name] = observer_call
    observed_node_names_set.add(node.name)
def insert_observer(
    node: Node,
    observer: torch.quantization.ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """Register ``observer`` on ``model`` and insert a node that applies it to ``node``.

    Args:
        node: node whose output the observer consumes.
        observer: observer (or equalization observer) module instance.
        model: module that receives ``observer`` as a new attribute.
        modules: updated in place so the new attribute name maps to ``observer``.
        graph: graph in which the ``call_module`` node is inserted right after ``node``.

    Returns:
        The newly created ``call_module`` node that calls ``observer`` on ``node``.
    """
    # Move the observer to the model's device when one is reported.
    device = assert_and_get_unique_device(model)
    if device:
        observer.to(device)

    # Equalization observers get a distinct attribute-name prefix.
    suffix = (
        '_equalization_process_'
        if is_equalization_observer(observer)
        else '_activation_post_process_'
    )
    make_name = get_new_attr_name_with_prefix(node.name + suffix)
    observer_attr = make_name(model)
    setattr(model, observer_attr, observer)
    modules[observer_attr] = observer

    # Place the observer call immediately after the observed node.
    with graph.inserting_after(node):
        return graph.create_node('call_module', observer_attr, (node,), {})