def forward(self, x1, x2):
    """Forward method.

    Concatenates the two inputs along dim 1 and runs the result through a
    four-stage encoder (conv -> batch-norm -> ReLU -> dropout units, 2x2
    max-pooling between stages) and a mirrored decoder with skip connections.

    Args:
        x1: First input tensor (presumably NCHW -- TODO confirm).
        x2: Second input tensor; must match ``x1`` in every dim except dim 1
            (``torch.cat`` requirement).

    Returns:
        The output of ``self.conv11d``. ``self.sm`` is NOT applied here.
    """
    x = torch.cat((x1, x2), 1)

    # Stage 1
    x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
    x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
    x1p = F.max_pool2d(x12, kernel_size=2, stride=2)

    # Stage 2
    x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
    x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
    x2p = F.max_pool2d(x22, kernel_size=2, stride=2)

    # Stage 3
    x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
    x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
    x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
    x3p = F.max_pool2d(x33, kernel_size=2, stride=2)

    # Stage 4
    x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
    x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
    x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
    x4p = F.max_pool2d(x43, kernel_size=2, stride=2)

    # Stage 4d: upsample, replication-pad the right/bottom edges up to the
    # skip tensor's H/W (pooling floors odd sizes), then concatenate.
    x4d = self.upconv4(x4p)
    pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2)))
    x4d = torch.cat((pad4(x4d), x43), 1)
    x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
    x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
    x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

    # Stage 3d
    x3d = self.upconv3(x41d)
    pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2)))
    x3d = torch.cat((pad3(x3d), x33), 1)
    x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
    x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
    x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

    # Stage 2d
    x2d = self.upconv2(x31d)
    pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2)))
    x2d = torch.cat((pad2(x2d), x22), 1)
    x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
    x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

    # Stage 1d
    x1d = self.upconv1(x21d)
    pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2)))
    x1d = torch.cat((pad1(x1d), x12), 1)
    x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
    x11d = self.conv11d(x12d)
    return x11d
def forward(self, x1, x2):
    """Residual encoder/decoder pass over the two inputs stacked along dim 1.

    The stacked input is normalised by ``self.bn`` and sent through four
    encoder levels (``encres{k}_1`` then ``encres{k}_2``); the output of each
    ``encres{k}_1`` is kept as a skip tensor. After each decoder level the
    result is replication-padded (right/bottom edges) back to the spatial
    size that the matching encoder level received, then concatenated with the
    next skip tensor along dim 1.

    Returns:
        ``self.sm`` applied to ``self.coupling`` of the final concatenation.
    """
    def pad_to(t, ref_size):
        # Replication-pad the right and bottom edges of t up to ref_size's H/W.
        pad = ReplicationPad2d((0, ref_size[3] - t.size(3), 0, ref_size[2] - t.size(2)))
        return pad(t)

    encoder = (
        (self.encres1_1, self.encres1_2),
        (self.encres2_1, self.encres2_2),
        (self.encres3_1, self.encres3_2),
        (self.encres4_1, self.encres4_2),
    )
    decoder = (
        (self.decres3_1, self.decres3_2),
        (self.decres2_1, self.decres2_2),
        (self.decres1_1, self.decres1_2),
    )

    h = self.bn(torch.cat((x1, x2), 1))

    level_sizes = []
    skips = []
    for first, second in encoder:
        level_sizes.append(h.size())
        skip = first(h)
        skips.append(skip)
        h = second(skip)

    h = pad_to(self.decres4_2(self.decres4_1(h)), level_sizes[3])
    for level, (first, second) in zip((2, 1, 0), decoder):
        h = first(torch.cat((h, skips[level + 1]), 1))
        h = pad_to(second(h), level_sizes[level])

    return self.sm(self.coupling(torch.cat((h, skips[0]), 1)))
