def build_nodes(self, stem_shapes:TensorShapes, conf_cell:Config,
                cell_index:int, cell_type:CellType, node_count:int,
                in_shape:TensorShape, out_shape:TensorShape) -> Tuple[TensorShapes, List[NodeDesc]]:
    """Create ``node_count`` node descriptions wired with 'xnas_op' edges.

    Node ``i`` receives ``i + 2`` edges, one reading each state id in
    ``0 .. i + 1``. In a reduction cell the edges that read state 0 or 1 get
    stride 2; every other edge gets stride 1.

    ``stem_shapes``, ``conf_cell`` and ``cell_index`` are accepted but not
    used by this implementation.

    Returns:
        ``(out_shapes, nodes)``: one deep copy of ``out_shape`` per node and
        the list of created ``NodeDesc`` objects.
    """
    # first dimension of input and output shapes must agree
    assert in_shape[0]==out_shape[0]

    is_reduction = (cell_type==CellType.Reduction)
    conv_params = ConvMacroParams(in_shape[0], out_shape[0])

    def make_edge(state_id:int)->EdgeDesc:
        # only edges reading the first two states are strided in a reduction cell
        stride = 2 if is_reduction and state_id < 2 else 1
        op_desc = OpDesc('xnas_op',
                         params={
                             'conv': conv_params,
                             'stride': stride
                         }, in_len=1, trainables=None, children=None)
        return EdgeDesc(op_desc, input_ids=[state_id])

    nodes:List[NodeDesc] = [
        NodeDesc(edges=[make_edge(state_id) for state_id in range(node_idx+2)],
                 conv_params=conv_params)
        for node_idx in range(node_count)
    ]

    out_shapes = [copy.deepcopy(out_shape) for _ in range(node_count)]
    return out_shapes, nodes
def _build_cell(self, cell_desc: CellDesc, rand_ops: RandOps) -> None:
    """Replace the edges of every node in ``cell_desc`` with the given ops.

    ``rand_ops.ops_and_ins`` must hold one ``(op_names, to_states)`` pair per
    node. For each pair, one single-input edge is created per
    ``(op_name, to_state)`` combination. In a reduction cell the edges reading
    state 0 or 1 get stride 2; every other edge gets stride 1.
    """
    # sanity check: there must be one entry of ops for each node
    assert len(cell_desc.nodes()) == len(rand_ops.ops_and_ins)

    is_reduction = (cell_desc.cell_type == CellType.Reduction)

    def make_edge(op_name, to_state) -> EdgeDesc:
        stride = 2 if is_reduction and to_state < 2 else 1
        op_desc = OpDesc(op_name,
                         params={
                             'conv': cell_desc.conv_params,
                             'stride': stride
                         }, in_len=1, trainables=None, children=None)
        return EdgeDesc(op_desc, input_ids=[to_state])

    for node, (op_names, to_states) in zip(cell_desc.nodes(),
                                           rand_ops.ops_and_ins):
        # the cell should consist only of the supplied ops, so drop
        # whatever edges the node had before
        node.edges.clear()
        for op_name, to_state in zip(op_names, to_states):
            node.edges.append(make_edge(op_name, to_state))
def build_nodes(self, stem_shapes:TensorShapes, conf_cell:Config,
                cell_index:int, cell_type:CellType, node_count:int,
                in_shape:TensorShape, out_shape:TensorShape) -> Tuple[TensorShapes, List[NodeDesc]]:
    """Create ``node_count`` node descriptions wired with 'div_op' edges.

    Node ``i`` receives ``i + 2`` edges, one reading each state id in
    ``0 .. i + 1``.

    How the stride works: every op connected to s0 or s1 applies the WxH
    reduction in a reduction cell (stride 2). Ops connected to later states
    use stride 1 because those states are already derived from s0 and s1.
    Channels are handled through the cell's ``conv_params``.

    ``stem_shapes``, ``conf_cell`` and ``cell_index`` are accepted but not
    used by this implementation.

    Returns:
        ``(out_shapes, nodes)``: one deep copy of ``out_shape`` per node and
        the list of created ``NodeDesc`` objects.
    """
    assert in_shape[0] == out_shape[0]

    is_reduction = (cell_type == CellType.Reduction)
    conv_params = ConvMacroParams(in_shape[0], out_shape[0])
    nodes: List[NodeDesc] = []

    for node_idx in range(node_count):
        node_edges = []
        for state_id in range(node_idx + 2):
            if is_reduction and state_id < 2:
                stride = 2
            else:
                stride = 1
            op_desc = OpDesc('div_op',
                             params={
                                 'conv': conv_params,
                                 'stride': stride
                             }, in_len=1, trainables=None, children=None)
            node_edges.append(EdgeDesc(op_desc, input_ids=[state_id]))
        nodes.append(NodeDesc(edges=node_edges, conv_params=conv_params))

    out_shapes = [copy.deepcopy(out_shape) for _ in range(node_count)]
    return out_shapes, nodes
