def __init__(self, model_desc: ModelDesc, droppath: bool, affine: bool):
    """Create all network modules described by ``model_desc``.

    Builds the two stem ops, then calls ``self._build_cell`` once per
    (cell description, aux tower description) pair, and finally creates the
    pooling op and the logits op.

    Args:
        model_desc: description of the whole model; must contain exactly
            two entries in ``model_stems``.
        droppath: forwarded unchanged to ``self._build_cell``.
        affine: forwarded to every ``Op.create`` call and to
            ``self._build_cell``.

    Raises:
        AssertionError: if ``model_desc.model_stems`` does not have exactly
            two entries.
    """
    super().__init__()

    # some of these fields are public as finalizer needs access to them
    self.desc = model_desc

    # TODO: support any number of stems
    assert len(model_desc.model_stems) == 2, \
        "Model compiler currently only supports 2 stems"
    stem0_op = Op.create(model_desc.model_stems[0], affine=affine)
    stem1_op = Op.create(model_desc.model_stems[1], affine=affine)
    self.model_stems = nn.ModuleList((stem0_op, stem1_op))

    # presumably populated by self._build_cell below — TODO confirm
    self.cells = nn.ModuleList()
    self._aux_towers = nn.ModuleList()
    # NOTE(review): zip stops at the shorter sequence, so a length mismatch
    # between cell_descs() and aux_tower_descs silently drops cells;
    # presumably both always have the same length — confirm.
    for cell_desc, aux_tower_desc in zip(model_desc.cell_descs(),
                                         model_desc.aux_tower_descs):
        self._build_cell(cell_desc, aux_tower_desc, droppath, affine)

    # adaptive pooling output size to 1x1
    self.pool_op = Op.create(model_desc.pool_op, affine=affine)
    # final classifier op, built from model_desc.logits_op
    self.logits_op = Op.create(model_desc.logits_op, affine=affine)
def build(self, model_desc: ModelDesc, search_iter: int) -> None:
    """Build every cell of ``model_desc`` for one search iteration.

    On any iteration after the first (``search_iter > 0``) the model is
    grown first by ``self.add_node``. Each cell is then built with the
    ``'gs_num_sample'`` entry of ``model_desc.params``.
    """
    grow_first = search_iter > 0
    if grow_first:
        self.add_node(model_desc)

    for desc in model_desc.cell_descs():
        # looked up per cell, exactly as many times as there are cells
        sample_count = model_desc.params['gs_num_sample']
        self._build_cell(desc, sample_count)
def seed(self, model_desc: ModelDesc) -> None:
    """Turn each cell of ``model_desc`` into the petridish seed model.

    For every cell:

    * if its first node has no edges, that node gets a single
      ``skip_connect`` edge from input id 1 (s1), with stride 2 in
      reduction cells and 1 otherwise;
    * nodes that still have no edges are dropped; the cell's nodes are reset
      (with clones of the kept nodes) only when something was dropped;
    * ``self._ensure_nonempty_nodes`` is called on the cell.
    """
    for cell_desc in model_desc.cell_descs():
        count = len(cell_desc.nodes())
        assert count >= 1
        head = cell_desc.nodes()[0]

        if not head.edges:
            # identity from s1 is the seed connection
            stride = 2 if cell_desc.cell_type == CellType.Reduction else 1
            identity_desc = OpDesc('skip_connect',
                                   params={'conv': cell_desc.conv_params,
                                           'stride': stride},
                                   in_len=1, trainables=None, children=None)
            head.edges.append(EdgeDesc(identity_desc, input_ids=[1]))

        # keep clones of only those nodes that have at least one edge
        kept = [node.clone() for node in cell_desc.nodes() if node.edges]
        if len(kept) != len(cell_desc.nodes()):
            cell_desc.reset_nodes(kept, cell_desc.node_ch_out,
                                  cell_desc.post_op.name)

        self._ensure_nonempty_nodes(cell_desc)
def build(self, model_desc: ModelDesc, search_iter: int) -> None:
    """Prune the cell matrix and vertex ops, then build every cell.

    Reads ``'cell_matrix'`` and ``'vertex_ops'`` from ``model_desc.params``,
    stores the result of ``model_matrix.prune`` on ``self._cell_matrix`` and
    ``self._vertex_ops``, and calls ``self._build_cell`` for each cell.
    ``search_iter`` is not used.
    """
    params = model_desc.params
    pruned = model_matrix.prune(params['cell_matrix'], params['vertex_ops'])
    self._cell_matrix, self._vertex_ops = pruned

    for desc in model_desc.cell_descs():
        self._build_cell(desc)
def build(self, model_desc:ModelDesc, search_iter:int)->None:
    """Build every cell of ``model_desc`` for one search iteration.

    After the first iteration (``search_iter > 0``) a new node is added to
    each cell via ``self.add_node``. The Gumbel-softmax sample count is read
    from the global config (``nas.search.gs.num_sample``), cached on
    ``self._gs_num_sample``, and passed to ``self._build_cell`` for each cell.
    """
    if search_iter > 0:
        self.add_node(model_desc)

    gs_section = get_conf()['nas']['search']['gs']
    self._gs_num_sample = gs_section['num_sample']

    for desc in model_desc.cell_descs():
        self._build_cell(desc, self._gs_num_sample)
