def __init__(self, n_nodes, C_pp, C_p, C, reduction_p, reduction):
    """Build a searchable cell: two input preprocessors plus a DAG of MixedOps.

    Args:
        n_nodes: number of intermediate nodes in the cell.
        C_pp: output channels of cell k-2 (C_out[k-2]).
        C_p: output channels of cell k-1 (C_out[k-1]).
        C: channel count used inside the current cell (C_in[k]).
        reduction_p: True if the previous cell is a reduction cell.
        reduction: True if the current cell is a reduction cell.
    """
    super().__init__()
    self.reduction = reduction
    self.n_nodes = n_nodes

    # When the previous cell reduced, cell[k-2]'s output no longer matches
    # the current input size, so it has to be reduced during preprocessing.
    self.preproc0 = (
        ops.FactorizedReduce(C_pp, C, affine=False)
        if reduction_p
        else ops.Conv2dBlock(C_pp, C, 1, 1, 0, affine=False)
    )
    self.preproc1 = ops.Conv2dBlock(C_p, C, 1, 1, 0, affine=False)

    def edge_stride(src):
        # Only edges coming from the two cell inputs (src 0 and 1) use
        # stride 2, and only in a reduction cell; all others use stride 1.
        return 2 if reduction and src < 2 else 1

    # Node `node` gets 2 + node MixedOp edges: one per cell input plus,
    # presumably, one per earlier intermediate node (forward() not visible).
    # Ops are created in the same (node, src) order as a nested loop would.
    self.dag = nn.ModuleList(
        nn.ModuleList(ops.MixedOp(C, edge_stride(src)) for src in range(2 + node))
        for node in range(self.n_nodes)
    )
def __init__(self, n_big_nodes, cells, start_p, end_p, C):
    """Build a DAG of big nodes made of MixedOp edges, followed by given cells.

    Args:
        n_big_nodes: number of big nodes; each gets its own ModuleList of
            stride-1 MixedOps.
        cells: indexable collection of modules; entries start_p .. end_p - 1
            are appended to ``self.DAG`` after the big-node edge lists.
        start_p: first index into ``cells`` to append (inclusive).
        end_p: index into ``cells`` at which appending stops (exclusive).
        C: channel count for both preprocessors and every MixedOp.
    """
    super().__init__()
    self.n_big_nodes = n_big_nodes

    self.preproc0 = ops.StdConv(C, C, 1, 1, 0, affine=False)
    self.preproc1 = ops.StdConv(C, C, 1, 1, 0, affine=False)

    self.DAG = nn.ModuleList()
    for node_idx in range(self.n_big_nodes):
        # The first big node has 2 incoming edges; every later one has 3.
        n_edges = 2 if node_idx == 0 else 3
        self.DAG.append(nn.ModuleList(ops.MixedOp(C, 1) for _ in range(n_edges)))

    # The cells come after the per-node edge lists, so the entries of
    # self.DAG past the first n_big_nodes are cells[start_p] .. cells[end_p - 1].
    for cell_idx in range(start_p, end_p):
        self.DAG.append(cells[cell_idx])
def __init__(self, in_features, n_nodes):
    """Build a chain of ``n_nodes`` MixedOps, each sized from the previous one.

    Args:
        in_features: size passed to the first node's MixedOp.
        n_nodes: number of intermediate nodes (one MixedOp per node).

    After construction, ``self.out_features`` holds the ``out_features`` of the
    last node's op, or ``in_features`` when ``n_nodes`` is 0.
    """
    super().__init__()
    self.n_nodes = n_nodes
    self.in_features = in_features
    self.out_features = None

    # Each node wraps a single MixedOp in its own ModuleList. The size fed
    # into a node is the out_features reported by the previous node's op.
    self.dag = nn.ModuleList()
    width = in_features
    for _ in range(self.n_nodes):
        node_op = ops.MixedOp(width)
        self.dag.append(nn.ModuleList([node_op]))
        width = node_op.out_features

    self.out_features = width
    print(f"new cell(in={self.in_features}, out={self.out_features})")