def _build_cell(self, cell_desc: CellDesc) -> None:
    """Append one 'nasbench101_op' edge to every node of ``cell_desc``.

    The inputs of node ``i`` are read from ``self._cell_matrix``: column
    ``i + 1`` describes the node (internal vertices start at 1 because vertex
    0 is the input node). A truthy entry in row 0 connects the node to s0 and
    sets ``first_proj``. A truthy entry in row ``j + 1`` connects it to the
    earlier internal node ``j``, whose state id is ``j + 2``.
    """
    for i, node in enumerate(cell_desc.nodes()):
        vertex = i + 1  # nasbench internal node starts at 1

        # if the input node is connected then it needs projection
        first_proj = bool(self._cell_matrix[0, vertex])
        input_ids = [0] if first_proj else []  # 0 == connect to s0

        # look at all internal vertices before this one; state ids are
        # offset by 2 because of s0, s1
        input_ids.extend(j + 2 for j in range(i)
                         if self._cell_matrix[j + 1, vertex])

        op_params = {
            'conv': cell_desc.conv_params,
            'stride': 1,
            'vertex_op': self._vertex_ops[vertex],  # offset because of input node
            'first_proj': first_proj
        }
        # TODO: should we pass children here?
        op_desc = OpDesc('nasbench101_op', params=op_params,
                         in_len=len(input_ids), trainables=None, children=None)
        node.edges.append(EdgeDesc(op_desc, input_ids=input_ids))
def build_nodes(self, stem_shapes:TensorShapes, conf_cell:Config,
                cell_index:int, cell_type:CellType, node_count:int,
                in_shape:TensorShape, out_shape:TensorShape) -> Tuple[TensorShapes, List[NodeDesc]]:
    """Build the petridish seed: a single node with a 'skip_connect' to s1.

    For petridish we add one node with identity to s1; this is the seed model
    to start with. Later in PetridishSearcher, one more node is added in the
    parent after each sampling.

    ``node_count``, ``stem_shapes``, ``conf_cell`` and ``cell_index`` are
    accepted but not used: exactly one node is always returned.

    Returns:
        ``(out_shapes, nodes)`` where ``nodes`` holds the single seed node and
        ``out_shapes`` holds one deep copy of ``out_shape``.
    """
    assert in_shape[0] == out_shape[0]

    # channels for conv filters
    conv_params = ConvMacroParams(in_shape[0], out_shape[0])

    if cell_type == CellType.Reduction:
        stride = 2
    else:
        stride = 1

    # identity op connecting s1 (state id 1) to the node
    skip_desc = OpDesc('skip_connect',
                       params={
                           'conv': conv_params,
                           'stride': stride
                       }, in_len=1, trainables=None, children=None)
    skip_edge = EdgeDesc(skip_desc, input_ids=[1])
    nodes = [NodeDesc(edges=[skip_edge], conv_params=conv_params)]

    # each node has same out channels as in channels
    out_shapes = [copy.deepcopy(out_shape) for _ in nodes]
    return out_shapes, nodes
