def test():
    """Register ``opt.fixed`` (A) and ``opt.moving`` (B) in both directions.

    Loads a ``ModelFlow_stride`` checkpoint from ``opt.modelpath``. For each
    direction it writes the following under ``savepath``:

    * the predicted flow scaled by ``range_flow``
      (``flow_A_B.nii.gz`` / ``flow_B_A.nii.gz``),
    * an inverse-flow estimate, obtained by warping the negated prediction
      with the prediction itself
      (``inverse_flow_B_A.nii.gz`` / ``inverse_flow_A_B.nii.gz``),
    * the source image warped by the predicted flow
      (``warped_A.nii.gz`` / ``warped_B.nii.gz``).

    Wall-clock timings of the A->B pass are printed to stdout.
    """
    model = ModelFlow_stride(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    model.load_state_dict(torch.load(opt.modelpath))
    model.eval()
    transform.eval()

    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1, ) + grid.shape)).cuda().float()

    # Pure inference. Without no_grad(), autograd keeps every intermediate
    # activation of the 3-D network alive for a backward pass that never
    # happens, which needlessly inflates GPU memory.
    with torch.no_grad():
        start = timeit.default_timer()
        A = torch.from_numpy(load_5D(opt.fixed)).cuda().float()
        B = torch.from_numpy(load_5D(opt.moving)).cuda().float()
        start2 = timeit.default_timer()
        print('Time for loading data: ', start2 - start)

        # ---- A -> B --------------------------------------------------------
        pred = model(A, B)
        F_AB = pred.permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        F_AB = F_AB.astype(np.float32) * range_flow
        warped_A = transform(A, pred.permute(0, 2, 3, 4, 1) * range_flow,
                             grid).cpu().numpy()[0, 0, :, :, :]
        start3 = timeit.default_timer()
        print('Time for registration: ', start3 - start2)

        # Inverse estimate: resample -pred with the displacement pred itself.
        warped_F_BA = transform(-pred, pred.permute(0, 2, 3, 4, 1) * range_flow,
                                grid).permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        warped_F_BA = warped_F_BA.astype(np.float32) * range_flow
        start4 = timeit.default_timer()
        print('Time for generating inverse flow: ', start4 - start3)

        save_flow(F_AB, savepath + '/flow_A_B.nii.gz')
        save_flow(warped_F_BA, savepath + '/inverse_flow_B_A.nii.gz')
        save_img(warped_A, savepath + '/warped_A.nii.gz')
        start5 = timeit.default_timer()
        print('Time for saving results: ', start5 - start4)

        # Drop the A->B prediction before running the second pass.
        del pred

        # ---- B -> A --------------------------------------------------------
        pred = model(B, A)
        F_BA = pred.permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        F_BA = F_BA.astype(np.float32) * range_flow
        warped_B = transform(B, pred.permute(0, 2, 3, 4, 1) * range_flow,
                             grid).cpu().numpy()[0, 0, :, :, :]
        warped_F_AB = transform(-pred, pred.permute(0, 2, 3, 4, 1) * range_flow,
                                grid).permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        warped_F_AB = warped_F_AB.astype(np.float32) * range_flow

        save_flow(F_BA, savepath + '/flow_B_A.nii.gz')
        save_flow(warped_F_AB, savepath + '/inverse_flow_A_B.nii.gz')
        save_img(warped_B, savepath + '/warped_B.nii.gz')
