예제 #1
0
    def __init__(  # noqa:C901
        self, C_in, C_out, norm, allow_up, allow_down, input_size,
        cell_type, cal_flops=True, using_gate=False,
        small_gate=False, gate_bias=1.5, affine=True
    ):
        """Build a cell: a ``Mixed_OP`` transform, resolution branches and a gate.

        The cell always owns ``cell_ops`` (a ``Mixed_OP``) and a ``res_keep``
        ReLU branch. Depending on the flags it also builds:

        * ``res_up`` (when ``allow_up``): ReLU + 1x1 conv (stride 1) reducing
          the width to ``C_out // 2`` ("resolution up and dim down"; the
          spatial upsampling itself is presumably done in ``forward`` —
          not visible here).
        * ``res_down`` (when ``allow_down``): ReLU + 1x1 conv with stride 2
          increasing the width to ``2 * C_out``.
        * ``gate_conv_beta`` (when ``using_gate``): conv -> adaptive average
          pool to 1x1 -> conv with ``gate_num`` output channels. Without a
          learned gate, a constant all-ones buffer ``gate_weights_beta`` of
          shape ``(1, gate_num, 1, 1)`` is registered instead.

        ``gate_num`` is 1 plus one for each of ``allow_up`` / ``allow_down``
        (presumably one gate value per branch). An estimated FLOP count is
        stored for each part: ``cell_flops``, ``res_keep_flops``,
        ``res_up_flops``, ``res_down_flops`` and ``gate_flops``.

        Args:
            C_in: input channel count (``Mixed_OP`` inplanes, gate input).
            C_out: output channel count of ``Mixed_OP``; the resolution
                branches operate on this width.
            norm: normalization spec forwarded to ``get_norm`` / ``Mixed_OP``.
            allow_up: build the ``res_up`` branch.
            allow_down: build the ``res_down`` branch.
            input_size: indexable spatial size; only ``input_size[0]`` and
                ``input_size[1]`` are read, for FLOP counting (presumably
                (H, W) — confirm against callers). A list, tuple or numpy
                array is accepted.
            cell_type: forwarded to ``Mixed_OP``.
            cal_flops: stored as ``self.cal_flops``; FLOP estimates are
                computed here regardless of its value.
            using_gate: build the learned gate instead of a constant buffer.
            small_gate: count gate FLOPs at 1/4 of ``input_size`` in each
                spatial dimension (presumably ``forward`` runs the gate on a
                downsampled input — confirm).
            gate_bias: passed as ``bias`` to ``weight_init.kaiming_init`` for
                the gate convolutions.
            affine: forwarded to ``Mixed_OP`` and the FLOP counters.
        """
        super(Cell, self).__init__()
        self.channel_in = C_in
        self.channel_out = C_out
        self.allow_up = allow_up
        self.allow_down = allow_down
        self.cal_flops = cal_flops
        self.using_gate = using_gate
        self.small_gate = small_gate

        def _init_weights(module, **kaiming_kwargs):
            # Kaiming (fan_in) init for every conv; BN layers start as the
            # identity affine transform (weight=1, bias=0).
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    weight_init.kaiming_init(m, mode='fan_in', **kaiming_kwargs)
                elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                    if m.weight is not None:
                        nn.init.constant_(m.weight, 1)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

        self.cell_ops = Mixed_OP(
            inplanes=self.channel_in, outplanes=self.channel_out,
            stride=1, cell_type=cell_type, norm=norm,
            affine=affine, input_size=input_size
        )
        self.cell_flops = self.cell_ops.flops
        # resolution keep
        self.res_keep = nn.ReLU()
        self.res_keep_flops = cal_op_flops.count_ReLU_flop(
            input_size[0], input_size[1], self.channel_out
        )
        # resolution up and dim down
        if self.allow_up:
            self.res_up = nn.Sequential(
                nn.ReLU(),
                Conv2d(
                    self.channel_out, self.channel_out // 2, kernel_size=1,
                    stride=1, padding=0, bias=False,
                    norm=get_norm(norm, self.channel_out // 2),
                    activation=nn.ReLU()
                )
            )
            self.res_up_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1], self.channel_out
            ) + cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_out,
                self.channel_out // 2, [1, 1], is_affine=affine
            )
            # Initialise right after construction, as before, so the RNG
            # consumption order is unchanged.
            _init_weights(self.res_up)
        # resolution down and dim up
        if self.allow_down:
            self.res_down = nn.Sequential(
                nn.ReLU(),
                Conv2d(
                    self.channel_out, 2 * self.channel_out,
                    kernel_size=1, stride=2, padding=0, bias=False,
                    norm=get_norm(norm, 2 * self.channel_out),
                    activation=nn.ReLU()
                )
            )
            self.res_down_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1], self.channel_out
            ) + cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_out,
                2 * self.channel_out, [1, 1], stride=2, is_affine=affine
            )
            _init_weights(self.res_down)
        # 1 for the keep branch, +1 for each enabled up/down branch.
        self.gate_num = 1 + int(bool(self.allow_up)) + int(bool(self.allow_down))
        if self.using_gate:
            self.gate_conv_beta = nn.Sequential(
                Conv2d(
                    self.channel_in, self.channel_in // 2, kernel_size=1,
                    stride=1, padding=0, bias=False,
                    norm=get_norm(norm, self.channel_in // 2),
                    activation=nn.ReLU()
                ),
                nn.AdaptiveAvgPool2d((1, 1)),
                Conv2d(
                    self.channel_in // 2, self.gate_num, kernel_size=1,
                    stride=1, padding=0, bias=True
                )
            )
            gate_size = input_size
            if self.small_gate:
                # Element-wise division: ``tuple // 4`` / ``list // 4`` raise
                # TypeError, while this also works for numpy arrays.
                gate_size = [s // 4 for s in input_size]
            self.gate_flops = cal_op_flops.count_ConvBNReLU_flop(
                gate_size[0], gate_size[1], self.channel_in,
                self.channel_in // 2, [1, 1], is_affine=affine
            ) + cal_op_flops.count_Pool2d_flop(
                gate_size[0], gate_size[1], self.channel_in // 2, [1, 1], 1
            ) + cal_op_flops.count_Conv_flop(
                1, 1, self.channel_in // 2, self.gate_num, [1, 1]
            )
            # Kaiming init with a predefined bias for the gate convs.
            _init_weights(self.gate_conv_beta, bias=gate_bias)
        else:
            # Constant (non-learned) gate. Do not call ``.cuda()`` here: it
            # fails without a GPU, and registered buffers already follow
            # ``module.to()`` / ``module.cuda()``.
            self.register_buffer(
                'gate_weights_beta', torch.ones(1, self.gate_num, 1, 1)
            )
            self.gate_flops = 0.0
예제 #2
0
 def get_flop(self, kernel_size, stride, out_channel, in_h, in_w):
     """Return the FLOP estimate from ``flops.count_Pool2d_flop``.

     The arguments are re-ordered into the helper's positional order
     ``(in_h, in_w, out_channel, kernel_size, stride)`` and its result is
     returned unchanged (presumably the cost of a 2-D pooling layer —
     confirm against the ``flops`` module). ``self`` is not used.
     """
     return flops.count_Pool2d_flop(
         in_h, in_w, out_channel, kernel_size, stride
     )