def build(self, model_desc: ModelDesc, search_iter: int) -> None:
    """Build every cell from one of two random op sets, chosen by cell type.

    Two ``RandOps`` instances are created up front with the node count of the
    first cell and ``max_edges = 2``. One is shared by every regular cell and
    the other by every reduction cell. Any other cell type raises
    ``NotImplementedError``. ``search_iter`` is not used.
    """
    assert len(model_desc.cell_descs())
    node_total = len(model_desc.cell_descs()[0].nodes())
    max_edges = 2

    # one independent set per cell type (regular set is created first)
    regular_ops = RandOps(node_total, max_edges)
    reduction_ops = RandOps(node_total, max_edges)

    for cell_desc in model_desc.cell_descs():
        kind = cell_desc.cell_type
        if kind == CellType.Regular:
            chosen_ops = regular_ops
        elif kind == CellType.Reduction:
            chosen_ops = reduction_ops
        else:
            raise NotImplementedError(
                f'CellType {cell_desc.cell_type} is not recognized')
        self._build_cell(cell_desc, chosen_ops)
def _add_node(self, model_desc: ModelDesc, model_desc_builder: ModelDescBuilder) -> None:
    """Insert one new petridish node into every cell, just before its last node.

    The new node has a single edge whose op is ``petridish_reduction_op`` in
    reduction cells and ``petridish_normal_op`` otherwise. Its shape is a
    deep copy of the last node's shape. Because the number of nodes feeding
    the cell's post op changes, the post op is rebuilt with
    ``model_desc_builder.build_cell_post_op`` and the cell is reset with the
    updated nodes, shapes and post op.
    """
    for cell_index, cell_desc in enumerate(model_desc.cell_descs()):
        is_reduction = cell_desc.cell_type == CellType.Reduction
        nodes = cell_desc.nodes()
        # petridish must seed with one node
        assert len(nodes) > 0
        # input/output channels for all nodes are same
        conv_params = nodes[0].conv_params

        # s0 and s1 have ids 0 and 1. The new node is inserted before the
        # last node, so it takes ids 0..len(nodes) as inputs (presumably both
        # stems plus every node except the current last one).
        input_ids = list(range(len(nodes) + 1))
        assert len(input_ids) >= 2  # 2 stem inputs

        op_name = ('petridish_reduction_op' if is_reduction
                   else 'petridish_normal_op')
        # one stride per input: ids 0 and 1 (the stems) get stride 2 in
        # reduction cells; later each primitive is given its own stride
        strides = [2 if is_reduction and j < 2 else 1 for j in input_ids]
        op_desc = OpDesc(op_name,
                         params={'conv': conv_params, '_strides': strides},
                         in_len=len(input_ids), trainables=None, children=None)
        new_node = NodeDesc(edges=[EdgeDesc(op_desc, input_ids=input_ids)],
                            conv_params=conv_params)
        nodes.insert(len(nodes) - 1, new_node)

        # output shape of all nodes are same: copy the last one
        node_shapes = cell_desc.node_shapes
        node_shapes.insert(len(node_shapes) - 1,
                           copy.deepcopy(node_shapes[-1]))

        post_op_shape, post_op_desc = model_desc_builder.build_cell_post_op(
            cell_desc.stem_shapes, node_shapes, cell_desc.conf_cell, cell_index)
        cell_desc.reset_nodes(nodes, node_shapes, post_op_desc, post_op_shape)
def __init__(self, model_desc:ModelDesc, droppath:bool, affine:bool):
    """Create all network modules described by ``model_desc`` and log a summary.

    Builds ``stem0_op`` and ``stem1_op``, then calls ``self._build_cell`` once
    per (cell description, aux tower description) pair, creates the pooling
    and logits ops, and finally logs ``self.summary()``.

    Args:
        model_desc: description of the whole model.
        droppath: forwarded unchanged to ``self._build_cell``.
        affine: forwarded to every ``Op.create`` call and to
            ``self._build_cell``.
    """
    super().__init__()

    # some of these fields are public as finalizer needs access to them
    self.desc = model_desc
    self.stem0_op = Op.create(model_desc.stem0_op, affine=affine)
    self.stem1_op = Op.create(model_desc.stem1_op, affine=affine)

    # presumably populated by self._build_cell below — TODO confirm
    self.cells = nn.ModuleList()
    self._aux_towers = nn.ModuleList()
    # NOTE(review): zip stops at the shorter sequence, so a length mismatch
    # between cell_descs() and aux_tower_descs silently drops cells;
    # presumably both always have the same length — confirm.
    for cell_desc, aux_tower_desc in zip(model_desc.cell_descs(),
                                         model_desc.aux_tower_descs):
        self._build_cell(cell_desc, aux_tower_desc, droppath, affine)

    # adaptive pooling output size to 1x1
    self.pool_op = Op.create(model_desc.pool_op, affine=affine)
    # final classifier op, built from model_desc.logits_op
    self.logits_op = Op.create(model_desc.logits_op, affine=affine)

    logger.info({'model_summary': self.summary()})
def build(self, model_desc: ModelDesc, search_iter: int) -> None:
    """Call ``self._build_cell`` on each cell of ``model_desc``.

    ``search_iter`` is accepted for interface compatibility but not used.
    """
    for desc in model_desc.cell_descs():
        self._build_cell(desc)