def train():
    """Train ``ModelFlow_stride`` with the inverse-consistent objective.

    Each step registers the pair both ways (X->Y and Y->X) and minimises

        sim + inverse * inv + antifold * ant + smooth * smo

    where
      * sim: MSE between each target and its warped source,
      * inv: MSE between each flow and the negated opposite flow resampled
        by it,
      * ant: ``antifoldloss`` on both flows,
      * smo: ``smoothloss`` on both flows.

    Every flow is multiplied by ``range_flow`` before use. Weights are saved
    to ``../Model/<step>.pth`` every ``n_checkpoint`` steps, and the per-step
    loss matrix is saved to ``../Model/loss.npy`` at the end.
    """
    model = ModelFlow_stride(2, 3, start_channel).cuda()
    transform = SpatialTransform().cuda()

    # The spatial transformer is a fixed resampler; keep it out of training.
    for p in transform.parameters():
        p.requires_grad = False
        p.volatile = True

    data_files = glob.glob(datapath + '/*.gz')

    base_grid = generate_grid(imgshape)
    batched_grid = np.reshape(base_grid, (1, ) + base_grid.shape)
    grid = Variable(torch.from_numpy(batched_grid)).cuda().float()
    print(grid.type())

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model_dir = '../Model'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    lossall = np.zeros((5, iteration))
    loader = Data.DataLoader(Dataset(data_files, iteration, True),
                             batch_size=1, shuffle=False, num_workers=2)

    def warp(img, flow):
        # SpatialTransform takes the flow channel-last and scaled by range_flow.
        return transform(img, flow.permute(0, 2, 3, 4, 1) * range_flow, grid)

    for step, (X, Y) in enumerate(loader):
        X = X.cuda().float()
        Y = Y.cuda().float()

        F_xy = model(X, Y)
        F_yx = model(Y, X)
        X_Y = warp(X, F_xy)
        Y_X = warp(Y, F_yx)

        # The right way to build the inverse flow depends on the transform.
        # Assuming SpatialTransform does backward warping (samples the input
        # at p + flow(p)), an exact inverse satisfies F_yx(p + F_xy(p)) =
        # -F_xy(p). So -F_yx resampled by F_xy should reproduce F_xy, and
        # symmetrically for F_yx.
        F_xy_ = warp(-F_yx, F_xy)
        F_yx_ = warp(-F_xy, F_yx)

        l_sim = mse_loss(Y, X_Y) + mse_loss(X, Y_X)
        l_inv = (mse_loss(F_xy * range_flow, F_xy_ * range_flow)
                 + mse_loss(F_yx * range_flow, F_yx_ * range_flow))
        l_ant = antifoldloss(F_xy * range_flow) + antifoldloss(F_yx * range_flow)
        l_smo = smoothloss(F_xy * range_flow) + smoothloss(F_yx * range_flow)
        total = l_sim + inverse * l_inv + antifold * l_ant + smooth * l_smo

        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        scalars = [t.item() for t in (total, l_sim, l_inv, l_ant, l_smo)]
        lossall[:, step] = np.array(scalars)
        sys.stdout.write(
            "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" - inv "{3:.4f}" - ant "{4:.4f}" -smo "{5:.4f}" '
            .format(step, *scalars))
        sys.stdout.flush()

        if step % n_checkpoint == 0:
            torch.save(model.state_dict(), model_dir + '/' + str(step) + '.pth')

    np.save(model_dir + '/loss.npy', lossall)
def test():
    """Run a trained SYMNet on one fixed/moving pair and save the results.

    Loads weights from ``opt.modelpath`` and reads ``fixed_path`` and
    ``moving_path`` with ``load_4D``. It integrates the two predicted fields
    (and their negations) with ``DiffeomorphicTransform`` and composes them
    with ``CompositionTransform`` into full X->Y and Y->X flows. It then saves
    both flows (scaled by ``range_flow``) and both warped images under
    ``savepath``.
    """
    model = SYMNet(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    model.load_state_dict(torch.load(opt.modelpath))
    for module in (model, transform, diff_transform, com_transform):
        module.eval()

    identity = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(identity, (1, ) + identity.shape)).cuda().float()

    use_cuda = True
    device = torch.device("cuda" if use_cuda else "cpu")

    def as_batch(path):
        # Add a leading batch axis to the loaded volume.
        return torch.from_numpy(load_4D(path)).float().to(device).unsqueeze(dim=0)

    fixed_img = as_batch(fixed_path)
    moved_img = as_batch(moving_path)

    def flow_array(flow):
        # Channel-last numpy copy of the first batch item, scaled by range_flow.
        arr = flow.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
        return arr.astype(np.float32) * range_flow

    def warped_array(img, flow):
        resampled = transform(img, flow.permute(0, 2, 3, 4, 1) * range_flow, grid)
        return resampled.data.cpu().numpy()[0, 0, :, :, :]

    with torch.no_grad():
        F_xy, F_yx = model(fixed_img, moved_img)

        half_xy = diff_transform(F_xy, grid, range_flow)
        half_yx = diff_transform(F_yx, grid, range_flow)
        half_xy_inv = diff_transform(-F_xy, grid, range_flow)
        half_yx_inv = diff_transform(-F_yx, grid, range_flow)

        F_X_Y = com_transform(half_xy, half_yx_inv, grid, range_flow)
        F_Y_X = com_transform(half_yx, half_xy_inv, grid, range_flow)

        F_BA = flow_array(F_Y_X)
        F_AB = flow_array(F_X_Y)
        warped_B = warped_array(moved_img, F_Y_X)
        warped_A = warped_array(fixed_img, F_X_Y)

        save_flow(F_BA, savepath + '/wrapped_flow_B_to_A.nii.gz')
        save_flow(F_AB, savepath + '/wrapped_flow_A_to_B.nii.gz')
        save_img(warped_B, savepath + '/wrapped_norm_B_to_A.nii.gz')
        save_img(warped_A, savepath + '/wrapped_norm_A_to_B.nii.gz')

    print("Finished.")
