Example No. 1
    def __init__(self, model_desc: ModelDesc, droppath: bool, affine: bool):
        super().__init__()

        # some of these fields are public as finalizer needs access to them
        self.desc = model_desc

        # TODO: support any number of stems
        assert len(model_desc.model_stems) == 2, \
            "Model compiler currently only supports 2 stems"
        stem0_op = Op.create(model_desc.model_stems[0], affine=affine)
        stem1_op = Op.create(model_desc.model_stems[1], affine=affine)
        self.model_stems = nn.ModuleList((stem0_op, stem1_op))

        self.cells = nn.ModuleList()
        self._aux_towers = nn.ModuleList()

        for cell_desc, aux_tower_desc in zip(model_desc.cell_descs(),
                                             model_desc.aux_tower_descs):
            self._build_cell(cell_desc, aux_tower_desc, droppath, affine)

        # adaptive pooling output size to 1x1
        self.pool_op = Op.create(model_desc.pool_op, affine=affine)
        # ch_p records the last cell's output channel count,
        # which is the input channel count for the logits op
        self.logits_op = Op.create(model_desc.logits_op, affine=affine)
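
The constructor above assembles the macro skeleton: two stems, a list of cells (with optional auxiliary towers), then adaptive pooling and a logits op. A minimal stand-alone sketch of the same layout in plain PyTorch (SimpleSkeleton and its hard-coded modules are illustrative stand-ins, not archai's ModelDesc/Op machinery):

import torch
from torch import nn

class SimpleSkeleton(nn.Module):
    # illustrative stand-in for the stems -> cells -> pool -> logits layout above
    def __init__(self, ch: int = 16, n_classes: int = 10):
        super().__init__()
        # two stems, mirroring the 2-stem assumption in the snippet
        self.stem0 = nn.Conv2d(3, ch, 3, padding=1)
        self.stem1 = nn.Conv2d(3, ch, 3, padding=1)
        # real cells consume both previous states; a plain conv stands in here
        self.cells = nn.ModuleList(nn.Conv2d(ch, ch, 3, padding=1) for _ in range(4))
        self.pool = nn.AdaptiveAvgPool2d(1)      # pool output down to 1x1
        self.logits = nn.Linear(ch, n_classes)   # last cell's channels feed logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.stem0(x) + self.stem1(x)
        for cell in self.cells:
            out = cell(out)
        return self.logits(self.pool(out).flatten(1))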
Example No. 2
    def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams],
                 affine: bool):
        super().__init__()

        # assume last PRIMITIVE is 'none'
        assert DivOp.PRIMITIVES[-1] == 'none'

        conf = get_conf()
        trainer = conf['nas']['search']['divnas']['archtrainer']
        finalizer = conf['nas']['search']['finalizer']

        if trainer == 'noalpha' and finalizer == 'default':
            raise NotImplementedError(
                'noalpha trainer is not implemented for the default finalizer')

        if trainer != 'noalpha':
            self._setup_arch_params(arch_params)
        else:
            self._alphas = None

        self._ops = nn.ModuleList()
        for primitive in DivOp.PRIMITIVES:
            op = Op.create(OpDesc(primitive,
                                  op_desc.params,
                                  in_len=1,
                                  trainables=None),
                           affine=affine,
                           arch_params=None)
            self._ops.append(op)

        # various state variables for diversity
        self._collect_activations = False
        self._forward_counter = 0
        self._batch_activs = None
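
DivOp holds one candidate module per primitive in an nn.ModuleList, plus architecture parameters, which is the standard mixed-op construction in differentiable NAS. A minimal sketch of how such an op is typically evaluated, as a softmax-weighted sum over candidates (MixedOp and its three primitives are illustrative; this is not DivOp's actual forward):

import torch
from torch import nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    def __init__(self, ch: int):
        super().__init__()
        # candidate primitives; the last one plays the role of 'none'
        self._ops = nn.ModuleList([
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Identity(),
        ])
        # one architecture weight (alpha) per primitive
        self._alphas = nn.Parameter(torch.zeros(len(self._ops)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = F.softmax(self._alphas, dim=0)
        return sum(wi * op(x) for wi, op in zip(w, self._ops))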
Example No. 3
    def __init__(self, desc: CellDesc, affine: bool, droppath: bool,
                 template_cell: Optional['Cell']):
        # template_cell, if any, is the source of arch params
        super().__init__()

        # some of these members are public as finalizer needs access
        self.desc = desc
        self.s0_op = Op.create(desc.s0_op, affine=affine)
        self.s1_op = Op.create(desc.s1_op, affine=affine)

        self.dag = Cell._create_dag(desc.nodes(),
                                    affine=affine,
                                    droppath=droppath,
                                    template_cell=template_cell)

        self.post_op = Op.create(desc.post_op, affine=affine)
Example No. 4
    def __init__(self, desc: CellDesc, affine: bool, droppath: bool,
                 trainables_from: Optional['Cell']):
        # trainables_from: template cell, if any, to use for arch params
        super().__init__()

        # some of these members are public as finalizer needs access
        self.desc = desc

        # TODO: support any number of stems
        assert len(desc.stems) == 2, "Cell compiler currently only supports 2 stems"
        self.s0_op = Op.create(desc.stems[0], affine=affine)
        self.s1_op = Op.create(desc.stems[1], affine=affine)

        self.dag = Cell._create_dag(desc.nodes(),
                                    affine=affine,
                                    droppath=droppath,
                                    trainables_from=trainables_from)

        self.post_op = Op.create(desc.post_op, affine=affine)
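
Examples 3 and 4 build the same cell shape: two preprocessing ops on the incoming states, a DAG of nodes, and a post op. A rough sketch of how a DARTS-style cell of this shape evaluates (the node combination and the concat-then-project post op are assumptions in that style, not archai's exact Cell.forward):

import torch
from torch import nn

class TinyCell(nn.Module):
    def __init__(self, ch: int, n_nodes: int = 2):
        super().__init__()
        self.s0_op = nn.Conv2d(ch, ch, 1)   # preprocess previous-previous state
        self.s1_op = nn.Conv2d(ch, ch, 1)   # preprocess previous state
        # each node combines earlier states; here a single conv per node
        self.dag = nn.ModuleList(nn.Conv2d(ch, ch, 3, padding=1) for _ in range(n_nodes))
        # post op: reduce concatenated node outputs back to ch channels
        self.post_op = nn.Conv2d(ch * n_nodes, ch, 1)

    def forward(self, s0: torch.Tensor, s1: torch.Tensor) -> torch.Tensor:
        states = [self.s0_op(s0), self.s1_op(s1)]
        outs = []
        for node in self.dag:
            out = node(sum(states))   # combine all earlier states, then transform
            states.append(out)
            outs.append(out)
        return self.post_op(torch.cat(outs, dim=1))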
Example No. 5
    def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams],
                 affine: bool):
        super().__init__()

        vertex_op_name = op_desc.params['vertex_op']
        proj_first = op_desc.params['proj_first']  # first input needs projection

        self._vertex_op = Op.create(OpDesc(vertex_op_name,
                                           params=op_desc.params,
                                           in_len=1,
                                           trainables=None),
                                    affine=affine,
                                    arch_params=None)

        self._in_len = op_desc.in_len

        # optional 1x1 projection applied to the first input
        if proj_first:
            self._proj_op = Op.create(OpDesc('convbnrelu_1x1',
                                             params=op_desc.params,
                                             in_len=1,
                                             trainables=None),
                                      affine=affine, arch_params=None)
        else:
            self._proj_op = None
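
The proj_first flag gates an optional 1x1 conv-bn-relu projection on the first input. A minimal sketch of that conditional-projection pattern (VertexLike and its modules are hypothetical; when proj_first is False, ch_in must already equal ch_out):

import torch
from torch import nn

class VertexLike(nn.Module):
    def __init__(self, ch_in: int, ch_out: int, proj_first: bool):
        super().__init__()
        self._vertex_op = nn.Conv2d(ch_out, ch_out, 3, padding=1)
        # project the first input only when channel counts need matching
        self._proj_op = (nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
                                       nn.BatchNorm2d(ch_out),
                                       nn.ReLU())
                         if proj_first else None)

    def forward(self, x0: torch.Tensor) -> torch.Tensor:
        if self._proj_op is not None:
            x0 = self._proj_op(x0)
        return self._vertex_op(x0)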
Example No. 6
    def __init__(self, desc: EdgeDesc, affine: bool, droppath: bool,
                 template_edge: Optional['DagEdge']) -> None:
        super().__init__()
        # we may need to wrap the op if droppath is needed
        self._wrapped = self._op = Op.create(
            desc.op_desc, affine,
            template_edge.op().arch_params()
            if template_edge is not None else None)
        if droppath and self._op.can_drop_path():
            assert self.training
            self._wrapped = nn.Sequential(self._op, DropPath_())
        self.input_ids = desc.input_ids
        self.desc = desc
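
Wrapping the op in nn.Sequential with a drop-path module, as above, adds per-edge stochastic depth. A small self-contained sketch of the idea (DropPath here is a hypothetical re-implementation, not archai's DropPath_):

import torch
from torch import nn

class DropPath(nn.Module):
    # zeroes the whole branch per-sample with probability p during training
    def __init__(self, p: float = 0.2):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p <= 0.0:
            return x
        keep = 1.0 - self.p
        # one Bernoulli draw per sample, broadcast over remaining dims
        mask = x.new_empty(x.shape[0], *([1] * (x.dim() - 1))).bernoulli_(keep)
        return x * mask / keep

op = nn.Conv2d(8, 8, 3, padding=1)
wrapped = nn.Sequential(op, DropPath(0.2))   # mirrors self._wrapped above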
Example No. 7
    def __init__(self, model_desc: ModelDesc, droppath: bool, affine: bool):
        super().__init__()

        # some of these fields are public as finalizer needs access to them
        self.desc = model_desc
        self.stem0_op = Op.create(model_desc.stem0_op, affine=affine)
        self.stem1_op = Op.create(model_desc.stem1_op, affine=affine)

        self.cells = nn.ModuleList()
        self._aux_towers = nn.ModuleList()

        for cell_desc, aux_tower_desc in zip(model_desc.cell_descs(),
                                             model_desc.aux_tower_descs):
            self._build_cell(cell_desc, aux_tower_desc, droppath, affine)

        # adaptive pooling output size to 1x1
        self.pool_op = Op.create(model_desc.pool_op, affine=affine)
        # ch_p records the last cell's output channel count,
        # which is the input channel count for the logits op
        self.logits_op = Op.create(model_desc.logits_op, affine=affine)

        # for i,cell in enumerate(self.cells):
        #     print(i, ml_utils.param_size(cell))
        logger.info({'model_summary': self.summary()})
Example No. 8
    def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams],
                 affine: bool):
        super().__init__()

        # assume last PRIMITIVE is 'none'
        assert GsOp.PRIMITIVES[-1] == 'none'

        self._gs_num_sample = op_desc.params['gs_num_sample']

        self._ops = nn.ModuleList()
        for primitive in GsOp.PRIMITIVES:
            op = Op.create(
                OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
                affine=affine, arch_params=None)
            self._ops.append(op)
        # we do this at the end so that we can capture all arch params registered by
        # any previous child modules
        self._setup_arch_params(arch_params)
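
GsOp samples gs_num_sample primitives per forward pass via Gumbel-softmax over the alphas. A minimal sketch of that sampling step using torch.nn.functional.gumbel_softmax (averaging the samples into one weight vector is an assumption, not GsOp's exact aggregation):

import torch
import torch.nn.functional as F

alphas = torch.zeros(5, requires_grad=True)   # one logit per primitive
num_sample = 3
# draw num_sample near-one-hot weight vectors and average them
samples = torch.stack([F.gumbel_softmax(alphas, tau=1.0, hard=True)
                       for _ in range(num_sample)])
weights = samples.mean(dim=0)   # weights over the 5 primitives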
Example No. 9
    def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams],
                 reduction: bool, affine: bool):
        super().__init__()

        # assume last PRIMITIVE is 'none' (this is used for finalize)
        assert PetridishOp.PRIMITIVES[-1] == 'none'

        # create edges for the op; each edge connects one input state,
        # and within each edge we have all N primitives
        self._edges = nn.ModuleList()

        for i in range(op_desc.in_len):
            # edge contains all primitives with alphas
            edge = nn.ModuleList()
            self._edges.append(edge)

            # the stride may differ per input, so make a copy of
            # our params and set the stride for this input
            params = deepcopy(op_desc.params)
            params['stride'] = op_desc.params['_strides'][i]

            # create primitives for the edge
            for primitive in PetridishOp.PRIMITIVES:
                primitive_op = Op.create(OpDesc(primitive,
                                                params=params,
                                                in_len=1,
                                                trainables=None),
                                         affine=affine,
                                         arch_params=None)
                # wrap primitive with sg
                op = nn.Sequential(StopGradient(), primitive_op)
                edge.append(op)

        # TODO: check with Dey: Do we really need StopForwardReductionOp
        #   or StopGradientReductionOp because these two will only make sense
        #   for cell stems.
        # NOTE: Consider the case where prev_prev is normal, prev is reduction
        # then s_0 is twice as big in each dimension as s_1 and the number of channels
        # won't match. So you have to use StopGradientReductionOp on s_1 to make it match.
        self._sf = StopForward()

        # we do this at the end so that we can capture all arch params registered by
        # any previous child modules
        self._setup_arch_params(arch_params, op_desc.in_len)
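
Each primitive above is wrapped in a StopGradient, and a StopForward is kept for the output side. A sketch of what such modules can look like when built on detach (these are hypothetical re-implementations of the idea, not archai's classes):

import torch
from torch import nn

class StopGradient(nn.Module):
    # identity in the forward pass; blocks gradients in the backward pass
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.detach()

class StopForward(nn.Module):
    # contributes zero to the forward value; passes gradients through unchanged
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - x.detach()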
Example No. 10
    def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams],
                 affine: bool):
        super().__init__()

        # assume last PRIMITIVE is 'none'
        assert XnasOp.PRIMITIVES[-1] == 'none'

        self._ops = nn.ModuleList()
        for primitive in XnasOp.PRIMITIVES:
            op = Op.create(OpDesc(primitive,
                                  op_desc.params,
                                  in_len=1,
                                  trainables=None),
                           affine=affine,
                           arch_params=None)
            self._ops.append(op)

        # for getting gradients at a non-leaf node
        self._is_first_call = True
        self._avg_grad_meter = AverageMeter()

        # we do this at the end so that we can capture all arch params registered by
        # any previous child modules
        self._setup_arch_params(arch_params)
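
XnasOp needs gradients at a non-leaf activation (hence the AverageMeter). The usual way to observe such gradients is a tensor hook registered during the forward pass; a small stand-alone sketch (independent of XnasOp's actual update rule):

import torch

x = torch.randn(4, 8, requires_grad=True)
y = (x * 2).relu()                           # y is a non-leaf tensor
grads = []
y.register_hook(lambda g: grads.append(g))   # invoked during backward
y.sum().backward()
print(grads[0].shape)                        # gradient w.r.t. y itself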
Example No. 11
def _stem_reductions(stems: List[OpDesc]) -> List[int]:
    # create stem ops to find out reduction factors
    ops = [Op.create(stem, affine=False) for stem in stems]
    assert all(isinstance(op, StemBase) for op in ops)
    return [op.reduction for op in ops]
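
The helper instantiates throwaway stem ops just to read their reduction attribute. A tiny analogue in plain PyTorch (Stem is a hypothetical stand-in for archai's StemBase):

from torch import nn

class Stem(nn.Module):
    def __init__(self, stride: int):
        super().__init__()
        self.reduction = stride   # downsampling factor exposed as an attribute
        self.op = nn.Conv2d(3, 16, 3, stride=stride, padding=1)

stems = [Stem(2), Stem(2)]
reductions = [s.reduction for s in stems]   # -> [2, 2]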