def build_ops(self, C, op_names, indices_out, indices_inp, concat, reduction):
    """Compile the cell.

    Groups the edges by the intermediate node they feed, validates that
    grouping, and only then instantiates one ``MixedOp`` per edge.

    :param C: channels of this cell
    :type C: int
    :param op_names: operation(s) of each edge in the description; an entry
        may be a single name or a list of candidate names
    :type op_names: list of str or list of list of str
    :param indices_out: output node of each edge; edges must be grouped by
        node in ascending order, starting at node 2
    :type indices_out: list of int
    :param indices_inp: input node of each edge (indices 0 and 1 refer to the
        cell inputs)
    :type indices_inp: list of int
    :param concat: cell concat list of output node
    :type concat: list of int
    :param reduction: whether to reduce
    :type reduction: bool
    :raises ValueError: if the three edge lists differ in length, if the
        output indices are not contiguous node-by-node starting at 2, or if
        the number of nodes differs from ``self.steps``
    """
    self._concat = concat
    self._multiplier = len(concat)
    if not len(op_names) == len(indices_out) == len(indices_inp):
        raise ValueError(
            "op_names, indices_out and indices_inp must have the same length, "
            "got {}, {} and {}".format(len(op_names), len(indices_out), len(indices_inp)))

    # Group input indices by the intermediate node they feed. Nodes are
    # numbered from 2, and out_inp_list[k] must hold the inputs of node k + 2
    # because the forward pass appends one state per node in order. Each edge
    # therefore either extends the current node's group or opens the group of
    # the very next node. Anything else (index < 2, a skipped node, or going
    # backwards) used to be mis-grouped silently or produce an empty group,
    # so it is rejected here, before any op is built.
    out_inp_list = []
    for pos, (out_idx, inp_idx) in enumerate(zip(indices_out, indices_inp)):
        next_node = len(out_inp_list) + 2
        if out_inp_list and out_idx == next_node - 1:
            out_inp_list[-1].append(inp_idx)
        elif out_idx == next_node:
            out_inp_list.append([inp_idx])
        else:
            expected = (next_node - 1, next_node) if out_inp_list else (next_node,)
            raise ValueError(
                "edge {} targets node {}, expected one of {}: output indices must "
                "start at 2 and be grouped by node in ascending order".format(
                    pos, out_idx, expected))
    if len(out_inp_list) != self.steps:
        raise ValueError(
            "out_inp_list length should equal to steps, got {} nodes for {} steps".format(
                len(out_inp_list), self.steps))
    self.out_inp_list = out_inp_list

    _op_list = []
    for op_name, inp_idx in zip(op_names, indices_inp):
        # In a reduction cell only edges reading the cell inputs (0, 1) use stride 2.
        stride = 2 if reduction and inp_idx < 2 else 1
        _op_list.append(MixedOp(C=C, stride=stride, ops_cands=op_name))
    self.op_list = Seq(*_op_list)
    self.oplist = list(self.op_list.children())