def forward(self, data, n_branches, extract_features=None, **kwargs):
    """Siamese/triplet forward pass with weight sharing across branches.

    Each branch runs the same three conv -> batch-norm -> ReLU -> max-pool
    stages. The pooled outputs are fused by absolute difference and refined by
    ``conv4``. Three transposed-conv decoder stages then upsample the result,
    each concatenating the absolute difference of branch 0/1 stage features.
    A three-layer linear head produces the final output.

    Args:
        data: Indexable of input tensors; ``data[i]`` feeds branch ``i``.
        n_branches: Number of branches to run (2 or 3).
        extract_features: ``'joint'`` returns the fused features after
            ``conv4``; ``'last'`` returns the decoder output before the linear
            head; any other value runs the whole network.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        The tensor selected by ``extract_features``, otherwise the output of
        ``self.linear3``.
    """
    layer1, layer2, layer3, res = [], [], [], []
    for i in range(n_branches):  # Siamese/triplet nets; sharing weights
        x = data[i]
        x1 = self.relu1(self.bn1(self.conv1(x)))
        x = self.maxpool1(x1)
        x2 = self.relu2(self.bn2(self.conv2(x)))
        x = self.maxpool2(x2)
        x3 = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool3(x3)
        layer1.append(x1)
        layer2.append(x2)
        layer3.append(x3)
        res.append(x)

    x = torch.abs(res[1] - res[0])
    if n_branches == 3:
        # torch.cat takes a *sequence* of tensors; the previous call
        # torch.cat(x, y, 1) raised a TypeError on the triplet path.
        # NOTE(review): conv4 must accept twice the channels here -- confirm.
        x = torch.cat((x, torch.abs(res[2] - res[1])), 1)
    x = self.relu4(self.bn4(self.conv4(x)))
    if extract_features == 'joint':
        return x

    def up_fuse(t, upsample, stage):
        # Upsample, replication-pad right/bottom to the stage's H/W, then
        # append |branch1 - branch0| of that stage along dim 1.
        t = upsample(t)
        ref = stage[0]
        pad = ReplicationPad2d((0, ref.shape[3] - t.shape[3], 0, ref.shape[2] - t.shape[2]))
        return torch.cat([pad(t), torch.abs(stage[1] - stage[0])], dim=1)

    # The decoder skips use only branches 0 and 1, even when n_branches == 3.
    x = up_fuse(x, self.convT5, layer3)
    x = self.relu5(self.bn5(self.conv5(x)))
    x = up_fuse(x, self.convT6, layer2)
    x = self.relu6(self.bn6(self.conv6(x)))
    x = up_fuse(x, self.convT7, layer1)
    x = self.relu7(self.bn7(self.conv7(x)))
    if extract_features == 'last':
        return x

    x = torch.flatten(x, 1)
    x = self.relu8(self.linear1(x))
    x = self.relu9(self.linear2(x))
    x = self.linear3(x)
    return x
def forward(self, x1, x2):
    """Forward method.

    ``x1`` and ``x2`` pass through the same encoder layers (shared weights).
    Decoding starts from the pooled bottleneck of ``x2``. At every decoder
    level the upsampled tensor is replication-padded (right/bottom edges) to
    the spatial size of the matching ``x1`` skip tensor and concatenated along
    dim 1 with the skip tensors of both inputs.

    Args:
        x1: First input tensor (presumably NCHW -- TODO confirm).
        x2: Second input tensor, encoded with the same layers as ``x1``.

    Returns:
        ``self.sm`` applied to the output of ``self.conv11d``.
    """
    def unit(conv, bn, drop, t):
        # conv -> batch-norm -> ReLU -> dropout
        return drop(F.relu(bn(conv(t))))

    def halve(t):
        return F.max_pool2d(t, kernel_size=2, stride=2)

    def encode(t):
        # Returns the skip tensor of each of the four stages and the pooled output.
        t = unit(self.conv11, self.bn11, self.do11, t)
        skip1 = unit(self.conv12, self.bn12, self.do12, t)
        t = unit(self.conv21, self.bn21, self.do21, halve(skip1))
        skip2 = unit(self.conv22, self.bn22, self.do22, t)
        t = unit(self.conv31, self.bn31, self.do31, halve(skip2))
        t = unit(self.conv32, self.bn32, self.do32, t)
        skip3 = unit(self.conv33, self.bn33, self.do33, t)
        t = unit(self.conv41, self.bn41, self.do41, halve(skip3))
        t = unit(self.conv42, self.bn42, self.do42, t)
        skip4 = unit(self.conv43, self.bn43, self.do43, t)
        return (skip1, skip2, skip3, skip4), halve(skip4)

    def up_concat(t, upconv, own, other):
        up = upconv(t)
        pad = ReplicationPad2d((0, own.size(3) - up.size(3), 0, own.size(2) - up.size(2)))
        return torch.cat((pad(up), own, other), 1)

    skips_1, _ = encode(x1)
    skips_2, bottleneck = encode(x2)

    y = up_concat(bottleneck, self.upconv4, skips_1[3], skips_2[3])
    y = unit(self.conv43d, self.bn43d, self.do43d, y)
    y = unit(self.conv42d, self.bn42d, self.do42d, y)
    y = unit(self.conv41d, self.bn41d, self.do41d, y)

    y = up_concat(y, self.upconv3, skips_1[2], skips_2[2])
    y = unit(self.conv33d, self.bn33d, self.do33d, y)
    y = unit(self.conv32d, self.bn32d, self.do32d, y)
    y = unit(self.conv31d, self.bn31d, self.do31d, y)

    y = up_concat(y, self.upconv2, skips_1[1], skips_2[1])
    y = unit(self.conv22d, self.bn22d, self.do22d, y)
    y = unit(self.conv21d, self.bn21d, self.do21d, y)

    y = up_concat(y, self.upconv1, skips_1[0], skips_2[0])
    y = unit(self.conv12d, self.bn12d, self.do12d, y)
    return self.sm(self.conv11d(y))
