def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        """Build the 1x1 projection and the five pairs of branch ops of this stem cell.

        Every op except ``conv_1x1`` and ``comb_iter_3_left`` is built with
        stride 2.

        Args:
            in_chs_left: Input channels of both ``comb_iter_0`` branches.
            out_chs_left: Output channels of both ``comb_iter_0`` branches.
            in_chs_right: Input channels of ``conv_1x1`` and ``comb_iter_4_left``.
            out_chs_right: Output channels of ``conv_1x1`` and of the
                ``comb_iter_1``..``comb_iter_4`` ops; the ``comb_iter_1``..``3``
                separable branches also take this many input channels.
            pad_type: Padding mode forwarded to every conv/pool factory call.
        """
        # Zero-argument super() binds to the class that actually encloses this
        # method. The explicit ``super(CellStem0, self)`` form looks ``CellStem0``
        # up as a module global at call time, and more than one ``__init__`` in
        # this file passes that same name, so it may not resolve to this class.
        super().__init__()
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)

        self.comb_iter_0_left = BranchSeparables(
            in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type)
        # Max-pool, then a 1x1 conv + BN so this branch also ends with out_chs_left channels.
        self.comb_iter_0_right = nn.Sequential(OrderedDict([
            ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)),
            ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)),
            ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
        ]))

        self.comb_iter_1_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type)
        self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type)

        self.comb_iter_2_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type)

        # No stride is passed here, so BranchSeparables' default applies
        # (presumably 1 -- confirm against its signature).
        self.comb_iter_3_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type)

        # stem_cell=True with in_chs_right inputs, unlike the other right-side branches.
        self.comb_iter_4_left = BranchSeparables(
            in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type)
        self.comb_iter_4_right = ActConvBn(
            out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type)
    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        """Create the first cell's input paths and its branch operations.

        ``conv_1x1`` projects ``in_chs_right`` -> ``out_chs_right`` channels.
        ``path_1`` and ``path_2`` each avg-pool with stride 2 and project
        ``in_chs_left`` -> ``out_chs_left`` channels. ``path_2`` first shifts
        its input by one pixel. ``final_path_bn`` normalises
        ``2 * out_chs_left`` channels (presumably the two paths concatenated;
        confirm against forward()). All separable branches and 3x3 avg pools
        run at stride 1 with ``out_chs_right`` channels.

        Args:
            in_chs_left: Channels entering ``path_1`` / ``path_2``.
            out_chs_left: Channels produced by each of ``path_1`` / ``path_2``.
            in_chs_right: Channels entering ``conv_1x1``.
            out_chs_right: Channels produced by ``conv_1x1`` and the separable branches.
            pad_type: Padding mode forwarded to the separable branches and pools.
        """
        super().__init__()

        def factorized_path(shift):
            # 1x1 avg-pool (stride 2) followed by a bias-free 1x1 conv. With
            # ``shift``, ZeroPad2d((-1, 1, -1, 1)) first crops one column/row at
            # the left/top and zero-pads one at the right/bottom.
            path = nn.Sequential()
            if shift:
                path.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
            path.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
            path.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
            return path

        def separable(kernel_size):
            return BranchSeparables(out_chs_right, out_chs_right, kernel_size, 1, pad_type)

        def avg_pool_3x3():
            return create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1)

        self.act = nn.ReLU()
        self.path_1 = factorized_path(shift=False)
        self.path_2 = factorized_path(shift=True)
        self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1)

        self.comb_iter_0_left = separable(5)
        self.comb_iter_0_right = separable(3)

        self.comb_iter_1_left = separable(5)
        self.comb_iter_1_right = separable(3)

        self.comb_iter_2_left = avg_pool_3x3()

        self.comb_iter_3_left = avg_pool_3x3()
        self.comb_iter_3_right = avg_pool_3x3()

        self.comb_iter_4_left = separable(3)
    def __init__(self, stem_size, num_channels=42, pad_type=''):
        """Build the first stem cell's 1x1 projection and branch operations.

        ``conv_1x1`` maps ``stem_size`` -> ``num_channels`` channels. The
        ``stem_cell=True`` separable branches (``comb_iter_0_right``,
        ``comb_iter_1_right``, ``comb_iter_2_right``) take ``stem_size`` input
        channels. The remaining separable branches take ``num_channels``.
        Everything is built with stride 2 except ``conv_1x1``,
        ``comb_iter_3_right`` and ``comb_iter_4_left`` (stride 1). There is no
        ``comb_iter_3_left`` op.

        Args:
            stem_size: Channels of the stem output fed into this cell.
            num_channels: Channel width produced by every branch.
            pad_type: Padding mode forwarded to the separable branches and pools.
        """
        # Zero-argument super() binds to the class that actually encloses this
        # method. The explicit ``super(CellStem0, self)`` form looks ``CellStem0``
        # up as a module global at call time, and more than one ``__init__`` in
        # this file passes that same name, so it may not resolve to this class.
        super().__init__()
        self.num_channels = num_channels
        self.stem_size = stem_size
        self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1)

        self.comb_iter_0_left = BranchSeparables(
            self.num_channels, self.num_channels, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(
            self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(
            self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

        self.comb_iter_2_left = create_pool2d(
            'avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(
            self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True)

        # Only a right-hand op for combination 3 (stride 1).
        self.comb_iter_3_right = create_pool2d(
            'avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(
            self.num_channels, self.num_channels, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        """Create the normal cell's two 1x1 input projections and branch operations.

        Every op is built with stride 1. ``comb_iter_0_left`` and
        ``comb_iter_4_left`` work on ``out_chs_right`` channels. The other
        separable branches work on ``out_chs_left`` channels.

        Args:
            in_chs_left: Channels entering ``conv_prev_1x1``.
            out_chs_left: Channels produced by ``conv_prev_1x1``.
            in_chs_right: Channels entering ``conv_1x1``.
            out_chs_right: Channels produced by ``conv_1x1``.
            pad_type: Padding mode forwarded to every conv/pool factory call.
        """
        super().__init__()

        def project(in_chs, out_chs):
            return ActConvBn(in_chs, out_chs, 1, stride=1, padding=pad_type)

        def separable(chs, kernel_size):
            return BranchSeparables(chs, chs, kernel_size, 1, pad_type)

        def avg_pool_3x3():
            return create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.conv_prev_1x1 = project(in_chs_left, out_chs_left)
        self.conv_1x1 = project(in_chs_right, out_chs_right)

        self.comb_iter_0_left = separable(out_chs_right, 5)
        self.comb_iter_0_right = separable(out_chs_left, 3)

        self.comb_iter_1_left = separable(out_chs_left, 5)
        self.comb_iter_1_right = separable(out_chs_left, 3)

        self.comb_iter_2_left = avg_pool_3x3()

        self.comb_iter_3_left = avg_pool_3x3()
        self.comb_iter_3_right = avg_pool_3x3()

        self.comb_iter_4_left = separable(out_chs_right, 3)
    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='',
                 is_reduction=False, match_prev_layer_dims=False):
        """Build a cell: two input projections plus five pairs of branch ops.

        Args:
            in_chs_left: Channels of the left input, projected by ``conv_prev_1x1``.
            out_chs_left: Channels after ``conv_prev_1x1``; also the width of
                ``comb_iter_0_left`` and ``comb_iter_4_left``.
            in_chs_right: Channels of the right input, projected by ``conv_1x1``.
            out_chs_right: Channels after ``conv_1x1``; also the width of the
                other separable branches and of ``comb_iter_4_right``.
            pad_type: Padding mode forwarded to every conv/pool factory call.
            is_reduction: Use stride 2 in the strided branches and create the
                1x1 ``comb_iter_4_right`` projection (otherwise it is ``None``).
            match_prev_layer_dims: Project the left input with
                ``FactorizedReduction`` instead of a 1x1 ``ActConvBn``.
        """
        super().__init__()

        # If `is_reduction` is set to `True` stride 2 is used for
        # convolution and pooling layers to reduce the spatial size of
        # the output of a cell approximately by a factor of 2.
        stride = 2 if is_reduction else 1

        # If `match_prev_layer_dims` is set to `True`, `FactorizedReduction`
        # is used to reduce the spatial size of the left input of a cell
        # approximately by a factor of 2. The flag is stored on the instance
        # as `match_prev_layer_dimensions`.
        self.match_prev_layer_dimensions = match_prev_layer_dims
        if match_prev_layer_dims:
            self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type)
        else:
            self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)

        self.comb_iter_0_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type)
        self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        self.comb_iter_1_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type)
        self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        self.comb_iter_2_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type)

        # Built without the cell stride, so BranchSeparables' default stride
        # applies (presumably its input is already downsampled; confirm against
        # forward()). `pad_type` is forwarded like in every sibling op so the
        # whole cell uses one padding mode.
        self.comb_iter_3_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type)
        if is_reduction:
            self.comb_iter_4_right = ActConvBn(
                out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type)
        else:
            self.comb_iter_4_right = None
    def __init__(self, stem_size, num_channels, pad_type=''):
        """Create the second stem cell's input paths and branch operations.

        ``conv_1x1`` maps ``2 * num_channels`` -> ``num_channels`` channels.
        ``path_1`` and ``path_2`` each avg-pool with stride 2 and project
        ``stem_size`` -> ``num_channels // 2`` channels. ``path_2`` first
        shifts its input by one pixel. ``final_path_bn`` normalises
        ``num_channels`` channels (presumably the two paths concatenated, which
        only adds up for an even ``num_channels``; confirm against forward()).
        Branches use stride 2 except ``comb_iter_3_right`` and
        ``comb_iter_4_left``. There is no ``comb_iter_3_left`` op.

        Args:
            stem_size: Channels entering ``path_1`` / ``path_2``.
            num_channels: Channel width produced by the projection and every branch.
            pad_type: Padding mode forwarded to the separable branches and pools.
        """
        super().__init__()
        self.num_channels = num_channels
        self.stem_size = stem_size
        half_chs = num_channels // 2

        def factorized_path(shift):
            # 1x1 avg-pool (stride 2) followed by a bias-free 1x1 conv. With
            # ``shift``, ZeroPad2d((-1, 1, -1, 1)) first crops one column/row at
            # the left/top and zero-pads one at the right/bottom.
            path = nn.Sequential()
            if shift:
                path.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
            path.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
            path.add_module('conv', nn.Conv2d(stem_size, half_chs, 1, stride=1, bias=False))
            return path

        def separable(kernel_size, stride):
            return BranchSeparables(num_channels, num_channels, kernel_size, stride, pad_type)

        self.conv_1x1 = ActConvBn(2 * num_channels, num_channels, 1, stride=1)

        self.act = nn.ReLU()
        self.path_1 = factorized_path(shift=False)
        self.path_2 = factorized_path(shift=True)
        self.final_path_bn = nn.BatchNorm2d(num_channels, eps=0.001, momentum=0.1)

        self.comb_iter_0_left = separable(5, 2)
        self.comb_iter_0_right = separable(7, 2)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = separable(7, 2)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = separable(5, 2)

        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = separable(3, 1)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)