コード例 #1
0
ファイル: resnet.py プロジェクト: younetcq/BSConv
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride):
        """Create the layers of a two-convolution residual unit.

        Args:
            in_channels: number of channels of the unit's input.
            out_channels: number of channels produced by the unit.
            stride: stride of the first 3x3 convolution; also used by the
                1x1 shortcut projection when one is created.
        """
        super().__init__()
        # The shortcut gets a 1x1 projection whenever the channel count
        # changes or the stride is not 1.
        channels_change = in_channels != out_channels
        stride_changes = stride != 1
        self.use_projection = channels_change or stride_changes

        self.conv1 = conv3x3_block(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride)
        # Second conv keeps stride 1 and has no activation of its own.
        self.conv2 = conv3x3_block(
            in_channels=out_channels,
            out_channels=out_channels,
            stride=1,
            activation=None)
        if self.use_projection:
            # NOTE(review): "pojection" looks like a typo for "projection".
            # It is kept unchanged because forward() and any saved state_dict
            # keys (not visible here) may depend on this exact name -- confirm
            # before renaming.
            self.pojection = conv1x1_block(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                activation=None)
        self.relu = torch.nn.ReLU(inplace=True)
コード例 #2
0
ファイル: mobilenet.py プロジェクト: younetcq/BSConv
    def __init__(self,
                 num_classes,
                 init_conv_channels,
                 init_conv_stride,
                 channels,
                 mid_channels,
                 final_conv_channels,
                 strides,
                 in_channels=3,
                 in_size=(224, 224),
                 use_data_batchnorm=True):
        """Assemble a MobileNet-style backbone of LinearBottleneck stages plus a classifier.

        Args:
            num_classes: number of outputs of the Classifier head.
            init_conv_channels: output channels of the initial 3x3 convolution.
            init_conv_stride: stride of the initial 3x3 convolution.
            channels: per-stage sequences of unit output channels (one entry
                per unit).
            mid_channels: per-stage, per-unit mid channel counts for each
                LinearBottleneck; indexed with the same nesting as ``channels``.
            final_conv_channels: output channels of the final 1x1 convolution,
                which is also the classifier's input width.
            strides: one stride per stage, applied only to that stage's first
                unit (later units use stride 1).
            in_channels: channel count of the network input.
            in_size: stored as ``self.in_size``; not used by this constructor.
            use_data_batchnorm: if truthy, a BatchNorm2d is prepended to the
                backbone to normalize the raw input.
        """
        super().__init__()
        self.use_data_batchnorm = use_data_batchnorm
        self.in_size = in_size

        backbone = torch.nn.Sequential()
        self.backbone = backbone

        # Optional normalization of the raw input.
        if self.use_data_batchnorm:
            backbone.add_module("data_bn",
                                torch.nn.BatchNorm2d(num_features=in_channels))

        # Stem convolution.
        backbone.add_module("init_conv", conv3x3_block(
            in_channels=in_channels,
            out_channels=init_conv_channels,
            stride=init_conv_stride,
            activation="relu6"))

        # Bottleneck stages; `prev_channels` tracks the running width.
        prev_channels = init_conv_channels
        for stage_idx, stage_out_channels in enumerate(channels):
            stage_module = torch.nn.Sequential()
            for unit_idx, unit_out_channels in enumerate(stage_out_channels):
                # Only the first unit of a stage applies the stage stride.
                if unit_idx == 0:
                    unit_stride = strides[stage_idx]
                else:
                    unit_stride = 1
                unit_mid = mid_channels[stage_idx][unit_idx]
                unit = LinearBottleneck(in_channels=prev_channels,
                                        mid_channels=unit_mid,
                                        out_channels=unit_out_channels,
                                        stride=unit_stride)
                stage_module.add_module("unit" + str(unit_idx + 1), unit)
                prev_channels = unit_out_channels
            backbone.add_module("stage" + str(stage_idx + 1), stage_module)

        # Head: 1x1 expansion conv followed by global average pooling.
        backbone.add_module("final_conv", conv1x1_block(
            in_channels=prev_channels,
            out_channels=final_conv_channels,
            activation="relu6"))
        backbone.add_module("global_pool",
                            torch.nn.AdaptiveAvgPool2d(output_size=1))

        self.classifier = Classifier(in_channels=final_conv_channels,
                                     num_classes=num_classes)

        self.init_params()
コード例 #3
0
ファイル: resnet.py プロジェクト: GG-yuki/bugs
    def __init__(self, in_channels, out_channels, preact=False):
        """Create a single stride-1 3x3 convolution block.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            preact: when truthy, conv3x3_block is built with use_bn=False and
                activation=None (presumably because normalization/activation
                is applied elsewhere in a pre-activation design -- confirm);
                otherwise with use_bn=True and activation="relu".
        """
        super().__init__()

        if preact:
            use_bn, activation = False, None
        else:
            use_bn, activation = True, "relu"

        self.conv = conv3x3_block(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            use_bn=use_bn,
            activation=activation)