def forward(self, data, n_branches, extract_features=None):
    """Forward method.

    ``data[0]`` and ``data[1]`` pass through one shared encoder of
    conv -> batch-norm -> ReLU units (no dropout). Decoding starts from the
    pooled bottleneck of ``data[1]``. Each decoder level replication-pads the
    upsampled tensor (right/bottom edges) to the ``data[0]`` skip size and
    concatenates it with ``|skip_0 - skip_1|`` for that level.

    Args:
        data: Indexable holding at least two input tensors.
        n_branches: Accepted but not used by this implementation.
        extract_features: Accepted but not used by this implementation.

    Returns:
        The output of ``self.conv11d`` (``self.sm`` is not applied).
    """
    first, second = data[0], data[1]

    def cbr(conv, bn, t):
        # conv -> batch-norm -> ReLU
        return F.relu(bn(conv(t)))

    def halve(t):
        return F.max_pool2d(t, kernel_size=2, stride=2)

    def encode(t):
        # Returns the skip tensor of each of the four stages and the pooled output.
        t = cbr(self.conv11, self.bn11, t)
        skip1 = cbr(self.conv12, self.bn12, t)
        t = cbr(self.conv21, self.bn21, halve(skip1))
        skip2 = cbr(self.conv22, self.bn22, t)
        t = cbr(self.conv31, self.bn31, halve(skip2))
        t = cbr(self.conv32, self.bn32, t)
        skip3 = cbr(self.conv33, self.bn33, t)
        t = cbr(self.conv41, self.bn41, halve(skip3))
        t = cbr(self.conv42, self.bn42, t)
        skip4 = cbr(self.conv43, self.bn43, t)
        return (skip1, skip2, skip3, skip4), halve(skip4)

    def up_diff(t, upconv, ref, other):
        up = upconv(t)
        pad = ReplicationPad2d((0, ref.size(3) - up.size(3), 0, ref.size(2) - up.size(2)))
        return torch.cat((pad(up), torch.abs(ref - other)), 1)

    skips_a, _ = encode(first)
    skips_b, bottleneck = encode(second)

    y = up_diff(bottleneck, self.upconv4, skips_a[3], skips_b[3])
    y = cbr(self.conv43d, self.bn43d, y)
    y = cbr(self.conv42d, self.bn42d, y)
    y = cbr(self.conv41d, self.bn41d, y)

    y = up_diff(y, self.upconv3, skips_a[2], skips_b[2])
    y = cbr(self.conv33d, self.bn33d, y)
    y = cbr(self.conv32d, self.bn32d, y)
    y = cbr(self.conv31d, self.bn31d, y)

    y = up_diff(y, self.upconv2, skips_a[1], skips_b[1])
    y = cbr(self.conv22d, self.bn22d, y)
    y = cbr(self.conv21d, self.bn21d, y)

    y = up_diff(y, self.upconv1, skips_a[0], skips_b[0])
    y = cbr(self.conv12d, self.bn12d, y)
    return self.conv11d(y)
def forward(self, data, n_branches, extract_features=None, **kwargs):
    """Weight-shared multi-branch U-Net pass with a bottleneck classifier.

    Every branch ``data[i]`` (``i < n_branches``) is encoded and decoded on
    its own with the same layers. The decoder replication-pads each upsampled
    tensor (right/bottom edges) to the matching encoder tensor's H/W before
    concatenating along dim 1. The absolute difference of the bottleneck
    features of branches 0 and 1 feeds a three-layer linear head.

    Args:
        data: Indexable of input tensors, one per branch.
        n_branches: Number of branches to run; at least 2, because branches 0
            and 1 are differenced.
        extract_features: ``'joint'`` returns the bottleneck difference;
            ``'last'`` returns the list of per-branch decoder outputs.
        **kwargs: Accepted and ignored.

    Returns:
        ``[head_output, decoder_outputs]`` unless ``extract_features`` selects
        an earlier exit.
    """
    def crb(conv, bn, relu, t):
        # conv -> batch-norm -> activation
        return relu(bn(conv(t)))

    def up_cat(t, upconv, skip):
        up = upconv(t)
        pad = ReplicationPad2d((0, skip.size(3) - up.size(3), 0, skip.size(2) - up.size(2)))
        return torch.cat((pad(up), skip), 1)

    def branch(x):
        # Encoding
        s1 = crb(self.conv1a, self.bn1a, self.relu1a, x)
        s2 = crb(self.conv2a, self.bn2a, self.relu2a, self.maxpool1(s1))
        t = crb(self.conv3a, self.bn3a, self.relu3a, self.maxpool2(s2))
        s3 = crb(self.conv3b, self.bn3b, self.relu3b, t)
        t = crb(self.conv4a, self.bn4a, self.relu4a, self.maxpool3(s3))
        s4 = crb(self.conv4b, self.bn4b, self.relu4b, t)
        bottleneck = crb(self.conv_bottle, self.bn_bottle, self.relu_bottle, self.maxpool4(s4))

        # Decoding
        d = up_cat(bottleneck, self.upconv4, s4)
        d = crb(self.conv4cd, self.bn4cd, self.relu4cd, d)
        d = crb(self.conv4ad, self.bn4ad, self.relu4ad, d)
        d = up_cat(d, self.upconv3, s3)
        d = crb(self.conv3cd, self.bn3cd, self.relu3cd, d)
        d = crb(self.conv3ad, self.bn3ad, self.relu3ad, d)
        d = up_cat(d, self.upconv2, s2)
        d = crb(self.conv2bd, self.bn2bd, self.relu2bd, d)
        d = crb(self.conv2ad, self.bn2ad, self.relu2ad, d)
        d = up_cat(d, self.upconv1, s1)
        d = crb(self.conv1bd, self.bn1bd, self.relu1bd, d)
        return bottleneck, d

    # Siamese/triplet nets; sharing weights -- branches run in index order.
    outputs = [branch(data[i]) for i in range(n_branches)]
    res = [bottleneck for bottleneck, _ in outputs]
    end = [decoded for _, decoded in outputs]

    # bottleneck classifier
    diff = torch.abs(res[1] - res[0])
    if extract_features == 'joint':
        return diff
    if extract_features == 'last':
        return end

    hidden = self.relu8(self.linear1(torch.flatten(diff, 1)))
    hidden = self.relu9(self.linear2(hidden))
    return [self.linear3(hidden), end]
def forward(self, s2_1, s2_2, s1_1, s1_2):
    """Forward method.

    Four-input pass with two encoders. ``s2_1`` and ``s2_2`` share the plain
    encoder layers (``conv11`` ... ``conv43``); ``s1_1`` and ``s1_2`` share
    the ``*_b`` encoder layers. Decoding starts from the pooled bottleneck of
    ``s2_2``. Every decoder level replication-pads the upsampled tensor
    (right/bottom edges) to the spatial size of the matching ``s2_1`` skip
    tensor and concatenates, along dim 1, the skip tensors of all four inputs
    in the order s2_1, s2_2, s1_1, s1_2. Those skips are concatenated without
    padding, so ``torch.cat`` requires them to share spatial size.

    Returns:
        ``self.sm`` applied to the output of ``self.conv11d``.
    """
    def unit(tag, t):
        # conv -> batch-norm -> ReLU -> dropout; modules looked up by name tag
        conv = getattr(self, "conv" + tag)
        bn = getattr(self, "bn" + tag)
        drop = getattr(self, "do" + tag)
        return drop(F.relu(bn(conv(t))))

    def halve(t):
        return F.max_pool2d(t, kernel_size=2, stride=2)

    def encode(t, suffix):
        # Returns the skip tensor of each of the four stages and the pooled output.
        t = unit("11" + suffix, t)
        skip1 = unit("12" + suffix, t)
        t = unit("21" + suffix, halve(skip1))
        skip2 = unit("22" + suffix, t)
        t = unit("31" + suffix, halve(skip2))
        t = unit("32" + suffix, t)
        skip3 = unit("33" + suffix, t)
        t = unit("41" + suffix, halve(skip3))
        t = unit("42" + suffix, t)
        skip4 = unit("43" + suffix, t)
        return (skip1, skip2, skip3, skip4), halve(skip4)

    streams = [
        encode(s2_1, ""),
        encode(s2_2, ""),
        encode(s1_1, "_b"),
        encode(s1_2, "_b"),
    ]
    skips = [stream_skips for stream_skips, _ in streams]
    bottleneck = streams[1][1]

    def up_concat(t, upconv, level):
        up = upconv(t)
        ref = skips[0][level]
        pad = ReplicationPad2d((0, ref.size(3) - up.size(3), 0, ref.size(2) - up.size(2)))
        return torch.cat((pad(up),) + tuple(s[level] for s in skips), 1)

    y = up_concat(bottleneck, self.upconv4, 3)
    for tag in ("43d", "42d", "41d"):
        y = unit(tag, y)
    y = up_concat(y, self.upconv3, 2)
    for tag in ("33d", "32d", "31d"):
        y = unit(tag, y)
    y = up_concat(y, self.upconv2, 1)
    for tag in ("22d", "21d"):
        y = unit(tag, y)
    y = up_concat(y, self.upconv1, 0)
    y = unit("12d", y)
    return self.sm(self.conv11d(y))
def forward(self, s2_1, s2_2, s1_1=None, s1_2=None):
    """Forward method.

    Runs both inputs through one shared encoder, then decodes from the pooled
    bottleneck of the second input. At every decoder level it concatenates
    the skip tensors of both inputs along dim 1.

    Args:
        s2_1: First input tensor (presumably NCHW -- TODO confirm).
        s2_2: Second input tensor, encoded with the same layers as ``s2_1``.
        s1_1: Optional tensor concatenated onto ``s2_1`` along dim 1.
        s1_2: Optional tensor concatenated onto ``s2_2`` along dim 1.
            The optional inputs are used only when BOTH are given; a single
            one on its own is ignored.

    Returns:
        ``self.sm`` applied to the output of ``self.conv11d``.
    """
    if s1_1 is not None and s1_2 is not None:
        s2_1 = torch.cat((s2_1, s1_1), 1)
        s2_2 = torch.cat((s2_2, s1_2), 1)

    def encode(x):
        """Run the shared encoder; return the four skip tensors and the pooled output."""
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
        x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12, kernel_size=2, stride=2)
        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22, kernel_size=2, stride=2)
        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33, kernel_size=2, stride=2)
        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43, kernel_size=2, stride=2)
        return x12, x22, x33, x43, x4p

    # Same modules, same call order as the previously duplicated encoder code.
    x12_1, x22_1, x33_1, x43_1, _ = encode(s2_1)
    x12_2, x22_2, x33_2, x43_2, x4p = encode(s2_2)

    # Stage 4d: upsample, replication-pad right/bottom to the skip's H/W,
    # then concatenate the skips of both inputs.
    x4d = self.upconv4(x4p)
    pad4 = ReplicationPad2d(
        (0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
    x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
    x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
    x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
    x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

    # Stage 3d
    x3d = self.upconv3(x41d)
    pad3 = ReplicationPad2d(
        (0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
    x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
    x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
    x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
    x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

    # Stage 2d
    x2d = self.upconv2(x31d)
    pad2 = ReplicationPad2d(
        (0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
    x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
    x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
    x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

    # Stage 1d
    x1d = self.upconv1(x21d)
    pad1 = ReplicationPad2d(
        (0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
    x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
    x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
    x11d = self.conv11d(x12d)
    return self.sm(x11d)