def _build_cell(self, cell_desc: CellDesc) -> None:
    """Insert a new petridish node just before the last node of the cell.

    All existing nodes are cloned, an empty node is inserted before the last
    one, and that node gets a single petridish edge whose inputs are the state
    ids ``0 .. len(new_nodes) - 1``. The cell's nodes are then reset.
    """
    is_reduction = (cell_desc.cell_type == CellType.Reduction)

    self._ensure_nonempty_nodes(cell_desc)

    # we operate on the last node, inserting another node before it
    cloned_nodes = [existing.clone() for existing in cell_desc.nodes()]
    petridish_node = NodeDesc(edges=[])
    insert_at = len(cloned_nodes) - 1
    cloned_nodes.insert(insert_at, petridish_node)

    input_ids = list(range(len(cloned_nodes)))  # 2 + len-2
    assert len(input_ids) >= 2

    if is_reduction:
        op_name = 'petridish_reduction_op'
    else:
        op_name = 'petridish_normal_op'

    # specify strides for each input, later we will
    # give this to each primitive
    strides = [2 if is_reduction and state_id < 2 else 1
               for state_id in input_ids]

    op_desc = OpDesc(op_name,
                     params={
                         'conv': cell_desc.conv_params,
                         '_strides': strides,
                     }, in_len=len(input_ids), trainables=None, children=None)
    petridish_node.edges.append(EdgeDesc(op_desc, input_ids=input_ids))

    # note that post op will be recreated which means there is no
    # warm start for post op when number of nodes changes
    cell_desc.reset_nodes(cloned_nodes, cell_desc.node_ch_out,
                          cell_desc.post_op.name)
def seed(self, model_desc: ModelDesc) -> None:
    """Turn every cell of ``model_desc`` into a petridish seed cell.

    For each cell: if the first node has no edges it receives a
    'skip_connect' edge from s1 (stride 2 in reduction cells, otherwise 1).
    Nodes without edges are then dropped, and the cell's nodes are reset only
    if something was actually dropped.
    """
    for cell_desc in model_desc.cell_descs():
        node_count = len(cell_desc.nodes())
        assert node_count >= 1

        first_node = cell_desc.nodes()[0]

        # if there are no edges for 1st node, add identity to s1
        if not first_node.edges:
            if cell_desc.cell_type == CellType.Reduction:
                stride = 2
            else:
                stride = 1
            skip_desc = OpDesc('skip_connect',
                               params={
                                   'conv': cell_desc.conv_params,
                                   'stride': stride
                               }, in_len=1, trainables=None, children=None)
            first_node.edges.append(EdgeDesc(skip_desc, input_ids=[1]))

        # remove empty nodes
        kept_nodes = []
        for candidate in cell_desc.nodes():
            if len(candidate.edges) > 0:
                kept_nodes.append(candidate.clone())

        if len(kept_nodes) != len(cell_desc.nodes()):
            cell_desc.reset_nodes(kept_nodes, cell_desc.node_ch_out,
                                  cell_desc.post_op.name)

        self._ensure_nonempty_nodes(cell_desc)
def build_nodes(self, stem_shapes:TensorShapes, conf_cell:Config,
                cell_index:int, cell_type:CellType, node_count:int,
                in_shape:TensorShape, out_shape:TensorShape) -> Tuple[TensorShapes, List[NodeDesc]]:
    """Create ``node_count`` nodes, each with one 'nasbench101_op' edge.

    The inputs of node ``i`` are read from column ``i + 1`` of
    ``self._cell_matrix`` (internal vertices start at 1 because vertex 0 is
    the input node). A truthy entry in row 0 connects the node to s0 and sets
    ``first_proj``. A truthy entry in row ``j + 1`` connects it to the earlier
    internal node ``j``, whose state id is ``j + 2``.

    ``stem_shapes``, ``conf_cell``, ``cell_index`` and ``cell_type`` are
    accepted but not used by this implementation.

    Returns:
        ``(out_shapes, nodes)``: one deep copy of ``out_shape`` per node and
        the list of created ``NodeDesc`` objects.
    """
    assert in_shape[0]==out_shape[0]

    conv_params = ConvMacroParams(in_shape[0], out_shape[0])
    nodes:List[NodeDesc] = []

    for node_idx in range(node_count):
        vertex = node_idx+1  # nasbench internal node starts at 1

        # if the input node is connected then it needs projection
        first_proj = bool(self._cell_matrix[0, vertex])
        input_ids = [0] if first_proj else []  # 0 == connect to s0

        # look at all internal vertices before this one; state ids are
        # offset by 2 because of s0, s1
        for prev_idx in range(node_idx):
            if self._cell_matrix[prev_idx+1, vertex]:
                input_ids.append(prev_idx+2)

        op_params = {
            'conv': conv_params,
            'stride': 1,
            'vertex_op': self._vertex_ops[vertex],  # offset because of input node
            'first_proj': first_proj
        }
        # TODO: should we pass children here?
        op_desc = OpDesc('nasbench101_op', params=op_params,
                         in_len=len(input_ids), trainables=None, children=None)
        single_edge = EdgeDesc(op_desc, input_ids=input_ids)
        nodes.append(NodeDesc(edges=[single_edge], conv_params=conv_params))

    out_shapes = [copy.deepcopy(out_shape) for _ in range(node_count)]
    return out_shapes, nodes
