Example #1
0
    def _build_cell(self, cell_desc: CellDesc, rand_ops: RandOps) -> None:
        """Replace the edges of every node in ``cell_desc`` with randomly sampled ops.

        ``rand_ops.ops_and_ins`` holds one ``(op_names, to_states)`` pair per
        node. For each position k, ``op_names[k]`` becomes an edge reading from
        state ``to_states[k]``. Any edges a node already had are discarded first.
        """
        nodes = cell_desc.nodes()
        # sanity check: exactly one (ops, inputs) pair is expected per node
        assert len(nodes) == len(rand_ops.ops_and_ins)

        is_reduction = cell_desc.cell_type == CellType.Reduction

        for node, (op_names, to_states) in zip(nodes, rand_ops.ops_and_ins):
            # the cell must be completely random, so drop the previous edges
            node.edges.clear()

            for op_name, state_id in zip(op_names, to_states):
                # in a reduction cell only edges reading states 0 and 1 use stride 2
                stride = 2 if is_reduction and state_id < 2 else 1
                op_desc = OpDesc(op_name,
                                 params={'conv': cell_desc.conv_params,
                                         'stride': stride},
                                 in_len=1, trainables=None, children=None)
                node.edges.append(EdgeDesc(op_desc, input_ids=[state_id]))
    def _build_cell(self, cell_desc: CellDesc) -> None:
        """Insert a petridish search node just before the last node of the cell.

        The inserted node gets a single edge whose op
        ('petridish_reduction_op' or 'petridish_normal_op') takes states
        ``0 .. len(nodes) - 1`` as inputs. These are the two cell inputs plus
        every node placed before it. The cell's node list is then replaced via
        ``cell_desc.reset_nodes``.
        """
        is_reduction = cell_desc.cell_type == CellType.Reduction

        self._ensure_nonempty_nodes(cell_desc)

        # clone the existing nodes and put the new node in front of the last one
        nodes = [existing.clone() for existing in cell_desc.nodes()]
        search_node = NodeDesc(edges=[])
        nodes.insert(len(nodes) - 1, search_node)

        # one input id per preceding state: 2 cell inputs + (len(nodes) - 2) nodes
        input_ids = list(range(len(nodes)))
        assert len(input_ids) >= 2

        if is_reduction:
            op_name = 'petridish_reduction_op'
        else:
            op_name = 'petridish_normal_op'
        # one stride per input; later each primitive receives its own stride
        strides = [2 if is_reduction and j < 2 else 1 for j in input_ids]
        op_desc = OpDesc(op_name,
                         params={'conv': cell_desc.conv_params,
                                 '_strides': strides},
                         in_len=len(input_ids), trainables=None, children=None)
        search_node.edges.append(EdgeDesc(op_desc, input_ids=input_ids))

        # the post op is recreated here, so there is no warm start for it
        # when the number of nodes changes
        cell_desc.reset_nodes(nodes, cell_desc.node_ch_out,
                              cell_desc.post_op.name)
Example #3
0
    def build_cell(self, in_shapes:TensorShapesList, conf_cell:Config,
                   cell_index:int) ->CellDesc:
        """Build the ``CellDesc`` for the cell at ``cell_index``.

        The method builds the stems, then the nodes (from scratch, or from
        ``self.template`` when one is set), then the post op. Finally it records
        this cell's output shape in ``in_shapes``.

        Args:
            in_shapes: shapes consumed by ``build_cell_stems``. Mutated in place:
                one entry, ``[post_op_shape]``, is appended for this cell.
            conf_cell: configuration for the cell.
            cell_index: index of the cell being built.

        Returns:
            The newly built ``CellDesc``.
        """
        stem_shapes, stems = self.build_cell_stems(in_shapes, conf_cell, cell_index)
        cell_type = self.get_cell_type(cell_index)

        if self.template is None:
            node_count = self.get_node_count(cell_index)
            in_shape = stem_shapes[0] # input shape to nodes is same as cell stem
            out_shape = stem_shapes[0] # we ask nodes to keep the output shape same
            node_shapes, nodes = self.build_nodes(stem_shapes, conf_cell,
                                                  cell_index, cell_type, node_count, in_shape, out_shape)
        else:
            node_shapes, nodes = self.build_nodes_from_template(stem_shapes, conf_cell, cell_index)

        post_op_shape, post_op_desc = self.build_cell_post_op(stem_shapes,
            node_shapes, conf_cell, cell_index)

        cell_desc = CellDesc(
            # reuse the cell type computed above instead of looking it up again
            id=cell_index, cell_type=cell_type,
            conf_cell=conf_cell,
            stems=stems, stem_shapes=stem_shapes,
            nodes=nodes, node_shapes=node_shapes,
            post_op=post_op_desc, out_shape=post_op_shape,
            trainables_from=self.get_trainables_from(cell_index)
        )

        # Append ONE entry: this cell's output shapes, which consist of the single
        # post-op shape. NOTE(review): presumably build_cell_stems of later cells
        # reads these entries as stem inputs; confirm against that method.
        in_shapes.append([post_op_shape])

        return cell_desc
    def finalize_cell(self, cell:Cell, cell_index:int,
                      model_desc:ModelDesc, *args, **kwargs)->CellDesc:
        """Return a finalized ``CellDesc`` for ``cell`` using DivNAS node covariances.

        The divnas wrapper for ``cell`` is looked up by ``id(cell)``. Each node
        of ``cell.dag`` is finalized with its matching covariance, the wrapper's
        collected activations are cleared, and a fresh ``CellDesc`` is assembled
        from the finalized stems, nodes and post op.
        """
        max_final_edges = model_desc.max_final_edges

        divnas_cell = self._divnas_cells[id(cell)]
        # exactly one covariance entry per dag node is expected
        assert len(cell.dag) == len(divnas_cell.node_covs)

        # nodes are finalized first because the node descs must be rebuilt
        # from their final versions
        node_descs:List[NodeDesc] = []
        for node_idx, node in enumerate(cell.dag):
            cov = divnas_cell.node_covs[id(node)]
            node_descs.append(self.finalize_node(node, node_idx,
                                                 cell.desc.nodes()[node_idx],
                                                 max_final_edges, cov))

        # (optional) drop all collected activation information
        divnas_cell.clear_collect_activations()

        desc = cell.desc
        stems = [cell.s0_op.finalize()[0], cell.s1_op.finalize()[0]]
        post_op = cell.post_op.finalize()[0]
        return CellDesc(id=desc.id, cell_type=desc.cell_type,
                        conf_cell=desc.conf_cell,
                        stems=stems, stem_shapes=desc.stem_shapes,
                        nodes=node_descs, node_shapes=desc.node_shapes,
                        post_op=post_op, out_shape=desc.out_shape,
                        trainables_from=desc.trainables_from)