def train():
    """Train ``ModelFlow_stride`` on ``cuda:6`` with the inverse-consistent objective.

    Each step pads the pair with ``train_padding``, registers it both ways
    and minimises

        sim + inverse * inv + antifold * ant + smooth * smo

    using MSE similarity, MSE inverse consistency, ``antifoldloss`` and
    ``smoothloss``. All flows are multiplied by ``range_flow``. Weights are
    saved to ``../Model/<step>.pth`` every ``n_checkpoint`` steps, and the
    loss matrix is saved to ``../Model/loss.npy`` at the end.
    """
    torch.cuda.empty_cache()
    device = torch.device("cuda:6")
    model = ModelFlow_stride(2, 3, start_channel).cuda(device)
    loss_similarity = mse_loss
    loss_inverse = mse_loss
    loss_antifold = antifoldloss
    loss_smooth = smoothloss

    transform = SpatialTransform().cuda(device)
    # The spatial transformer is a fixed resampler; keep it out of training.
    for param in transform.parameters():
        param.requires_grad = False

    names = glob.glob(datapath + '/*.gz')
    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1, ) + grid.shape)).cuda(device).float()
    print(grid.type())

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model_dir = '../Model'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    lossall = np.zeros((5, iteration))
    training_generator = Data.DataLoader(Dataset(names, iteration, True),
                                         batch_size=1, shuffle=False, num_workers=0)
    step = 0
    for X, Y in training_generator:
        X = X.cuda(device).float()
        Y = Y.cuda(device).float()
        X, Y = train_padding(X, Y)
        # train_padding does not hand back CUDA tensors, so move them again.
        X = X.cuda(device).float()
        Y = Y.cuda(device).float()

        F_xy = model(X, Y)
        F_yx = model(Y, X)
        X_Y = transform(X, F_xy.permute(0, 2, 3, 4, 1) * range_flow, grid)
        Y_X = transform(Y, F_yx.permute(0, 2, 3, 4, 1) * range_flow, grid)

        # Inverse consistency. Assuming backward warping, an exact inverse
        # satisfies F_yx(p + F_xy(p)) = -F_xy(p). So -F_yx resampled by F_xy
        # must reproduce F_xy, and vice versa. The previous
        # transform(-F_xy, F_xy) compared each flow with its own negated
        # self-warp. That only forces each flow to undo itself and never ties
        # F_xy to F_yx.
        F_xy_ = transform(-F_yx, F_xy.permute(0, 2, 3, 4, 1) * range_flow, grid)
        F_yx_ = transform(-F_xy, F_yx.permute(0, 2, 3, 4, 1) * range_flow, grid)

        loss1 = loss_similarity(Y, X_Y) + loss_similarity(X, Y_X)
        loss2 = (loss_inverse(F_xy * range_flow, F_xy_ * range_flow)
                 + loss_inverse(F_yx * range_flow, F_yx_ * range_flow))
        loss3 = loss_antifold(F_xy * range_flow) + loss_antifold(F_yx * range_flow)
        loss4 = loss_smooth(F_xy * range_flow) + loss_smooth(F_yx * range_flow)
        loss = loss1 + inverse * loss2 + antifold * loss3 + smooth * loss4

        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients

        # .item() yields Python floats. np.array() over CUDA tensors
        # (the old `.data` values) raises TypeError.
        values = [loss.item(), loss1.item(), loss2.item(), loss3.item(), loss4.item()]
        lossall[:, step] = np.array(values)
        sys.stdout.write(
            "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" - inv "{3:.4f}" - ant "{4:.4f}" -smo "{5:.4f}" '
            .format(step, *values))
        sys.stdout.flush()

        if step % n_checkpoint == 0:
            modelname = model_dir + '/' + str(step) + '.pth'
            torch.save(model.state_dict(), modelname)
        step += 1

        # Drop the batch references first so empty_cache can actually
        # return their memory.
        del X
        del Y
        torch.cuda.empty_cache()

    np.save(model_dir + '/loss.npy', lossall)
