def forward(self, ed, es, source):
    """Register ``ed`` and ``es`` onto ``source`` and warp each one.

    Each moving volume is concatenated with ``source`` along the channel axis
    and passed through its own U-Net (``self.unet1`` for ``ed``,
    ``self.unet2`` for ``es``). The shared ``self.flow`` head then predicts a
    displacement field, which is used to warp that moving volume.

    Args:
        ed: first moving volume; assumed (batch, channels, D, H, W) -- TODO confirm.
        es: second moving volume, same layout as ``ed``.
        source: fixed volume, same layout as ``ed``.

    Returns:
        ``(flow_field_x1, ed_source, flow_field_x2, es_source)``: the flow
        predicted for ``ed``, ``ed`` warped by it, the flow predicted for
        ``es``, and ``es`` warped by it.
    """

    def register(unet, moving):
        # Predict a flow for (moving, source) and resample `moving` with it.
        flow = self.flow(unet(torch.cat([moving, source], dim=1)))
        # The sampling grid is sized from the flow's (D, H, W) rather than a
        # hard-coded 128x128 in-plane size, so other resolutions also work.
        # NOTE(review): assumes layer.SpatialTransformer adds its grid to the
        # flow (as VoxelMorph's does), so the sizes must match -- confirm.
        warped = layer.SpatialTransformer(tuple(flow.shape[2:]))(moving, flow)
        return flow, warped

    flow_field_x1, ed_source = register(self.unet1, ed)
    flow_field_x2, es_source = register(self.unet2, es)
    return flow_field_x1, ed_source, flow_field_x2, es_source
def forward(self, source, target):
    """Predict a displacement field aligning ``source`` to ``target`` and apply it.

    ``source`` and ``target`` are concatenated along the channel axis and
    passed through ``self.unet`` and then ``self.flow`` to obtain the flow.

    Args:
        source: moving volume; assumed (batch, channels, D, H, W) -- TODO confirm.
        target: fixed volume, same layout as ``source``.

    Returns:
        ``(y_source, flow_field)``: ``source`` warped by the predicted flow,
        and the flow itself.
    """
    flow_field = self.flow(self.unet(torch.cat([source, target], dim=1)))
    # Size the sampling grid from the flow's (D, H, W) instead of a hard-coded
    # 192x192 in-plane size.
    # NOTE(review): assumes the transformer's grid must match the flow's
    # spatial shape, as in VoxelMorph's SpatialTransformer -- confirm.
    transformer = layer.SpatialTransformer(tuple(flow_field.shape[2:]))
    y_source = transformer(source, flow_field)
    return y_source, flow_field
criterion_ncc(img_pre[:, :, 1:-1, mi3:mx3, mi4:mx4], step3_flow[:, :, 1:-1, mi3:mx3, mi4:mx4])+ \ criterion_ncc(img_ed[:, :, 1:-1, mi3:mx3, mi4:mx4], step4_flow[:, :, 1:-1, mi3:mx3, mi4:mx4]))/4)/2 ############## speed_combin = 0.25 * (speed_field[:, 0:3, :, :, :] + speed_field[:, 3:6, :, :, :] + \ speed_field[:, 6:9, :, :, :] + speed_field[:, 9:12, :, :, :]) fm = flow_field - flow_field.mean() f = fm / fm.std() sm = speed_combin - speed_combin.mean() s = sm / sm.std() loss_com = criterion_MSE(f, s) ##################### w0 = torch.unsqueeze(torch.unsqueeze(labeles[0, 0, :, :, :], 0), 0) w1 = torch.unsqueeze(torch.unsqueeze(labeles[0, 1, :, :, :], 0), 0) w2 = torch.unsqueeze(torch.unsqueeze(labeles[0, 2, :, :, :], 0), 0) w3 = torch.unsqueeze(torch.unsqueeze(labeles[0, 3, :, :, :], 0), 0) w0 = layer.SpatialTransformer((depth, 128, 128))(w0, flow_field) w1 = layer.SpatialTransformer((depth, 128, 128))(w1, flow_field) w2 = layer.SpatialTransformer((depth, 128, 128))(w2, flow_field) w3 = layer.SpatialTransformer((depth, 128, 128))(w3, flow_field) # print(flfinseg0.shape) ws = F.softmax(torch.cat([w0, w1, w2, w3], dim=1), dim=1) loss_fse = criterion_dice(ws[:, :, 1:-1, :, :], labeled[:, :, 1:-1, :, :]) ############### loss_reg = loss_smooth + loss_com + loss_reg_ncc + 0.01 * loss_fse loss_continous_motion = loss_reg_ncc flow_field_reg = flow_field ################################ loss_seg.backward() loss_reg.backward(retain_graph=True)
def forward(self, start, final, pre_img, mid_img, aft_img, labeled, labeles):
    """Register ``start`` to ``final`` and fuse the two frames' segmentations.

    Args:
        start: first image volume; assumed (batch, channels, D, H, W) -- TODO confirm.
        final: second image volume, same layout as ``start``.
        pre_img, mid_img, aft_img: not used by this method; kept so existing
            callers keep working.
        labeled: label volume whose channels 0-3 are warped by ``-fsflow``.
        labeles: label volume whose channels 0-3 are warped by ``fsflow``.

    Returns:
        Tuple of:
        - ``fsim``: first output of ``self.reg(start, final)``.
        - ``fake_es``: ``final`` warped by ``-fsflow``.
        - ``preimflow``, ``midimflow``, ``afimflow``: ``start`` warped by 25%,
          50% and 75% of ``fsflow`` (presumably approximating intermediate
          frames; confirm with the training code).
        - ``fsflow``: the predicted flow.
        - ``startseg``, ``finalseg``: softmax-fused segmentations.
        - ``flowfn``, ``flowst``: warped label channels 0-3.
    """
    fsim, fsflow = self.reg(start, final)
    # NOTE(review): negating a displacement field only approximates its inverse.
    inv_flow = -1 * fsflow
    # One sampler serves every warp below. Its grid takes the flow's (D, H, W)
    # instead of the previously hard-coded (depth, 192, 192).
    warp = layer.SpatialTransformer(tuple(fsflow.shape[2:]))

    def warp_classes(volume, flow):
        # Warp channels 0-3 one at a time (as the original code did) and
        # restack them. Slicing `[:, c:c + 1]` keeps the whole batch; the old
        # `volume[0, c]` + unsqueeze kept only batch element 0, which broke the
        # later concatenations for batch sizes > 1.
        return torch.cat([warp(volume[:, c:c + 1], flow) for c in range(4)], dim=1)

    preimflow = warp(start, 0.25 * fsflow)
    midimflow = warp(start, 0.5 * fsflow)
    afimflow = warp(start, 0.75 * fsflow)
    fake_es = warp(final, inv_flow)

    # Each frame's segmentation is carried into the other frame with the flow...
    startseg1 = self.unet(start)
    flfinseg = warp_classes(startseg1, fsflow)
    finalseg1 = self.unet(final)
    flstaseg = warp_classes(finalseg1, inv_flow)

    # ...then fused with that frame's own prediction through v1 -> v2 -> softmax.
    startseg = F.softmax(self.v2(self.v1(torch.cat([startseg1, flstaseg], dim=1))), dim=1)
    finalseg = F.softmax(self.v2(self.v1(torch.cat([finalseg1, flfinseg], dim=1))), dim=1)

    # Ground-truth labels warped by the same flows ("reg_label").
    flowfn = warp_classes(labeles, fsflow)
    flowst = warp_classes(labeled, inv_flow)
    return (fsim, fake_es, preimflow, midimflow, afimflow, fsflow,
            startseg, finalseg, flowfn, flowst)
img_ed = img_ed.to(device).float() img_es = img_es.to(device).float() labeled = labeled.to(device).float() labeles = labeles.to(device).float() source = source.to(device).float() # mi3, mx3, mi4, mx4 = criterion.location(labeles, labeled, 5, 4) depth = img_ed.shape[2] ################################################################################ ed_seg = Segnet(img_ed) es_seg = Segnet(img_es) seg = Segnet(source) with torch.no_grad(): flow_field_x1, ed_source, flow_field_x2, es_source = Flownet( img_es, img_ed, source) ed_source_1 = layer.SpatialTransformer( (depth, 128, 128))(ed_seg[:, 0:1, :, :, :], flow_field_x1) ed_source_2 = layer.SpatialTransformer( (depth, 128, 128))(ed_seg[:, 1:2, :, :, :], flow_field_x1) ed_source_3 = layer.SpatialTransformer( (depth, 128, 128))(ed_seg[:, 2:3, :, :, :], flow_field_x1) ed_source_4 = layer.SpatialTransformer( (depth, 128, 128))(ed_seg[:, 3:4, :, :, :], flow_field_x1) ed_source = torch.cat( (ed_source_1, ed_source_2, ed_source_3, ed_source_4), 1) #################################################################################################### es_source_1 = layer.SpatialTransformer( (depth, 128, 128))(es_seg[:, 0:1, :, :, :], flow_field_x2) es_source_2 = layer.SpatialTransformer( (depth, 128, 128))(es_seg[:, 1:2, :, :, :], flow_field_x2) es_source_3 = layer.SpatialTransformer( (depth, 128, 128))(es_seg[:, 2:3, :, :, :], flow_field_x2)