def build_nodes(self, stem_shapes:TensorShapes, conf_cell:Config,
                cell_index:int, cell_type:CellType, node_count:int,
                in_shape:TensorShape, out_shape:TensorShape) -> Tuple[TensorShapes, List[NodeDesc]]:
    """Create node descriptions from the pre-selected ops stored on ``self``.

    ``self._reduction_ops`` is used for reduction cells and
    ``self._normal_ops`` otherwise. Its ``ops_and_ins`` must contain exactly
    ``node_count`` ``(op_names, to_states)`` pairs. Each pair yields one node
    with one single-input edge per ``(op_name, to_state)`` combination. In a
    reduction cell the edges reading state 0 or 1 get stride 2; every other
    edge gets stride 1.

    ``stem_shapes``, ``conf_cell`` and ``cell_index`` are accepted but not
    used by this implementation.

    Returns:
        ``(out_shapes, nodes)``: one deep copy of ``out_shape`` per node and
        the list of created ``NodeDesc`` objects.
    """
    assert in_shape[0] == out_shape[0]

    is_reduction = (cell_type == CellType.Reduction)
    if is_reduction:
        ops = self._reduction_ops
    else:
        ops = self._normal_ops
    assert node_count == len(ops.ops_and_ins)

    conv_params = ConvMacroParams(in_shape[0], out_shape[0])

    def make_edge(op_name, to_state) -> EdgeDesc:
        stride = 2 if is_reduction and to_state < 2 else 1
        op_desc = OpDesc(op_name,
                         params={
                             'conv': conv_params,
                             'stride': stride
                         }, in_len=1, trainables=None, children=None)
        return EdgeDesc(op_desc, input_ids=[to_state])

    nodes: List[NodeDesc] = []
    for op_names, to_states in ops.ops_and_ins:
        node_edges = [make_edge(op_name, to_state)
                      for op_name, to_state in zip(op_names, to_states)]
        nodes.append(NodeDesc(edges=node_edges, conv_params=conv_params))

    out_shapes = [copy.deepcopy(out_shape) for _ in range(node_count)]
    return out_shapes, nodes
def finalize_node(self, node: nn.ModuleList, node_index: int,
                  node_desc: NodeDesc, max_final_edges: int,
                  *args, **kwargs) -> NodeDesc:
    """Pick the final edges of a node by repeated gumbel-softmax sampling.

    The alphas of all ``GsOp`` edges in ``node`` are concatenated, sampled
    ``num_sample`` times (value read from the
    ``nas.search.model_desc.cell.gs`` config section) with hard gumbel
    softmax, and the samples are summed. Every ``GsOp`` edge whose slice of
    the summed samples has any non-zero entry is finalized with that slice.
    At most ``max_final_edges`` of the resulting edges are kept.

    Returns:
        A ``NodeDesc`` built from the kept edges and ``node_desc.conv_params``.
    """
    conf = get_conf()
    gs_num_sample = conf['nas']['search']['model_desc']['cell']['gs']['num_sample']

    def is_gs_edge(edge) -> bool:
        return hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == GsOp

    gs_edges = [edge for edge in node if is_gs_edge(edge)]

    # gather the alphas of all edges in this node
    alpha_list = []
    for edge in gs_edges:
        alpha_list.extend([alpha for op, alpha in edge._op.ops()])

    # TODO: will creating a tensor from a list of tensors preserve the graph?
    node_alphas = torch.Tensor(alpha_list)
    assert node_alphas.nelement() > 0

    # sample ops via gumbel softmax
    sample_storage = [
        F.gumbel_softmax(node_alphas, tau=1, hard=True, eps=1e-10, dim=-1)
        for _ in range(gs_num_sample)
    ]
    stacked_samples = torch.stack(sample_storage, dim=0)
    samples_summed = torch.sum(stacked_samples, dim=0)

    # send the sampled op weights to their respective edges
    # to be used for edge level finalize
    selected_edges = []
    offset = 0
    for edge in gs_edges:
        num_primitives = len(edge._op.PRIMITIVES)
        edge_weights = samples_summed[offset:offset + num_primitives]
        offset += num_primitives

        # finalize the edge only if at least one of its ops was sampled
        if not edge_weights.bool().any():
            continue
        op_desc, _ = edge._op.finalize(edge_weights)
        selected_edges.append(EdgeDesc(op_desc, edge.input_ids))

    # delete excess edges
    if len(selected_edges) > max_final_edges:
        # since these are sampled edges there is no ordering
        # amongst them so we just arbitrarily select a few
        selected_edges = selected_edges[:max_final_edges]

    return NodeDesc(selected_edges, node_desc.conv_params)