imgshape=imgshape_2, range_flow=range_flow, model_lvl1=model_lvl1).cuda() model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(in_channel, n_classes, start_channel, is_train=isTrainLvl3, imgshape=imgshape, range_flow=range_flow, model_lvl2=model_lvl2).cuda() transform = SpatialTransform_unit().cuda() model.load_state_dict(torch.load(opt.modelpath)) print("pretrained model is loaded: " + opt.modelpath) model.eval() transform.eval() use_cuda = True device = torch.device("cuda" if use_cuda else "cpu") print(" processor use: ", device) grid = generate_grid(imgshape) grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).cuda().float() #grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).to(device).float() def testImages(testingImagesLst): totalDice = 0.0 testing_generator = Data.DataLoader(Dataset_epoch(testingLst, norm=doNormalisation, isSeg=isSeg), batch_size=1, shuffle=True, num_workers=0) for imgPair in testing_generator: moving_path = imgPair[1][1][0] fixed_path = imgPair[1][0][0] moving_tensor = imgPair[0][0] fixed_tensor = imgPair[0][1] # print("moving_tensor.shape : ",moving_tensor.shape) # print("fixed_path : ",fixed_path) # print("moving_path : ", moving_path)
def train_lvl1():
    """Train the level-1 LapIRN model (``imgshape_4`` resolution).

    Optimises ``NCC(win=3) + antifold * Jdet + smooth * smoothness`` on the
    ``*.nii`` volumes in ``datapath``, for steps 0..``iteration_lvl1``
    inclusive. Every ``n_checkpoint`` steps it saves the weights and the loss
    history to ``../Model/Stage``. The full loss history is saved once more
    at the end.
    """
    print("Training lvl1...")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
        2, 3, start_channel, is_train=True, imgshape=imgshape_4,
        range_flow=range_flow).to(device)

    loss_similarity = NCC(win=3)
    loss_Jdet = neg_Jdet_loss
    loss_smooth = smoothloss

    # OASIS
    names = sorted(glob.glob(datapath + '/*.nii'))
    grid_4 = generate_grid(imgshape_4)
    grid_4 = torch.from_numpy(np.reshape(grid_4, (1, ) + grid_4.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    model_dir = '../Model/Stage'
    # makedirs also creates the '../Model' parent. os.mkdir raised
    # FileNotFoundError when that parent did not exist yet.
    os.makedirs(model_dir, exist_ok=True)

    # Steps run from 0 to iteration_lvl1 inclusive, hence the extra column.
    lossall = np.zeros((4, iteration_lvl1 + 1))
    training_generator = Data.DataLoader(Dataset_epoch(names, norm=False),
                                         batch_size=1, shuffle=True, num_workers=2)

    step = 0
    load_model = False  # set True to resume from the hard-coded step-3000 checkpoint
    if load_model:
        model_path = "../Model/LDR_LPBA_NCC_lap_share_preact_1_05_3000.pth"
        print("Loading weight: ", model_path)
        step = 3000
        model.load_state_dict(torch.load(model_path))
        temp_lossall = np.load("../Model/loss_LDR_LPBA_NCC_lap_share_preact_1_05_3000.npy")
        lossall[:, 0:3000] = temp_lossall[:, 0:3000]

    while step <= iteration_lvl1:
        for X, Y in training_generator:
            X = X.to(device).float()
            Y = Y.to(device).float()

            # output_disp_e0, warpped_inputx_lvl1_out, down_y, output_disp_e0_v, e0
            F_X_Y, X_Y, Y_4x, F_xy, _ = model(X, Y)

            loss_multiNCC = loss_similarity(X_Y, Y_4x)

            F_X_Y_norm = transform_unit_flow_to_flow_cuda(F_X_Y.permute(0, 2, 3, 4, 1).clone())
            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid_4)

            # reg2 - use velocity. Each displacement channel is multiplied by
            # the size of a spatial axis: channel 0 by the last axis, channel
            # 2 by the first. Presumably this converts unit-normalised
            # displacements to voxels; TODO confirm. Done in place on F_X_Y.
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            loss = loss_multiNCC + antifold * loss_Jacobian + smooth * loss_regulation

            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            values = [loss.item(), loss_multiNCC.item(),
                      loss_Jacobian.item(), loss_regulation.item()]
            lossall[:, step] = np.array(values)
            # "{2:.4f}" fixes the former "{2:4f}" typo (width 4 instead of precision 4).
            sys.stdout.write(
                "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim_NCC "{2:.4f}" - Jdet "{3:.10f}" -smo "{4:.4f}"'
                .format(step, *values))
            sys.stdout.flush()

            # with lr 1e-3 + with bias
            if step % n_checkpoint == 0:
                modelname = model_dir + '/' + model_name + "stagelvl1_" + str(step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(model_dir + '/loss' + model_name + "stagelvl1_" + str(step) + '.npy',
                        lossall)

            step += 1
            if step > iteration_lvl1:
                break
        print("one epoch pass")

    np.save(model_dir + '/loss' + model_name + 'stagelvl1.npy', lossall)
def test():
    """Register ``opt.fixed`` (A) and ``opt.moving`` (B) in both directions on ``cuda:6``.

    Reads both images with SimpleITK as float32 arrays and passes them through
    ``load_5D``. It loads a ``ModelFlow_stride`` checkpoint from
    ``opt.modelpath`` (mapped onto ``cuda:6``). For each direction it writes
    the following under ``savepath``:

    * the predicted flow scaled by ``range_flow``,
    * an inverse-flow estimate, obtained by warping the negated prediction
      with the prediction itself,
    * the source image warped by the predicted flow.

    Wall-clock timings of the A->B pass are printed to stdout.
    """
    device = torch.device("cuda:6")
    model = ModelFlow_stride(2, 3, opt.start_channel).cuda(device)
    transform = SpatialTransform().cuda(device)
    model.load_state_dict(torch.load(opt.modelpath, map_location=device))
    model.eval()
    transform.eval()

    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1, ) + grid.shape)).cuda(device).float()

    # Pure inference. Without no_grad(), autograd keeps every intermediate
    # activation of the 3-D network alive for a backward pass that never
    # happens, which needlessly inflates GPU memory.
    with torch.no_grad():
        start = timeit.default_timer()
        A = sitk.GetArrayFromImage(sitk.ReadImage(opt.fixed, sitk.sitkFloat32))
        B = sitk.GetArrayFromImage(sitk.ReadImage(opt.moving, sitk.sitkFloat32))
        # NOTE(review): load_5D receives an array here, whereas other callers
        # pass a path. Confirm load_5D accepts arrays.
        A = torch.from_numpy(load_5D(A)).cuda(device).float()
        B = torch.from_numpy(load_5D(B)).cuda(device).float()
        start2 = timeit.default_timer()
        print('Time for loading data: ', start2 - start)

        # ---- A -> B --------------------------------------------------------
        pred = model(A, B)
        F_AB = pred.permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        F_AB = F_AB.astype(np.float32) * range_flow
        warped_A = transform(A, pred.permute(0, 2, 3, 4, 1) * range_flow,
                             grid).cpu().numpy()[0, 0, :, :, :]
        start3 = timeit.default_timer()
        print('Time for registration: ', start3 - start2)

        # Inverse estimate: resample -pred with the displacement pred itself.
        warped_F_BA = transform(-pred, pred.permute(0, 2, 3, 4, 1) * range_flow,
                                grid).permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        warped_F_BA = warped_F_BA.astype(np.float32) * range_flow
        start4 = timeit.default_timer()
        print('Time for generating inverse flow: ', start4 - start3)

        save_flow(F_AB, savepath + '/flow_A_B.nii.gz')
        save_flow(warped_F_BA, savepath + '/inverse_flow_B_A.nii.gz')
        save_img(warped_A, savepath + '/warped_A.nii.gz')
        start5 = timeit.default_timer()
        print('Time for saving results: ', start5 - start4)

        # Drop the A->B prediction before running the second pass.
        del pred

        # ---- B -> A --------------------------------------------------------
        pred = model(B, A)
        F_BA = pred.permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        F_BA = F_BA.astype(np.float32) * range_flow
        warped_B = transform(B, pred.permute(0, 2, 3, 4, 1) * range_flow,
                             grid).cpu().numpy()[0, 0, :, :, :]
        warped_F_AB = transform(-pred, pred.permute(0, 2, 3, 4, 1) * range_flow,
                                grid).permute(0, 2, 3, 4, 1).cpu().numpy()[0, :, :, :, :]
        warped_F_AB = warped_F_AB.astype(np.float32) * range_flow

        save_flow(F_BA, savepath + '/flow_B_A.nii.gz')
        save_flow(warped_F_AB, savepath + '/inverse_flow_A_B.nii.gz')
        save_img(warped_B, savepath + '/warped_B.nii.gz')
