def __init__(self, genotype, C_prev_prev, C_prev, C, k=9, d=1):
    """Build the cell from the ``normal`` description carried by ``genotype``.

    Args:
        genotype: Object exposing ``normal`` (an iterable that unzips into
            exactly two sequences: op names and indices) and
            ``normal_concat``.
        C_prev_prev: First entry of the ``[C_prev_prev, C]`` list handed to
            the ``preprocess0`` ``BasicConv`` (presumably the channel count
            of the input two cells back -- confirm against the caller).
        C_prev: First entry of the ``[C_prev, C]`` list handed to the
            ``preprocess1`` ``BasicConv`` (presumably the channel count of
            the previous cell's output -- confirm against the caller).
        C: Second entry of both ``BasicConv`` lists; also forwarded to
            ``_compile``.
        k: Forwarded to ``DilatedKnn2d`` as ``k``.
        d: Forwarded to ``DilatedKnn2d`` as ``dilation``.
    """
    super(Cell, self).__init__()

    # Two input projections, one per incoming tensor; both target ``C``.
    self.preprocess0 = BasicConv(
        [C_prev_prev, C], 'relu', 'batch', bias=False
    )
    self.preprocess1 = BasicConv(
        [C_prev, C], 'relu', 'batch', bias=False
    )

    self.dilated_knn_graph = DilatedKnn2d(k=k, dilation=d)

    # ``genotype.normal`` is a sequence of pairs; transpose it into two
    # parallel tuples before handing everything to ``_compile``.
    names, idxs = zip(*genotype.normal)
    self._compile(C, names, idxs, genotype.normal_concat)
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, k=9, d=1):
    """Build the cell with one ``MixedOp`` per (step, predecessor) pair.

    Args:
        steps: Number of outer iterations; step ``i`` contributes
            ``2 + i`` ``MixedOp`` instances. Stored as ``self._steps``.
        multiplier: Stored as ``self._multiplier``; not otherwise used in
            this constructor.
        C_prev_prev: First entry of the ``[C_prev_prev, C]`` list handed to
            the ``preprocess0`` ``BasicConv`` (presumably the channel count
            of the input two cells back -- confirm against the caller).
        C_prev: First entry of the ``[C_prev, C]`` list handed to the
            ``preprocess1`` ``BasicConv`` (presumably the channel count of
            the previous cell's output -- confirm against the caller).
        C: Second entry of both ``BasicConv`` lists and first argument of
            every ``MixedOp``.
        k: Forwarded to ``DilatedKnn2d`` as ``k``.
        d: Forwarded to ``DilatedKnn2d`` as ``dilation``.
    """
    super(Cell, self).__init__()

    # Plain bookkeeping values (not submodules).
    self._steps = steps
    self._multiplier = multiplier

    # Two input projections, one per incoming tensor; both target ``C``.
    self.preprocess0 = BasicConv(
        [C_prev_prev, C], 'relu', 'batch', bias=False
    )
    self.preprocess1 = BasicConv(
        [C_prev, C], 'relu', 'batch', bias=False
    )

    self.dilated_knn_graph = DilatedKnn2d(k=k, dilation=d)

    self._ops = nn.ModuleList()
    self._bns = nn.ModuleList()  # created empty; nothing is added here

    # Every op in this constructor is built with the same stride.
    op_stride = 1
    for step in range(self._steps):
        # Step ``step`` gets ``2 + step`` ops, appended in flat order.
        for _ in range(2 + step):
            self._ops.append(MixedOp(C, op_stride))