def _build_cell(self, cell_desc: CellDesc, rand_ops: RandOps) -> None:
    """Replace every node's edges with the randomly sampled ops in rand_ops."""
    nodes = cell_desc.nodes()
    # one (op names, input states) pair must have been sampled per node
    assert len(nodes) == len(rand_ops.ops_and_ins)

    is_reduction = cell_desc.cell_type == CellType.Reduction

    for node, (names, states) in zip(nodes, rand_ops.ops_and_ins):
        # the cell must be completely random, so drop any pre-existing edges
        node.edges.clear()
        for name, state in zip(names, states):
            # only edges reading states 0/1 of a reduction cell use stride 2
            stride = 2 if is_reduction and state < 2 else 1
            desc = OpDesc(name,
                          params={'conv': cell_desc.conv_params,
                                  'stride': stride},
                          in_len=1, trainables=None, children=None)
            node.edges.append(EdgeDesc(desc, input_ids=[state]))
def _build_cell(self, cell_desc: CellDesc) -> None:
    """Grow the cell by inserting one new 'petridish' node before its last node.

    The inserted node carries a single multi-input petridish op whose
    inputs are all earlier states; the cell's nodes are then reset so the
    post op is rebuilt for the new node count.
    """
    reduction = (cell_desc.cell_type == CellType.Reduction)

    self._ensure_nonempty_nodes(cell_desc)

    # we operate on last node, inserting another node before it
    new_nodes = [n.clone() for n in cell_desc.nodes()]
    petridish_node = NodeDesc(edges=[])
    # len(new_nodes) is evaluated before the insert, so the new node lands
    # at the second-to-last position of the grown list
    new_nodes.insert(len(new_nodes) - 1, petridish_node)

    # presumably ids 0 and 1 are the stem states and internal nodes follow
    # (consistent with the j < 2 stride rule below) — verify against EdgeDesc
    input_ids = list(range(len(new_nodes)))  # 2 + len-2
    assert len(input_ids) >= 2
    op_desc = OpDesc('petridish_reduction_op' if reduction else 'petridish_normal_op',
                     params={
                         'conv': cell_desc.conv_params,
                         # specify strides for each input, later we will
                         # give this to each primitive
                         '_strides': [2 if reduction and j < 2 else 1
                                      for j in input_ids],
                     },
                     in_len=len(input_ids),
                     trainables=None,
                     children=None)
    edge = EdgeDesc(op_desc, input_ids=input_ids)
    petridish_node.edges.append(edge)

    # note that post op will be recreated which means there is no
    # warm start for post op when number of nodes changes
    cell_desc.reset_nodes(new_nodes, cell_desc.node_ch_out,
                          cell_desc.post_op.name)
def build_cell(self, in_shapes: TensorShapesList, conf_cell: Config,
               cell_index: int) -> CellDesc:
    """Build the descriptor for the cell at cell_index.

    Creates the cell stems, builds the nodes (from scratch or from a
    template), adds the post op, and appends the cell's output shape to
    in_shapes for use by the next cell.
    """
    stem_shapes, stems = self.build_cell_stems(in_shapes, conf_cell, cell_index)
    cell_type = self.get_cell_type(cell_index)

    if self.template is None:
        node_count = self.get_node_count(cell_index)
        in_shape = stem_shapes[0]   # input shape to nodes is same as cell stem
        out_shape = stem_shapes[0]  # we ask nodes to keep the output shape same
        node_shapes, nodes = self.build_nodes(stem_shapes, conf_cell,
                                              cell_index, cell_type,
                                              node_count, in_shape, out_shape)
    else:
        node_shapes, nodes = self.build_nodes_from_template(stem_shapes,
                                                            conf_cell,
                                                            cell_index)

    post_op_shape, post_op_desc = self.build_cell_post_op(stem_shapes,
                                                          node_shapes,
                                                          conf_cell,
                                                          cell_index)

    cell_desc = CellDesc(
        id=cell_index,
        # reuse the value computed above instead of calling
        # get_cell_type() a second time
        cell_type=cell_type,
        conf_cell=conf_cell,
        stems=stems, stem_shapes=stem_shapes,
        nodes=nodes, node_shapes=node_shapes,
        post_op=post_op_desc, out_shape=post_op_shape,
        trainables_from=self.get_trainables_from(cell_index)
    )

    # output same shape twice to indicate s0 and s1 inputs for next cell
    in_shapes.append([post_op_shape])

    return cell_desc