def train():
    """Train ``ModelFlow_stride`` with a cycle-consistent, identity-regularised loss.

    Each step computes:
      * L_regist: NCC between each target and its warped source, plus
        ``lambda_`` times the smoothness of both flows;
      * L_cycle: L1 between each input and its X->Y->X (or Y->X->Y)
        round trip;
      * L_identity: NCC between each input and itself warped by the flow
        predicted for the (input, input) pair.
    It minimises ``L_regist + alpha * L_cycle + beta * L_identity``. Weights
    are saved to ``../Model/<step>.pth`` every ``n_checkpoint`` steps, and
    the loss matrix is saved to ``../Model/loss.npy`` at the end.
    """
    model = ModelFlow_stride(2, 3, start_channel).cuda()
    loss_similarity = NCC().loss
    loss_cycle = l1_loss
    loss_smooth = smoothloss

    transform = SpatialTransform().cuda()
    # The spatial transformer is a fixed resampler; keep it out of training.
    for param in transform.parameters():
        param.requires_grad = False

    names = glob.glob(datapath + '/*.gz')
    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1, ) + grid.shape)).cuda().float()
    print(grid.type())

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model_dir = '../Model'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    lossall = np.zeros((5, iteration))
    training_generator = Data.DataLoader(Dataset(names, iteration, False),
                                         batch_size=1, shuffle=False, num_workers=2)
    step = 0
    for X, Y in training_generator:
        X = X.cuda().float()
        Y = Y.cuda().float()

        # Forward registration in both directions.
        F_xy = model(X, Y)
        F_yx = model(Y, X)
        X_Y = transform(X, F_xy.permute(0, 2, 3, 4, 1) * range_flow, grid)
        Y_X = transform(Y, F_yx.permute(0, 2, 3, 4, 1) * range_flow, grid)

        # Cycle: register the warped images back onto each other.
        F_xy_ = model(Y_X, X_Y)
        F_yx_ = model(X_Y, Y_X)
        Y_X_Y = transform(Y_X, F_xy_.permute(0, 2, 3, 4, 1) * range_flow, grid)
        X_Y_X = transform(X_Y, F_yx_.permute(0, 2, 3, 4, 1) * range_flow, grid)

        # Identity: registering an image to itself should leave it unchanged.
        F_xx = model(X, X)
        F_yy = model(Y, Y)
        X_X = transform(X, F_xx.permute(0, 2, 3, 4, 1) * range_flow, grid)
        Y_Y = transform(Y, F_yy.permute(0, 2, 3, 4, 1) * range_flow, grid)

        L_smooth = loss_smooth(F_xy * range_flow) + loss_smooth(F_yx * range_flow)
        L_regist = loss_similarity(Y, X_Y) + loss_similarity(X, Y_X) + lambda_ * L_smooth
        L_cycle = loss_cycle(X, X_Y_X) + loss_cycle(Y, Y_X_Y)
        L_identity = loss_similarity(X, X_X) + loss_similarity(Y, Y_Y)
        loss = L_regist + alpha * L_cycle + beta * L_identity

        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients

        # The losses are 0-dim tensors. Indexing them with `.data[0]` raises
        # IndexError on PyTorch >= 0.4, so use .item().
        values = [loss.item(), L_regist.item(), L_cycle.item(),
                  L_identity.item(), L_smooth.item()]
        lossall[:, step] = np.array(values)
        sys.stdout.write(
            "\r" + 'step "{0}" -> training loss "{1:.4f}" - reg "{2:.4f}" - cyc "{3:.4f}" - ind "{4:.4f}" -smo "{5:.4f}" '
            .format(step, *values))
        sys.stdout.flush()

        if step % n_checkpoint == 0:
            modelname = model_dir + '/' + str(step) + '.pth'
            torch.save(model.state_dict(), modelname)
        step += 1

    np.save(model_dir + '/loss.npy', lossall)
