Example #1
0
    def __init__(self,
                 motion='GRU',
                 se_layer=False,
                 dilation=True,
                 basic_model='resnext50'):
        """Construct the R3Net model.

        Args:
            motion: identifier for the temporal-aggregation variant;
                only stored here (consumed by other methods).
            se_layer: if True, add SELayer blocks on the reduced
                high-/low-level feature paths.
            dilation: if True, rewrite layer3/layer4 into dilated
                convolutions (dilate=2 and 4) via ``_nostride_dilate``.
            basic_model: backbone choice -- 'resnext50', 'resnext101',
                'resnet50'; any other value falls back to ResNet101.
        """
        super(R3Net, self).__init__()

        self.motion = motion
        self.se_layer = se_layer
        self.dilation = dilation

        # Pick the backbone; unrecognised names default to ResNet101.
        if basic_model == 'resnext50':
            backbone = ResNeXt50()
        elif basic_model == 'resnext101':
            backbone = ResNeXt101()
        elif basic_model == 'resnet50':
            backbone = ResNet50()
        else:
            backbone = ResNet101()

        # Re-register the five backbone stages on this module.
        for stage in ('layer0', 'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, stage, getattr(backbone, stage))

        # Fuse the concatenated low-level stages (64+256+512 ch) to 256 ch.
        self.reduce_low = nn.Sequential(
            nn.Conv2d(64 + 256 + 512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.PReLU())
        # Fuse the high-level stages (1024+2048 ch) to 256 ch, end in ASPP.
        self.reduce_high = nn.Sequential(
            nn.Conv2d(1024 + 2048, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.PReLU(),
            _ASPP(256))

        # self.motion_predict = nn.Conv2d(256, 1, kernel_size=1)

        if self.se_layer:
            self.reduce_high_se = SELayer(256)
            self.reduce_low_se = SELayer(256)
            # self.motion_se = SELayer(32)

        if dilation:
            # In-place edit of the backbone stages; self.layer3/self.layer4
            # alias the same module objects, so they are affected too.
            backbone.layer3.apply(partial(self._nostride_dilate, dilate=2))
            backbone.layer4.apply(partial(self._nostride_dilate, dilate=4))

        # Switch every activation/dropout to in-place mode to save memory.
        for m in self.modules():
            if isinstance(m, (nn.ReLU, nn.Dropout)):
                m.inplace = True
Example #2
0
    def __init__(self, motion='GRU', se_layer=False, attention=False):
        """Construct the DSS model.

        Args:
            motion: temporal-variant identifier; only stored here.
            se_layer: flag stored on the instance (unused in __init__).
            attention: flag stored on the instance (unused in __init__).
        """
        super(DSS, self).__init__()

        self.motion = motion
        self.se_layer = se_layer
        self.attention = attention

        # ResNet-101 backbone; its stages are aliased onto this module.
        backbone = ResNet101()

        for stage in ('layer0', 'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, stage, getattr(backbone, stage))

        def _side_output(in_ch, mid_ch, ksize, pad):
            # Two same-size convs with ReLU, then a 1x1 projection down
            # to a single-channel side-output map (DSS-style branch).
            return nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, kernel_size=ksize, padding=pad),
                nn.ReLU(),
                nn.Conv2d(mid_ch, mid_ch, kernel_size=ksize, padding=pad),
                nn.ReLU(),
                nn.Conv2d(mid_ch, 1, kernel_size=1))

        # One side-output branch per backbone stage, deepest first.
        self.dsn6 = _side_output(2048, 512, 7, 3)
        self.dsn5 = _side_output(1024, 512, 5, 2)
        self.dsn4 = _side_output(512, 256, 5, 2)

        # 1x1 convs that merge single-channel maps across stages.
        self.dsn4_fuse = nn.Conv2d(3, 1, kernel_size=1)

        self.dsn3 = _side_output(256, 256, 5, 2)

        self.dsn3_fuse = nn.Conv2d(3, 1, kernel_size=1)

        self.dsn2 = _side_output(64, 128, 3, 1)

        self.dsn2_fuse = nn.Conv2d(5, 1, kernel_size=1)

        self.dsn_all_fuse = nn.Conv2d(5, 1, kernel_size=1)
Example #3
0
    def __init__(self,
                 motion='GRU',
                 se_layer=False,
                 attention=False,
                 pre_attention=True):
        """Build the DSS saliency network with an optional motion branch.

        Args:
            motion: selects the temporal module for the reduced
                high-level features -- 'GRU' (ConvGRU), 'LSTM'
                (ConvLSTM) or 'no' (a plain conv stack).
            se_layer: stored on the instance; not used in this __init__.
            attention: stored on the instance; not used in this __init__.
            pre_attention: if True, create SELayer blocks that re-weight
                stacks of earlier saliency predictions.
        """
        super(DSS, self).__init__()

        self.motion = motion
        self.se_layer = se_layer
        self.attention = attention
        self.pre_attention = pre_attention

        # ResNet-101 backbone; its stages are aliased onto this module.
        resnext = ResNet101()

        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

        # Compress concatenated layer3+layer4 features (1024+2048 ch)
        # down to 256 channels, finishing with an ASPP block.
        self.reduce_high = nn.Sequential(
            nn.Conv2d(1024 + 2048, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256),
            nn.PReLU(), _ASPP(256))

        # Temporal aggregation over the reduced features at a fixed
        # 119x119 spatial size.
        if self.motion == 'GRU':
            self.reduce_high_motion = ConvGRU(input_size=(119, 119),
                                              input_dim=256,
                                              hidden_dim=128,
                                              kernel_size=(3, 3),
                                              num_layers=1,
                                              batch_first=True,
                                              bias=True,
                                              return_all_layers=False)
            # self.motion_predict = nn.Conv2d(256, 1, kernel_size=1)

        elif self.motion == 'LSTM':
            # NOTE(review): hidden_dim is 32 here but 128 in the GRU
            # branch, while the 129-channel predict*_motion heads below
            # only match a 128-wide feature (128 + 1) -- confirm that
            # the LSTM path is wired correctly in forward().
            self.reduce_high_motion = ConvLSTM(input_size=(119, 119),
                                               input_dim=256,
                                               hidden_dim=32,
                                               kernel_size=(3, 3),
                                               num_layers=1,
                                               padding=1,
                                               dilation=1,
                                               batch_first=True,
                                               bias=True,
                                               return_all_layers=False)
        elif self.motion == 'no':
            # Static fallback: plain conv stack reducing 256 -> 32 ch.
            self.reduce_high_motion = nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128), nn.PReLU(),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128), nn.PReLU(),
                nn.Conv2d(128, 32, kernel_size=1))

        # Motion-refinement heads: 129 input channels -- presumably 128
        # motion features concatenated with one saliency map (TODO
        # confirm against forward()).
        self.predict1_motion = nn.Sequential(
            nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
            nn.PReLU(), nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.PReLU(), nn.Conv2d(64, 1, kernel_size=1))
        self.predict2_motion = nn.Sequential(
            nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
            nn.PReLU(), nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.PReLU(), nn.Conv2d(64, 1, kernel_size=1))
        self.predict3_motion = nn.Sequential(
            nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
            nn.PReLU(), nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.PReLU(), nn.Conv2d(64, 1, kernel_size=1))
        self.predict4_motion = nn.Sequential(
            nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
            nn.PReLU(), nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.PReLU(), nn.Conv2d(64, 1, kernel_size=1))

        # Channel attention over stacks of 2/3/4 earlier predictions.
        if self.pre_attention:
            self.pre_sals_attention2 = SELayer(2, 1)
            self.pre_sals_attention3 = SELayer(3, 1)
            self.pre_sals_attention4 = SELayer(4, 1)

        # DSS-style side-output branches: each maps one backbone stage to
        # a single-channel map; the *_fuse 1x1 convs merge maps across
        # stages.
        self.dsn6 = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=7, padding=3), nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=7, padding=3), nn.ReLU(),
            nn.Conv2d(512, 1, kernel_size=1))

        self.dsn5 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=5, padding=2), nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, padding=2), nn.ReLU(),
            nn.Conv2d(512, 1, kernel_size=1))

        self.dsn4 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=5, padding=2), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=5, padding=2), nn.ReLU(),
            nn.Conv2d(256, 1, kernel_size=1))

        self.dsn4_fuse = nn.Conv2d(3, 1, kernel_size=1)

        self.dsn3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=5, padding=2), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=5, padding=2), nn.ReLU(),
            nn.Conv2d(256, 1, kernel_size=1))

        self.dsn3_fuse = nn.Conv2d(3, 1, kernel_size=1)

        self.dsn2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1))

        self.dsn2_fuse = nn.Conv2d(5, 1, kernel_size=1)

        self.dsn_all_fuse = nn.Conv2d(5, 1, kernel_size=1)
