Example #1
0
    def __init__(self, pretrained=True):
        """Build DSSNet on a torchvision ResNet-34 backbone.

        Args:
            pretrained: forwarded to ``torchvision.models.resnet34`` as the
                ``pretrained`` keyword. Defaults to ``True`` (the previously
                hard-coded value); pass ``False`` to build without loading
                ImageNet weights, e.g. for offline construction.
        """
        super(DSSNet, self).__init__()
        # Passed by keyword: newer torchvision model builders take
        # keyword-only arguments and deprecate the positional boolean.
        self.resnet = torchvision.models.resnet34(pretrained=pretrained)

        # NOTE: the backbone's own layers are reused below, so the same module
        # objects are registered both under ``resnet.*`` and under
        # ``conv1``/``encode*``. Removing ``self.resnet`` would change the
        # state_dict keys.
        # Stem: conv1 -> bn1 -> relu (the backbone's maxpool is not included).
        self.conv1 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu)

        # Encoder stages: ResNet-34 layer1..layer4 (64/128/256/512 output
        # channels), each followed by an SCSE block of matching width.
        self.encode2 = nn.Sequential(self.resnet.layer1,
                                     SCSE(64))
        self.encode3 = nn.Sequential(self.resnet.layer2,
                                     SCSE(128))
        self.encode4 = nn.Sequential(self.resnet.layer3,
                                     SCSE(256))
        self.encode5 = nn.Sequential(self.resnet.layer4,
                                     SCSE(512))

        # Bottleneck: two SeRes(512) blocks, BasicConv2d 512 -> 256 (args 3, 1
        # presumably kernel size / padding — confirm in BasicConv2d), then a
        # 2x2 max-pool with stride 2.
        self.center = nn.Sequential(SeRes(512),
                                    SeRes(512),
                                    BasicConv2d(512, 256, 3, 1),
                                    nn.MaxPool2d(2, 2))

        # Decoder stages. The second Decoderv2 argument (512/256/128/64)
        # matches the encode5..encode2 widths — presumably the skip inputs;
        # TODO confirm against forward().
        self.decode5 = Decoderv2(256, 512, 64)
        self.decode4 = Decoderv2(64, 256, 64)
        self.decode3 = Decoderv2(64, 128, 64)
        self.decode2 = Decoderv2(64, 64, 64)
        self.decode1 = Decoder(64, 64)

        # 1x1 conv, 4 -> 1 channels, no bias. Presumably fuses four stacked
        # side outputs — TODO confirm against forward().
        self.logit = nn.Sequential(nn.Conv2d(4, 1, kernel_size=1, bias=False))

        # 1x1 conv heads, 64 -> 1 channels each, no bias.
        self.logit1 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1, bias=False))
        self.logit2 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1, bias=False))
        self.logit3 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1, bias=False))
        self.logit4 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1, bias=False))
Example #2
0
    def __init__(self):
        """Assemble AUNet around an Xception backbone.

        Wires backbone layers into five encoder stages and adds a bottleneck,
        four ``Decoderv2`` stages, one ``Decoder`` stage and 1x1 output heads.
        The data flow between these modules is defined in ``forward`` (not
        shown here).
        """
        super(AUNet, self).__init__()
        backbone = xception()
        self.xception = backbone

        # Stem: both conv/bn/relu triples of the backbone, in order.
        self.encode1 = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu1,
            backbone.conv2, backbone.bn2, backbone.relu2,
        )
        self.encode2 = backbone.block1
        self.encode3 = backbone.block2
        self.encode4 = backbone.block3

        # block4 .. block11 chained into one stage.
        middle_blocks = (
            backbone.block4, backbone.block5, backbone.block6,
            backbone.block7, backbone.block8, backbone.block9,
            backbone.block10, backbone.block11,
        )
        self.encode5 = nn.Sequential(*middle_blocks)

        # Bottleneck: two SeRes(728), BasicConv2d 728 -> 256, 2x2/2 max-pool.
        width = 728
        self.center = nn.Sequential(
            SeRes(width),
            SeRes(width),
            BasicConv2d(width, 256, 3, 1),
            nn.MaxPool2d(2, 2),
        )

        self.decode5 = Decoderv2(256, width, 64)
        self.decode4 = Decoderv2(64, 256, 64)
        self.decode3 = Decoderv2(64, 128, 64)
        self.decode2 = Decoderv2(64, 64, 64)
        self.decode1 = Decoder(64, 64)

        def _head(in_channels):
            # Single-channel 1x1 conv (no bias) wrapped in a Sequential.
            return nn.Sequential(
                nn.Conv2d(in_channels, 1, kernel_size=1, bias=False))

        # Fusion head over 4 channels, then four 64-channel side heads.
        self.logit = _head(4)
        self.logit1 = _head(64)
        self.logit2 = _head(64)
        self.logit3 = _head(64)
        self.logit4 = _head(64)