def finalize_cell(self, cell: Cell, cell_index: int, model_desc: ModelDesc,
                  *args, **kwargs) -> CellDesc:
    """Produce the finalized CellDesc for a searched divnas cell.

    Each node is finalized with the activation-covariance information
    collected for it during search; stems and post op are finalized from
    their live ops.
    """
    # first finalize each node, we will need to recreate node desc with final version
    max_final_edges = model_desc.max_final_edges

    node_descs: List[NodeDesc] = []
    # divnas bookkeeping side table, keyed by the cell's object identity
    dcell = self._divnas_cells[id(cell)]
    # every dag node must have a covariance entry collected during search
    assert len(cell.dag) == len(list(dcell.node_covs.values()))
    for i, node in enumerate(cell.dag):
        # per-node covariance is keyed by the node's object identity
        node_cov = dcell.node_covs[id(node)]
        node_desc = self.finalize_node(node, i, cell.desc.nodes()[i],
                                       max_final_edges, node_cov)
        node_descs.append(node_desc)

    # (optional) clear out all activation collection information
    dcell.clear_collect_activations()

    desc = cell.desc
    finalized = CellDesc(
        id=desc.id, cell_type=desc.cell_type, conf_cell=desc.conf_cell,
        stems=[cell.s0_op.finalize()[0], cell.s1_op.finalize()[0]],
        stem_shapes=desc.stem_shapes,
        nodes=node_descs, node_shapes=desc.node_shapes,
        post_op=cell.post_op.finalize()[0],
        out_shape=desc.out_shape,
        trainables_from=desc.trainables_from
    )
    return finalized
def _build_cell(self, cell_desc: CellDesc) -> None:
    """Wire each node of the cell according to the nasbench101 cell matrix.

    Row/column 0 of self._cell_matrix is the cell's input vertex and
    internal vertices start at index 1; edge input ids are offset by 2
    because states 0 and 1 are the two cell stems.
    """
    for i, node in enumerate(cell_desc.nodes()):
        input_ids = []
        first_proj = False  # if input node is connected then it needs projection
        if self._cell_matrix[0, i + 1]:  # nasbench internal node starts at 1
            input_ids.append(0)  # connect to s0
            first_proj = True
        for j in range(i):  # look at all internal vertices before us
            if self._cell_matrix[j + 1, i + 1]:  # if there is connection
                input_ids.append(j + 2)  # offset because of s0, s1
        # single multi-input op per node; its in_len matches the number of
        # connections found above
        op_desc = OpDesc(
            'nasbench101_op',
            params={
                'conv': cell_desc.conv_params,
                'stride': 1,
                'vertex_op': self._vertex_ops[i + 1],  # offset because of input node
                'first_proj': first_proj
            },
            in_len=len(input_ids), trainables=None, children=None)  # TODO: should we pass children here?
        edge = EdgeDesc(op_desc, input_ids=input_ids)
        node.edges.append(edge)
def finalize_cell(self, cell: Cell, *args, **kwargs) -> CellDesc:
    """Produce the finalized CellDesc for a searched divnas cell.

    Nodes are finalized with the activation covariances gathered during
    search; stems and post op are finalized from their live ops.
    """
    # first finalize each node, we will need to recreate node desc with final version
    # lazy %-args: no string formatting cost when INFO is disabled
    logger.info('cell id %s', cell.desc.id)

    node_descs: List[NodeDesc] = []
    dcell = self._divnas_cells[cell]
    # node_covs holds one covariance entry per dag node, keyed by id(node);
    # len() on the mapping replaces the redundant len(list(...values()))
    assert len(cell.dag) == len(dcell.node_covs)
    for i, node in enumerate(cell.dag):
        node_cov = dcell.node_covs[id(node)]
        logger.info('node %d', i)
        node_desc = self.finalize_node(node, cell.desc.max_final_edges,
                                       node_cov, cell, i)
        node_descs.append(node_desc)

    # (optional) clear out all activation collection information
    dcell.clear_collect_activations()

    finalized = CellDesc(
        cell_type=cell.desc.cell_type,
        id=cell.desc.id,
        nodes=node_descs,
        s0_op=cell.s0_op.finalize()[0],
        s1_op=cell.s1_op.finalize()[0],
        template_cell=cell.desc.template_cell,
        max_final_edges=cell.desc.max_final_edges,
        node_ch_out=cell.desc.node_ch_out,
        post_op=cell.post_op.finalize()[0]
    )
    return finalized
