示例#1
0
 def get_flop(self, kernel_size, stride, in_channel, out_channel, affine,
              in_h, in_w):
     """Return the summed FLOP estimate of this block's four stages.

     The total is built, in this order, from the ``flops`` helpers:

     1. a 1x1 ConvBNReLU ``in_channel -> self.hidden_dim``;
     2. a ``kernel_size`` ConvBNReLU ``self.hidden_dim -> self.hidden_dim``
        applied with ``stride``;
     3. a 1x1 convolution ``self.hidden_dim -> out_channel`` counted with
        ``is_bias=False``;
     4. a BatchNorm over ``out_channel``.

     Stages 3 and 4 are counted at the reduced resolution
     ``in_h // stride`` by ``in_w // stride``.

     Args:
         kernel_size: kernel size forwarded to the second stage.
         stride: stride of the second stage and the integer divisor applied
             to the spatial size afterwards.
         in_channel: channel count entering the first stage.
         out_channel: channel count produced by the last convolution.
         affine: forwarded as ``is_affine`` to the ConvBNReLU counters and
             as the last argument of the BN counter.
         in_h: input height.
         in_w: input width.

     Returns:
         The sum of the four helper results.
     """
     width = self.hidden_dim

     # Stages 1 and 2 are both counted at the input resolution.
     expand_cost = flops.count_ConvBNReLU_flop(
         in_h, in_w, in_channel, width, [1, 1], False, is_affine=affine)
     spatial_cost = flops.count_ConvBNReLU_flop(
         in_h, in_w, width, width, kernel_size, False,
         stride=stride, is_affine=affine)

     # The remaining stages see the resolution left by the strided stage.
     out_h = in_h // stride
     out_w = in_w // stride
     project_cost = flops.count_Conv_flop(
         out_h, out_w, width, out_channel,
         kernel_size=[1, 1], is_bias=False)
     norm_cost = flops.count_BN_flop(out_h, out_w, out_channel, affine)

     return expand_cost + spatial_cost + project_cost + norm_cost
示例#2
0
 def get_flop(self, kernel_size, stride, in_channel, out_channel, affine,
              in_h, in_w):
     """Return the summed FLOP estimate of three stacked separable stages.

     Counted in this order with the ``flops`` helpers:

     1. a depthwise (``groups=in_channel``) convolution applied with
        ``stride`` at the input resolution, then a 1x1 ConvBNReLU;
     2. a stride-1 depthwise convolution, then a 1x1 ConvBNReLU;
     3. a stride-1 depthwise convolution, then a 1x1 convolution
        ``in_channel -> out_channel`` followed by a BatchNorm.

     Everything after the first depthwise convolution is counted at
     ``in_h // stride`` by ``in_w // stride``.

     Args:
         kernel_size: kernel size of the three depthwise convolutions.
         stride: stride of the first depthwise convolution and the integer
             divisor applied to the spatial size afterwards.
         in_channel: channel count kept through all but the last 1x1 conv.
         out_channel: channel count produced by the last 1x1 conv.
         affine: forwarded as ``is_affine`` to the ConvBNReLU counters and
             as the last argument of the BN counter.
         in_h: input height.
         in_w: input width.

     Returns:
         The sum of the seven helper results.
     """
     # Stage 1: the only part counted at the input resolution.
     first_dw = flops.count_Conv_flop(
         in_h, in_w, in_channel, in_channel,
         kernel_size, False, stride, groups=in_channel)

     out_h = in_h // stride
     out_w = in_w // stride
     first_pw = flops.count_ConvBNReLU_flop(
         out_h, out_w, in_channel, in_channel,
         kernel_size=[1, 1], is_bias=False, is_affine=affine)

     # Stage 2: second separable pair, stride 1.
     second_dw = flops.count_Conv_flop(
         out_h, out_w, in_channel, in_channel,
         kernel_size, False, stride=1, groups=in_channel)
     second_pw = flops.count_ConvBNReLU_flop(
         out_h, out_w, in_channel, in_channel,
         kernel_size=[1, 1], is_bias=False, is_affine=affine)

     # Stage 3: third depthwise conv, then the projection to out_channel
     # and its BatchNorm (counted separately, no ReLU term).
     third_dw = flops.count_Conv_flop(
         out_h, out_w, in_channel, in_channel,
         kernel_size, False, stride=1, groups=in_channel)
     third_pw = flops.count_Conv_flop(
         out_h, out_w, in_channel, out_channel,
         kernel_size=[1, 1], is_bias=False)
     final_bn = flops.count_BN_flop(out_h, out_w, out_channel, affine)

     return (first_dw + first_pw + second_dw + second_pw + third_dw +
             third_pw + final_bn)
