def extract_mutation_from_pt_module( pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
    """Scan a traced PyTorch module and build a skeleton ``Model`` plus the
    mutators describing every mutable (choice) found inside it.

    Returns ``(model, None)`` when no mutable was discovered, otherwise
    ``(model, mutators)`` where the NAS-Bench-101 mutators are ordered last.
    """
    ir_model = Model(_internal=True)
    root_graph = Graph(ir_model, uid(), '_model', _internal=True)._register()
    ir_model.python_class = pytorch_model.__class__

    # A constructor that takes arguments (beyond self) must have been traced
    # by @model_wrapper so the init kwargs can be replayed later.
    if len(inspect.signature(ir_model.python_class.__init__).parameters) > 1:
        if not is_model_wrapped(pytorch_model):
            raise ValueError(
                'Please annotate the model with @model_wrapper decorator in python execution mode '
                'if your model has init parameters.')
        ir_model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
    else:
        ir_model.python_init_params = {}

    # hyper-parameter choice
    namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
    for param_spec in namespace.parameter_specs:
        assert param_spec.categorical and param_spec.type == 'choice'
        spec_node = root_graph.add_node(
            f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
        spec_node.label = param_spec.name

    for module_name, submodule in pytorch_model.named_modules():
        # tricky case: value choice that serves as parameters are stored in traced arguments
        if is_basic_unit(submodule):
            traced_kwargs = cast(Dict[str, Any], submodule.trace_kwargs)
            for arg_name, arg_value in traced_kwargs.items():
                if isinstance(arg_value, ValueChoiceX):
                    for idx, inner in enumerate(arg_value.inner_choices()):
                        vc_node = root_graph.add_node(
                            f'{module_name}.init.{arg_name}.{idx}', 'ValueChoice',
                            {'candidates': inner.candidates})
                        vc_node.label = inner.label

        if isinstance(submodule, (LayerChoice, InputChoice, ValueChoice)):
            # TODO: check the label of module and warn if it's auto-generated
            pass
        if isinstance(submodule, LayerChoice):
            lc_node = root_graph.add_node(module_name, 'LayerChoice', {'candidates': submodule.names})
            lc_node.label = submodule.label
        if isinstance(submodule, InputChoice):
            ic_node = root_graph.add_node(module_name, 'InputChoice', {
                'n_candidates': submodule.n_candidates,
                'n_chosen': submodule.n_chosen
            })
            ic_node.label = submodule.label
        if isinstance(submodule, ValueChoiceX):
            for idx, inner in enumerate(submodule.inner_choices()):
                vc_node = root_graph.add_node(
                    f'{module_name}.{idx}', 'ValueChoice', {'candidates': inner.candidates})
                vc_node.label = inner.label
        if isinstance(submodule, NasBench101Cell):
            cell_node = root_graph.add_node(
                module_name, 'NasBench101Cell', {'max_num_edges': submodule.max_num_edges})
            cell_node.label = submodule.label
        if isinstance(submodule, Placeholder):
            raise NotImplementedError('Placeholder is not supported in python execution mode.')

    ir_model.status = ModelStatus.Frozen
    if not root_graph.hidden_nodes:
        # nothing mutable was discovered
        return ir_model, None

    regular_mutators = []
    deferred_mutators = []
    for group in _group_by_label_and_type(root_graph.hidden_nodes):
        group_label = group[0].label
        assert group_label is not None, f'label of {group[0]} can not be None.'
        assert _is_all_equal(map(lambda n: n.operation.type, group)), \
            f'Node with label "{group_label}" does not all have the same type.'
        assert _is_all_equal(map(lambda n: n.operation.parameters, group)), \
            f'Node with label "{group_label}" does not agree on parameters.'
        if group[0].operation.type == 'NasBench101Cell':
            # The mutation of Nas-bench-101 is special, and has to be done lastly.
            deferred_mutators.append(NasBench101Mutator(group_label))
        else:
            regular_mutators.append(ManyChooseManyMutator(group_label))
    return ir_model, regular_mutators + deferred_mutators
def extract_mutation_from_pt_module( pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
    """Walk *pytorch_model* and translate every mutable module it contains
    into a node on a freshly created ``Model``'s graph.

    Returns the model together with the list of mutators, or ``(model, None)``
    when the network holds no mutables; NAS-Bench-101 mutators come last.
    """
    base_model = Model(_internal=True)
    base_graph = Graph(base_model, uid(), '_model', _internal=True)._register()
    base_model.python_class = pytorch_model.__class__

    # A constructor with parameters (other than self) must be traced via
    # @model_wrapper, which records the kwargs and sets _nni_model_wrapper.
    has_init_args = len(inspect.signature(base_model.python_class.__init__).parameters) > 1
    if has_init_args:
        if not getattr(pytorch_model, '_nni_model_wrapper', False):
            raise ValueError(
                'Please annotate the model with @model_wrapper decorator in python execution mode '
                'if your model has init parameters.')
        base_model.python_init_params = pytorch_model.trace_kwargs
    else:
        base_model.python_init_params = {}

    for path, child in pytorch_model.named_modules():
        # tricky case: value choice that serves as parameters are stored in traced arguments
        if is_basic_unit(child):
            for kwarg_name, kwarg_value in child.trace_kwargs.items():
                if isinstance(kwarg_value, ValueChoice):
                    init_node = base_graph.add_node(
                        f'{path}.init.{kwarg_name}', 'ValueChoice', {'candidates': kwarg_value.candidates})
                    init_node.label = kwarg_value.label

        if isinstance(child, (LayerChoice, InputChoice, ValueChoice)):
            # TODO: check the label of module and warn if it's auto-generated
            pass
        if isinstance(child, LayerChoice):
            placeholder = base_graph.add_node(path, 'LayerChoice', {'candidates': child.names})
            placeholder.label = child.label
        if isinstance(child, InputChoice):
            placeholder = base_graph.add_node(path, 'InputChoice', {
                'n_candidates': child.n_candidates,
                'n_chosen': child.n_chosen
            })
            placeholder.label = child.label
        if isinstance(child, ValueChoice):
            placeholder = base_graph.add_node(path, 'ValueChoice', {'candidates': child.candidates})
            placeholder.label = child.label
        if isinstance(child, Repeat) and child.min_depth <= child.max_depth:
            placeholder = base_graph.add_node(path, 'Repeat', {
                'candidates': list(range(child.min_depth, child.max_depth + 1))
            })
            placeholder.label = child.label
        if isinstance(child, NasBench101Cell):
            placeholder = base_graph.add_node(path, 'NasBench101Cell', {'max_num_edges': child.max_num_edges})
            placeholder.label = child.label
        if isinstance(child, Placeholder):
            raise NotImplementedError('Placeholder is not supported in python execution mode.')

    base_model.status = ModelStatus.Frozen
    if not base_graph.hidden_nodes:
        # no mutable node registered — nothing to mutate
        return base_model, None

    ordinary = []
    run_last = []
    for bucket in _group_by_label_and_type(base_graph.hidden_nodes):
        assert _is_all_equal(map(lambda n: n.operation.type, bucket)), \
            f'Node with label "{bucket[0].label}" does not all have the same type.'
        assert _is_all_equal(map(lambda n: n.operation.parameters, bucket)), \
            f'Node with label "{bucket[0].label}" does not agree on parameters.'
        if bucket[0].operation.type == 'NasBench101Cell':
            # NAS-Bench-101 mutation must run after all other mutators.
            run_last.append(NasBench101Mutator(bucket[0].label))
        else:
            ordinary.append(ManyChooseManyMutator(bucket[0].label))
    return base_model, ordinary + run_last