def finalize_node(self, node: nn.ModuleList, max_final_edges: int) -> NodeDesc:
    """Finalize a node by randomly sampling ``max_final_edges`` of its ops.

    Every op on every edge that is not a ``Zero`` instance is a candidate.
    ``max_final_edges`` candidates are drawn without replacement via
    ``random.sample`` and each becomes an edge that keeps the input ids of
    the edge it came from.
    """
    # collect all (edge, op) pairs incoming to this node, skipping Zero ops
    candidates = []
    for edge in node:
        for op, _ in edge._op.ops():
            if not isinstance(op, Zero):
                candidates.append((edge, op))
    assert len(candidates) >= max_final_edges

    chosen = random.sample(candidates, max_final_edges)

    # the first value returned by finalize() is the finalized op desc
    selected_edges = []
    for edge, op in chosen:
        op_desc = op.finalize()[0]
        selected_edges.append(EdgeDesc(op_desc, edge.input_ids))

    return NodeDesc(selected_edges)
def _build_cell(self, cell_desc:CellDesc)->None:
    """Append an 'xnas_op' edge from every earlier state to every node.

    Node ``i`` receives ``i + 2`` new edges, one reading each state id in
    ``0 .. i + 1``. In a reduction cell the edges that read state 0 or 1 get
    stride 2; every other edge gets stride 1.
    """
    is_reduction = (cell_desc.cell_type==CellType.Reduction)

    for node_idx, node in enumerate(cell_desc.nodes()):
        for state_id in range(node_idx+2):
            if is_reduction and state_id < 2:
                stride = 2
            else:
                stride = 1
            op_desc = OpDesc('xnas_op',
                             params={
                                 'conv': cell_desc.conv_params,
                                 'stride': stride
                             }, in_len=1, trainables=None, children=None)
            node.edges.append(EdgeDesc(op_desc, input_ids=[state_id]))
def finalize_node(self, node:nn.ModuleList, node_index:int,
                  node_desc:NodeDesc, max_final_edges:int,
                  *args, **kwargs)->NodeDesc:
    """Select the final edges of a node by brute-force subset selection.

    The covariance matrix is not a named parameter of this method, so it is
    taken from ``kwargs['cov']`` or, failing that, from the first extra
    positional argument.

    Args:
        node: the edges of the node being finalized.
        node_index: not used by this implementation.
        node_desc: supplies ``conv_params`` for the returned node.
        max_final_edges: number of ops to select.
        cov (via ``*args``/``**kwargs``): square 2-D covariance matrix with
            one row/column per valid div op, in edge order.

    Returns:
        A ``NodeDesc`` whose edges are the ops chosen by
        ``compute_brute_force_sol``.

    Raises:
        TypeError: if no covariance matrix was supplied.
    """
    # The body previously read an undefined name `cov`; resolve it explicitly.
    if 'cov' in kwargs:
        cov = kwargs['cov']
    elif args:
        cov = args[0]
    else:
        raise TypeError("finalize_node() requires the covariance matrix 'cov' "
                        "as a keyword argument or as the first extra "
                        "positional argument")

    # node is a list of edges
    assert len(node) >= max_final_edges

    # covariance matrix shape must be square 2-D
    assert len(cov.shape) == 2
    assert cov.shape[0] == cov.shape[1]

    # the number of primitive operators has to be greater than or
    # equal to the maximum number of final edges allowed
    assert cov.shape[0] >= max_final_edges

    # get total number of ops incoming to this node
    num_ops = sum(edge._op.num_valid_div_ops for edge in node)

    # bookkeeping: map every covariance index to (edge index, op index)
    edge_num_and_op_ind = [
        (edge_ind, op_ind)
        for edge_ind, edge in enumerate(node)
        if type(edge._op) == DivOp
        for op_ind in range(edge._op.num_valid_div_ops)
    ]
    assert len(edge_num_and_op_ind) == num_ops

    # run brute force set selection algorithm
    max_subset, _ = compute_brute_force_sol(cov, max_final_edges)

    # convert the cov indices to edge descs
    selected_edges = []
    for ind in max_subset:
        edge_ind, op_ind = edge_num_and_op_ind[ind]
        op_desc = node[edge_ind]._op.get_valid_op_desc(op_ind)
        selected_edges.append(EdgeDesc(op_desc, node[edge_ind].input_ids))

    return NodeDesc(selected_edges, node_desc.conv_params)
