Esempio n. 1
0
def extract_mutation_from_pt_module(
        pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
    """Describe the mutable parts of ``pytorch_model`` as a ``Model`` plus mutators.

    A fresh ``Model`` and a ``Graph`` named ``'_model'`` are created. The class
    of ``pytorch_model`` and its init arguments are recorded on the model, and
    one graph node is added for every parameter spec in the model namespace
    and for every choice found while walking ``pytorch_model.named_modules()``.
    The model status is then set to ``ModelStatus.Frozen``.

    Returns:
        ``(model, None)`` when ``graph.hidden_nodes`` is empty; otherwise
        ``(model, mutators)`` with one mutator per group yielded by
        ``_group_by_label_and_type``. ``NasBench101Mutator`` instances are
        placed after all ``ManyChooseManyMutator`` instances.

    Raises:
        ValueError: the model class ``__init__`` signature has more than one
            parameter and ``is_model_wrapped(pytorch_model)`` is false.
        NotImplementedError: a ``Placeholder`` module is encountered.
        AssertionError: a parameter spec is not a categorical ``'choice'``,
            a grouped node has no label, or nodes in one group disagree on
            operation type or parameters.
    """
    model = Model(_internal=True)
    graph = Graph(model, uid(), '_model', _internal=True)._register()
    model.python_class = pytorch_model.__class__

    # More than one parameter in the ``__init__`` signature is treated as
    # "the model has init parameters"; those are only available as
    # ``trace_kwargs`` on a wrapped model.
    init_signature = inspect.signature(model.python_class.__init__)
    if len(init_signature.parameters) <= 1:
        model.python_init_params = {}
    elif is_model_wrapped(pytorch_model):
        model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
    else:
        raise ValueError(
            'Please annotate the model with @model_wrapper decorator in python execution mode '
            'if your model has init parameters.')

    # Hyper-parameter choices declared on the model namespace.
    model_namespace = cast(ModelNamespace, pytorch_model._model_namespace)
    for spec in model_namespace.parameter_specs:
        assert spec.categorical and spec.type == 'choice'
        spec_node = graph.add_node(
            f'param_spec_{spec.name}',
            'ModelParameterChoice',
            {'candidates': spec.values},
        )
        spec_node.label = spec.name

    for module_name, submodule in pytorch_model.named_modules():
        # Value choices passed as constructor arguments of a basic unit are
        # not sub-modules; they are only reachable through ``trace_kwargs``.
        if is_basic_unit(submodule):
            traced_args = cast(Dict[str, Any], submodule.trace_kwargs)
            for arg_name, arg_value in traced_args.items():
                if not isinstance(arg_value, ValueChoiceX):
                    continue
                for index, inner in enumerate(arg_value.inner_choices()):
                    arg_node = graph.add_node(
                        f'{module_name}.init.{arg_name}.{index}',
                        'ValueChoice',
                        {'candidates': inner.candidates},
                    )
                    arg_node.label = inner.label

        if isinstance(submodule, (LayerChoice, InputChoice, ValueChoice)):
            # TODO: check the label of module and warn if it's auto-generated
            pass

        # The checks below are deliberately independent ``if`` statements
        # (not ``elif``): each matching type contributes its own node(s).
        if isinstance(submodule, LayerChoice):
            layer_node = graph.add_node(
                module_name, 'LayerChoice', {'candidates': submodule.names})
            layer_node.label = submodule.label

        if isinstance(submodule, InputChoice):
            input_node = graph.add_node(
                module_name,
                'InputChoice',
                {
                    'n_candidates': submodule.n_candidates,
                    'n_chosen': submodule.n_chosen,
                },
            )
            input_node.label = submodule.label

        if isinstance(submodule, ValueChoiceX):
            for index, inner in enumerate(submodule.inner_choices()):
                value_node = graph.add_node(
                    f'{module_name}.{index}',
                    'ValueChoice',
                    {'candidates': inner.candidates},
                )
                value_node.label = inner.label

        if isinstance(submodule, NasBench101Cell):
            cell_node = graph.add_node(
                module_name,
                'NasBench101Cell',
                {'max_num_edges': submodule.max_num_edges},
            )
            cell_node.label = submodule.label

        if isinstance(submodule, Placeholder):
            raise NotImplementedError(
                'Placeholder is not supported in python execution mode.')

    model.status = ModelStatus.Frozen
    if not graph.hidden_nodes:
        return model, None

    regular_mutators = []
    # The mutation of Nas-bench-101 is special, and has to be done lastly,
    # so those mutators are collected separately and appended at the end.
    deferred_mutators = []
    for group in _group_by_label_and_type(graph.hidden_nodes):
        head = group[0]
        label = head.label
        assert label is not None, f'label of {head} can not be None.'
        assert _is_all_equal(member.operation.type for member in group), \
            f'Node with label "{label}" does not all have the same type.'
        assert _is_all_equal(member.operation.parameters for member in group), \
            f'Node with label "{label}" does not agree on parameters.'

        if head.operation.type == 'NasBench101Cell':
            deferred_mutators.append(NasBench101Mutator(label))
        else:
            regular_mutators.append(ManyChooseManyMutator(label))

    return model, [*regular_mutators, *deferred_mutators]
Esempio n. 2
0
def extract_mutation_from_pt_module(
        pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
    """Describe the mutable parts of ``pytorch_model`` as a ``Model`` plus mutators.

    A fresh ``Model`` and a ``Graph`` named ``'_model'`` are created. The class
    of ``pytorch_model`` and its init arguments are recorded on the model, and
    one graph node is added for every choice found while walking
    ``pytorch_model.named_modules()``. The model status is then set to
    ``ModelStatus.Frozen``.

    Returns:
        ``(model, None)`` when ``graph.hidden_nodes`` is empty; otherwise
        ``(model, mutators)`` with one mutator per group yielded by
        ``_group_by_label_and_type``. ``NasBench101Mutator`` instances are
        placed after all ``ManyChooseManyMutator`` instances.

    Raises:
        ValueError: the model class ``__init__`` signature has more than one
            parameter and the model has no truthy ``_nni_model_wrapper``
            attribute.
        NotImplementedError: a ``Placeholder`` module is encountered.
        AssertionError: nodes in one group disagree on operation type or
            parameters.
    """
    model = Model(_internal=True)
    graph = Graph(model, uid(), '_model', _internal=True)._register()
    model.python_class = pytorch_model.__class__

    # More than one parameter in the ``__init__`` signature is treated as
    # "the model has init parameters"; those are only available as
    # ``trace_kwargs`` on a model carrying the ``_nni_model_wrapper`` flag.
    init_signature = inspect.signature(model.python_class.__init__)
    if len(init_signature.parameters) <= 1:
        model.python_init_params = {}
    elif getattr(pytorch_model, '_nni_model_wrapper', False):
        model.python_init_params = pytorch_model.trace_kwargs
    else:
        raise ValueError(
            'Please annotate the model with @model_wrapper decorator in python execution mode '
            'if your model has init parameters.')

    for module_name, submodule in pytorch_model.named_modules():
        # Value choices passed as constructor arguments of a basic unit are
        # not sub-modules; they are only reachable through ``trace_kwargs``.
        if is_basic_unit(submodule):
            for arg_name, arg_value in submodule.trace_kwargs.items():
                if not isinstance(arg_value, ValueChoice):
                    continue
                arg_node = graph.add_node(
                    '.'.join((module_name, 'init', arg_name)),
                    'ValueChoice',
                    {'candidates': arg_value.candidates},
                )
                arg_node.label = arg_value.label

        if isinstance(submodule, (LayerChoice, InputChoice, ValueChoice)):
            # TODO: check the label of module and warn if it's auto-generated
            pass

        # The checks below are deliberately independent ``if`` statements
        # (not ``elif``): each matching type contributes its own node.
        if isinstance(submodule, LayerChoice):
            layer_node = graph.add_node(
                module_name, 'LayerChoice', {'candidates': submodule.names})
            layer_node.label = submodule.label

        if isinstance(submodule, InputChoice):
            input_node = graph.add_node(
                module_name,
                'InputChoice',
                {
                    'n_candidates': submodule.n_candidates,
                    'n_chosen': submodule.n_chosen,
                },
            )
            input_node.label = submodule.label

        if isinstance(submodule, ValueChoice):
            value_node = graph.add_node(
                module_name, 'ValueChoice', {'candidates': submodule.candidates})
            value_node.label = submodule.label

        if isinstance(submodule, Repeat) and submodule.min_depth <= submodule.max_depth:
            # Candidates are every depth from min_depth to max_depth inclusive.
            depth_options = list(
                range(submodule.min_depth, submodule.max_depth + 1))
            repeat_node = graph.add_node(
                module_name, 'Repeat', {'candidates': depth_options})
            repeat_node.label = submodule.label

        if isinstance(submodule, NasBench101Cell):
            cell_node = graph.add_node(
                module_name,
                'NasBench101Cell',
                {'max_num_edges': submodule.max_num_edges},
            )
            cell_node.label = submodule.label

        if isinstance(submodule, Placeholder):
            raise NotImplementedError(
                'Placeholder is not supported in python execution mode.')

    model.status = ModelStatus.Frozen
    if not graph.hidden_nodes:
        return model, None

    regular_mutators = []
    # NasBench101 mutators are collected separately so that they end up
    # after every other mutator in the returned list.
    deferred_mutators = []
    for group in _group_by_label_and_type(graph.hidden_nodes):
        head = group[0]
        assert _is_all_equal(member.operation.type for member in group), \
            f'Node with label "{head.label}" does not all have the same type.'
        assert _is_all_equal(member.operation.parameters for member in group), \
            f'Node with label "{head.label}" does not agree on parameters.'

        if head.operation.type == 'NasBench101Cell':
            deferred_mutators.append(NasBench101Mutator(head.label))
        else:
            regular_mutators.append(ManyChooseManyMutator(head.label))

    return model, [*regular_mutators, *deferred_mutators]