Ejemplo n.º 1
0
 def __init__(self):
     """Build a 4-node cell over two named 16 -> 16 linear candidates with custom pre/post processing."""
     super().__init__()
     # Named candidates; only 'second' is created without a bias term.
     candidate_ops = {
         'first': nn.Linear(16, 16),
         'second': nn.Linear(16, 16, bias=False),
     }
     self.cell = nn.Cell(
         candidate_ops,
         num_nodes=4,
         num_ops_per_node=2,
         num_predecessors=2,
         preprocessor=CellPreprocessor(),
         postprocessor=CellPostprocessor(),
         merge_op='all',
     )
Ejemplo n.º 2
0
 def __init__(self):
     """Build a 4-node cell from a list of two 16 -> 16 linear candidates."""
     super().__init__()
     candidates = [nn.Linear(16, 16), nn.Linear(16, 16, bias=False)]
     self.cell = nn.Cell(candidates, num_nodes=4, num_ops_per_node=2,
                         num_predecessors=2, merge_op='all')
Ejemplo n.º 3
0
    def __call__(self, repeat_idx: int):
        """Build the cell at position ``repeat_idx`` of the repeat.

        Calls must arrive in order (``repeat_idx`` equal to ``self._expect_idx``),
        because each call updates ``self.C_prev_in``, ``self.C_in`` and
        ``self.last_cell_reduce``, which the next call uses to build its
        preprocessor.

        Args:
            repeat_idx: index of this cell within the repeat.

        Returns:
            The constructed ``nn.Cell``, labelled ``'reduce'`` or ``'normal'``.

        Raises:
            ValueError: if ``repeat_idx`` is not the expected next index.
        """
        if self._expect_idx != repeat_idx:
            raise ValueError(
                f'Expect index {self._expect_idx}, found {repeat_idx}')

        # Number of predecessors for each cell is fixed to 2.
        num_predecessors = 2
        # Number of ops per node is fixed to 2.
        num_ops_per_node = 2

        # Reduction cell means stride = 2 and channel multiplied by 2.
        is_reduction_cell = repeat_idx == 0 and self.first_cell_reduce

        # self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
        preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C,
                                        self.last_cell_reduce)

        def make_op_factory(op_name, channels):
            # Bind the candidate name (and channel count) per candidate. A lambda
            # written directly inside the dict comprehension would look ``op`` up
            # late, so every factory would build the *last* candidate.
            def op_factory(node_index, op_index, input_index):
                # Stride 2 only in a reduction cell, and only for ops reading one of
                # the cell's predecessors. input_index could be None when
                # constructing the search space.
                stride = 2 if is_reduction_cell and (
                    input_index is None or input_index < num_predecessors) else 1
                return OPS[op_name](channels, stride, True)

            return op_factory

        # Keys make final chosen ops named with their aliases.
        ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
            op: make_op_factory(op, self.C)
            for op in self.op_candidates
        }

        cell = nn.Cell(ops_factory,
                       self.num_nodes,
                       num_ops_per_node,
                       num_predecessors,
                       self.merge_op,
                       preprocessor=preprocessor,
                       postprocessor=CellPostprocessor(),
                       label='reduce' if is_reduction_cell else 'normal')

        # update state
        self.C_prev_in = self.C_in
        self.C_in = self.C * len(cell.output_node_indices)
        self.last_cell_reduce = is_reduction_cell
        self._expect_idx += 1

        return cell
Ejemplo n.º 4
0
 def __init__(self):
     """Build a stem conv, an ``nn.Repeat`` of 3-node cells (depth ``(1, 3)``) and a linear head.

     Each conv candidate sizes its input channels from the chosen input index
     ``inp``: an index below 1 gets 5 channels in repetition 0 and ``3 * 4``
     channels in later repetitions; any other index (or ``None``) gets 4.
     """
     super().__init__()
     self.stem = nn.Conv2d(1, 5, 7, stride=4)

     def make_cell(index):
         # Width used for input indices below 1; 5 matches the stem's output channels.
         outer_channels = 5 if index == 0 else 3 * 4

         def in_channels(inp):
             if inp is not None and inp < 1:
                 return outer_channels
             return 4

         def conv1(_, __, inp):
             return nn.Conv2d(in_channels(inp), 4, 1)

         def conv2(_, __, inp):
             return nn.Conv2d(in_channels(inp), 4, 3, padding=1)

         return nn.Cell({'conv1': conv1, 'conv2': conv2}, 3, merge_op='loose_end')

     self.cells = nn.Repeat(make_cell, (1, 3))
     self.fc = nn.Linear(3 * 4, 10)
Ejemplo n.º 5
0
    def __call__(self, repeat_idx: int):
        """Create the next cell of the repeat and advance the channel bookkeeping.

        Args:
            repeat_idx: position of the requested cell; must equal ``self._expect_idx``.

        Returns:
            The new ``nn.Cell``, labelled ``'reduce'`` or ``'normal'``.

        Raises:
            ValueError: when cells are requested out of order.
        """
        if self._expect_idx != repeat_idx:
            raise ValueError(
                f'Expect index {self._expect_idx}, found {repeat_idx}')

        # Only the first cell of the repeat may be a reduction cell
        # (stride 2, channels multiplied by 2).
        is_reduction_cell = repeat_idx == 0 and self.first_cell_reduce
        label = 'reduce' if is_reduction_cell else 'normal'

        # Uses the bookkeeping left behind by the previous call.
        preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C,
                                        self.last_cell_reduce)

        # One partial per candidate; partial binds ``op`` eagerly, so each
        # factory builds its own candidate.
        ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
            op: partial(self.op_factory,
                        op=op,
                        channels=cast(int, self.C),
                        is_reduction_cell=is_reduction_cell)
            for op in self.op_candidates
        }

        cell = nn.Cell(ops_factory,
                       self.num_nodes,
                       self.num_ops_per_node,
                       self.num_predecessors,
                       self.merge_op,
                       preprocessor=preprocessor,
                       postprocessor=CellPostprocessor(),
                       label=label)

        # The next cell's preprocessor sees this cell's input and output widths.
        self.C_prev_in, self.C_in = self.C_in, self.C * len(cell.output_node_indices)
        self.last_cell_reduce = is_reduction_cell
        self._expect_idx += 1
        return cell
Ejemplo n.º 6
0
 def __init__(self):
     """Build a 4-node cell over two 16 -> 16 linear candidates; other Cell settings use defaults."""
     super().__init__()
     linear_ops = [
         nn.Linear(16, 16),
         nn.Linear(16, 16, bias=False),
     ]
     self.cell = nn.Cell(linear_ops, num_nodes=4)
Ejemplo n.º 7
0
 def __init__(self):
     """Build a 4-node cell whose two linear candidates size their input from the chosen input."""
     super().__init__()

     def in_features(chosen):
         # chosen == 0 -> 3 input features; anything else (including None) -> 16.
         return 3 if chosen == 0 else 16

     def first(_, __, chosen):
         return nn.Linear(in_features(chosen), 16)

     def second(_, __, chosen):
         return nn.Linear(in_features(chosen), 16, bias=False)

     self.cell = nn.Cell({'first': first, 'second': second},
                         num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')