コード例 #1
0
    def __init__(self, config, n_nodes, chn_in, stride, shared_a, allocator,
                 merger_state, merger_out, enumerator, preproc, aggregate,
                 edge_cls, edge_kwargs):
        """Build a searchable DAG layer, or record its settings for augment.

        Node ``i`` (state index ``n_input + i``) receives one edge per
        source-index tuple produced by ``enumerator``. Node output widths are
        derived with ``merger_state`` and the layer output width with
        ``merger_out``.

        Args:
            config: Only ``config.augment`` is read. When truthy, no edges
                are built here; ``chn_states``, ``edge_cls`` and
                ``edge_kwargs`` are stored and ``self.fixed`` is set to True.
            n_nodes: Number of intermediate nodes.
            chn_in: Sequence of channel counts, one per input state.
            stride: Stride for edges whose sources are all input states;
                every other edge gets stride 1.
            shared_a: Forwarded to each edge as ``shared_a``; when truthy,
                ``NASModule.add_shared_param()`` is called.
            allocator: Factory called as ``allocator(n_input, n_nodes)``;
                its ``chn_in`` gives each edge's input channels.
            merger_state: Factory; its ``chn_out`` maps a node's edge output
                widths to the node's width.
            merger_out: Factory called with ``start=n_input``; provides
                ``merge_range`` and the layer's ``chn_out``.
            enumerator: Factory; its ``enum`` / ``len_enum`` yield / count
                the source-index tuples feeding a node.
            preproc: Falsy for no preprocessing, otherwise an indexable of
                factories called as ``preproc[i](chn_in[i], chn_cur, False)``.
            aggregate: Optional factory used to build ``self.merge_filter``.
            edge_cls: Edge module class, instantiated with ``**edge_kwargs``.
            edge_kwargs: Edge keyword arguments; must contain ``'chn_in'``.
                A shallow copy is used, so the caller's dict is not modified.
        """
        super().__init__()
        global edge_id
        # Unique id per constructed layer, taken from a module-level counter.
        self.edge_id = edge_id
        edge_id += 1
        # Private shallow copy: the edge loop below overwrites 'chn_in',
        # 'stride' and 'shared_a' for every edge. Without the copy, those
        # per-edge values would leak into any other layer built from the
        # same caller-owned dict (affecting n_input_e and chn_cur there).
        edge_kwargs = dict(edge_kwargs)
        self.n_nodes = n_nodes
        self.stride = stride
        self.chn_in = chn_in
        self.n_input = len(chn_in)
        self.n_states = self.n_input + self.n_nodes
        # Number of sources per edge, taken from the template edge kwargs.
        self.n_input_e = len(edge_kwargs['chn_in'])
        self.shared_a = shared_a
        if shared_a:
            NASModule.add_shared_param()
        self.allocator = allocator(self.n_input, self.n_nodes)
        self.merger_state = merger_state()
        self.merger_out = merger_out(start=self.n_input)
        self.merge_out_range = self.merger_out.merge_range(self.n_states)
        self.enumerator = enumerator()

        # Channel width of every state: inputs first, then one per node.
        chn_states = []
        if not preproc:
            self.preprocs = None
            chn_states.extend(chn_in)
        else:
            # Every input is projected to the template edge's first width.
            chn_cur = edge_kwargs['chn_in'][0]
            self.preprocs = nn.ModuleList()
            for i in range(self.n_input):
                self.preprocs.append(preproc[i](chn_in[i], chn_cur, False))
                chn_states.append(chn_cur)

        if not config.augment:
            self.fixed = False
            self.dag = nn.ModuleList()
            # Flat view of all edges; they are registered through self.dag.
            self.edges = []
            self.num_edges = 0
            for i in range(n_nodes):
                cur_state = self.n_input + i
                self.dag.append(nn.ModuleList())
                num_edges = self.enumerator.len_enum(cur_state,
                                                     self.n_input_e)
                for sidx in self.enumerator.enum(cur_state, self.n_input_e):
                    e_chn_in = self.allocator.chn_in(
                        [chn_states[s] for s in sidx], sidx, cur_state)
                    edge_kwargs['chn_in'] = e_chn_in
                    # Only edges reading exclusively from layer inputs use
                    # the layer stride; edges touching a node use stride 1.
                    edge_kwargs['stride'] = stride if all(
                        s < self.n_input for s in sidx) else 1
                    edge_kwargs['shared_a'] = shared_a
                    e = edge_cls(**edge_kwargs)
                    self.dag[i].append(e)
                    self.edges.append(e)
                self.num_edges += num_edges
                chn_states.append(
                    self.merger_state.chn_out(
                        [ei.chn_out for ei in self.dag[i]]))
        else:
            # Augment mode: edges are not built here; keep what is needed
            # to build them later.
            self.chn_states = chn_states
            self.edge_cls = edge_cls
            self.edge_kwargs = edge_kwargs
            self.fixed = True

        if aggregate is not None:
            # NOTE(review): ``//`` binds tighter than ``+``, so n_out is
            # n_input + (n_nodes // 2). Confirm this was not meant to be
            # (n_input + n_nodes) // 2.
            self.merge_filter = aggregate(n_in=self.n_input + self.n_nodes,
                                          n_out=self.n_input +
                                          self.n_nodes // 2)
        else:
            self.merge_filter = None
        # Computed once, over the complete list of state widths.
        self.chn_out = self.merger_out.chn_out(chn_states)
コード例 #2
0
    def __init__(self,
                 config,
                 n_nodes,
                 chn_in,
                 stride,
                 shared_a,
                 allocator,
                 merger_out,
                 preproc,
                 aggregate,
                 child_cls,
                 child_kwargs,
                 edge_cls,
                 edge_kwargs,
                 children=None,
                 edges=None):
        """Build ``n_nodes`` (edge, child) pairs that all read every input.

        For node ``i``, the edge is resolved in this order:

        - taken from ``edges[i]``,
        - built with ``edge_cls(**edge_kwargs)``,
        - or left as ``None``.

        The child is then resolved the same way:

        - taken from ``children[i]``,
        - built with ``child_cls(**child_kwargs)``, using the edge's output
          width (or the allocated width if there is no edge) as ``chn_in``,
        - or left as ``None``.

        Args:
            config: Not used by this constructor.
            n_nodes: Number of nodes, i.e. (edge, child) pairs.
            chn_in: Input channel count, or a sequence of counts (one per
                input state); an ``int`` is wrapped in a 1-tuple.
            stride: Stride given to every edge built from ``edge_cls``.
            shared_a: When truthy, ``NASModule.add_shared_param()`` is
                called. It is forwarded to built edges only if
                ``edge_kwargs`` already contains a ``'shared_a'`` key.
            allocator: Factory called as ``allocator(n_input, n_nodes)``;
                its ``chn_in`` gives node ``i``'s edge input channels.
            merger_out: Factory called with ``start=n_input``; provides
                ``merge_range``.
            preproc: Falsy for no preprocessing, otherwise a factory called
                as ``preproc(chn_in[i], chn_cur, 1, 1, 0, False)`` per input.
            aggregate: Optional factory used to build ``self.merge_filter``.
            child_cls: Optional child module class.
            child_kwargs: Keyword arguments for ``child_cls``. A shallow copy
                is used, so the caller's dict is not modified.
            edge_cls: Optional edge module class.
            edge_kwargs: Keyword arguments for ``edge_cls``. A shallow copy
                is used, so the caller's dict is not modified.
                ``edge_kwargs['chn_in'][0]`` is read when ``preproc`` is set.
            children: Optional prebuilt children, indexed by node.
            edges: Optional prebuilt edges, indexed by node; each must
                expose ``chn_out``.
        """
        super().__init__()
        # Private shallow copies: the node loop below overwrites per-node
        # keys ('chn_in', 'stride', 'shared_a'). Without the copies, those
        # values would leak back into dicts the caller may reuse for other
        # layers.
        if edge_kwargs is not None:
            edge_kwargs = dict(edge_kwargs)
        if child_kwargs is not None:
            child_kwargs = dict(child_kwargs)
        self.edges = nn.ModuleList()
        self.subnets = nn.ModuleList()
        chn_in = (chn_in, ) if isinstance(chn_in, int) else chn_in
        self.n_input = len(chn_in)
        self.n_nodes = n_nodes
        self.n_states = self.n_input + self.n_nodes
        self.allocator = allocator(self.n_input, self.n_nodes)
        self.merger_out = merger_out(start=self.n_input)
        self.merge_out_range = self.merger_out.merge_range(self.n_states)
        if shared_a:
            NASModule.add_shared_param()

        # Channel width of every input state (after optional preprocessing).
        chn_states = []
        if not preproc:
            self.preprocs = None
            chn_states.extend(chn_in)
        else:
            chn_cur = edge_kwargs['chn_in'][0]
            self.preprocs = nn.ModuleList()
            for i in range(self.n_input):
                self.preprocs.append(
                    preproc(chn_in[i], chn_cur, 1, 1, 0, False))
                chn_states.append(chn_cur)

        # Every node reads from all input states.
        sidx = range(self.n_input)
        for i in range(self.n_nodes):
            e_chn_in = self.allocator.chn_in([chn_states[s] for s in sidx],
                                             sidx, i)
            if edges is not None:
                self.edges.append(edges[i])
                c_chn_in = edges[i].chn_out
            elif edge_cls is not None:
                edge_kwargs['chn_in'] = e_chn_in
                edge_kwargs['stride'] = stride
                if 'shared_a' in edge_kwargs:
                    edge_kwargs['shared_a'] = shared_a
                e = edge_cls(**edge_kwargs)
                self.edges.append(e)
                c_chn_in = e.chn_out
            else:
                # No edge: the child consumes the allocated inputs directly.
                self.edges.append(None)
                c_chn_in = e_chn_in
            if children is not None:
                self.subnets.append(children[i])
            elif child_cls is not None:
                child_kwargs['chn_in'] = c_chn_in
                self.subnets.append(child_cls(**child_kwargs))
            else:
                self.subnets.append(None)

        if aggregate is not None:
            self.merge_filter = aggregate(n_in=self.n_states,
                                          n_out=self.n_states // 2)
        else:
            self.merge_filter = None