def _build_cell(self, cell_desc: CellDesc) -> None:
    """Append a 'mixed_op' edge from every earlier state to every node.

    Node ``i`` receives ``i + 2`` new edges, one reading each state id in
    ``0 .. i + 1``.

    How the stride works: every op connected to s0 or s1 applies the WxH
    reduction in a reduction cell (stride 2). Ops connected to later states
    use stride 1 because those states are already derived from s0 and s1.
    Channels are handled through the cell's ``conv_params``.
    """
    is_reduction = (cell_desc.cell_type == CellType.Reduction)

    def make_edge(state_id: int) -> EdgeDesc:
        stride = 2 if is_reduction and state_id < 2 else 1
        op_desc = OpDesc('mixed_op',
                         params={
                             'conv': cell_desc.conv_params,
                             'stride': stride
                         }, in_len=1, trainables=None, children=None)
        return EdgeDesc(op_desc, input_ids=[state_id])

    for node_idx, node in enumerate(cell_desc.nodes()):
        for state_id in range(node_idx + 2):
            node.edges.append(make_edge(state_id))
def _add_node(self, model_desc: ModelDesc,
              model_desc_builder: ModelDescBuilder) -> None:
    """Insert one petridish node before the last node of every cell.

    The new node has a single petridish edge reading all state ids
    ``0 .. len(nodes)`` (as counted before insertion). Its shape is a deep
    copy of the last node's shape. The post op is rebuilt through
    ``model_desc_builder.build_cell_post_op`` and the cell is reset with the
    updated nodes, shapes and post op.
    """
    for ci, cell_desc in enumerate(model_desc.cell_descs()):
        is_reduction = (cell_desc.cell_type == CellType.Reduction)

        nodes = cell_desc.nodes()
        # petridish must seed with one node
        assert len(nodes) > 0

        # input/output channels for all nodes are same
        conv_params = nodes[0].conv_params

        # assign input IDs to nodes, s0 and s1 have IDs 0 and 1
        # however as we will be inserting new node before last one
        input_ids = list(range(len(nodes) + 1))
        assert len(input_ids) >= 2  # 2 stem inputs

        if is_reduction:
            op_name = 'petridish_reduction_op'
        else:
            op_name = 'petridish_normal_op'

        # specify strides for each input, later we will
        # give this to each primitive
        strides = [2 if is_reduction and state_id < 2 else 1
                   for state_id in input_ids]

        op_desc = OpDesc(op_name,
                         params={
                             'conv': conv_params,
                             '_strides': strides,
                         }, in_len=len(input_ids), trainables=None,
                         children=None)
        petridish_edge = EdgeDesc(op_desc, input_ids=input_ids)
        petridish_node = NodeDesc(edges=[petridish_edge],
                                  conv_params=conv_params)
        node_insert_at = len(nodes) - 1
        nodes.insert(node_insert_at, petridish_node)

        # output shape of all nodes are same
        node_shapes = cell_desc.node_shapes
        added_shape = copy.deepcopy(node_shapes[-1])
        shape_insert_at = len(node_shapes) - 1
        node_shapes.insert(shape_insert_at, added_shape)

        # post op needs rebuilding because number of inputs to it has
        # changed so input/output channels may be different
        rebuilt = model_desc_builder.build_cell_post_op(
            cell_desc.stem_shapes, node_shapes, cell_desc.conf_cell, ci)
        post_op_shape, post_op_desc = rebuilt

        cell_desc.reset_nodes(nodes, node_shapes, post_op_desc, post_op_shape)
