コード例 #1
0
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 is_deconv=True,
                 in_channels=3,
                 is_batchnorm=True):
        """Create the layers of the plain (non-residual) U-Net.

        Args:
            feature_scale: Divisor applied to the base channel widths
                ``(64, 128, 256, 512, 1024)``; each quotient is truncated to
                an int. With the default of 4 the widths are
                ``16, 32, 64, 128, 256``.
            n_classes: Number of output channels of the final 1x1 conv.
            is_deconv: Passed through to every ``unetUp`` decoder stage.
            in_channels: Channel count of the tensor fed to ``conv1``.
            is_batchnorm: Passed through to every ``unetConv2`` stage.
        """
        super(unetoriginal, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        # Wider base widths (128..2048, 256..4096) were tried previously.
        base_widths = (64, 128, 256, 512, 1024)
        w1, w2, w3, w4, w5 = [int(w / self.feature_scale) for w in base_widths]
        bn = self.is_batchnorm
        deconv = self.is_deconv

        # Contracting path: conv stage, then a 2x2 max-pool (stride defaults
        # to the kernel size, so spatial dims are halved).
        self.conv1 = unetConv2(self.in_channels, w1, bn)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(w1, w2, bn)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(w2, w3, bn)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(w3, w4, bn)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck at the deepest level.
        self.center = unetConv2(w4, w5, bn)

        # Expanding path, deepest level first.
        self.up_concat4 = unetUp(w5, w4, deconv)
        self.up_concat3 = unetUp(w4, w3, deconv)
        self.up_concat2 = unetUp(w3, w2, deconv)
        self.up_concat1 = unetUp(w2, w1, deconv)

        # 1x1 projection to n_classes (no skip concatenation here), plus dropout.
        self.final = nn.Conv2d(w1, n_classes, 1)
        self.drop = nn.Dropout(p=0.25)
コード例 #2
0
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 is_deconv=True,
                 in_channels=3,
                 is_batchnorm=True):
        """Create the layers of the residual-variant U-Net.

        Besides the usual encoder/decoder stages, this builds strided 2x2
        convolutions (``half1``-``half4``) and 1x1 projection shortcuts
        (``shortcut1``-``shortcut4``, ``shortcut_center``). How ``forward``
        combines them is defined outside this method.

        Args:
            feature_scale: Divisor applied to the base channel widths
                ``[64, 128, 256, 512, 1024]``; each quotient is truncated to
                an int. With the default of 4 the widths are
                ``[16, 32, 64, 128, 256]``.
            n_classes: Number of output channels of the final 1x1 conv.
            is_deconv: Passed through to every ``unetUp`` decoder stage.
            in_channels: Channel count of the network input.
            is_batchnorm: Passed through to the ``unetfirstConv2`` /
                ``unetConv2`` stages. The shortcut branches below always
                include ``BatchNorm2d`` regardless of this flag.
        """
        import logging

        super(unet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        # filters = [128,256,512,1024,2048]
        filters = [int(x / self.feature_scale) for x in filters]
        # Was a bare print() on every construction; emit at DEBUG level so
        # library users are not spammed on stdout.
        logging.getLogger(__name__).debug("unet filters: %s", filters)

        # TODO: DOUBLE THE BASE FILTER NUMBER
        # TODO: RESIDUAL CONNECTION
        # TODO: MOVE DROPOUT LAYERS
        # TODO: TUNE DOWN THE LEARNING RATE WHEN TRAINING LOSS FLUCTUATES

        # Downsampling path. Channel counts in comments assume feature_scale=4.
        # Each level has both a 2x2 max-pool and a learned 2x2/stride-2 conv
        # ("halfN"); which one forward() uses is not visible here.
        self.conv1 = unetfirstConv2(self.in_channels, filters[0],
                                    self.is_batchnorm)  # 3 -> 16
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.half1 = nn.Conv2d(filters[0], filters[0], 2, 2, 0)

        self.conv2 = unetConv2(filters[0], filters[1],
                               self.is_batchnorm)  # 16 -> 32
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.half2 = nn.Conv2d(filters[1], filters[1], 2, 2, 0)

        self.conv3 = unetConv2(filters[1], filters[2],
                               self.is_batchnorm)  # 32 -> 64
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.half3 = nn.Conv2d(filters[2], filters[2], 2, 2, 0)

        self.conv4 = unetConv2(filters[2], filters[3],
                               self.is_batchnorm)  # 64 -> 128
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        self.half4 = nn.Conv2d(filters[3], filters[3], 2, 2, 0)

        self.center = unetConv2(filters[3], filters[4],
                                self.is_batchnorm)  # 128 -> 256

        # Upsampling path, deepest level first.
        # NOTE(review): the original comment flagged up_concat4 as a "problem";
        # the cause is not visible here.
        self.up_concat4 = unetUp(filters[4], filters[3],
                                 self.is_deconv)  # 256 -> 128
        self.up_concat3 = unetUp(filters[3], filters[2],
                                 self.is_deconv)  # 128 -> 64
        self.up_concat2 = unetUp(filters[2], filters[1],
                                 self.is_deconv)  # 64 -> 32
        self.up_concat1 = unetUp(filters[1], filters[0],
                                 self.is_deconv)  # 32 -> 16

        # Final 1x1 conv to class scores (no skip concatenation).
        self.final = nn.Conv2d(filters[0], n_classes,
                               1)  # 16 -> n_classes (21 by default)
        self.drop = nn.Dropout(p=0.3)

        # 1x1 projection shortcuts (conv + BatchNorm). shortcut1 keeps stride 1;
        # the others use stride 2, presumably to match the halved spatial size
        # of the downsampled main path -- confirm against forward().
        # NOTE(review): BatchNorm2d is used here even when is_batchnorm=False.
        self.shortcut1 = nn.Sequential(
            nn.Conv2d(self.in_channels, filters[0], 1, 1, 0),
            nn.BatchNorm2d(filters[0]))
        self.shortcut2 = nn.Sequential(
            nn.Conv2d(filters[0], filters[1], 1, 2, 0),
            nn.BatchNorm2d(filters[1]))
        self.shortcut3 = nn.Sequential(
            nn.Conv2d(filters[1], filters[2], 1, 2, 0),
            nn.BatchNorm2d(filters[2]))
        self.shortcut4 = nn.Sequential(
            nn.Conv2d(filters[2], filters[3], 1, 2, 0),
            nn.BatchNorm2d(filters[3]))
        self.shortcut_center = nn.Sequential(
            nn.Conv2d(filters[3], filters[4], 1, 2, 0),
            nn.BatchNorm2d(filters[4]))