def finalize_cell(self, cell: Cell, cell_index: int, model_desc: ModelDesc,
                  *args, **kwargs) -> CellDesc:
    """Create the final CellDesc from a searched cell.

    Every node is finalized first (respecting model_desc.max_final_edges),
    then a fresh CellDesc is assembled from the finalized stems, the new
    node descs, and the finalized post op.
    """
    # node descs must be recreated with their finalized versions
    limit = model_desc.max_final_edges
    node_descs: List[NodeDesc] = [
        self.finalize_node(node, idx, cell.desc.nodes()[idx], limit)
        for idx, node in enumerate(cell.dag)
    ]

    src = cell.desc
    return CellDesc(
        id=src.id,
        cell_type=src.cell_type,
        conf_cell=src.conf_cell,
        stems=[cell.s0_op.finalize()[0], cell.s1_op.finalize()[0]],
        stem_shapes=src.stem_shapes,
        nodes=node_descs,
        node_shapes=src.node_shapes,
        post_op=cell.post_op.finalize()[0],
        out_shape=src.out_shape,
        trainables_from=src.trainables_from)
def _build_cell(self, cell_desc: CellDesc) -> None:
    """Attach an 'xnas_op' edge from every earlier state to each node."""
    is_reduction = cell_desc.cell_type == CellType.Reduction

    for node_idx, node in enumerate(cell_desc.nodes()):
        # states 0 .. node_idx+1 are available as inputs to this node
        for state in range(node_idx + 2):
            # stride 2 only for the stem states of a reduction cell
            stride = 2 if is_reduction and state < 2 else 1
            desc = OpDesc('xnas_op',
                          params={'conv': cell_desc.conv_params,
                                  'stride': stride},
                          in_len=1, trainables=None, children=None)
            node.edges.append(EdgeDesc(desc, input_ids=[state]))
def __init__(self, desc: CellDesc, affine: bool, droppath: bool,
             template_cell: Optional['Cell']):  # template cell, if any, to use for arch params
    """Instantiate the cell's runtime ops from its descriptor.

    Args:
        desc: descriptor providing op descs for the stems, nodes and post op.
        affine: forwarded to every Op.create call.
        droppath: forwarded to the dag builder — presumably enables
            drop-path regularization; confirm in _create_dag.
        template_cell: template cell, if any, to use for arch params.
    """
    super().__init__()

    # some of these members are public as finalizer needs access
    self.desc = desc
    self.s0_op = Op.create(desc.s0_op, affine=affine)
    self.s1_op = Op.create(desc.s1_op, affine=affine)

    self.dag = Cell._create_dag(desc.nodes(),
                                affine=affine, droppath=droppath,
                                template_cell=template_cell)

    self.post_op = Op.create(desc.post_op, affine=affine)
def __init__(self, desc: CellDesc, affine: bool, droppath: bool,
             trainables_from: Optional['Cell']):
    """Build the cell's live ops from its descriptor.

    trainables_from is the template cell, if any, to use for arch params.
    Several members are left public because the finalizer needs them.
    """
    super().__init__()

    self.desc = desc

    stems = desc.stems
    # TODO: support any number of stems
    assert len(stems) == 2, "Cell compiler currently only supports 2 stems"
    self.s0_op = Op.create(stems[0], affine=affine)
    self.s1_op = Op.create(stems[1], affine=affine)

    self.dag = Cell._create_dag(desc.nodes(),
                                affine=affine, droppath=droppath,
                                trainables_from=trainables_from)

    self.post_op = Op.create(desc.post_op, affine=affine)
def _build_cell(self, cell_desc: CellDesc) -> None:
    """Attach a 'mixed_op' edge from every earlier state to each node.

    Stride rule: only ops reading s0/s1 (states 0 and 1) of a reduction
    cell use stride 2; every later state is derived from s0 and s1 and so
    is already reduced in WxH. Channel changes come from the cell's
    conv_params, not from the stride here.
    """
    is_reduction = cell_desc.cell_type == CellType.Reduction

    for node_idx, node in enumerate(cell_desc.nodes()):
        # a node may read any of the states 0 .. node_idx+1
        for state in range(node_idx + 2):
            stride = 2 if is_reduction and state < 2 else 1
            desc = OpDesc('mixed_op',
                          params={'conv': cell_desc.conv_params,
                                  'stride': stride},
                          in_len=1, trainables=None, children=None)
            node.edges.append(EdgeDesc(desc, input_ids=[state]))
def _ensure_nonempty_nodes(self, cell_desc: CellDesc): assert len(cell_desc.nodes()) > 0 for node in cell_desc.nodes(): assert len(node.edges) > 0