コード例 #4
0
ファイル: resnet.py プロジェクト: GG-yuki/bugs
    def __init__(self, in_channels, out_channels, stride):
        """Create the layers of a residual unit.

        Layers created: a BatchNorm2d + ReLU over the input channels, two 3x3
        conv blocks (the second without batch norm or activation) and, when
        the channel count or stride differs from identity, a 1x1 projection
        without batch norm or activation.

        Args:
            in_channels: number of channels of the unit's input.
            out_channels: number of channels produced by the unit.
            stride: stride of the first 3x3 conv and of the projection.
        """
        super().__init__()
        needs_projection = in_channels != out_channels or stride != 1
        self.use_projection = needs_projection

        self.bn = torch.nn.BatchNorm2d(num_features=in_channels)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv1 = conv3x3_block(in_channels=in_channels,
                                   out_channels=out_channels,
                                   stride=stride)
        # No stride is passed here, so conv3x3_block's default applies.
        self.conv2 = conv3x3_block(in_channels=out_channels,
                                   out_channels=out_channels,
                                   use_bn=False,
                                   activation=None)
        if needs_projection:
            self.projection = conv1x1_block(in_channels=in_channels,
                                            out_channels=out_channels,
                                            stride=stride,
                                            use_bn=False,
                                            activation=None)
コード例 #5
0
ファイル: mobilenet.py プロジェクト: younetcq/BSConv
    def __init__(self,
                 num_classes,
                 init_conv_channels,
                 init_conv_stride,
                 final_conv_channels,
                 final_conv_se,
                 channels,
                 mid_channels,
                 strides,
                 se_units,
                 kernel_sizes,
                 activations,
                 dropout_rate=0.0,
                 in_channels=3,
                 in_size=(224, 224),
                 use_data_batchnorm=True):
        """Assemble a MobileNet-style backbone with SE/hswish options plus a classifier.

        Args:
            num_classes: number of outputs of the Classifier head.
            init_conv_channels: output channels of the initial 3x3 convolution.
            init_conv_stride: stride of the initial 3x3 convolution.
            final_conv_channels: sequence whose element 0 is the width of the
                1x1 conv applied before pooling; if it has more than one
                element, element 1 is the width of a second 1x1 conv (no batch
                norm) applied after pooling.
            final_conv_se: if truthy, an SEUnit follows the first final conv.
            channels: per-stage sequences of unit output channels.
            mid_channels: per-stage, per-unit LinearBottleneck mid channels.
            strides: one stride per stage, applied only to its first unit.
            se_units: per-stage, per-unit flags; a value equal to 1 enables
                squeeze-excitation in that unit.
            kernel_sizes: one kernel size per stage, shared by its units.
            activations: one activation per stage, shared by its units.
            dropout_rate: if non-zero, a Dropout with this rate is appended
                to the end of the backbone.
            in_channels: channel count of the network input.
            in_size: stored as ``self.in_size``; not used by this constructor.
            use_data_batchnorm: if truthy, a BatchNorm2d is prepended to the
                backbone to normalize the raw input.
        """
        super().__init__()
        self.use_data_batchnorm = use_data_batchnorm
        self.in_size = in_size
        self.dropout_rate = dropout_rate

        backbone = torch.nn.Sequential()
        self.backbone = backbone

        # Optional normalization of the raw input.
        if self.use_data_batchnorm:
            backbone.add_module("data_bn",
                                torch.nn.BatchNorm2d(num_features=in_channels))

        # Stem convolution.
        backbone.add_module("init_conv", conv3x3_block(
            in_channels=in_channels,
            out_channels=init_conv_channels,
            stride=init_conv_stride,
            activation="hswish"))

        # Bottleneck stages; `cur_channels` tracks the running width.
        cur_channels = init_conv_channels
        for stage_idx, stage_out_channels in enumerate(channels):
            stage_module = torch.nn.Sequential()
            for unit_idx, unit_out_channels in enumerate(stage_out_channels):
                # Only the first unit of a stage applies the stage stride.
                if unit_idx == 0:
                    unit_stride = strides[stage_idx]
                else:
                    unit_stride = 1
                unit_mid = mid_channels[stage_idx][unit_idx]
                unit_use_se = (se_units[stage_idx][unit_idx] == 1)
                stage_kernel = kernel_sizes[stage_idx]
                stage_activation = activations[stage_idx]
                unit = LinearBottleneck(in_channels=cur_channels,
                                        mid_channels=unit_mid,
                                        out_channels=unit_out_channels,
                                        stride=unit_stride,
                                        activation=stage_activation,
                                        use_se=unit_use_se,
                                        kernel_size=stage_kernel)
                stage_module.add_module("unit" + str(unit_idx + 1), unit)
                cur_channels = unit_out_channels
            backbone.add_module("stage" + str(stage_idx + 1), stage_module)

        # Head, part 1: 1x1 conv (+ optional SE) before pooling.
        pre_pool_channels = final_conv_channels[0]
        backbone.add_module("final_conv1", conv1x1_block(
            in_channels=cur_channels,
            out_channels=pre_pool_channels,
            activation="hswish"))
        cur_channels = pre_pool_channels
        if final_conv_se:
            backbone.add_module("final_se", SEUnit(
                channels=cur_channels,
                squeeze_factor=4,
                squeeze_activation="relu",
                excite_activation="hsigmoid"))
        backbone.add_module("final_pool",
                            torch.nn.AdaptiveAvgPool2d(output_size=1))

        # Head, part 2: optional post-pool 1x1 conv without batch norm.
        if len(final_conv_channels) > 1:
            post_pool_channels = final_conv_channels[1]
            backbone.add_module("final_conv2", conv1x1_block(
                in_channels=cur_channels,
                out_channels=post_pool_channels,
                activation="hswish",
                use_bn=False))
            cur_channels = post_pool_channels
        if self.dropout_rate != 0.0:
            backbone.add_module("final_dropout",
                                torch.nn.Dropout(dropout_rate))

        # The classifier consumes whatever width the head ended with.
        self.classifier = Classifier(in_channels=cur_channels,
                                     num_classes=num_classes)

        self.init_params()