Example #5
0
    def _build_cell(self, cell_desc: CellDesc) -> None:
        """Append one 'nasbench101_op' edge to every node of ``cell_desc``.

        Connectivity is read from ``self._cell_matrix[src, dst]``. Vertex 0 is
        the nasbench input vertex, and internal vertices start at 1. When the
        edge's input ids are emitted, the input vertex maps to state 0 (s0) and
        internal node j maps to state j + 2, because s0 and s1 come first.
        """
        for i, node in enumerate(cell_desc.nodes()):
            vertex = i + 1  # nasbench internal vertices start at 1

            # an edge from the input vertex is wired to s0 and needs a projection
            has_input_edge = bool(self._cell_matrix[0, vertex])
            input_ids = [0] if has_input_edge else []

            # earlier internal vertices feeding this one; +2 skips s0, s1
            input_ids.extend(j + 2 for j in range(i)
                             if self._cell_matrix[j + 1, vertex])

            op_desc = OpDesc('nasbench101_op',
                             params={'conv': cell_desc.conv_params,
                                     'stride': 1,
                                     'vertex_op': self._vertex_ops[vertex],
                                     'first_proj': has_input_edge},
                             in_len=len(input_ids),
                             trainables=None,
                             children=None)  # TODO: should we pass children here?
            node.edges.append(EdgeDesc(op_desc, input_ids=input_ids))
Example #6
0
    def finalize_cell(self, cell: Cell, *args, **kwargs) -> CellDesc:
        """Build the final ``CellDesc`` of ``cell`` from DivNAS node covariances.

        The divnas wrapper is looked up with the cell object itself as key.
        Every node of ``cell.dag`` is finalized with its covariance, the
        collected activations are cleared, and a new ``CellDesc`` is assembled
        from the finalized stems, nodes and post op.
        """
        logger.info(f'cell id {cell.desc.id}')
        divnas_cell = self._divnas_cells[cell]
        covs = divnas_cell.node_covs
        # exactly one covariance entry per dag node is expected
        assert len(cell.dag) == len(covs)

        max_final_edges = cell.desc.max_final_edges
        # node descs are rebuilt from the finalized node versions
        node_descs: List[NodeDesc] = []
        for node_idx, node in enumerate(cell.dag):
            node_cov = covs[id(node)]
            logger.info(f'node {node_idx}')
            node_descs.append(self.finalize_node(
                node, max_final_edges, node_cov, cell, node_idx))

        # (optional) drop all collected activation information
        divnas_cell.clear_collect_activations()

        desc = cell.desc
        return CellDesc(
            cell_type=desc.cell_type,
            id=desc.id,
            nodes=node_descs,
            s0_op=cell.s0_op.finalize()[0],
            s1_op=cell.s1_op.finalize()[0],
            template_cell=desc.template_cell,
            max_final_edges=desc.max_final_edges,
            node_ch_out=desc.node_ch_out,
            post_op=cell.post_op.finalize()[0])
