def __init__(self):
    """Build a cell whose two candidate ops are given as a named (dict) mapping,
    with explicit pre/post-processors and ``merge_op='all'``."""
    super().__init__()
    # Candidate operations, keyed by alias.
    candidate_ops = {
        'first': nn.Linear(16, 16),
        'second': nn.Linear(16, 16, bias=False),
    }
    self.cell = nn.Cell(
        candidate_ops,
        num_nodes=4,
        num_ops_per_node=2,
        num_predecessors=2,
        preprocessor=CellPreprocessor(),
        postprocessor=CellPostprocessor(),
        merge_op='all',
    )
def __init__(self):
    """Build a cell whose two candidate ops are given as an unnamed (list) sequence."""
    super().__init__()
    # Candidate operations; list form means ops are indexed, not aliased.
    candidate_ops = [
        nn.Linear(16, 16),
        nn.Linear(16, 16, bias=False),
    ]
    self.cell = nn.Cell(
        candidate_ops,
        num_nodes=4,
        num_ops_per_node=2,
        num_predecessors=2,
        merge_op='all',
    )
def __call__(self, repeat_idx: int):
    """Build the cell at position ``repeat_idx`` inside a repeat.

    It takes an index that is the index in the repeat. Raises ``ValueError``
    when called out of order. Returns the constructed ``nn.Cell`` and updates
    the channel bookkeeping (``C_prev_in``/``C_in``/``last_cell_reduce``)
    consumed by the next cell.
    """
    if self._expect_idx != repeat_idx:
        raise ValueError(f'Expect index {self._expect_idx}, found {repeat_idx}')
    # Number of predecessors for each cell is fixed to 2.
    num_predecessors = 2
    # Number of ops per node is fixed to 2.
    num_ops_per_node = 2
    # Reduction cell means stride = 2 and channel multiplied by 2.
    is_reduction_cell = repeat_idx == 0 and self.first_cell_reduce
    # self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
    preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)
    ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
        op:  # make final chosen ops named with their aliases
        # BUGFIX: bind the loop variable as a default argument (op=op). The
        # original lambda closed over ``op`` late, so after the comprehension
        # finished every factory built OPS[<last candidate>] regardless of key.
        lambda node_index, op_index, input_index, op=op: OPS[op](
            self.C,
            # Reduction cells stride only on the fixed predecessors.
            2 if is_reduction_cell and (
                input_index is None  # could be none when constructing search space
                or input_index < num_predecessors
            ) else 1,
            True)
        for op in self.op_candidates
    }
    cell = nn.Cell(ops_factory, self.num_nodes, num_ops_per_node, num_predecessors,
                   self.merge_op, preprocessor=preprocessor,
                   postprocessor=CellPostprocessor(),
                   label='reduce' if is_reduction_cell else 'normal')
    # update state
    self.C_prev_in = self.C_in
    self.C_in = self.C * len(cell.output_node_indices)
    self.last_cell_reduce = is_reduction_cell
    self._expect_idx += 1
    return cell
def __init__(self):
    """Stem conv, 1-3 repeated 3-node cells with 'loose_end' merge, then a linear head."""
    super().__init__()
    self.stem = nn.Conv2d(1, 5, 7, stride=4)

    def make_cell(index):
        # Input width for an op: input slot 0 of the first cell sees the
        # 5-channel stem output; slot 0 of later cells sees the previous
        # cell's 3 * 4 concatenated channels. Any other source (inp is None
        # or inp >= 1) carries 4 channels.
        def in_channels(inp):
            return (5 if index == 0 else 3 * 4) if inp is not None and inp < 1 else 4

        return nn.Cell(
            {
                'conv1': lambda _, __, inp: nn.Conv2d(in_channels(inp), 4, 1),
                'conv2': lambda _, __, inp: nn.Conv2d(in_channels(inp), 4, 3, padding=1),
            },
            3,
            merge_op='loose_end',
        )

    self.cells = nn.Repeat(make_cell, (1, 3))
    self.fc = nn.Linear(3 * 4, 10)
def __call__(self, repeat_idx: int):
    """Build the ``repeat_idx``-th cell of the repeat and advance the bookkeeping state.

    Raises ``ValueError`` if invoked out of order; returns the new ``nn.Cell``.
    """
    if self._expect_idx != repeat_idx:
        raise ValueError(f'Expect index {self._expect_idx}, found {repeat_idx}')
    # Reduction cell means stride = 2 and channel multiplied by 2.
    is_reduction_cell = repeat_idx == 0 and self.first_cell_reduce
    # self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
    preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)
    # ``partial`` binds ``op`` eagerly, so a comprehension is safe here
    # (no late-binding closure issue).
    ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
        op: partial(self.op_factory, op=op, channels=cast(int, self.C),
                    is_reduction_cell=is_reduction_cell)
        for op in self.op_candidates
    }
    cell = nn.Cell(ops_factory, self.num_nodes, self.num_ops_per_node,
                   self.num_predecessors, self.merge_op,
                   preprocessor=preprocessor, postprocessor=CellPostprocessor(),
                   label='reduce' if is_reduction_cell else 'normal')
    # State consumed by the next cell in the repeat.
    self.C_prev_in = self.C_in
    self.C_in = self.C * len(cell.output_node_indices)
    self.last_cell_reduce = is_reduction_cell
    self._expect_idx += 1
    return cell
def __init__(self):
    """Minimal cell: two list-form candidate ops, 4 nodes, everything else defaulted."""
    super().__init__()
    first_op = nn.Linear(16, 16)
    second_op = nn.Linear(16, 16, bias=False)
    self.cell = nn.Cell([first_op, second_op], num_nodes=4)
def __init__(self):
    """Cell whose op factories size their input by which predecessor index is chosen."""
    super().__init__()

    def first(_, __, chosen):
        # Chosen input 0 carries 3 features; any other source carries 16.
        return nn.Linear(3 if chosen == 0 else 16, 16)

    def second(_, __, chosen):
        return nn.Linear(3 if chosen == 0 else 16, 16, bias=False)

    self.cell = nn.Cell(
        {'first': first, 'second': second},
        num_nodes=4,
        num_ops_per_node=2,
        num_predecessors=2,
        merge_op='all',
    )