Example #4
0
    def __init__(self,
                 motion='GRU',
                 se_layer=False,
                 dilation=True,
                 basic_model='resnext50'):
        """Build the R3Net variant with recurrent feature aggregation.

        Args:
            motion: 'GRU' builds ConvGRU modules for both the low and
                high feature paths; 'LSTM' builds a ConvLSTM for the
                high path only.
            se_layer: if True, add SELayer blocks on the reduced
                high-/low-level feature paths.
            dilation: if True, rewrite layer3/layer4 into dilated
                convolutions (dilate=2 and 4) via ``_nostride_dilate``.
            basic_model: backbone choice -- 'resnext50', 'resnext101',
                'resnet50'; any other value falls back to ResNet101.
        """
        super(R3Net, self).__init__()

        self.motion = motion
        self.se_layer = se_layer
        self.dilation = dilation
        # Backbone selection; unknown names default to ResNet101.
        if basic_model == 'resnext50':
            resnext = ResNeXt50()
        elif basic_model == 'resnext101':
            resnext = ResNeXt101()
        elif basic_model == 'resnet50':
            resnext = ResNet50()
        else:
            resnext = ResNet101()
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

        # Fuse concatenated low-level stages (64+256+512 ch) to 256 ch.
        self.reduce_low = nn.Sequential(
            nn.Conv2d(64 + 256 + 512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256),
            nn.PReLU(), nn.Conv2d(256, 256, kernel_size=1),
            nn.BatchNorm2d(256), nn.PReLU())
        # Fuse high-level stages (1024+2048 ch) to 256 ch, end in ASPP.
        self.reduce_high = nn.Sequential(
            nn.Conv2d(1024 + 2048, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256),
            nn.PReLU(), _ASPP(256))
        # Recurrent aggregation over both reduced paths at a fixed
        # 119x119 spatial size.
        if self.motion == 'GRU':
            self.reduce_low_GRU = ConvGRU(input_size=(119, 119),
                                          input_dim=256,
                                          hidden_dim=256,
                                          kernel_size=(3, 3),
                                          num_layers=1,
                                          batch_first=True,
                                          bias=True,
                                          return_all_layers=False)

            self.reduce_high_GRU = ConvGRU(input_size=(119, 119),
                                           input_dim=256,
                                           hidden_dim=256,
                                           kernel_size=(3, 3),
                                           num_layers=1,
                                           batch_first=True,
                                           bias=True,
                                           return_all_layers=False)
            # self.motion_predict = nn.Conv2d(256, 1, kernel_size=1)

        elif self.motion == 'LSTM':
            # NOTE(review): unlike the GRU branch, no low-path module is
            # created here (the equivalent is commented out below); any
            # forward() use of reduce_low_GRU would fail with
            # motion='LSTM' -- confirm intended.
            # self.reduce_low_GRU = ConvLSTM(input_size=(119, 119), input_dim=256,
            #                               hidden_dim=256,
            #                               kernel_size=(3, 3),
            #                               num_layers=1,
            #                               padding=1,
            #                               dilation=1,
            #                               batch_first=True,
            #                               bias=True,
            #                               return_all_layers=False)

            self.reduce_high_GRU = ConvLSTM(input_size=(119, 119),
                                            input_dim=256,
                                            hidden_dim=256,
                                            kernel_size=(3, 3),
                                            num_layers=1,
                                            padding=1,
                                            dilation=1,
                                            batch_first=True,
                                            bias=True,
                                            return_all_layers=False)
            # self.motion_predict = nn.Conv2d(256, 1, kernel_size=1)

        if self.se_layer:
            self.reduce_high_se = SELayer(256)
            self.reduce_low_se = SELayer(256)
            # self.motion_se = SELayer(32)

        if dilation:
            # In-place edit of the backbone stages; self.layer3/4 alias
            # the same module objects, so they are affected as well.
            resnext.layer3.apply(partial(self._nostride_dilate, dilate=2))
            resnext.layer4.apply(partial(self._nostride_dilate, dilate=4))

        # Switch every activation/dropout to in-place mode to save memory.
        for m in self.modules():
            if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
                m.inplace = True
