Example #1
    def __init__(self,
                 prev_layers,
                 channels,
                 stride,
                 drop_path_keep_prob=None,
                 node_id=0,
                 layer_id=0,
                 layers=0,
                 steps=0):
        super(Node, self).__init__()
        self.channels = channels
        self.stride = stride
        self.drop_path_keep_prob = drop_path_keep_prob
        self.node_id = node_id
        self.layer_id = layer_id
        self.layers = layers
        self.steps = steps
        self.x_op = nn.ModuleList()
        self.y_op = nn.ModuleList()

        num_possible_inputs = node_id + 2

        # avg_pool
        self.x_avg_pool = WSAvgPool2d(3, padding=1)
        # max_pool
        self.x_max_pool = WSMaxPool2d(3, padding=1)
        # sep_conv
        self.x_sep_conv_3 = WSSepConv(num_possible_inputs, channels, channels,
                                      3, None, 1)
        self.x_sep_conv_5 = WSSepConv(num_possible_inputs, channels, channels,
                                      5, None, 2)
        if self.stride > 1:
            assert self.stride == 2
            self.x_id_reduce_1 = FactorizedReduce(prev_layers[0][-1], channels)
            self.x_id_reduce_2 = FactorizedReduce(prev_layers[1][-1], channels)

        # avg_pool
        self.y_avg_pool = WSAvgPool2d(3, padding=1)
        # max_pool
        self.y_max_pool = WSMaxPool2d(3, padding=1)
        # sep_conv
        self.y_sep_conv_3 = WSSepConv(num_possible_inputs, channels, channels,
                                      3, None, 1)
        self.y_sep_conv_5 = WSSepConv(num_possible_inputs, channels, channels,
                                      5, None, 2)
        if self.stride > 1:
            assert self.stride == 2
            self.y_id_reduce_1 = FactorizedReduce(prev_layers[0][-1], channels)
            self.y_id_reduce_2 = FactorizedReduce(prev_layers[1][-1], channels)

        self.out_shape = [
            prev_layers[0][0] // stride, prev_layers[0][1] // stride, channels
        ]
Example #2
    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction,
                 reduction_prev, height, width):
        """

        :param genotype:
        :param C_prev_prev:
        :param C_prev:
        :param C:
        :param reduction:
        :param reduction_prev:
        """
        super(Cell, self).__init__()

        print(C_prev_prev, C_prev, C)

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            first_layers, indices, second_layers = zip(*genotype.reduce)
            concat = genotype.reduce_concat
            bottleneck = genotype.reduce_bottleneck
        else:
            first_layers, indices, second_layers = zip(*genotype.normal)
            concat = genotype.normal_concat
            bottleneck = genotype.normal_bottleneck
        self._compile(C, first_layers, second_layers, indices, concat,
                      reduction, bottleneck, height, width)
Example #3
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction,
                 reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev,
                                          C,
                                          1,
                                          1,
                                          0,
                                          affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier

        self._ops = nn.ModuleList()
        self._bns = nn.ModuleList()
        for i in range(self._steps):
            for j in range(2 + i):
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride)
                self._ops.append(op)
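The `2 + i` inner loop above is where a cell's edge count comes from: node i sees the two cell inputs plus all i earlier nodes. A quick standalone check of that arithmetic, assuming the usual steps=4:

# Node i has 2 + i incoming edges (two cell inputs plus all earlier nodes),
# so a 4-step cell builds 2 + 3 + 4 + 5 = 14 MixedOps.
steps = 4
num_edges = sum(2 + i for i in range(steps))
assert num_edges == 14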
Example #4
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction,
                 reduction_prev, weights):
        super(InnerCell, self).__init__()
        self.reduction = reduction

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev,
                                          C,
                                          1,
                                          1,
                                          0,
                                          affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier

        self._ops = nn.ModuleList()
        self._bns = nn.ModuleList()
        # len(self._ops)=2+3+4+5=14
        offset = 0
        keys = list(OPS.keys())
        for i in range(self._steps):
            for j in range(2 + i):
                stride = 2 if reduction and j < 2 else 1
                weight = weights.data[offset + j]
                choice = keys[weight.argmax()]
                op = OPS[choice](C, stride, False)
                if 'pool' in choice:
                    op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
                self._ops.append(op)
            offset += i + 2
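The argmax discretization above picks one primitive per edge from a row of architecture weights. A minimal sketch of just that selection step, with a hypothetical 8-primitive ordering standing in for the repository's OPS dict:

import torch

# Hypothetical primitive names; only the selection logic mirrors InnerCell.
keys = ['none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
        'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5']
weights = torch.softmax(torch.randn(14, len(keys)), dim=-1)  # one row per edge
choice = keys[weights[0].argmax()]  # the primitive built for edge 0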
Example #5
    def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space):
        """
        Args:
            num_nodes: Number of intermediate cell nodes
            c_prev_prev: channels_out[k-2]
            c_prev : Channels_out[k-1]
            c_cur   : Channels_in[k] (current)
            reduction_prev: flag for whether the previous cell is reduction cell or not
            reduction_cur: flag for whether the current cell is reduction cell or not
        """

        super(Cell, self).__init__()
        self.reduction_cur = reduction_cur
        self.num_nodes = num_nodes

        # If the previous cell is a reduction cell, the current input size does
        # not match the output size of cell[k-2], so output[k-2] must be
        # downsampled by the preprocessing
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(c_prev_prev, c_cur)
        else:
            self.preprocess0 = StdConv(c_prev_prev, c_cur, kernel_size=1, stride=1, padding=0)
        self.preprocess1 = StdConv(c_prev, c_cur, kernel_size=1, stride=1, padding=0)

        # Generate dag from mixed operations
        self.dag_ops = nn.ModuleList()

        for i in range(self.num_nodes):
            self.dag_ops.append(nn.ModuleList())
            # Each node has 2 + i incoming edges (the 2 cell inputs plus all
            # earlier nodes)
            for j in range(2 + i):
                # stride-2 reduction applies only to edges from the input nodes
                stride = 2 if reduction_cur and j < 2 else 1
                op = MixedOp(c_cur, stride, search_space)
                self.dag_ops[i].append(op)
Example #6
    def __init__(self, steps, multiplier, c_prev_prev, c_prev, c, reduction,
                 reduction_prev, switches, p):
        super(Cell, self).__init__()
        self.reduction = reduction
        self.p = p
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(c_prev_prev, c, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(c_prev_prev,
                                          c,
                                          1,
                                          1,
                                          0,
                                          affine=False)
        self.preprocess1 = ReLUConvBN(c_prev, c, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier

        self.cell_ops = nn.ModuleList()
        switch_count = 0
        for i in range(self._steps):
            for j in range(2 + i):
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(c,
                             stride,
                             switches=switches,
                             index=switch_count,
                             p=self.p)
                self.cell_ops.append(op)
                switch_count += 1
Example #7
    def __init__(self,
                 search_space,
                 prev_layers,
                 nodes,
                 channels,
                 reduction,
                 layer_id,
                 layers,
                 steps,
                 drop_path_keep_prob=None):
        super(Cell, self).__init__()
        self.search_space = search_space
        assert len(prev_layers) == 2
        print(prev_layers)
        self.reduction = reduction
        self.layer_id = layer_id
        self.layers = layers
        self.steps = steps
        self.drop_path_keep_prob = drop_path_keep_prob
        self.ops = nn.ModuleList()
        self.nodes = nodes

        # maybe calibrate size
        prev_layers = [list(prev_layers[0]), list(prev_layers[1])]
        self.maybe_calibrate_size = MaybeCalibrateSize(prev_layers, channels)
        prev_layers = self.maybe_calibrate_size.out_shape

        stride = 2 if self.reduction else 1
        for i in range(self.nodes):
            node = Node(search_space, prev_layers, channels, stride,
                        drop_path_keep_prob, i, layer_id, layers, steps)
            self.ops.append(node)
            prev_layers.append(node.out_shape)
        out_hw = min(shape[0] for shape in prev_layers)

        if reduction:
            self.fac_1 = FactorizedReduce(prev_layers[0][-1], channels,
                                          prev_layers[0])
            self.fac_2 = FactorizedReduce(prev_layers[1][-1], channels,
                                          prev_layers[1])
        self.final_combine_conv = WSReLUConvBN(self.nodes + 2, channels,
                                               channels, 1)

        self.out_shape = [out_hw, out_hw, channels]
Example #8
    def __init__(self, C_op0_prev, C_op1_prev, C, reduction, op0_reduction,
                 op1_reduction, op1_name, op2_name, op0_prev, op1_prev):
        super(Cell, self).__init__()
        self.multiplier = 2
        if reduction:
            stride = 2
        else:
            stride = 1
        self.op0_re = op0_reduction
        self.op1_re = op1_reduction

        # FactorizedReduce is only needed when exactly one of the two inputs
        # must be downsampled; otherwise both get a 1x1 ReLUConvBN
        if op0_prev == 1 and op1_prev == 2 and op0_reduction and not op1_reduction:
            self.preprocess0 = ReLUConvBN(C_op0_prev, C, 1, 1, 0)
            self.preprocess1 = FactorizedReduce(C_op1_prev, C)
        elif op0_prev == 2 and op1_prev == 1 and not op0_reduction and op1_reduction:
            self.preprocess0 = FactorizedReduce(C_op0_prev, C)
            self.preprocess1 = ReLUConvBN(C_op1_prev, C, 1, 1, 0)
        else:
            self.preprocess0 = ReLUConvBN(C_op0_prev, C, 1, 1, 0)
            self.preprocess1 = ReLUConvBN(C_op1_prev, C, 1, 1, 0)

        self.op1 = OPS[op1_name](C, stride, True)
        self.op2 = OPS[op2_name](C, stride, True)
Example #9
    def __init__(self, steps, multiplier, cpp, cp, c, reduction,
                 reduction_prev):
        """

        :param steps: 4, number of layers inside a cell
        :param multiplier: 4
        :param cpp: 48
        :param cp: 48
        :param c: 16
        :param reduction: indicates whether to reduce the output maps width
        :param reduction_prev: when previous cell reduced width, s1_d = s0_d//2
        in order to keep same shape between s1 and s0, we adopt prep0 layer to
        reduce the s0 width by half.
        """
        super().__init__()

        # whether the current cell is a reduction cell
        self.reduction = reduction
        self.reduction_prev = reduction_prev

        # preprocess0 handles the output of the prev-prev cell
        if reduction_prev:
            # the previous cell halved the width (and doubled the channels),
            # so halve s0's width here to match s1
            self.preprocess0 = FactorizedReduce(cpp, c, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(cpp,
                                          c,
                                          kernel_size=1,
                                          stride=1,
                                          padding=0,
                                          affine=False)
        # preprocess1 handles the output of the prev cell
        self.preprocess1 = ReLUConvBN(cp,
                                      c,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0,
                                      affine=False)

        # steps inside a cell
        self.steps = steps  # 4
        self.multiplier = multiplier  # 4
        self.layers = nn.ModuleList()

        for i in range(self.steps):
            for j in range(2 + i):
                # for reduction cell, it will reduce the heading 2 inputs only
                stride = 2 if reduction and j < 2 else 1  # 只对和 s0,s1 相连的边做reduction
                layer = Layer(c, stride)
                self.layers.append(layer)
Example #10
    def __init__(self, steps, multiplier, cpp, cp, c, reduction,
                 reduction_prev):
        """
        Each cell k takes input from last two cells k-2, k-1. The cell consists of `steps` so that on each step i,
        we take output of all previous i steps + 2 cell inputs, apply op on each of these outputs and produce their
        sum as output of i-th step.
        Each op output has c channels. The output of the cell is produced by forward() is concatenation of last
        `multiplier` number of layers. Cell could be a reduction cell or it could be a normal cell. The only
        diference between two is that reduction cell uses stride=2 for the ops that connects to cell inputs.

        :param steps: 4, number of layers inside a cell
        :param multiplier: 4, number of last nodes to concatenate as output, this will multiply number of channels in node
        :param cpp: 48, channels from cell k-2
        :param cp: 48, channels from cell k-1
        :param c: 16, output channels for each node
        :param reduction: indicates whether to reduce the output maps width
        :param reduction_prev: when previous cell reduced width, s1_d = s0_d//2
        in order to keep same shape between s1 and s0, we adopt prep0 layer to
        reduce the s0 width by half.
        """
        super(Cell, self).__init__()

        # whether the current cell is a reduction cell
        self.reduction = reduction
        self.reduction_prev = reduction_prev

        # preprocess0 handles the output of the prev-prev cell
        if reduction_prev:
            # the previous cell halved the width (and doubled the channels),
            # so halve s0's width here to match s1
            self.preprocess0 = FactorizedReduce(cpp, c, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(cpp, c, 1, 1, 0, affine=False)
        # preprocess1 handles the output of the prev cell
        self.preprocess1 = ReLUConvBN(cp, c, 1, 1, 0, affine=False)

        # steps inside a cell
        self.steps = steps  # 4
        self.multiplier = multiplier  # 4

        self.layers = nn.ModuleList()

        for i in range(self.steps):
            # node i connects to the outputs of all previous nodes plus the
            # two cell inputs
            for j in range(2 + i):
                # a reduction cell applies stride 2 only on the edges to the
                # cell inputs
                stride = 2 if reduction and j < 2 else 1
                layer = MixedLayer(c, stride)
                self.layers.append(layer)
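For reference, a minimal sketch of the forward pass this docstring describes (not the repository's own code; it assumes each MixedLayer is called as layer(x, weights)):

import torch

def cell_forward(cell, s0, s1, weights):
    # Preprocess the two cell inputs to c channels each
    s0 = cell.preprocess0(s0)
    s1 = cell.preprocess1(s1)
    states = [s0, s1]
    offset = 0
    for _ in range(cell.steps):
        # Each new node is the sum of one op output per predecessor state
        s = sum(cell.layers[offset + j](h, weights[offset + j])
                for j, h in enumerate(states))
        offset += len(states)
        states.append(s)
    # The cell output concatenates the last `multiplier` nodes channel-wise
    return torch.cat(states[-cell.multiplier:], dim=1)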
Example #11
    def _init_nodes(self, op_cls):
        """
		Initialize nodes to create DAG with 2 input nodes come from 2 previous cell C[k-2] and C[k-1]
		"""
        self.node_ops = nn.ModuleList()
        if self.reduction_prev:
            self.node0 = FactorizedReduce(self.C_pp, self.C, affine=False)
        else:
            self.node0 = ReLUConvBN(self.C_pp, self.C, 1, 1, 0, affine=False)
        self.node1 = ReLUConvBN(self.C_p, self.C, 1, 1, 0, affine=False)

        for i in range(self.num_nodes):
            # Create edges connecting node `i` to every earlier node `j`
            # (j < i), including the two input nodes
            for j in range(2 + i):
                stride = 2 if self.reduction and j < 2 else 1
                op = op_cls(self.C, stride)
                self.node_ops.append(op)
Example #12
    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction,
                 reduction_prev):
        super(Cell, self).__init__()

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            op_names, indices = zip(*genotype.reduce)
            concat = genotype.reduce_concat
        else:
            op_names, indices = zip(*genotype.normal)
            concat = genotype.normal_concat
        self._compile(C, op_names, indices, concat, reduction)
Example #13
    def MNAS_sparsification(dag,
                            is_reduction,
                            fraction_edges_to_skip_connect=0.5):
        num_to_skip = int(len(dag.keys()) * fraction_edges_to_skip_connect)
        # Edges that are already Identity or FactorizedReduce count as skips
        already_skip = []
        for key in dag.keys():
            if isinstance(dag[key], (Identity, FactorizedReduce)):
                already_skip.append(key)
        # Randomly pick the remaining edges to turn into skip connections
        for key in random.sample(
                [key for key in dag.keys() if key not in already_skip],
                num_to_skip):
            if is_reduction and "0" in key:
                # edges out of the input nodes must still downsample
                dag[key] = FactorizedReduce(dag[key].channels_out,
                                            dag[key].channels_out)
            else:
                dag[key] = Identity(dag[key].channels_out)
        return dag
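A self-contained toy run of the sparsification above, with stub op classes standing in for the repository's Identity and FactorizedReduce modules (assumed to live in the same module as MNAS_sparsification):

import random

class _StubOp:
    def __init__(self, channels_out):
        self.channels_out = channels_out

class Identity(_StubOp):
    pass

class FactorizedReduce(_StubOp):
    def __init__(self, c_in, c_out):
        super().__init__(c_out)

# A 6-edge toy dag; half the edges should become skip connections
dag = {str((i, j)): _StubOp(16) for i in range(3) for j in range(i + 1, 4)}
dag = MNAS_sparsification(dag, is_reduction=False)
assert sum(isinstance(op, Identity) for op in dag.values()) == 3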
Example #14
    def __init__(self, genotype, C_pp, C_p, C, reduction, reduction_prev,
                 dropout_rate):
        super(DerivedCell, self).__init__()
        self.reduction = reduction
        if reduction_prev:
            self.node0 = FactorizedReduce(C_pp, C)
        else:
            self.node0 = ReLUConvBN(C_pp, C, 1, 1, 0)
        self.node1 = ReLUConvBN(C_p, C, 1, 1, 0)
        self.dropout = nn.Dropout(dropout_rate)

        if reduction:
            dag = genotype.reduce
            concat = genotype.reduce_concat
        else:
            dag = genotype.normal
            concat = genotype.normal_concat
        self.num_nodes = len(dag)
        self.concat = concat
        self.ops, self.nodes = self._compile_dag(C, dag)
Example #15
    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction,
                 reduction_prev):
        super(Cell, self).__init__()
        # print(C_prev_prev, C_prev, C)

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            op_names, indices = zip(*genotype.reduce)
        else:
            op_names, indices = zip(*genotype.normal)
        # two (op_name, input_index) entries per node, so halve for the node count
        num_nodes = len(op_names) // 2
        concat = range(2, num_nodes + 2)
        self._compile(C, op_names, indices, concat, reduction)
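For context, a minimal sketch of the genotype structure these Cell classes consume, assuming the standard DARTS namedtuple layout of (op_name, input_index) pairs, two per node:

from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Toy 2-node genotype; real genotypes typically have 4 nodes (8 pairs)
g = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('sep_conv_3x3', 1)],
    normal_concat=range(2, 4),
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('skip_connect', 2), ('max_pool_3x3', 0)],
    reduce_concat=range(2, 4))

op_names, indices = zip(*g.normal)  # ('sep_conv_3x3', ...) and (0, 1, 0, 1)
num_nodes = len(op_names) // 2      # two entries per node -> 2 nodes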
Example #16
    def __init__(self,
                 genotype_sequence,
                 concat_sequence,
                 C_prev_prev,
                 C_prev,
                 C,
                 reduction,
                 reduction_prev,
                 op_dict=None,
                 separate_reduce_cell=True,
                 C_mid=None):
        """Create a final cell with a single architecture.

    The Cell class in model_search.py is the equivalent for searching multiple architectures.

    # Arguments

      op_dict: The dictionary of possible operation creation functions.
        All primitive name strings defined in the genotype must be in the op_dict.
    """
        super(Cell, self).__init__()
        print(C_prev_prev, C_prev, C)
        self.reduction = reduction
        if op_dict is None:
            op_dict = operations.OPS
        # _op_dict holds the operations available for use; _ops is the actual
        # sequence of ops instantiated for this cell
        self._op_dict = op_dict

        if reduction_prev is None:
            self.preprocess0 = operations.Identity()
        elif reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, stride=2)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        op_names, indices = zip(*genotype_sequence)
        self._compile(C, op_names, indices, concat_sequence, reduction, C_mid)
Example #17
    def __init__(self, steps: int, multiplier: int, cpp: int, cp: int, c: int,
                 reduction: bool, reduction_prev: bool, height: int,
                 width: int, setting: AttLocation):
        """
        :param steps: 4, number of layers inside a cell
        :param multiplier: 4
        :param cpp: 48
        :param cp: 48
        :param c: 16
        :param reduction: indicates whether to reduce the output maps width
        :param reduction_prev: when previous cell reduced width, s1_d = s0_d//2
        in order to keep same shape between s1 and s0, we adopt prep0 layer to
        reduce the s0 width by half.
        """
        super(Cell, self).__init__()

        # whether the current cell is a reduction cell
        self.reduction = reduction
        self.reduction_prev = reduction_prev

        self.setting = setting

        # preprocess0 handles the output of the prev-prev cell
        if reduction_prev:
            # the previous cell halved the width (and doubled the channels),
            # so halve s0's width here to match s1
            self.preprocess0 = FactorizedReduce(cpp, c, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(cpp, c, 1, 1, 0, affine=False)
        # preprocess1 handles the output of the prev cell
        self.preprocess1 = ReLUConvBN(cp, c, 1, 1, 0, affine=False)

        # steps inside a cell
        self.steps = steps  # 4
        self.multiplier = multiplier  # 4

        self.layers = nn.ModuleList()

        for i in range(self.steps):
            # node i connects to the outputs of all previous nodes plus the
            # two cell inputs
            for j in range(2 + i):
                # a reduction cell applies stride 2 only on the edges to the
                # cell inputs
                stride = 2 if reduction and j < 2 else 1
                layer = MixedLayer(c, stride, height, width, setting)
                self.layers.append(layer)

        self.bottleneck_attns = nn.ModuleList()
        if setting in [AttLocation.END, AttLocation.AFTER_EVERY_AND_END]:
            for attn_primitive in ATTN_PRIMIVIVES:
                attn = ATTNS[attn_primitive](c * steps, height, width)
                self.bottleneck_attns.append(attn)

        elif setting in [
                AttLocation.AFTER_EVERY, AttLocation.NO_ATTENTION,
                AttLocation.MIXED_WITH_OPERATION, AttLocation.DOUBLE_MIXED
        ]:
            pass

        else:
            raise ValueError(f'unsupported attention setting: {setting}')
Example #18
def op(node_id, op_id, x_shape, channels, strides):  # x means feature maps
    br_op = None
    x_id_fact_reduce = None

    # Only edges leaving the two cell inputs (node_id 0 or 1) are strided;
    # later nodes already operate at the reduced resolution
    x_stride = strides if node_id in [0, 1] else 1

    if op_id == 0:
        br_op = SepConv(C_in=channels,
                        C_out=channels,
                        kernel_size=3,
                        strides=x_stride,
                        padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 1:
        br_op = SepConv(C_in=channels,
                        C_out=channels,
                        kernel_size=5,
                        strides=x_stride,
                        padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 2:
        br_op = layers.AveragePooling2D(pool_size=3,
                                        strides=x_stride,
                                        padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, x_shape[-1]]
    elif op_id == 3:
        br_op = layers.MaxPool2D(pool_size=3, strides=x_stride, padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, x_shape[-1]]
    elif op_id == 4:
        br_op = Identity()
        if x_stride > 1:
            assert x_stride == 2
            x_id_fact_reduce = FactorizedReduce(C_in=x_shape[-1],
                                                C_out=channels)
            x_shape = [
                x_shape[0] // x_stride, x_shape[1] // x_stride, channels
            ]
    elif op_id == 5:
        br_op = Identity()
        if x_stride > 1:
            assert x_stride == 2
            x_id_fact_reduce = FactorizedReduce(C_in=x_shape[-1],
                                                C_out=channels)
            x_shape = [
                x_shape[0] // x_stride, x_shape[1] // x_stride, channels
            ]
    elif op_id == 6:
        br_op = Conv(C_in=channels,
                     C_out=channels,
                     kernel_size=1,
                     strides=x_stride,
                     padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 7:
        br_op = Conv(C_in=channels,
                     C_out=channels,
                     kernel_size=3,
                     strides=x_stride,
                     padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 8:
        br_op = Conv(C_in=channels,
                     C_out=channels,
                     kernel_size=(1, 3),
                     strides=x_stride,
                     padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 9:
        br_op = Conv(C_in=channels,
                     C_out=channels,
                     kernel_size=(1, 7),
                     strides=x_stride,
                     padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 10:
        br_op = layers.MaxPool2D(pool_size=2, strides=x_stride, padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 11:
        br_op = layers.MaxPool2D(pool_size=3, strides=x_stride, padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 12:
        br_op = layers.MaxPool2D(pool_size=5, strides=x_stride, padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 13:
        br_op = layers.AveragePooling2D(pool_size=2,
                                        strides=x_stride,
                                        padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 14:
        br_op = layers.AveragePooling2D(pool_size=3,
                                        strides=x_stride,
                                        padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
    elif op_id == 15:
        br_op = layers.AveragePooling2D(pool_size=5,
                                        strides=x_stride,
                                        padding='same')
        x_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]

    return br_op, x_shape, x_id_fact_reduce
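The shape bookkeeping in op() is plain integer arithmetic and can be checked without instantiating any layers; for example, a stride-2 op on a 32x32x16 feature map:

x_shape, channels, x_stride = [32, 32, 16], 32, 2
out_shape = [x_shape[0] // x_stride, x_shape[1] // x_stride, channels]
assert out_shape == [16, 16, 32]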
Example #19
    def __init__(self, x_id, x_op, y_id, y_op, x_shape, y_shape, channels,
                 stride=1, drop_path_keep_prob=None, layer_id=0, layers=0,
                 steps=0):
        super(Node, self).__init__()
        self.channels = channels
        self.stride = stride
        self.drop_path_keep_prob = drop_path_keep_prob
        self.layer_id = layer_id
        self.layers = layers
        self.steps = steps
        self.x_id = x_id
        self.x_op_id = x_op
        self.y_id = y_id
        self.y_op_id = y_op

        def build_branch(op_id, node_id, shape):
            """Instantiate one branch op; returns (op, fact_reduce, out_shape).

            Only edges leaving the two cell inputs (node_id 0 or 1) are
            strided; later nodes already operate at the reduced resolution.
            """
            s = stride if node_id in [0, 1] else 1
            fact_reduce = None
            out_c = channels
            if op_id == 0:    # sep conv 3x3
                op = OPERATIONS[op_id](channels, channels, 3, s, 1)
            elif op_id == 1:  # sep conv 5x5
                op = OPERATIONS[op_id](channels, channels, 5, s, 2)
            elif op_id == 2:  # avg pool 3x3, keeps the channel count
                op = OPERATIONS[op_id](3, stride=s, padding=1,
                                       count_include_pad=False)
                out_c = shape[-1]
            elif op_id == 3:  # max pool 3x3, keeps the channel count
                op = OPERATIONS[op_id](3, stride=s, padding=1)
                out_c = shape[-1]
            elif op_id in (4, 5):  # identity; needs a reduce when strided
                table = OPERATIONS if op_id == 4 else OPERATIONS_large
                op = table[op_id]()
                if s > 1:
                    assert s == 2
                    fact_reduce = FactorizedReduce(shape[-1], channels)
                else:
                    out_c = shape[-1]
            elif op_id == 6:   # conv 1x1
                op = OPERATIONS_large[op_id](channels, channels, 1, s, 0)
            elif op_id == 7:   # conv 3x3
                op = OPERATIONS_large[op_id](channels, channels, 3, s, 1)
            elif op_id == 8:   # conv 1x3 then 3x1
                op = OPERATIONS_large[op_id](channels, channels, (1, 3), s,
                                             ((0, 1), (1, 0)))
            elif op_id == 9:   # conv 1x7 then 7x1
                op = OPERATIONS_large[op_id](channels, channels, (1, 7), s,
                                             ((0, 3), (3, 0)))
            elif op_id == 10:  # max pool 2x2
                op = OPERATIONS_large[op_id](2, stride=s, padding=0)
            elif op_id == 11:  # max pool 3x3
                op = OPERATIONS_large[op_id](3, stride=s, padding=1)
            elif op_id == 12:  # max pool 5x5
                op = OPERATIONS_large[op_id](5, stride=s, padding=2)
            elif op_id == 13:  # avg pool 2x2
                op = OPERATIONS_large[op_id](2, stride=s, padding=0)
            elif op_id == 14:  # avg pool 3x3
                op = OPERATIONS_large[op_id](3, stride=s, padding=1)
            elif op_id == 15:  # avg pool 5x5
                op = OPERATIONS_large[op_id](5, stride=s, padding=2)
            return op, fact_reduce, [shape[0] // s, shape[1] // s, out_c]

        self.x_op, self.x_id_fact_reduce, x_shape = build_branch(
            x_op, x_id, list(x_shape))
        self.y_op, self.y_id_fact_reduce, y_shape = build_branch(
            y_op, y_id, list(y_shape))

        assert x_shape[0] == y_shape[0] and x_shape[1] == y_shape[1]
        self.out_shape = list(x_shape)
Example #20
    def create_dag(level: int,
                   alpha: Alpha,
                   alpha_dags: list,
                   primitives: dict,
                   channels_in_x1: int,
                   channels_in_x2=None,
                   channels=None,
                   is_reduction=False,
                   prev_reduction=False,
                   learnt_op=False,
                   input_stride=1):
        '''
        - Recursive function to create the computational DAG from a given point.
        - Done this way to ensure that channels_in is correct for each operation.
        - Called with the top-level DAG parameters in model.__init__; recursively
          generates the entire model.
        - For learnt-model extraction, alpha_dags must contain exactly one alpha_dag.
        - For weight-sharing model training, pass every alpha_dag you want shared.
        '''

        # Initialize variables
        num_nodes = alpha.num_nodes_at_level[level]
        # maps stringified edge tuple -> nn.Module (to build an nn.ModuleDict)
        dag = {}

        for node_a in range(0, num_nodes - 1):
            # Determine stride: at the top level, edges leaving the two input
            # nodes of a reduction cell use stride 2
            if level == alpha.num_levels - 1 and is_reduction and node_a < 2:
                stride = 2
            elif node_a == 0:
                stride = input_stride
            else:
                stride = 1
            # Determine pre-processing if necessary
            if alpha.num_levels - 1 == level:
                if prev_reduction:
                    dag[PREPROC_X] = FactorizedReduce(channels_in_x1,
                                                      channels,
                                                      affine=learnt_op)
                else:
                    dag[PREPROC_X] = ReLUConvBN(channels_in_x1,
                                                channels,
                                                1,
                                                1,
                                                0,
                                                affine=learnt_op)
                dag[PREPROC_X2] = ReLUConvBN(channels_in_x2,
                                             channels,
                                             1,
                                             1,
                                             0,
                                             affine=learnt_op)
            # Determine channels in
            if channels is None:
                channels = channels_in_x1
            ###################
            # Select Operations
            ###################
            if learnt_op:
                chosen_ops = {}
                # Loop over all node_b > node_a to pick the operation on each
                # outgoing edge of node_a
                for node_b in range(node_a + 1, num_nodes):

                    # At the top level, input nodes do not connect to the
                    # output node or to each other
                    if (level == alpha.num_levels - 1 and
                            ((node_a < 2 and node_b == 1) or
                             (node_b == num_nodes - 1))):
                        continue

                    # Determine Operation to Choose
                    edge = (node_a, node_b)
                    # At the primitive level the last op is the zero op; exclude it
                    if level == 0:
                        alpha_candidates = alpha_dags[0][edge].cpu().detach()[:-1]
                    else:
                        alpha_candidates = alpha_dags[0][edge].cpu().detach()
                    chosen_ops[edge] = int(argmax(alpha_candidates))

                ops_to_create = sorted(set(chosen_ops.values()))

            else:
                ops_to_create = range(0, alpha.num_ops_at_level[level])

            base_operations = {}

            if level == 0:
                # Base case: no recursion below the primitive level. Append the
                # mandatory ops (identity, zero) to the primitives.
                primitives.update(MANDATORY_OPS)
                for i, key in enumerate(primitives.keys()):
                    base_operations[i] = primitives[key](C=channels,
                                                         stride=stride,
                                                         affine=learnt_op)
            else:
                # Recursive case, use create_dag to create the list of operations
                if not learnt_op and level == alpha.num_levels - 1:
                    base_operations[0] = HierarchicalOperation.create_dag(
                        level=level - 1,
                        alpha=alpha,
                        alpha_dags=alpha.parameters[level - 1],
                        primitives=primitives,
                        channels_in_x1=channels,
                        input_stride=stride,
                        learnt_op=learnt_op)
                else:
                    for op_num in ops_to_create:
                        base_operations[op_num] = HierarchicalOperation.create_dag(
                            level=level - 1,
                            alpha=alpha,
                            alpha_dags=[alpha.parameters[level - 1][op_num]],
                            primitives=primitives,
                            channels_in_x1=channels,
                            input_stride=stride,
                            learnt_op=learnt_op)
            # Create mixed operations / place the selected operation on every
            # outgoing edge (node_a, node_b) with node_b > node_a
            for node_b in range(node_a + 1, num_nodes):

                # At the top level, input nodes do not connect to the
                # output node or to each other
                if (level == alpha.num_levels - 1 and
                        ((node_a < 2 and node_b == 1) or
                         (node_b == num_nodes - 1))):
                    continue

                # Create a mixed operation / select the learnt operation on
                # the outgoing edge
                edge = (node_a, node_b)
                if not learnt_op:
                    dag[str(edge)] = MixedOperation(
                        base_operations,
                        [alpha_dag[edge] for alpha_dag in alpha_dags])
                else:
                    dag[str(edge)] = deepcopy(base_operations[chosen_ops[edge]])
        # Return a HierarchicalOperation created from the dag
        if learnt_op:
            if alpha.num_levels == 1:  # DARTS sim - training phase
                dag = HierarchicalOperation.darts_sparsification(
                    dag, alpha_dags[0], num_nodes)

        return HierarchicalOperation(alpha.num_nodes_at_level[level],
                                     dag,
                                     channels,
                                     level == alpha.num_levels - 1,
                                     learnt_op=learnt_op)
Example #21
    def forward(self, x, x2=None, op_num=0, temp=None):
        '''Iteratively compute the output of each edge of the DAG.'''
        output = {}

        # Apply preprocessing if applicable
        if PREPROC_X in self.ops:
            x = self.ops[PREPROC_X].forward(x)
        if PREPROC_X2 in self.ops:
            x2 = self.ops[PREPROC_X2].forward(x2)

        for node_a in range(0, self.num_nodes - 1):
            # Determine the input to node_a
            if node_a == 0:
                # trivial: the first input of the whole module
                input = x
            elif node_a == 1 and x2 is not None:
                # at the top level, x2 (when provided) feeds the second node
                input = x2
            else:
                # otherwise: the sum of the outputs of every edge (prev_node, node_a)
                input = []
                for prev_node in range(0, node_a):
                    edge = str((prev_node, node_a))
                    if edge in output:
                        input.append(output[edge])
                input = sum(input)

            for node_b in range(node_a + 1, self.num_nodes):

                edge = str((node_a, node_b))

                # Skip edges that shouldn't exist
                if x2 is not None and node_a < 2 and (
                        node_b == 1 or node_b == self.num_nodes - 1):
                    continue
                elif x2 is not None and node_b == self.num_nodes - 1:
                    # output collation edge: pass the input through unchanged
                    output[edge] = input
                elif edge not in self.ops:
                    # edge was removed during sparsification
                    continue
                elif isinstance(self.ops[edge], MixedOperation):
                    output[edge] = self.ops[edge].forward(input,
                                                          op_num=op_num,
                                                          temp=temp)
                else:
                    # Maybe drop path when below the top level; never on Identity
                    if (self.learnt_op and (self.darts_sim or x2 is None)
                            and not isinstance(self.ops[edge], Identity)):
                        output[edge] = drop_path(self.ops[edge].forward(input),
                                                 DROP_PROB)
                    else:
                        output[edge] = self.ops[edge].forward(input)

        # The final output collects all inputs to the final node
        if x2 is not None:  # at the top level, skip the input nodes
            start_node = 2
        else:
            start_node = 0

        # Concatenate the output only if this is the top-level op
        if self.concatenate_output:
            return cat(tuple(output[str((prev_node, self.num_nodes - 1))]
                             for prev_node in range(start_node, self.num_nodes - 1)),
                       dim=1)
        else:
            if output[str((0, self.num_nodes - 1))].shape[3] != x.shape[3]:
                x = FactorizedReduce(x.shape[1], x.shape[1]).cuda()(x)
            return sum(output[str((prev_node, self.num_nodes - 1))]
                       for prev_node in range(start_node, self.num_nodes - 1)) + x
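For reference, the drop_path used above is the DARTS-style stochastic path drop; a common implementation looks like this sketch (the repository's own version may differ in detail):

import torch

def drop_path(x, drop_prob):
    # Zero out whole samples with probability drop_prob, rescale the survivors
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = torch.bernoulli(
            torch.full((x.size(0), 1, 1, 1), keep_prob, device=x.device))
        x = x / keep_prob * mask
    return x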