class Cell(ops.Module):
    """Cell structure according to desc."""

    concat_size = 0

    def __init__(self, genotype, steps, concat, reduction, reduction_prev=None, C_prev_prev=None, C_prev=None,
                 C=None):
        """Init Cell.

        :param genotype: one ``(op_name(s), out_index, in_index)`` triple per edge
        :param steps: number of intermediate nodes in the cell
        :type steps: int
        :param concat: indices of the states concatenated into the cell output
        :type concat: list of int
        :param reduction: whether this is a reduction cell
        :type reduction: bool
        :param reduction_prev: whether the previous cell was a reduction cell;
            if so, ``s0`` is downsampled with ``FactorizedReduce``
        :param C_prev_prev: channels of the ``s0`` input
        :param C_prev: channels of the ``s1`` input
        :param C: channels of this cell
        """
        super(Cell, self).__init__()
        self.genotype = genotype
        self.steps = steps
        self.concat = concat
        self.reduction = reduction
        self.reduction_prev = reduction_prev
        self.C_prev_prev = C_prev_prev
        self.C_prev = C_prev
        self.C = C
        self.concat_size = 0
        # When the first edge's op entry is a list (a set of candidate ops),
        # the preprocessing layers are built with affine=False.
        affine = not isinstance(self.genotype[0][0], list)
        if self.reduction_prev:
            self.preprocess0 = FactorizedReduce(self.C_prev_prev, self.C, affine)
        else:
            self.preprocess0 = ReLUConvBN(self.C_prev_prev, self.C, 1, 1, 0, affine)
        self.preprocess1 = ReLUConvBN(self.C_prev, self.C, 1, 1, 0, affine)
        op_names, indices_out, indices_inp = zip(*self.genotype)
        self.build_ops(self.C, op_names, indices_out, indices_inp, self.concat, self.reduction)
        self.concat_size = len(self.concat)

    def build_ops(self, C, op_names, indices_out, indices_inp, concat, reduction):
        """Compile the cell.

        Groups the edges by the intermediate node they feed, validates that
        grouping, and only then instantiates one ``MixedOp`` per edge.

        :param C: channels of this cell
        :type C: int
        :param op_names: operation(s) of each edge in the description; an entry
            may be a single name or a list of candidate names
        :type op_names: list of str or list of list of str
        :param indices_out: output node of each edge; edges must be grouped by
            node in ascending order, starting at node 2
        :type indices_out: list of int
        :param indices_inp: input node of each edge (0 and 1 are the
            preprocessed cell inputs ``s0`` and ``s1``)
        :type indices_inp: list of int
        :param concat: cell concat list of output node
        :type concat: list of int
        :param reduction: whether to reduce
        :type reduction: bool
        :raises ValueError: if the three edge lists differ in length, if the
            output indices are not contiguous node-by-node starting at 2, or if
            the number of nodes differs from ``self.steps``
        """
        self._concat = concat
        self._multiplier = len(concat)
        if not len(op_names) == len(indices_out) == len(indices_inp):
            raise ValueError(
                "op_names, indices_out and indices_inp must have the same length, "
                "got {}, {} and {}".format(len(op_names), len(indices_out), len(indices_inp)))

        # Group input indices by the intermediate node they feed. ``call``
        # appends one state per node in order, so out_inp_list[k] must hold the
        # inputs of node k + 2. Each edge therefore either extends the current
        # node's group or opens the group of the very next node. Anything else
        # (index < 2, a skipped node, or going backwards) used to be mis-grouped
        # silently or produce an empty group, so it is rejected here, before any
        # op is built.
        out_inp_list = []
        for pos, (out_idx, inp_idx) in enumerate(zip(indices_out, indices_inp)):
            next_node = len(out_inp_list) + 2
            if out_inp_list and out_idx == next_node - 1:
                out_inp_list[-1].append(inp_idx)
            elif out_idx == next_node:
                out_inp_list.append([inp_idx])
            else:
                expected = (next_node - 1, next_node) if out_inp_list else (next_node,)
                raise ValueError(
                    "edge {} targets node {}, expected one of {}: output indices must "
                    "start at 2 and be grouped by node in ascending order".format(
                        pos, out_idx, expected))
        if len(out_inp_list) != self.steps:
            raise ValueError(
                "out_inp_list length should equal to steps, got {} nodes for {} steps".format(
                    len(out_inp_list), self.steps))
        self.out_inp_list = out_inp_list

        _op_list = []
        for op_name, inp_idx in zip(op_names, indices_inp):
            # In a reduction cell only edges reading the cell inputs (0, 1) use stride 2.
            stride = 2 if reduction and inp_idx < 2 else 1
            _op_list.append(MixedOp(C=C, stride=stride, ops_cands=op_name))
        self.op_list = Seq(*_op_list)
        self.oplist = list(self.op_list.children())

    def call(self, s0, s1, weights=None, drop_path_prob=0, selected_idxs=None):
        """Forward function of Cell.

        :param s0: feature map of previous of previous cell
        :type s0: torch tensor
        :param s1: feature map of previous cell
        :type s1: torch tensor
        :param weights: weights of operations in cell, indexed by edge
        :type weights: torch tensor, 2 dimension
        :param drop_path_prob: drop-path probability; only applied when
            ``selected_idxs`` is None, and never to edges whose first child
            op is ``ops.Identity``
        :type drop_path_prob: float
        :param selected_idxs: optional per-edge selection: -1 runs the edge as
            a weighted mixed op, 0 drops the edge, any other value is passed
            to the edge op as its selected index
        :return: cell output (concatenation of the states listed in concat)
        :rtype: torch tensor
        """
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        idx = 0
        for i in range(self.steps):
            hlist = []
            for j, inp in enumerate(self.out_inp_list[i]):
                op = self.oplist[idx + j]
                if selected_idxs is None:
                    if weights is None:
                        h = op(states[inp])
                    else:
                        h = op(states[inp], weights[idx + j])
                    if drop_path_prob > 0. and not isinstance(list(op.children())[0], ops.Identity):
                        h = ops.drop_path(h, drop_path_prob)
                    hlist.append(h)
                elif selected_idxs[idx + j] == -1:
                    # undecided mix edges
                    hlist.append(op(states[inp], weights[idx + j]))
                elif selected_idxs[idx + j] == 0:
                    # zero operation
                    continue
                else:
                    hlist.append(op(states[inp], None, selected_idxs[idx + j]))
            # Accumulate out-of-place. With ``s += h``, an identity-like first
            # edge that returns its input tensor would be mutated in place,
            # corrupting an earlier entry of ``states``.
            # NOTE(review): if every incoming edge is pruned (selected_idxs == 0),
            # hlist is empty and this raises IndexError.
            s = hlist[0]
            for h in hlist[1:]:
                s = s + h
            states.append(s)
            idx += len(self.out_inp_list[i])
        return ops.concat(tuple(states[k] for k in self._concat))