Example #1
    def forward(self, x1, x2):
        """Forward method."""
        # Early fusion: concatenate the two input images along the channel axis.
        x = torch.cat((x1, x2), 1)

        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
        x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43, kernel_size=2, stride=2)

        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), x43), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), x33), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), x22), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), x12), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return x11d
        # Alternatively, return class probabilities: return self.sm(x11d)
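This variant fuses the two inputs before the encoder (early fusion), so the first convolution must be sized for both images stacked together. A minimal, self-contained sketch of that input handling; the 3-channel images and the 16 output channels are assumptions, not values from the code above:

import torch
import torch.nn as nn

x1 = torch.randn(1, 3, 64, 64)   # image at time 1 (channel count assumed)
x2 = torch.randn(1, 3, 64, 64)   # image at time 2
x = torch.cat((x1, x2), 1)       # early fusion: (1, 6, 64, 64)
conv11 = nn.Conv2d(6, 16, kernel_size=3, padding=1)  # in_channels = 2x per-image channels
print(conv11(x).shape)           # torch.Size([1, 16, 64, 64])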
Example #2
    def forward(self, x1, x2):

        x = torch.cat((x1, x2), 1)

        # Normalize the fused input before the encoder.
        x = self.bn(x)

        # Encoder: record each stage's input size so the decoder can pad back to it.
        s1_1 = x.size()
        x1 = self.encres1_1(x)
        x = self.encres1_2(x1)

        s2_1 = x.size()
        x2 = self.encres2_1(x)
        x = self.encres2_2(x2)

        s3_1 = x.size()
        x3 = self.encres3_1(x)
        x = self.encres3_2(x3)

        s4_1 = x.size()
        x4 = self.encres4_1(x)
        x = self.encres4_2(x4)

        # Decoder: upsample, then replication-pad to the size recorded on the way down.
        x = self.decres4_1(x)
        x = self.decres4_2(x)
        s4_2 = x.size()
        pad4 = ReplicationPad2d((0, s4_1[3] - s4_2[3], 0, s4_1[2] - s4_2[2]))
        x = pad4(x)

        # Skip connection: fuse with the encoder feature x4 before decoding.
        x = self.decres3_1(torch.cat((x, x4), 1))
        x = self.decres3_2(x)
        s3_2 = x.size()
        pad3 = ReplicationPad2d((0, s3_1[3] - s3_2[3], 0, s3_1[2] - s3_2[2]))
        x = pad3(x)

        x = self.decres2_1(torch.cat((x, x3), 1))
        x = self.decres2_2(x)
        s2_2 = x.size()
        pad2 = ReplicationPad2d((0, s2_1[3] - s2_2[3], 0, s2_1[2] - s2_2[2]))
        x = pad2(x)

        x = self.decres1_1(torch.cat((x, x2), 1))
        x = self.decres1_2(x)
        s1_2 = x.size()
        pad1 = ReplicationPad2d((0, s1_1[3] - s1_2[3], 0, s1_1[2] - s1_2[2]))
        x = pad1(x)

        x = self.coupling(torch.cat((x, x1), 1))
        x = self.sm(x)

        return x
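Every decoder stage above repeats one idiom: after upsampling, ReplicationPad2d((0, right, 0, bottom)) grows the tensor back to the size recorded during encoding, which matters whenever pooling floored an odd dimension. A self-contained sketch of the idiom; the layer sizes are illustrative assumptions:

import torch
import torch.nn as nn
from torch.nn import ReplicationPad2d

skip = torch.randn(1, 8, 25, 25)                              # odd spatial size
down = nn.MaxPool2d(2, 2)(skip)                               # (1, 8, 12, 12): 25 // 2 floors to 12
up = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)(down)  # (1, 8, 24, 24): one row/col short
# Pad tuple is (left, right, top, bottom); only right and bottom are grown.
pad = ReplicationPad2d((0, skip.size(3) - up.size(3), 0, skip.size(2) - up.size(2)))
up = pad(up)                                                  # (1, 8, 25, 25)
x = torch.cat((up, skip), 1)                                  # sizes now match for the skip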
Example #3
    def forward(self, data, n_branches, extract_features=None, **kwargs):
        layer1 = list()
        layer2 = list()
        layer3 = list()
        res = list()
        for i in range(n_branches): # Siamese/triplet nets; sharing weights
            x = data[i]
            x1 = self.relu1(self.bn1(self.conv1(x)))
            x = self.maxpool1(x1)
            x2 = self.relu2(self.bn2(self.conv2(x)))
            x = self.maxpool2(x2)
            x3 = self.relu3(self.bn3(self.conv3(x)))
            x = self.maxpool3(x3)
            
            layer1.append(x1)
            layer2.append(x2)
            layer3.append(x3)
            res.append(x)
        
        x = torch.abs(res[1] - res[0])
        if n_branches == 3:
            # torch.cat expects a sequence of tensors, not separate arguments.
            x = torch.cat((x, torch.abs(res[2] - res[1])), 1)

        x = self.relu4(self.bn4(self.conv4(x)))
        if extract_features == 'joint':
            return x
        
        x = self.convT5(x)
        pad = ReplicationPad2d((0, layer3[0].shape[3] - x.shape[3], 0, layer3[0].shape[2] - x.shape[2]))
        x = torch.cat([pad(x), torch.abs(layer3[1]-layer3[0])], dim=1)
        x = self.relu5(self.bn5(self.conv5(x)))
        
        x = self.convT6(x)
        pad = ReplicationPad2d((0, layer2[0].shape[3] - x.shape[3], 0, layer2[0].shape[2] - x.shape[2]))
        x = torch.cat([pad(x), torch.abs(layer2[1]-layer2[0])], dim=1)
        x = self.relu6(self.bn6(self.conv6(x)))
        
        x = self.convT7(x)
        pad = ReplicationPad2d((0, layer1[0].shape[3] - x.shape[3], 0, layer1[0].shape[2] - x.shape[2]))
        x = torch.cat([pad(x), torch.abs(layer1[1]-layer1[0])], dim=1)
        x = self.relu7(self.bn7(self.conv7(x)))

        if extract_features == 'last':
            return x
        
        x = torch.flatten(x, 1)
        x = self.relu8(self.linear1(x))
        x = self.relu9(self.linear2(x))
        x = self.linear3(x)
        
        return x
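The loop at the top applies one set of layers to every branch, which is what makes the network siamese: the branches share weights by construction, and change is read out as a feature-space difference. A minimal sketch of the pattern; the toy encoder here is an assumption, not the layers above:

import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2))
data = [torch.randn(1, 3, 32, 32) for _ in range(2)]

res = [encoder(d) for d in data]   # the same weights process every branch
x = torch.abs(res[1] - res[0])     # change signal: absolute feature difference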
Example #4
    def forward(self, x1, x2):
        """Forward method."""
        # Stage 1
        x11_1 = self.do11(F.relu(self.bn11(self.conv11(x1))))
        x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11_1))))
        x1p_1 = F.max_pool2d(x12_1, kernel_size=2, stride=2)

        # Stage 2
        x21_1 = self.do21(F.relu(self.bn21(self.conv21(x1p_1))))
        x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21_1))))
        x2p_1 = F.max_pool2d(x22_1, kernel_size=2, stride=2)

        # Stage 3
        x31_1 = self.do31(F.relu(self.bn31(self.conv31(x2p_1))))
        x32_1 = self.do32(F.relu(self.bn32(self.conv32(x31_1))))
        x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32_1))))
        x3p_1 = F.max_pool2d(x33_1, kernel_size=2, stride=2)

        # Stage 4
        x41_1 = self.do41(F.relu(self.bn41(self.conv41(x3p_1))))
        x42_1 = self.do42(F.relu(self.bn42(self.conv42(x41_1))))
        x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42_1))))
        x4p_1 = F.max_pool2d(x43_1, kernel_size=2, stride=2)

        ####################################################
        # Stage 1
        x11_2 = self.do11(F.relu(self.bn11(self.conv11(x2))))
        x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11_2))))
        x1p_2 = F.max_pool2d(x12_2, kernel_size=2, stride=2)

        # Stage 2
        x21_2 = self.do21(F.relu(self.bn21(self.conv21(x1p_2))))
        x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21_2))))
        x2p_2 = F.max_pool2d(x22_2, kernel_size=2, stride=2)

        # Stage 3
        x31_2 = self.do31(F.relu(self.bn31(self.conv31(x2p_2))))
        x32_2 = self.do32(F.relu(self.bn32(self.conv32(x31_2))))
        x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32_2))))
        x3p_2 = F.max_pool2d(x33_2, kernel_size=2, stride=2)

        # Stage 4
        x41_2 = self.do41(F.relu(self.bn41(self.conv41(x3p_2))))
        x42_2 = self.do42(F.relu(self.bn42(self.conv42(x41_2))))
        x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42_2))))
        x4p_2 = F.max_pool2d(x43_2, kernel_size=2, stride=2)

        ####################################################
        # Stage 4d
        x4d = self.upconv4(x4p_2)
        pad4 = ReplicationPad2d(
            (0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d(
            (0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d(
            (0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d(
            (0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d)
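Because each decoder stage concatenates the upsampled tensor with the skip features of *both* encoder passes, the decoder convolutions must be built for up-channels + 2 × skip-channels. A sketch of that bookkeeping with illustrative channel counts (the actual widths are assumptions):

import torch
import torch.nn as nn

up = torch.randn(1, 32, 16, 16)       # output of upconv4 (width assumed)
skip_1 = torch.randn(1, 64, 16, 16)   # x43_1 from the first image
skip_2 = torch.randn(1, 64, 16, 16)   # x43_2 from the second image
x4d = torch.cat((up, skip_1, skip_2), 1)    # 32 + 64 + 64 = 160 channels
conv43d = nn.Conv2d(160, 64, 3, padding=1)  # must match the fused width
print(conv43d(x4d).shape)                   # torch.Size([1, 64, 16, 16])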
Example #5
    def forward(self, data, n_branches, extract_features=None):
        """Forward method."""
        
        x1 = data[0]
        x2 = data[1]
        
        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x1)))
        x12_1 = F.relu(self.bn12(self.conv12(x11)))
        x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22_1 = F.relu(self.bn22(self.conv22(x21)))
        x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33_1 = F.relu(self.bn33(self.conv33(x32)))
        x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43_1 = F.relu(self.bn43(self.conv43(x42)))
        x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)

        ####################################################
        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x2)))
        x12_2 = F.relu(self.bn12(self.conv12(x11)))
        x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22_2 = F.relu(self.bn22(self.conv22(x21)))
        x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33_2 = F.relu(self.bn33(self.conv33(x32)))
        x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43_2 = F.relu(self.bn43(self.conv43(x42)))
        x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)

        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)

        return x11d  # optionally self.sm(x11d) for per-class probabilities
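Unlike Example #4, the skips here are absolute differences of the two encoder passes, so the fused tensor is no wider than in a single-stream U-Net and the decoder keeps the narrower widths. A sketch of the contrast, with the same assumed channel counts as above:

import torch

f1 = torch.randn(1, 64, 16, 16)   # x43_1
f2 = torch.randn(1, 64, 16, 16)   # x43_2
up = torch.randn(1, 32, 16, 16)   # upsampled decoder tensor (width assumed)

diff = torch.abs(f1 - f2)         # still 64 channels
fused = torch.cat((up, diff), 1)  # 32 + 64 = 96, vs. 32 + 128 when concatenating both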
Example #6
    def forward(self, data, n_branches, extract_features=None, **kwargs):

        res = list()
        end = list()
        for i in range(n_branches):  # Siamese/triplet nets; sharing weights
            # encoding
            x = data[i]
            x1 = self.relu1a(self.bn1a(self.conv1a(x)))
            x = self.maxpool1(x1)
            x2 = self.relu2a(self.bn2a(self.conv2a(x)))
            x = self.maxpool2(x2)
            x3a = self.relu3a(self.bn3a(self.conv3a(x)))
            x3 = self.relu3b(self.bn3b(self.conv3b(x3a)))
            x = self.maxpool3(x3)
            x4a = self.relu4a(self.bn4a(self.conv4a(x)))
            x4 = self.relu4b(self.bn4b(self.conv4b(x4a)))
            x = self.maxpool4(x4)
            xb = self.relu_bottle(self.bn_bottle(self.conv_bottle(x)))

            res.append(xb)

            # decoding
            x4d = self.upconv4(xb)
            pad4 = ReplicationPad2d(
                (0, x4.size(3) - x4d.size(3), 0, x4.size(2) - x4d.size(2)))
            x4d = torch.cat((pad4(x4d), x4), 1)
            x4d = self.relu4cd(self.bn4cd(self.conv4cd(x4d)))
            x4d = self.relu4ad(self.bn4ad(self.conv4ad(x4d)))

            # Stage 3d
            x3d = self.upconv3(x4d)
            pad3 = ReplicationPad2d(
                (0, x3.size(3) - x3d.size(3), 0, x3.size(2) - x3d.size(2)))
            x3d = torch.cat((pad3(x3d), x3), 1)
            x3d = self.relu3cd(self.bn3cd(self.conv3cd(x3d)))
            x3d = self.relu3ad(self.bn3ad(self.conv3ad(x3d)))

            # Stage 2d
            x2d = self.upconv2(x3d)
            pad2 = ReplicationPad2d(
                (0, x2.size(3) - x2d.size(3), 0, x2.size(2) - x2d.size(2)))
            x2d = torch.cat((pad2(x2d), x2), 1)
            x2d = self.relu2bd(self.bn2bd(self.conv2bd(x2d)))
            x2d = self.relu2ad(self.bn2ad(self.conv2ad(x2d)))

            # Stage 1d
            x1d = self.upconv1(x2d)
            pad1 = ReplicationPad2d(
                (0, x1.size(3) - x1d.size(3), 0, x1.size(2) - x1d.size(2)))
            x1d = torch.cat((pad1(x1d), x1), 1)
            x1d = self.relu1bd(self.bn1bd(self.conv1bd(x1d)))

            end.append(x1d)

        # bottleneck classifier
        diff = torch.abs(res[1] - res[0])
        if extract_features == 'joint':
            return diff

        if extract_features == 'last':
            return end

        bottle = torch.flatten(diff, 1)
        bottle = self.relu8(self.linear1(bottle))
        bottle = self.relu9(self.linear2(bottle))
        bottle = self.linear3(bottle)

        return [bottle, end]
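The classifier branch flattens the bottleneck difference and pushes it through three linear layers. A self-contained sketch of that head; the spatial size, channel counts, and hidden widths are assumptions:

import torch
import torch.nn as nn

b1 = torch.randn(2, 128, 8, 8)    # bottleneck features of branch 0
b2 = torch.randn(2, 128, 8, 8)    # bottleneck features of branch 1
diff = torch.abs(b2 - b1)

bottle = torch.flatten(diff, 1)   # (2, 8192)
head = nn.Sequential(nn.Linear(128 * 8 * 8, 256), nn.ReLU(),
                     nn.Linear(256, 64), nn.ReLU(),
                     nn.Linear(64, 2))
logits = head(bottle)             # (2, 2)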
Example #7
    def forward(self, s2_1, s2_2, s1_1, s1_2):
        """Forward method."""

        #################################################### encoder S2 ####################################################

        # siamese processing of input s2_1
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(s2_1))))
        x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)

        ####################################################

        # siamese processing of input s2_2
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(s2_2))))
        x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)

        #################################################### encoder S1 ####################################################

        # siamese processing of input s1_1
        # Stage 1
        x11_b = self.do11_b(F.relu(self.bn11_b(self.conv11_b(s1_1))))
        x12_1_b = self.do12_b(F.relu(self.bn12_b(self.conv12_b(x11_b))))
        x1p_b = F.max_pool2d(x12_1_b, kernel_size=2, stride=2)

        # Stage 2
        x21_b = self.do21_b(F.relu(self.bn21_b(self.conv21_b(x1p_b))))
        x22_1_b = self.do22_b(F.relu(self.bn22_b(self.conv22_b(x21_b))))
        x2p_b = F.max_pool2d(x22_1_b, kernel_size=2, stride=2)

        # Stage 3
        x31_b = self.do31_b(F.relu(self.bn31_b(self.conv31_b(x2p_b))))
        x32_b = self.do32_b(F.relu(self.bn32_b(self.conv32_b(x31_b))))
        x33_1_b = self.do33_b(F.relu(self.bn33_b(self.conv33_b(x32_b))))
        x3p_b = F.max_pool2d(x33_1_b, kernel_size=2, stride=2)

        # Stage 4
        x41_b = self.do41_b(F.relu(self.bn41_b(self.conv41_b(x3p_b))))
        x42_b = self.do42_b(F.relu(self.bn42_b(self.conv42_b(x41_b))))
        x43_1_b = self.do43_b(F.relu(self.bn43_b(self.conv43_b(x42_b))))
        x4p_b = F.max_pool2d(x43_1_b, kernel_size=2, stride=2)

        ####################################################

        # siamese processing of input s1_2
        # Stage 1
        x11_b = self.do11_b(F.relu(self.bn11_b(self.conv11_b(s1_2))))
        x12_2_b = self.do12_b(F.relu(self.bn12_b(self.conv12_b(x11_b))))
        x1p_b = F.max_pool2d(x12_2_b, kernel_size=2, stride=2)

        # Stage 2
        x21_b = self.do21_b(F.relu(self.bn21_b(self.conv21_b(x1p_b))))
        x22_2_b = self.do22_b(F.relu(self.bn22_b(self.conv22_b(x21_b))))
        x2p_b = F.max_pool2d(x22_2_b, kernel_size=2, stride=2)

        # Stage 3
        x31_b = self.do31_b(F.relu(self.bn31_b(self.conv31_b(x2p_b))))
        x32_b = self.do32_b(F.relu(self.bn32_b(self.conv32_b(x31_b))))
        x33_2_b = self.do33_b(F.relu(self.bn33_b(self.conv33_b(x32_b))))
        x3p_b = F.max_pool2d(x33_2_b, kernel_size=2, stride=2)

        # Stage 4
        x41_b = self.do41_b(F.relu(self.bn41_b(self.conv41_b(x3p_b))))
        x42_b = self.do42_b(F.relu(self.bn42_b(self.conv42_b(x41_b))))
        x43_2_b = self.do43_b(F.relu(self.bn43_b(self.conv43_b(x42_b))))
        x4p_b = F.max_pool2d(x43_2_b, kernel_size=2, stride=2)

        #################################################### decoder ####################################################
        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = ReplicationPad2d(
            (0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), x43_1, x43_2, x43_1_b, x43_2_b), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d(
            (0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), x33_1, x33_2, x33_1_b, x33_2_b), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d(
            (0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), x22_1, x22_2, x22_1_b, x22_2_b), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d(
            (0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), x12_1, x12_2, x12_1_b, x12_2_b), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d)
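The *_b layers form a second encoder: weights are shared across the two dates of each sensor, but the S1 and S2 inputs go through different layer stacks, and all four skip stacks meet in the decoder. A sketch of that sharing pattern; the band counts (4 for S2, 2 for S1) and widths are assumptions:

import torch
import torch.nn as nn

enc_s2 = nn.Conv2d(4, 16, 3, padding=1)   # shared by s2_1 and s2_2
enc_s1 = nn.Conv2d(2, 16, 3, padding=1)   # shared by s1_1 and s1_2 (the *_b stack)

s2_1, s2_2 = torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64)
s1_1, s1_2 = torch.randn(1, 2, 64, 64), torch.randn(1, 2, 64, 64)

feats = [enc_s2(s2_1), enc_s2(s2_2), enc_s1(s1_1), enc_s1(s1_2)]
fused = torch.cat(feats, 1)               # four skip stacks fused in the decoder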
Example #8
    def forward(self, s2_1, s2_2, s1_1=None, s1_2=None):
        """Forward method."""

        # Optionally fuse the S1 bands into each S2 stream along the channel axis.
        if s1_1 is not None and s1_2 is not None:
            s2_1 = torch.cat((s2_1, s1_1), 1)
            s2_2 = torch.cat((s2_2, s1_2), 1)

        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(s2_1))))
        x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)

        ####################################################
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(s2_2))))
        x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)

        ####################################################
        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = ReplicationPad2d(
            (0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d(
            (0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d(
            (0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d(
            (0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d)
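With optional S1 inputs, each stream carries either the S2 bands alone or the S2 + S1 bands, so conv11 must be constructed for whichever path will actually be used. A sketch of the two call paths; the band counts are assumptions:

import torch

s2_1, s2_2 = torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64)
s1_1, s1_2 = torch.randn(1, 2, 64, 64), torch.randn(1, 2, 64, 64)

# S2-only path: each stream keeps 4 channels, so conv11 needs in_channels=4.
a, b = s2_1, s2_2

# Fused path: each stream grows to 4 + 2 = 6 channels,
# so conv11 would need in_channels=6 instead.
a = torch.cat((s2_1, s1_1), 1)
b = torch.cat((s2_2, s1_2), 1)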