def finalize_edge(self, edge) -> Tuple[EdgeDesc, Optional[float]]:
    """Finalize the op of ``edge`` and wrap it in a new ``EdgeDesc``.

    Returns:
        ``(edge_desc, rank)`` where ``edge_desc`` keeps the input ids of
        ``edge`` and ``rank`` is the second value returned by the op's
        ``finalize()``.
    """
    finalized = edge._op.finalize()
    op_desc, rank = finalized
    edge_desc = EdgeDesc(op_desc, edge.input_ids)
    return (edge_desc, rank)
def finalize_node(self, node:nn.ModuleList, node_index:int,
                  node_desc:NodeDesc, max_final_edges:int,
                  cov:np.array, cell: Cell, node_id: int,
                  *args, **kwargs)->NodeDesc:
    """Select the final edges of a node from its highest-alpha ops.

    All non-``Zero`` ops of the node are ranked by alpha. The top half (but
    at least ``max_final_edges``) is kept, and brute-force subset selection
    is run on the covariance sub-matrix of those ops. A heatmap of that
    sub-matrix is written to the experiment directory as a diagnostic.

    Args:
        node: the edges of the node being finalized.
        node_index: not used by this implementation.
        node_desc: supplies ``conv_params`` for the returned node.
        max_final_edges: number of ops to select.
        cov: square 2-D covariance matrix, indexed through
            ``node_num_to_node_op_to_cov_ind`` of the cell's divnas record.
        cell: the cell owning this node; key into ``self._divnas_cells``.
        node_id: id of the node within ``cell``.

    Returns:
        A ``NodeDesc`` built from the selected edges and
        ``node_desc.conv_params``.
    """
    # node is a list of edges
    assert len(node) >= max_final_edges

    # covariance matrix shape must be square 2-D
    assert len(cov.shape) == 2
    assert cov.shape[0] == cov.shape[1]

    # the number of primitive operators has to be greater than or
    # equal to the maximum number of final edges allowed
    assert cov.shape[0] >= max_final_edges

    # get the alpha of all ops other than 'none'
    in_ops = [(edge, op, alpha, edge_num)
              for edge_num, edge in enumerate(node)
              for op, alpha in edge._op.ops()
              if not isinstance(op, Zero)]
    assert len(in_ops) >= max_final_edges

    # order all the ops by alpha, largest first
    in_ops_sorted = sorted(in_ops, key=lambda in_op: in_op[2], reverse=True)

    # keep under consideration top half of the ops
    num_to_keep = max(max_final_edges, len(in_ops_sorted)//2)
    top_ops = in_ops_sorted[:num_to_keep]

    # get the covariance submatrix of the top ops only
    cov_inds = [
        self._divnas_cells[cell].node_num_to_node_op_to_cov_ind[node_id][op]
        for _, op, _, _ in top_ops
    ]
    cov_top_ops = cov[np.ix_(cov_inds, cov_inds)]

    assert len(cov_inds) == len(top_ops)
    assert len(top_ops) >= max_final_edges
    assert cov_top_ops.shape[0] == cov_top_ops.shape[1]
    assert len(cov_top_ops.shape) == 2

    # run brute force set selection algorithm only on the top ops
    max_subset, _ = compute_brute_force_sol(cov_top_ops, max_final_edges)

    # note that elements of max_subset are indices into top_ops only
    selected_edges = []
    for ind in max_subset:
        edge, op, _, edge_num = top_ops[ind]
        op_desc, _ = op.finalize()
        new_edge = EdgeDesc(op_desc, edge.input_ids)
        logger.info(f'selected edge: {edge_num}, op: {op_desc.name}')
        selected_edges.append(new_edge)

    # save diagnostic information to disk
    expdir = get_expdir()
    savename = os.path.join(
        expdir, f'cell_{cell.desc.id}_node_{node_id}_cov.png')
    # Draw on a dedicated figure and close it afterwards. Drawing on the
    # implicit current figure made every call paint over the previous
    # heatmap (adding another colorbar) and left the figure open.
    fig, ax = plt.subplots()
    try:
        sns.heatmap(cov_top_ops, annot=True, fmt='.1g', cmap='coolwarm',
                    ax=ax)
        fig.savefig(savename)
    finally:
        plt.close(fig)

    logger.info('')
    return NodeDesc(selected_edges, node_desc.conv_params)