def _get_return_nodes(self, model):
    set_rng_seed(0)
    exclude_nodes_filter = [
        "getitem",
        "floordiv",
        "size",
        "chunk",
        "_assert",
        "eq",
        "dim",
        "getattr",
    ]
    train_nodes, eval_nodes = get_graph_node_names(
        model,
        tracer_kwargs={"leaf_modules": self.leaf_modules},
        suppress_diff_warning=True,
    )
    # Get rid of any nodes that don't return tensors as they cause issues
    # when testing backward pass.
    train_nodes = [
        n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)
    ]
    eval_nodes = [
        n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)
    ]
    return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
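# A minimal standalone sketch (not part of the original test class) showing how
# node names sampled like this might be fed back into torchvision's
# create_feature_extractor and checked for a working backward pass. The model
# choice and sample size below are assumptions for illustration only.
import random

import torch
import torchvision
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

model = torchvision.models.resnet18()
train_nodes, eval_nodes = get_graph_node_names(model)
exclude_nodes_filter = ("getitem", "floordiv", "size", "chunk")
# Skip the input placeholder and any nodes that don't produce tensors.
candidates = [
    n for n in eval_nodes[1:] if not any(x in n for x in exclude_nodes_filter)
]
return_nodes = random.sample(candidates, 5)

extractor = create_feature_extractor(model, return_nodes=return_nodes)
outputs = extractor(torch.randn(1, 3, 224, 224))
# Every sampled node should yield a differentiable tensor.
sum(o.sum() for o in outputs.values()).backward()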
def _create_fx_model(model, train=False):
    # This block of code does a bit of juggling to handle any case where there
    # are multiple outputs in train mode. We trace once and look at the graph
    # to get the indices of the nodes that lead into the original fx output
    # node, then use those indices to select from the train_nodes returned by
    # torchvision's get_graph_node_names.
    tracer_kwargs = dict(
        leaf_modules=list(_leaf_modules),
        autowrap_functions=list(_autowrap_functions),
        # enable_cpatching=True,
        param_shapes_constant=True,
    )
    train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)

    eval_return_nodes = [eval_nodes[-1]]
    train_return_nodes = [train_nodes[-1]]
    if train:
        tracer = NodePathTracer(**tracer_kwargs)
        graph = tracer.trace(model)
        graph_nodes = list(reversed(graph.nodes))
        output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
        graph_node_names = [n.name for n in graph_nodes]
        output_node_indices = [
            -graph_node_names.index(node_name) for node_name in output_node_names
        ]
        train_return_nodes = [train_nodes[ix] for ix in output_node_indices]

    fx_model = create_feature_extractor(
        model,
        train_return_nodes=train_return_nodes,
        eval_return_nodes=eval_return_nodes,
        tracer_kwargs=tracer_kwargs,
    )
    return fx_model
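# Hypothetical usage of the helper above; a sketch that assumes timm is
# installed and that _create_fx_model (together with the timm/torchvision
# symbols it relies on) is importable from this test module.
import timm
import torch

model = timm.create_model('resnet18', num_classes=10)
fx_model = _create_fx_model(model, train=True)

outputs = fx_model(torch.randn(2, 3, 224, 224))  # dict: return-node name -> tensor
for name, tensor in outputs.items():
    print(name, tuple(tensor.shape))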
def _get_return_nodes(self, model):
    set_rng_seed(0)
    exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk']
    train_nodes, eval_nodes = get_graph_node_names(
        model,
        tracer_kwargs={'leaf_modules': self.leaf_modules},
        suppress_diff_warning=True)
    # Get rid of any nodes that don't return tensors as they cause issues
    # when testing backward pass.
    train_nodes = [
        n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)
    ]
    eval_nodes = [
        n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)
    ]
    return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
def convert_names(
    model: nn.Module,
    names: str | list[str] | None,
    leaf_modules: list[nn.Module] | None,
    train_mode: bool,
) -> list[str]:
    # Helper: normalize the requested node names and validate them against the
    # names reported by get_graph_node_names for the chosen mode.
    if isinstance(names, str):
        names = [names]

    tracer_kwargs = {}
    if leaf_modules is not None:
        tracer_kwargs['leaf_modules'] = leaf_modules

    _names = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
    _names = _names[0] if train_mode else _names[1]
    _names = _names[1:]  # drop the first element, which is the input node

    if names is None:
        names = _names
    else:
        if not set(names) <= set(_names):
            diff = set(names) - set(_names)
            raise RuntimeError(f'Unknown names: {list(diff)}')
    return names
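# Hypothetical calls to the helper above, assuming convert_names and its
# imports are in scope; the torchvision model below is only an example.
import torchvision

model = torchvision.models.resnet18()

# names=None returns every traced eval-mode node except the input placeholder.
all_names = convert_names(model, None, leaf_modules=None, train_mode=False)
print(all_names[:3], '...', all_names[-1])

# A single string is normalized to a one-element list after validation.
print(convert_names(model, all_names[-1], leaf_modules=None, train_mode=False))

# Unknown names raise a RuntimeError listing the offending entries.
try:
    convert_names(model, 'not_a_node', leaf_modules=None, train_mode=False)
except RuntimeError as err:
    print(err)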
def test_node_name_conventions(self):
    model = TestModule()
    train_nodes, _ = get_graph_node_names(model)
    assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
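# TestModule and test_module_nodes are fixtures defined elsewhere in the test
# suite. As a rough, assumed illustration of the naming convention being
# checked, here is a hypothetical stand-in module and the kind of names
# get_graph_node_names yields for it.
from torch import nn
from torchvision.models.feature_extraction import get_graph_node_names


class TinyModule(nn.Module):  # hypothetical stand-in, not the real TestModule
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x + 1


train_nodes, eval_nodes = get_graph_node_names(TinyModule())
print(train_nodes)  # roughly: ['x', 'conv', 'relu', 'add']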
    sys.exit(0)
elif not args.models:
    sys.stderr.write("Please specify at least one model to be exported\n")
    sys.exit(-1)

device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'
logging.info("Device: %s", device)

# An instance of your model.
for mname in args.models:
    if mname not in model_classes:
        logging.warning("model %s is unknown and will not be exported", mname)
        continue

    if args.print_extract:
        train_nodes, eval_nodes = get_graph_node_names(model_classes[mname]())
        print("*** Extractable layers of", mname, "***")
        for node in train_nodes:
            print(node)
        continue

    logging.info("Exporting model %s %s", mname, "(pretrained)" if args.pretrained else "")
    detection = mname in detection_model_classes
    segmentation = mname in segmentation_model_classes

    if detection:
        if "fasterrcnn" in mname and version.parse(torchvision.__version__) < version.parse("0.10.0"):
            raise RuntimeError(
                "Fasterrcnn needs torchvision >= 0.10.0 (current = %s)" % torchvision.__version__)
        if mname in ["fasterrcnn", "retinanet"]:
            if args.backbone and args.backbone in model_classes: