예제 #1
0
파일: mutator.py 프로젝트: linbinskn/nni
 def mutate(self, model: Model):
     """Record input-choice samples for the first node labelled ``self.label``.

     The mutation itself has no effect on ``model``; its only purpose is to
     record samples (through ``self.choice`` and ``self._cur_samples``) so they
     appear in the mutation history.

     * When ``self.number_of_chosen(node)`` is ``None``, each candidate is kept
       or dropped by its own ``self.choice([False, True])`` call, and
       ``self._cur_samples`` is overwritten with the kept candidates.
     * Otherwise ``self.choice`` is called ``n_chosen`` times, each time over
       the node's full candidate list.

     Only the first matching node is used; if no node carries the label,
     nothing happens. Always returns ``None``.
     """
     missing = object()
     node = next(iter(model.get_nodes_by_label(self.label)), missing)
     if node is missing:
         return  # no node with this label: nothing to sample

     n_chosen = self.number_of_chosen(node)
     if n_chosen is not None:
         # Fixed number of picks: one choice over all candidates per pick.
         for _ in range(n_chosen):
             self.choice(self.candidates(node))
         return

     # Variable number of picks: an independent keep/drop flip per candidate.
     # FIXME This is a hack to make choice align with the previous format:
     # flips [False, True, True] become [1, 2] (assuming the candidates are
     # the indices 0..n-1 — TODO confirm against candidates()).
     kept = []
     for candidate in self.candidates(node):
         if self.choice([False, True]):
             kept.append(candidate)
     self._cur_samples = kept
예제 #2
0
    def mutate(self, model: Model):
        """Validate and canonicalize the NAS-Bench-101 cell sampled for ``model``.

        Reads the samples already recorded in ``model.history`` under the
        labels ``{self.label}/num_nodes``, ``{self.label}/input{i}`` and
        ``{self.label}/op{i}``. It assembles them into a connection matrix
        plus an operation list, then prunes the cell with ``prune``. The
        pruned cell is stored as a single-element list in
        ``self._cur_samples``, holding a dict with ``op{i}`` / ``input{i}``
        entries.

        Raises:
            InvalidMutation: if the sampled cell has more edges than the
                ``max_num_edges`` parameter of the labelled node, or if
                ``prune`` rejects the cell.
            AssertionError: (only when assertions are enabled) if no node in
                ``model`` carries ``self.label``, or if the operation list and
                the connection matrix disagree in size.
        """
        # Only the first node with this label is consulted for the edge budget.
        max_num_edges = cast(int, None)
        for node in model.get_nodes_by_label(self.label):
            max_num_edges = node.operation.parameters['max_num_edges']
            break
        assert max_num_edges is not None, \
            f'No node labelled {self.label!r} found in model'

        # Map each mutator label to the samples it recorded.
        mutation_dict = {
            mut.mutator.label: mut.samples
            for mut in model.history
        }
        num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
        # Sampled inputs of nodes 1 .. num_nodes-1 (node 0 has no inputs).
        adjacency_list = [
            mutation_dict[f'{self.label}/input{i}']
            for i in range(1, num_nodes)
        ]
        num_edges = sum(len(inputs) for inputs in adjacency_list)
        if num_edges > max_num_edges:
            raise InvalidMutation(
                f'Expected at most {max_num_edges} edges, '
                f'found {num_edges}: {adjacency_list}')
        matrix = _NasBench101CellFixed.build_connection_matrix(
            adjacency_list, num_nodes)

        # The first node is the cell input and the last node the cell output;
        # only the intermediate nodes carry a sampled operation.
        operations = ['IN'] + [
            mutation_dict[f'{self.label}/op{i}'][0]
            for i in range(1, num_nodes - 1)
        ] + ['OUT']
        assert len(operations) == len(matrix)
        matrix, operations = prune(
            matrix, operations)  # possible to raise InvalidMutation inside

        # NOTE: a hack to maintain a clean copy of what the nasbench101 cell
        # looks like after pruning. Built in a local so self._cur_samples is
        # never left half-populated or of the wrong type.
        cell_samples = {}
        size = len(matrix)
        for i in range(1, size):
            if i + 1 < size:  # the last (output) node has no operation
                cell_samples[f'op{i}'] = operations[i]
            cell_samples[f'input{i}'] = [
                k for k in range(i) if matrix[k, i]
            ]
        # By design, _cur_samples is a list of samples.
        self._cur_samples = [cell_samples]