def train(lvlID, opt=None, model_lvl1_path="", model_lvl2_path=""):
    """Train one level of the LapIRN pyramid and return the final checkpoint path.

    Args:
        lvlID: level to train. 1 uses ``imgshape_4``, 2 uses ``imgshape_2``
            (wrapping a frozen level-1 model), 3 uses ``imgshape`` (wrapping
            a frozen level-2 model).
        opt: parsed options object (required). The attributes read are
            ``freeze_step``, ``simLossType`` (0 = NCC, 1 = MSE, 2 = Dice),
            ``start_channel``, ``iteration_lvl1/2/3``,
            ``sIteration_lvl1/2/3``, ``checkpoint``, ``doAugmentation`` and
            ``isSeg``.
        model_lvl1_path: level-1 weights, loaded when ``lvlID == 2``.
        model_lvl2_path: level-2 weights, loaded when ``lvlID == 3``.

    Returns:
        Path of the ``.pth`` file with the weights after the last step.

    Raises:
        ValueError: if ``opt`` is missing, ``lvlID`` is not 1, 2 or 3, or
            ``opt.simLossType`` is not 0, 1 or 2.
    """
    # Validate up front. Previously opt=[] failed on the first attribute
    # access, and an unknown loss type failed later on an undefined loss_sim.
    if opt is None:
        raise ValueError("train() requires the parsed options object 'opt'")
    if lvlID not in (1, 2, 3):
        raise ValueError("lvlID must be 1, 2 or 3, got {!r}".format(lvlID))
    if opt.simLossType not in (0, 1, 2):
        raise ValueError("error: not supported loss (simLossType={!r})".format(opt.simLossType))

    print("Training " + str(lvlID) + "===========================================================")
    model_dir = '../Model/Stage'
    # makedirs also creates the '../Model' parent, which os.mkdir did not.
    os.makedirs(model_dir, exist_ok=True)
    logLvlPath = model_dir + "/logLvl_" + str(lvlID) + ".txt"
    numWorkers = 0  # DataLoader worker processes
    freeze_step = opt.freeze_step  # level 3: step at which the level-2 sub-model is unfrozen

    lossName = "_NCC_" if opt.simLossType == 0 else (
        "_MSE_" if opt.simLossType == 1 else "_DICE_")
    model_name = ("LDR_OASIS" + lossName + "_disp_" + str(opt.start_channel) + "_"
                  + str(opt.iteration_lvl1) + "_f" + str(opt.iteration_lvl2)
                  + "_f" + str(opt.iteration_lvl3) + "_")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    n_checkpoint = opt.checkpoint
    if lvlID == 3:
        n_checkpoint = 10  # level 3 always checkpoints every 10 steps

    # ---- build the model for this level ----------------------------------
    if lvlID == 1:
        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel, n_classes, start_channel, is_train=isTrainLvl1,
            imgshape=imgshape_4, range_flow=range_flow).to(device)
        grid = generate_grid(imgshape_4)
        start_iteration = opt.sIteration_lvl1
        num_iteration = opt.iteration_lvl1
    elif lvlID == 2:
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel, n_classes, start_channel, is_train=isTrainLvl1,
            imgshape=imgshape_4, range_flow=range_flow).to(device)
        model_lvl1.load_state_dict(torch.load(model_lvl1_path))
        # Freeze model_lvl1 weights
        for param in model_lvl1.parameters():
            param.requires_grad = False
        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel, n_classes, start_channel, is_train=isTrainLvl2,
            imgshape=imgshape_2, range_flow=range_flow, model_lvl1=model_lvl1).to(device)
        grid = generate_grid(imgshape_2)
        start_iteration = opt.sIteration_lvl2
        num_iteration = opt.iteration_lvl2
    else:  # lvlID == 3
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel, n_classes, start_channel, is_train=isTrainLvl1,
            imgshape=imgshape_4, range_flow=range_flow).to(device)
        model_lvl2 = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel, n_classes, start_channel, is_train=isTrainLvl2,
            imgshape=imgshape_2, range_flow=range_flow, model_lvl1=model_lvl1).to(device)
        model_lvl2.load_state_dict(torch.load(model_lvl2_path))
        # Freeze model_lvl2 weights (this includes its embedded level-1 model)
        for param in model_lvl2.parameters():
            param.requires_grad = False
        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(
            in_channel, n_classes, start_channel, is_train=isTrainLvl3,
            imgshape=imgshape, range_flow=range_flow, model_lvl2=model_lvl2).to(device)
        grid = generate_grid(imgshape)
        start_iteration = opt.sIteration_lvl3
        num_iteration = opt.iteration_lvl3

    # ---- losses / helpers -------------------------------------------------
    loss_Jdet = neg_Jdet_loss
    loss_smooth = smoothloss
    transform = SpatialTransform_unit().to(device)
    for param in transform.parameters():
        param.requires_grad = False

    # The NCC window grows with the level. Build the module once instead of
    # on every step.
    loss_similarity = None
    if opt.simLossType == 0:
        if lvlID == 1:
            loss_similarity = NCC(win=3)
        elif lvlID == 2:
            loss_similarity = multi_resolution_NCC(win=5, scale=2)
        else:
            loss_similarity = multi_resolution_NCC(win=7, scale=3)

    grid = torch.from_numpy(np.reshape(grid, (1, ) + grid.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # One column per step; steps run 0..num_iteration inclusive.
    lossall = np.zeros((4, num_iteration + 1))

    training_generator = Data.DataLoader(
        Dataset_epoch(trainingLst, norm=doNormalisation, aug=opt.doAugmentation, isSeg=opt.isSeg),
        batch_size=1, shuffle=True, num_workers=numWorkers)

    step = 0
    if start_iteration > 0:
        # NOTE(review): this loads the checkpoint saved at num_iteration and
        # sets step past num_iteration, so the training loop below is
        # skipped. start_iteration therefore acts only as a "reuse finished
        # level" flag. Confirm this is intended rather than resuming at
        # start_iteration.
        model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(lvlID) + "_" + str(num_iteration) + '.pth'
        loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(lvlID) + "_" + str(num_iteration) + '.npy'
        print("Loading weight and loss : ", model_lvl_path)
        step = num_iteration + 1
        model.load_state_dict(torch.load(model_lvl_path))
        temp_lossall = np.load(loss_lvl_path)
        lossall[:, 0:num_iteration] = temp_lossall[:, 0:num_iteration]
    else:
        # Start a fresh (empty) log file. The old code called `.close`
        # without parentheses, leaving the handle open.
        with open(logLvlPath, "w"):
            pass

    # ---- training loop ----------------------------------------------------
    while step <= num_iteration:
        for pair in training_generator:
            # Indexing as used here: pair[0] holds (X, Y) tensors and pair[1]
            # holds the matching paths (moving first). TODO confirm against
            # Dataset_epoch.
            X = pair[0][0]
            Y = pair[0][1]
            movingPath = pair[1][0][0]
            fixedPath = pair[1][1][0]
            ext = ".nii.gz" if ".nii.gz" in fixedPath else (
                ".nii" if ".nii" in fixedPath else ".nrrd")

            X = X.to(device).float()
            Y = Y.to(device).float()
            # Check on the device instead of copying whole volumes to the CPU.
            assert not torch.isnan(X).any()
            assert not torch.isnan(Y).any()

            # F_X_Y: displacement field, X_Y: warped moving image,
            # Y_4x: fixed image at this level's resolution. Any remaining
            # outputs (velocity fields) are unused here.
            F_X_Y, X_Y, Y_4x, *_ = model(X, Y)

            if opt.simLossType == 0:  # NCC
                loss_sim = loss_similarity(X_Y, Y_4x)
            elif opt.simLossType == 1:  # MSE
                loss_sim = mse_loss(X_Y, Y_4x)
            else:  # Dice on segmentations derived from the image paths
                dv = math.pow(2, 3 - lvlID)
                fixedSeg = img2SegTensor(fixedPath, ext, dv)
                movingSeg = img2SegTensor(movingPath, ext, dv)
                movingSeg = movingSeg[np.newaxis, ...]
                movingSeg = torch.from_numpy(movingSeg).float().to(device).unsqueeze(dim=0)
                transformedSeg = transform(movingSeg, F_X_Y.permute(0, 2, 3, 4, 1),
                                           grid).data.cpu().numpy()[0, 0, :, :, :]
                transformedSeg[transformedSeg > 0] = 1.0
                # NOTE(review): transformedSeg is a detached numpy array, so
                # this loss cannot propagate gradients to the model.
                loss_sim = DiceLoss.getDiceLoss(fixedSeg, transformedSeg)

            F_X_Y_norm = transform_unit_flow_to_flow_cuda(F_X_Y.permute(0, 2, 3, 4, 1).clone())
            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid)

            # reg2 - use velocity. Each displacement channel is multiplied in
            # place by the size of a spatial axis: channel 0 by the last axis,
            # channel 2 by the first.
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            assert not math.isnan(loss_sim.item())
            assert not math.isnan(loss_Jacobian.item())
            assert not math.isnan(loss_regulation.item())

            loss = loss_sim + antifold * loss_Jacobian + smooth * loss_regulation
            assert not math.isnan(loss.item())

            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            values = [loss.item(), loss_sim.item(), loss_Jacobian.item(), loss_regulation.item()]
            lossall[:, step] = np.array(values)

            logLine = "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:4f}" - Jdet "{3:.10f}" -smo "{4:.4f}"'.format(
                step, *values)
            print(logLine)
            with open(logLvlPath, "a") as logLvlFile:
                logLvlFile.write(logLine)

            # with lr 1e-3 + with bias
            if step % n_checkpoint == 0:
                model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(lvlID) + "_" + str(step) + '.pth'
                torch.save(model.state_dict(), model_lvl_path)
                iaLog2Fig(logLvlPath)
                if doValidation:
                    validate(model)

            if lvlID == 3 and step == freeze_step:
                model.unfreeze_modellvl2()

            step += 1
            if step > num_iteration:
                break
        print("one epoch pass ....")

    model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(lvlID) + "_" + str(num_iteration) + '.pth'
    torch.save(model.state_dict(), model_lvl_path)
    return model_lvl_path