Example #3
0
 def __init__(self, x1_c, x2_c, out_c, pad=None):
     """Create one decoder stage with two input branches.

     Args:
         x1_c: channels of the input to the transposed convolution.
         x2_c: channels of the input to ``conx2``.
         out_c: channels produced by the final ``con`` block.
         pad: accepted for interface compatibility; unused in this
             constructor.
     """
     super(Decoderv2, self).__init__()
     half = x1_c // 2
     self.ct_out, self.x2_c = half, x2_c
     # Transposed conv, kernel 2, stride 2, x1_c -> x1_c // 2 channels.
     self.ct_conv = nn.ConvTranspose2d(x1_c, half, kernel_size=2, stride=2)
     # Projects the second input to the same half width.
     self.conx2 = BasicConv2d(x2_c, half, 3, 1)
     # NOTE(review): the SE-residual blocks run at full x1_c width, and
     # 2 * (x1_c // 2) == x1_c only for even x1_c — presumably forward()
     # concatenates the two half-width branches; confirm, and keep x1_c even.
     self.res1 = SeRes(x1_c)
     self.res2 = SeRes(x1_c)
     self.con = BasicConv2d(x1_c, out_c, 3, 1)
Example #4
0
    def __init__(self, pretrained=True):
        """Build AUNet on a torchvision VGG-19 (no batch norm) backbone.

        Args:
            pretrained: forwarded to ``torchvision.models.vgg19`` as the
                ``pretrained`` keyword. Defaults to ``True`` (the previously
                hard-coded value); pass ``False`` to build without loading
                ImageNet weights.
        """
        super(AUNet, self).__init__()
        # Passed by keyword: newer torchvision model builders take
        # keyword-only arguments and deprecate the positional boolean.
        self.vgg = torchvision.models.vgg19(pretrained=pretrained)
        features = self.vgg.features

        # The first four conv/ReLU groups of vgg19.features, each ending at
        # its MaxPool2d; features[28:] are not used. Unpacking a slice
        # re-numbers the children from 0, exactly like listing them one by
        # one, so parameter names / state_dict keys are unchanged.
        # Original size notes (#128/#64/#32/#16) presumably assume a 256 px
        # input — TODO confirm.
        # #128: 1/2 resolution, 64 channels
        self.encode2 = nn.Sequential(*features[0:5])
        # #64: 1/4 resolution, 128 channels
        self.encode3 = nn.Sequential(*features[5:10])
        # #32: 1/8 resolution, 256 channels
        self.encode4 = nn.Sequential(*features[10:19])
        # #16: 1/16 resolution, 512 channels
        self.encode5 = nn.Sequential(*features[19:28])

        # Bottleneck: two SeRes(512), BasicConv2d 512 -> 256, 2x2/2 max-pool.
        self.center = nn.Sequential(SeRes(512), SeRes(512),
                                    BasicConv2d(512, 256, 3, 1),
                                    nn.MaxPool2d(2, 2))

        # Decoder stages. The second Decoderv2 argument (512/256/128/64)
        # matches the encode5..encode2 widths — presumably the skip inputs;
        # TODO confirm against forward().
        self.decode5 = Decoderv2(256, 512, 64)
        self.decode4 = Decoderv2(64, 256, 64)
        self.decode3 = Decoderv2(64, 128, 64)
        self.decode2 = Decoderv2(64, 64, 64)
        self.decode1 = Decoder(64, 64)

        # 1x1 conv, 4 -> 1 channels, no bias (presumably a fusion head).
        self.logit = nn.Sequential(nn.Conv2d(4, 1, kernel_size=1, bias=False))

        # 1x1 conv heads, 64 -> 1 channels each, no bias.
        self.logit1 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1,
                                              bias=False))
        self.logit2 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1,
                                              bias=False))
        self.logit3 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1,
                                              bias=False))
        self.logit4 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1,
                                              bias=False))
Example #5
0
 def __init__(self, x1_c, out_c):
     """Create the final decoder stage.

     Args:
         x1_c: channels of the input to the transposed convolution.
         out_c: channels produced by the transposed convolution and kept by
             both SE-residual blocks.
     """
     super(Decoder, self).__init__()
     # Transposed conv, kernel 2, stride 2, x1_c -> out_c channels.
     self.ct_conv = nn.ConvTranspose2d(x1_c, out_c, kernel_size=2, stride=2)
     # Two SE-residual blocks at out_c width; the second SeRes argument (8)
     # is presumably a reduction ratio — confirm in SeRes.
     self.res1, self.res2 = SeRes(out_c, 8), SeRes(out_c, 8)