Example #7
0
    def finalize_cell(self, cell: Cell, cell_index: int, model_desc: ModelDesc,
                      *args, **kwargs) -> CellDesc:
        """Assemble the final ``CellDesc`` of ``cell`` from its finalized parts.

        Each node in ``cell.dag`` is passed to ``finalize_node`` together with
        its current node desc and ``model_desc.max_final_edges``. The stems and
        post op are finalized, and the remaining fields are copied from
        ``cell.desc``.
        """
        edge_budget = model_desc.max_final_edges

        # node descs must be recreated from the finalized node versions
        node_descs: List[NodeDesc] = [
            self.finalize_node(node, i, cell.desc.nodes()[i], edge_budget)
            for i, node in enumerate(cell.dag)
        ]

        desc = cell.desc
        final_stems = [cell.s0_op.finalize()[0], cell.s1_op.finalize()[0]]
        final_post_op = cell.post_op.finalize()[0]
        return CellDesc(id=desc.id, cell_type=desc.cell_type,
                        conf_cell=desc.conf_cell,
                        stems=final_stems, stem_shapes=desc.stem_shapes,
                        nodes=node_descs, node_shapes=desc.node_shapes,
                        post_op=final_post_op, out_shape=desc.out_shape,
                        trainables_from=desc.trainables_from)
Example #8
0
    def _build_cell(self, cell_desc:CellDesc)->None:
        """Attach one 'xnas_op' edge from each earlier state to every node.

        Node i receives edges from states 0 .. i+1. In a reduction cell, edges
        reading states 0 and 1 use stride 2; all other edges use stride 1.
        """
        is_reduction = cell_desc.cell_type == CellType.Reduction

        for i, node in enumerate(cell_desc.nodes()):
            node.edges.extend(
                EdgeDesc(OpDesc('xnas_op',
                                params={'conv': cell_desc.conv_params,
                                        'stride': 2 if is_reduction and j < 2 else 1},
                                in_len=1, trainables=None, children=None),
                         input_ids=[j])
                for j in range(i + 2))
Example #9
0
    def __init__(self, desc: CellDesc, affine: bool, droppath: bool,
                 template_cell: Optional['Cell']):
        """Instantiate the cell's ops from ``desc``.

        Args:
            desc: cell description, stored as ``self.desc``.
            affine: forwarded to every ``Op.create`` call and to the DAG builder.
            droppath: forwarded to ``Cell._create_dag``.
            template_cell: template cell, if any, to use for arch params.
        """
        super().__init__()

        # these members stay public because the finalizer needs access to them
        self.desc = desc

        # ops are created in a fixed order: s0 stem, s1 stem, dag, post op
        self.s0_op, self.s1_op = (Op.create(stem_desc, affine=affine)
                                  for stem_desc in (desc.s0_op, desc.s1_op))

        self.dag = Cell._create_dag(desc.nodes(), affine=affine,
                                    droppath=droppath, template_cell=template_cell)

        self.post_op = Op.create(desc.post_op, affine=affine)
Example #10
0
    def __init__(self, desc:CellDesc,
                 affine:bool, droppath:bool,
                 trainables_from:Optional['Cell']):
        """Instantiate the cell's stems, node DAG and post op from ``desc``.

        Args:
            desc: cell description, stored as ``self.desc``. Must have exactly
                two stems.
            affine: forwarded to every ``Op.create`` call and to the DAG builder.
            droppath: forwarded to ``Cell._create_dag``.
            trainables_from: template cell, if any, to use for arch params.
        """
        super().__init__()

        # these members stay public because the finalizer needs access to them
        self.desc = desc

        # TODO: support any number of stems
        assert len(desc.stems)==2, "Cell compiler currently only supports 2 stems"
        s0_desc, s1_desc = desc.stems[0], desc.stems[1]
        self.s0_op = Op.create(s0_desc, affine=affine)
        self.s1_op = Op.create(s1_desc, affine=affine)

        node_descs = desc.nodes()
        self.dag = Cell._create_dag(node_descs, affine=affine, droppath=droppath,
                                    trainables_from=trainables_from)

        self.post_op = Op.create(desc.post_op, affine=affine)
Example #11
0
    def _build_cell(self, cell_desc: CellDesc) -> None:
        """Connect every node to all states before it through a 'mixed_op' edge.

        Stride rule: in a reduction cell, ops fed by s0 and s1 (states 0 and 1)
        reduce WxH with stride 2. Every other op uses stride 1, because all later
        states are derived from s0 and s1 and are therefore already reduced.
        Channel increase is carried by ``cell_desc.conv_params``.
        """
        is_reduction = (cell_desc.cell_type == CellType.Reduction)

        for node_idx, node in enumerate(cell_desc.nodes()):
            # this node may read s0, s1 and the node_idx nodes preceding it
            for state_id in range(node_idx + 2):
                downsample = is_reduction and state_id < 2
                params = {'conv': cell_desc.conv_params,
                          'stride': 2 if downsample else 1}
                op_desc = OpDesc('mixed_op', params=params, in_len=1,
                                 trainables=None, children=None)
                node.edges.append(EdgeDesc(op_desc, input_ids=[state_id]))
 def _ensure_nonempty_nodes(self, cell_desc: CellDesc):
     """Assert that ``cell_desc`` has at least one node and that no node lacks edges."""
     nodes = cell_desc.nodes()
     assert len(nodes) > 0
     assert all(len(n.edges) > 0 for n in nodes)