Example #5
0
    def __init__(self,
                 motion='GRU',
                 se_layer=False,
                 attention=False,
                 pre_attention=False,
                 isTriplet=False,
                 basic_model='resnext50',
                 sta=False,
                 naive_fuse=False):
        """Build R3Net_prior: R3Net with motion, attention and prior-map
        fusion options.

        Args:
            motion: temporal module for the reduced high-level features
                -- 'GRU' (ConvGRU), 'LSTM' (ConvLSTM) or 'no' (plain
                conv stack).
            se_layer: if True, add SELayer blocks on the high path and
                the motion features.
            attention: if True, add a BaseOC context-attention module.
            pre_attention: if True, add SELayer blocks re-weighting
                stacks of earlier saliency predictions.
            isTriplet: stored on the instance; not used in this
                __init__.
            basic_model: backbone choice -- 'resnext50', 'resnext101',
                'resnet50'; any other value falls back to ResNet101.
            sta: if True, add an STA module plus a 256->64 projection.
            naive_fuse: if True, add the 256->64 projection only.
        """
        super(R3Net_prior, self).__init__()

        self.motion = motion
        self.se_layer = se_layer
        self.attention = attention
        self.pre_attention = pre_attention
        self.isTriplet = isTriplet
        self.sta = sta
        self.naive_fuse = naive_fuse

        # Backbone selection; unknown names default to ResNet101.
        if basic_model == 'resnext50':
            resnext = ResNeXt50()
        elif basic_model == 'resnext101':
            resnext = ResNeXt101()
        elif basic_model == 'resnet50':
            resnext = ResNet50()
        else:
            resnext = ResNet101()
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

        # Fuse concatenated low-level stages (64+256+512 ch) to 256 ch.
        self.reduce_low = nn.Sequential(
            nn.Conv2d(64 + 256 + 512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256),
            nn.PReLU(), nn.Conv2d(256, 256, kernel_size=1),
            nn.BatchNorm2d(256), nn.PReLU())
        # Fuse high-level stages (1024+2048 ch) to 256 ch, end in ASPP.
        self.reduce_high = nn.Sequential(
            nn.Conv2d(1024 + 2048, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.PReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256),
            nn.PReLU(), _ASPP(256))
        # Temporal aggregation at a fixed 119x119 spatial size; GRU/LSTM
        # use hidden_dim=64 to match the 65-ch predict*_motion heads.
        if self.motion == 'GRU':
            self.reduce_high_motion = ConvGRU(input_size=(119, 119),
                                              input_dim=256,
                                              hidden_dim=64,
                                              kernel_size=(3, 3),
                                              num_layers=1,
                                              batch_first=True,
                                              bias=True,
                                              return_all_layers=False)
            # self.motion_predict = nn.Conv2d(256, 1, kernel_size=1)

        elif self.motion == 'LSTM':
            self.reduce_high_motion = ConvLSTM(input_size=(119, 119),
                                               input_dim=256,
                                               hidden_dim=64,
                                               kernel_size=(3, 3),
                                               num_layers=1,
                                               padding=1,
                                               dilation=1,
                                               batch_first=True,
                                               bias=True,
                                               return_all_layers=False)
        elif self.motion == 'no':
            # Static fallback: plain conv stack reducing 256 -> 32 ch.
            # NOTE(review): outputs 32 channels, not the 64 the
            # recurrent branches produce -- confirm forward() handles it.
            self.reduce_high_motion = nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128), nn.PReLU(),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128), nn.PReLU(),
                nn.Conv2d(128, 32, kernel_size=1))
            # self.motion_predict = nn.Conv2d(256, 1, kernel_size=1)

        if self.se_layer:
            self.reduce_high_se = SELayer(256)
            # self.reduce_low_se = SELayer(256)
            self.motion_se = SELayer(32)

        # Self-attention context module over the 256-ch high features.
        if self.attention:
            self.reduce_atte = BaseOC_Context_Module(256,
                                                     256,
                                                     128,
                                                     128,
                                                     0.05,
                                                     sizes=([2]))

        # Channel attention over stacks of 2/3/4 earlier predictions.
        if self.pre_attention:
            self.pre_sals_attention2 = SELayer(2, 1)
            self.pre_sals_attention3 = SELayer(3, 1)
            self.pre_sals_attention4 = SELayer(4, 1)

        if self.sta:
            self.sta_module = STA_Module(64)
            self.sp_down = nn.Sequential(nn.Conv2d(256, 64, kernel_size=1),
                                         nn.PReLU())

        # NOTE(review): when both sta and naive_fuse are True, this
        # second sp_down silently replaces the one created above
        # (identical architecture, fresh weights).
        if self.naive_fuse:
            self.sp_down = nn.Sequential(nn.Conv2d(256, 64, kernel_size=1),
                                         nn.PReLU())
            # self.sp_down2 = nn.Sequential(
            #     nn.Conv2d(128, 64, kernel_size=1), nn.PReLU()
            # )

        # Recurrent R3Net refinement heads: predict0 seeds from the
        # 256-ch features; predict1..6 each take 257 ch -- presumably
        # 256 features plus the previous saliency map (TODO confirm).
        self.predict0 = nn.Conv2d(256, 1, kernel_size=1)
        self.predict1 = nn.Sequential(
            nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128),
            nn.PReLU(), nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.PReLU(), nn.Conv2d(128, 1, kernel_size=1))
        self.predict2 = nn.Sequential(
            nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128),
            nn.PReLU(), nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.PReLU(), nn.Conv2d(128, 1, kernel_size=1))
        self.predict3 = nn.Sequential(
            nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128),
            nn.PReLU(), nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.PReLU(), nn.Conv2d(128, 1, kernel_size=1))
        self.predict4 = nn.Sequential(
            nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128),
            nn.PReLU(), nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.PReLU(), nn.Conv2d(128, 1, kernel_size=1))
        self.predict5 = nn.Sequential(
            nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128),
            nn.PReLU(), nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.PReLU(), nn.Conv2d(128, 1, kernel_size=1))
        self.predict6 = nn.Sequential(
            nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128),
            nn.PReLU(), nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.PReLU(), nn.Conv2d(128, 1, kernel_size=1))

        # Motion-refinement heads: 65 in-channels -- presumably 64
        # motion features plus one saliency map (TODO confirm).
        self.predict1_motion = nn.Sequential(
            nn.Conv2d(65, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32),
            nn.PReLU(), nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.PReLU(), nn.Conv2d(32, 1, kernel_size=1))
        self.predict2_motion = nn.Sequential(
            nn.Conv2d(65, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32),
            nn.PReLU(), nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.PReLU(), nn.Conv2d(32, 1, kernel_size=1))
        self.predict3_motion = nn.Sequential(
            nn.Conv2d(65, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32),
            nn.PReLU(), nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.PReLU(), nn.Conv2d(32, 1, kernel_size=1))
        self.predict4_motion = nn.Sequential(
            nn.Conv2d(65, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32),
            nn.PReLU(), nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.PReLU(), nn.Conv2d(32, 1, kernel_size=1))

        # Switch every activation/dropout to in-place mode to save memory.
        for m in self.modules():
            if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
                m.inplace = True