示例#3
0
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """Build one 1x1 decoder conv per input feature plus the predictor.

        For every name in ``cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES`` a bias-free
        1x1 Conv-norm-ReLU layer is appended to ``self.layer_decoder_list``;
        it halves the channel count, except for ``'layer_0'`` which keeps
        it. ``self.predictor`` is a 3x3 convolution from the channel count
        of ``'layer_0'`` to the number of classes. The FLOP estimate of each
        created layer is accumulated into ``self.real_flops``.

        Args:
            cfg: config node; the ``MODEL.SEM_SEG_HEAD`` entries
                ``IN_FEATURES``, ``IGNORE_VALUE``, ``NUM_CLASSES``, ``NORM``,
                ``LOSS_WEIGHT`` and ``MODEL.CAL_FLOPS`` are read.
            input_shape: shape spec per feature name; ``stride``,
                ``channels``, ``height`` and ``width`` are read from each
                value. It must contain the key ``'layer_0'``.
        """
        super().__init__()
        head_cfg = cfg.MODEL.SEM_SEG_HEAD
        self.in_features = head_cfg.IN_FEATURES

        # Split the shape specs into per-feature lookup tables. The strides
        # are collected but not consumed anywhere in this constructor.
        feature_strides = {}
        feature_channels = {}
        feature_resolution = {}
        for name, spec in input_shape.items():
            feature_strides[name] = spec.stride
            feature_channels[name] = spec.channels
            feature_resolution[name] = np.array([spec.height, spec.width])

        self.ignore_value = head_cfg.IGNORE_VALUE
        num_classes = head_cfg.NUM_CLASSES
        norm = head_cfg.NORM
        self.loss_weight = head_cfg.LOSS_WEIGHT
        self.cal_flops = cfg.MODEL.CAL_FLOPS
        self.real_flops = 0.0

        # The FLOP helpers count affine BatchNorm only for 'Sync*' norms.
        affine = 'Sync' in norm

        # Simple decoder: one 1x1 conv per incoming feature map.
        self.layer_decoder_list = nn.ModuleList()
        for name in self.in_features:
            height, width = feature_resolution[name]
            channels_in = feature_channels[name]
            channels_out = (channels_in
                            if name == 'layer_0' else channels_in // 2)
            self.layer_decoder_list.append(
                Conv2d(channels_in,
                       channels_out,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, channels_out),
                       activation=nn.ReLU()))
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                height, width, channels_in, channels_out, [1, 1],
                is_affine=affine)

        # Kaiming init for the decoder convs, after all of them exist.
        for decoder_layer in self.layer_decoder_list:
            weight_init.kaiming_init_module(decoder_layer, mode='fan_in')

        # Output layer, sized from the 'layer_0' feature.
        base_channels = feature_channels['layer_0']
        base_height, base_width = feature_resolution['layer_0']
        self.predictor = Conv2d(in_channels=base_channels,
                                out_channels=num_classes,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        self.real_flops += cal_op_flops.count_Conv_flop(
            base_height, base_width, base_channels, num_classes, [3, 3])
        # Kaiming init for the predictor as well.
        weight_init.kaiming_init_module(self.predictor, mode='fan_in')
示例#4
0
    def __init__(self,
                 in_channels=3,
                 mid_channels=64,
                 out_channels=64,
                 input_res=None,
                 sept_stem=True,
                 norm="BN",
                 affine=True):
        """
        Build basic STEM for Dynamic Network.

        Three layers are created (``stem_1``, ``stem_2``, ``stem_3``) and the
        FLOP estimate of each one is accumulated into ``self.real_flops``.
        ``self.out_res`` is set to the input resolution floor-divided by 4
        and ``self.out_cha`` to ``out_channels``.

        Args:
            in_channels (int): channels of the input fed to ``stem_1``.
            mid_channels (int): output channels of ``stem_1`` and ``stem_2``.
            out_channels (int): output channels of ``stem_3``.
            input_res: input (height, width). It is indexed with ``[0]`` and
                ``[1]`` and floor-divided as a whole (``input_res // 2``),
                so it must support both operations (presumably a numpy
                array; confirm against callers). Required in practice
                despite the ``None`` default.
            sept_stem (bool): if True, ``stem_2`` and ``stem_3`` are each a
                3x3 convolution with ``groups=mid_channels`` followed by a
                1x1 convolution; otherwise each is a single 3x3 convolution.
            norm (str or callable): a callable that takes the number of
                channels and return a `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}).
            affine (bool): forwarded as ``is_affine`` to the FLOP counters;
                it is not passed to the layers themselves.

        Raises:
            TypeError: if ``input_res`` is None.
        """
        super().__init__()

        # The FLOP bookkeeping below always needs the resolution; fail with
        # a clear message instead of an opaque "NoneType is not
        # subscriptable" halfway through construction.
        if input_res is None:
            raise TypeError("input_res must be provided to build the stem")

        self.real_flops = 0.0
        # start with 3 stem layers down-sampling by 4.
        # NOTE(review): no padding is passed to stem_1 while the FLOP
        # counts below assume the resolution is exactly halved - confirm
        # this is intended.
        self.stem_1 = Conv2d(in_channels,
                             mid_channels,
                             kernel_size=3,
                             stride=2,
                             bias=False,
                             norm=get_norm(norm, mid_channels),
                             activation=nn.ReLU())
        # Count with the real input channel number (was hard-coded to 3,
        # which is wrong whenever in_channels != 3).
        self.real_flops += cal_op_flops.count_ConvBNReLU_flop(input_res[0],
                                                              input_res[1],
                                                              in_channels,
                                                              mid_channels,
                                                              [3, 3],
                                                              stride=2,
                                                              is_affine=affine)
        # stem 2 works on the resolution left by the stride-2 stem_1.
        input_res = input_res // 2
        if not sept_stem:
            self.stem_2 = Conv2d(mid_channels,
                                 mid_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 bias=False,
                                 norm=get_norm(norm, mid_channels),
                                 activation=nn.ReLU())
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                input_res[0],
                input_res[1],
                mid_channels,
                mid_channels, [3, 3],
                is_affine=affine)
        else:
            # Separable variant: 3x3 grouped conv followed by 1x1 conv.
            self.stem_2 = nn.Sequential(
                Conv2d(mid_channels,
                       mid_channels,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       groups=mid_channels,
                       bias=False),
                Conv2d(mid_channels,
                       mid_channels,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, mid_channels),
                       activation=nn.ReLU()))
            self.real_flops += (
                cal_op_flops.count_Conv_flop(input_res[0],
                                             input_res[1],
                                             mid_channels,
                                             mid_channels, [3, 3],
                                             groups=mid_channels) +
                cal_op_flops.count_ConvBNReLU_flop(input_res[0],
                                                   input_res[1],
                                                   mid_channels,
                                                   mid_channels, [1, 1],
                                                   is_affine=affine))
        # stem 3 (stride 2)
        if not sept_stem:
            self.stem_3 = Conv2d(mid_channels,
                                 out_channels,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 bias=False,
                                 norm=get_norm(norm, out_channels),
                                 activation=nn.ReLU())
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                input_res[0],
                input_res[1],
                mid_channels,
                out_channels, [3, 3],
                stride=2,
                is_affine=affine)
        else:
            self.stem_3 = nn.Sequential(
                Conv2d(mid_channels,
                       mid_channels,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       groups=mid_channels,
                       bias=False),
                Conv2d(mid_channels,
                       out_channels,
                       kernel_size=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, out_channels),
                       activation=nn.ReLU()))
            # The 1x1 conv runs after the stride-2 conv, hence the halved
            # resolution in its FLOP count.
            self.real_flops += (
                cal_op_flops.count_Conv_flop(input_res[0],
                                             input_res[1],
                                             mid_channels,
                                             mid_channels, [3, 3],
                                             stride=2,
                                             groups=mid_channels) +
                cal_op_flops.count_ConvBNReLU_flop(input_res[0] // 2,
                                                   input_res[1] // 2,
                                                   mid_channels,
                                                   out_channels, [1, 1],
                                                   is_affine=affine))
        # input_res was already halved once above, so this is input // 4.
        self.out_res = input_res // 2
        self.out_cha = out_channels
        # using Kaiming init
        for layer in [self.stem_1, self.stem_2, self.stem_3]:
            weight_init.kaiming_init_module(layer, mode='fan_in')
    def __init__(self,
                 C_in,
                 C_out,
                 norm,
                 allow_up,
                 allow_down,
                 input_size,
                 cell_type,
                 cal_flops=True,
                 using_gate=False,
                 small_gate=False,
                 gate_bias=1.5,
                 affine=True):
        """Build a cell: its mixed op, its resolution branches and its gate.

        Args:
            C_in (int): input channels, stored as ``self.channel_in``.
            C_out (int): output channels, stored as ``self.channel_out``.
            norm: normalisation spec forwarded to ``Mixed_OP`` and
                ``get_norm``.
            allow_up (bool): if True, build ``self.res_up`` (ReLU followed
                by a 1x1 Conv-norm-ReLU halving the channels) and
                ``self.res_up_flops``.
            allow_down (bool): if True, build ``self.res_down`` (ReLU
                followed by a stride-2 1x1 Conv-norm-ReLU doubling the
                channels) and ``self.res_down_flops``.
            input_size: (height, width) of the cell input; indexed with
                ``[0]`` / ``[1]`` and, when ``small_gate`` is set,
                floor-divided as a whole by 4.
            cell_type: forwarded to ``Mixed_OP``.
            cal_flops (bool): stored as ``self.cal_flops``.
            using_gate (bool): if True, build the gate network
                ``self.gate_conv_beta``; otherwise register the constant
                buffer ``gate_weights_beta`` filled with ones.
            small_gate (bool): stored; within this constructor it only
                makes the gate FLOP estimate use ``input_size // 4``.
            gate_bias (float): passed as ``bias`` to the Kaiming init of
                the gate network.
            affine (bool): forwarded to ``Mixed_OP`` and, as ``is_affine``,
                to the FLOP counters.
        """
        super(Cell, self).__init__()
        self.channel_in = C_in
        self.channel_out = C_out
        self.allow_up = allow_up
        self.allow_down = allow_down
        self.cal_flops = cal_flops
        self.using_gate = using_gate
        self.small_gate = small_gate

        self.cell_ops = Mixed_OP(inplanes=self.channel_in,
                                 outplanes=self.channel_out,
                                 stride=1,
                                 cell_type=cell_type,
                                 norm=norm,
                                 affine=affine,
                                 input_size=input_size)
        self.cell_flops = self.cell_ops.flops
        # resolution keep
        self.res_keep = nn.ReLU()
        self.res_keep_flops = cal_op_flops.count_ReLU_flop(
            input_size[0], input_size[1], self.channel_out)
        # resolution up and dim down
        if self.allow_up:
            self.res_up = nn.Sequential(
                nn.ReLU(),
                Conv2d(self.channel_out,
                       self.channel_out // 2,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, self.channel_out // 2),
                       activation=nn.ReLU()))
            # calculate Flops
            self.res_up_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1],
                self.channel_out) + cal_op_flops.count_ConvBNReLU_flop(
                    input_size[0],
                    input_size[1],
                    self.channel_out,
                    self.channel_out // 2, [1, 1],
                    is_affine=affine)
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_up, mode='fan_in')
        # resolution down and dim up
        if self.allow_down:
            self.res_down = nn.Sequential(
                nn.ReLU(),
                Conv2d(self.channel_out,
                       2 * self.channel_out,
                       kernel_size=1,
                       stride=2,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, 2 * self.channel_out),
                       activation=nn.ReLU()))
            # calculate Flops
            self.res_down_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1],
                self.channel_out) + cal_op_flops.count_ConvBNReLU_flop(
                    input_size[0],
                    input_size[1],
                    self.channel_out,
                    2 * self.channel_out, [1, 1],
                    stride=2,
                    is_affine=affine)
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_down, mode='fan_in')
        # One gate output per available path: keep, plus up and/or down.
        if self.allow_up and self.allow_down:
            self.gate_num = 3
        elif self.allow_up or self.allow_down:
            self.gate_num = 2
        else:
            self.gate_num = 1
        if self.using_gate:
            # Gate network: 1x1 Conv-norm-ReLU, global average pooling,
            # then a biased 1x1 conv producing gate_num values.
            self.gate_conv_beta = nn.Sequential(
                Conv2d(self.channel_in,
                       self.channel_in // 2,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, self.channel_in // 2),
                       activation=nn.ReLU()), nn.AdaptiveAvgPool2d((1, 1)),
                Conv2d(self.channel_in // 2,
                       self.gate_num,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=True))
            # With small_gate the gate FLOPs are counted on a resolution
            # reduced by 4; only the estimate below is affected here.
            if self.small_gate:
                input_size = input_size // 4
            self.gate_flops = cal_op_flops.count_ConvBNReLU_flop(
                input_size[0],
                input_size[1],
                self.channel_in,
                self.channel_in // 2, [1, 1],
                is_affine=affine) + cal_op_flops.count_Pool2d_flop(
                    input_size[0], input_size[1], self.channel_in // 2,
                    [1, 1], 1) + cal_op_flops.count_Conv_flop(
                        1, 1, self.channel_in // 2, self.gate_num, [1, 1])
            # using Kaiming init and predefined bias for gate
            weight_init.kaiming_init_module(self.gate_conv_beta,
                                            mode='fan_in',
                                            bias=gate_bias)
        else:
            # Constant gate: every path gets weight 1.
            gate_weights = torch.ones(1, self.gate_num, 1, 1)
            # An unconditional .cuda() made construction fail on CPU-only
            # hosts. Keep the original placement when CUDA is present; as a
            # registered buffer it follows the module on .to()/.cuda().
            if torch.cuda.is_available():
                gate_weights = gate_weights.cuda()
            self.register_buffer('gate_weights_beta', gate_weights)
            self.gate_flops = 0.0