def train():
    """Train SYMNet on the first 255 ``*.nii`` volumes in ``datapath``.

    Each step:
      1. integrates the two predicted fields and their negations with
         ``DiffeomorphicTransform``;
      2. composes them into full X->Y and Y->X flows with
         ``CompositionTransform``;
      3. minimises ``sim_mid + sim_full + magnitude * mag + local_ori * Jdet
         + smooth * smo``.

    The loop runs steps 0..``iteration`` inclusive. Every ``n_checkpoint``
    steps it saves weights and the loss history to ``../Model``. The full
    history is saved once more at the end.
    """
    model = SYMNet(2, 3, start_channel).cuda()
    loss_similarity = NCC()
    loss_smooth = smoothloss
    loss_magnitude = magnitude_loss
    loss_Jdet = neg_Jdet_loss

    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()
    # The spatial transformer is a fixed resampler; keep it out of training.
    for param in transform.parameters():
        param.requires_grad = False

    names = sorted(glob.glob(datapath + '/*.nii'))[0:255]
    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1, ) + grid.shape)).cuda().float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    model_dir = '../Model'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # `while step <= iteration` writes column `iteration` too. With only
    # `iteration` columns the last step raised IndexError.
    lossall = np.zeros((6, iteration + 1))
    training_generator = Data.DataLoader(Dataset_epoch(names, norm=False),
                                         batch_size=1, shuffle=True, num_workers=2)
    step = 0
    while step <= iteration:
        for X, Y in training_generator:
            X = X.cuda().float()
            Y = Y.cuda().float()

            F_xy, F_yx = model(X, Y)

            F_X_Y_half = diff_transform(F_xy, grid, range_flow)
            F_Y_X_half = diff_transform(F_yx, grid, range_flow)
            F_X_Y_half_inv = diff_transform(-F_xy, grid, range_flow)
            F_Y_X_half_inv = diff_transform(-F_yx, grid, range_flow)

            X_Y_half = transform(X, F_X_Y_half.permute(0, 2, 3, 4, 1) * range_flow, grid)
            Y_X_half = transform(Y, F_Y_X_half.permute(0, 2, 3, 4, 1) * range_flow, grid)

            F_X_Y = com_transform(F_X_Y_half, F_Y_X_half_inv, grid, range_flow)
            F_Y_X = com_transform(F_Y_X_half, F_X_Y_half_inv, grid, range_flow)

            X_Y = transform(X, F_X_Y.permute(0, 2, 3, 4, 1) * range_flow, grid)
            Y_X = transform(Y, F_Y_X.permute(0, 2, 3, 4, 1) * range_flow, grid)

            loss1 = loss_similarity(X_Y_half, Y_X_half)
            loss2 = loss_similarity(Y, X_Y) + loss_similarity(X, Y_X)
            loss3 = loss_magnitude(F_X_Y_half * range_flow, F_Y_X_half * range_flow)
            loss4 = (loss_Jdet(F_X_Y.permute(0, 2, 3, 4, 1) * range_flow, grid)
                     + loss_Jdet(F_Y_X.permute(0, 2, 3, 4, 1) * range_flow, grid))
            loss5 = loss_smooth(F_xy * range_flow) + loss_smooth(F_yx * range_flow)

            loss = loss1 + loss2 + magnitude * loss3 + local_ori * loss4 + smooth * loss5

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            values = [loss.item(), loss1.item(), loss2.item(),
                      loss3.item(), loss4.item(), loss5.item()]
            lossall[:, step] = np.array(values)
            # "{3:.4f}" fixes the former "{3:4f}" typo (width 4 instead of precision 4).
            sys.stdout.write(
                "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim_mid "{2:.4f}" - sim_full "{3:.4f}" - mag "{4:.4f}" - Jdet "{5:.10f}" -smo "{6:.4f}" '
                .format(step, *values))
            sys.stdout.flush()

            if step % n_checkpoint == 0:
                modelname = model_dir + '/SYMNet_' + str(step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(model_dir + '/loss_SYMNet_' + str(step) + '.npy', lossall)

            step += 1
            if step > iteration:
                break
        print("one epoch pass")

    np.save(model_dir + '/loss_SYMNet.npy', lossall)