Example #6
0
    def __init__(self,
                 motion='GRU',
                 se_layer=False,
                 attention=False,
                 dilation=True,
                 basic_model='resnext50'):
        """Build the R3Net variant with dual (PAM/CAM) attention heads.

        Args:
            motion: 'GRU' builds a ConvGRU over 128-ch features; 'GGNN'
                builds a graph module; other values add no motion module.
            se_layer: stored on the instance; not used in this __init__.
            attention: stored on the instance; not used in this __init__.
            dilation: if True, rewrite layer3/layer4 into dilated
                convolutions (dilate=2 and 4) via ``_nostride_dilate``.
            basic_model: backbone choice -- 'resnext50', 'resnext101',
                'resnet50'; any other value falls back to ResNet101.
        """
        super(R3Net, self).__init__()

        self.motion = motion
        self.se_layer = se_layer
        self.attention = attention
        self.dilation = dilation
        # Backbone selection; unknown names default to ResNet101.
        if basic_model == 'resnext50':
            resnext = ResNeXt50()
        elif basic_model == 'resnext101':
            resnext = ResNeXt101()
        elif basic_model == 'resnet50':
            resnext = ResNet50()
        else:
            resnext = ResNet101()

        # Dilation is applied before the stages are aliased below
        # (equivalent either way, since aliasing shares the objects).
        if dilation:
            resnext.layer3.apply(partial(self._nostride_dilate, dilate=2))
            resnext.layer4.apply(partial(self._nostride_dilate, dilate=4))

        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

        # Fuse concatenated low-level stages (64+256+512 ch) to 128 ch.
        self.reduce_low = nn.Sequential(
            nn.Conv2d(64 + 256 + 512, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.PReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128),
            nn.PReLU(), nn.Conv2d(128, 128, kernel_size=1),
            nn.BatchNorm2d(128), nn.PReLU())

        # Fuse high-level stages (1024+2048 ch) to 512 ch; no ASPP here
        # (it is commented out), the attention heads below follow.
        self.reduce_high = nn.Sequential(
            nn.Conv2d(1024 + 2048, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.PReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.PReLU(),
            # _ASPP(128)
        )

        # Dual-attention heads over the 512-ch features: conv5a/sa/conv51
        # form one path through PAM_Module, conv5c/sc/conv52 the parallel
        # path through CAM_Module, each at 512//4 = 128 channels.
        inter_channels = 512 // 4
        self.conv5a = nn.Sequential(
            nn.Conv2d(512, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())

        self.conv5c = nn.Sequential(
            nn.Conv2d(512, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())

        self.sa = PAM_Module(inter_channels)
        self.sc = CAM_Module(inter_channels)

        self.conv51 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1,
                      bias=False), nn.BatchNorm2d(inter_channels), nn.ReLU())
        self.conv52 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1,
                      bias=False), nn.BatchNorm2d(inter_channels), nn.ReLU())

        # Optional temporal/graph aggregation at a fixed 119x119 size.
        if self.motion == 'GRU':
            # self.reduce_low_GRU = ConvGRU(input_size=(119, 119), input_dim=256,
            #                          hidden_dim=256,
            #                          kernel_size=(3, 3),
            #                          num_layers=1,
            #                          batch_first=True,
            #                          bias=True,
            #                          return_all_layers=False)

            self.reduce_high_motion = ConvGRU(input_size=(119, 119),
                                              input_dim=128,
                                              hidden_dim=128,
                                              kernel_size=(3, 3),
                                              num_layers=1,
                                              batch_first=True,
                                              bias=True,
                                              return_all_layers=False)
            # self.motion_predict = nn.Conv2d(256, 1, kernel_size=1)
        elif self.motion == 'GGNN':
            self.graph_module = GGNN(5, 1, 1, 3, 1)

        # self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(inter_channels, 1, 1))
        # self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(inter_channels, 1, 1))
        # self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(inter_channels, 1, 1))

        # Refinement heads: predict0 seeds from 128-ch features;
        # predict1/2 take 129 ch -- presumably 128 features plus the
        # previous saliency map (TODO confirm against forward()).
        self.predict0 = nn.Conv2d(128, 1, kernel_size=1)
        self.predict1 = nn.Sequential(
            nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
            nn.PReLU(), nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.PReLU(), nn.Conv2d(64, 1, kernel_size=1))
        self.predict2 = nn.Sequential(
            nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
            nn.PReLU(), nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.PReLU(), nn.Conv2d(64, 1, kernel_size=1))
        # self.predict3 = nn.Sequential(
        #     nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
        #     nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
        #     nn.Conv2d(64, 1, kernel_size=1)
        # )
        # self.predict3 = nn.Sequential(
        #     nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 1, kernel_size=1)
        # )
        # self.predict4 = nn.Sequential(
        #     nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 1, kernel_size=1)
        # )
        # self.predict5 = nn.Sequential(
        #     nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 1, kernel_size=1)
        # )
        # self.predict6 = nn.Sequential(
        #     nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
        #     nn.Conv2d(128, 1, kernel_size=1)
        # )

        # Switch every activation/dropout to in-place mode to save memory.
        for m in self.modules():
            if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
                m.inplace = True