示例#1
0
def test():
    """Register the fixed/moving pair in both directions and save results.

    Saves forward flow, approximate inverse flow, and the warped image for
    each direction under `savepath`, printing per-stage timings.
    """
    # Trained network plus the (parameter-free) warping layer, both on GPU.
    model = ModelFlow_stride(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    model.load_state_dict(torch.load(opt.modelpath))
    model.eval()
    transform.eval()

    # Identity sampling grid with a leading batch axis, shared by every warp.
    grid = generate_grid(imgshape)
    grid = Variable(torch.from_numpy(
        np.reshape(grid, (1, ) + grid.shape))).cuda().float()

    t_start = timeit.default_timer()

    fixed_t = Variable(torch.from_numpy(load_5D(opt.fixed))).cuda().float()
    moving_t = Variable(torch.from_numpy(load_5D(opt.moving))).cuda().float()
    t_loaded = timeit.default_timer()
    print('Time for loading data: ', t_loaded - t_start)

    # Forward direction: A -> B.
    disp = model(fixed_t, moving_t)
    flow_ab = disp.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
    flow_ab = flow_ab.astype(np.float32) * range_flow
    warped_fixed = transform(
        fixed_t, disp.permute(0, 2, 3, 4, 1) * range_flow,
        grid).data.cpu().numpy()[0, 0, :, :, :]
    t_registered = timeit.default_timer()
    print('Time for registration: ', t_registered - t_loaded)

    # Approximate inverse flow B -> A: warp the negated flow with itself.
    inv_flow_ba = transform(
        -disp, disp.permute(0, 2, 3, 4, 1) * range_flow,
        grid).permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
    inv_flow_ba = inv_flow_ba.astype(np.float32) * range_flow

    t_inverse = timeit.default_timer()
    print('Time for generating inverse flow: ', t_inverse - t_registered)

    save_flow(flow_ab, savepath + '/flow_A_B.nii.gz')
    save_flow(inv_flow_ba, savepath + '/inverse_flow_B_A.nii.gz')
    save_img(warped_fixed, savepath + '/warped_A.nii.gz')

    t_saved = timeit.default_timer()
    print('Time for saving results: ', t_saved - t_inverse)
    del disp

    # Reverse direction: B -> A (same pipeline with the inputs swapped).
    disp = model(moving_t, fixed_t)
    flow_ba = disp.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
    flow_ba = flow_ba.astype(np.float32) * range_flow
    warped_moving = transform(
        moving_t, disp.permute(0, 2, 3, 4, 1) * range_flow,
        grid).data.cpu().numpy()[0, 0, :, :, :]
    inv_flow_ab = transform(
        -disp, disp.permute(0, 2, 3, 4, 1) * range_flow,
        grid).permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
    inv_flow_ab = inv_flow_ab.astype(np.float32) * range_flow
    save_flow(flow_ba, savepath + '/flow_B_A.nii.gz')
    save_flow(inv_flow_ab, savepath + '/inverse_flow_A_B.nii.gz')
    save_img(warped_moving, savepath + '/warped_B.nii.gz')
示例#2
0
def train():
    """Train ModelFlow_stride with bidirectional, inverse-consistent losses.

    Optimises four weighted terms: MSE image similarity, inverse-consistency
    between the forward/backward flows, an anti-folding penalty, and flow
    smoothness.  Checkpoints are written to ../Model every `n_checkpoint`
    steps and the per-step loss history is saved as ../Model/loss.npy.
    """
    model =ModelFlow_stride(2,3,start_channel).cuda()
    # Loss components; `inverse`, `antifold`, `smooth` weights are globals.
    loss_similarity =mse_loss
    loss_inverse = mse_loss
    loss_antifold = antifoldloss
    loss_smooth = smoothloss
    transform = SpatialTransform().cuda()
    # The warping layer has no trainable parameters; freeze it.
    for param in transform.parameters():
        param.requires_grad = False
        param.volatile=True
    names = glob.glob(datapath + '/*.gz')
    # Identity sampling grid with a leading batch axis, reused for every warp.
    grid = generate_grid(imgshape)
    grid = Variable(torch.from_numpy(np.reshape(grid, (1,) + grid.shape))).cuda().float()

    print(grid.type())
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    model_dir = '../Model'

    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # Rows: total, similarity, inverse, antifold, smooth loss per step.
    lossall = np.zeros((5,iteration))
    training_generator = Data.DataLoader(Dataset(names,iteration,True), batch_size=1,
                        shuffle=False, num_workers=2)
    step=0
    for  X,Y in training_generator:

        X = X.cuda().float()
        Y = Y.cuda().float()
        # Predict flows in both directions with the same network.
        F_xy = model(X,Y)
        F_yx = model(Y,X)

        X_Y = transform(X,F_xy.permute(0,2,3,4,1)*range_flow,grid)
        Y_X = transform(Y,F_yx.permute(0,2,3,4,1)*range_flow,grid)
        # Note: the generation of the inverse flow depends on the definition of
        # `transform`; the strategies differ slightly between backward warping
        # and forward warping (here the opposite flow is warped, negated).
        F_xy_ = transform(-F_yx,F_xy.permute(0,2,3,4,1)*range_flow,grid)
        F_yx_ = transform(-F_xy,F_yx.permute(0,2,3,4,1)*range_flow,grid)
        loss1 = loss_similarity(Y,X_Y) + loss_similarity(X,Y_X)
        loss2 = loss_inverse(F_xy*range_flow,F_xy_*range_flow) + loss_inverse(F_yx*range_flow,F_yx_*range_flow)


        loss3 =  loss_antifold(F_xy*range_flow) + loss_antifold(F_yx*range_flow)
        loss4 =  loss_smooth(F_xy*range_flow) + loss_smooth(F_yx*range_flow)
        loss = loss1+inverse*loss2 + antifold*loss3 + smooth*loss4
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients
        lossall[:,step] = np.array([loss.item(),loss1.item(),loss2.item(),loss3.item(),loss4.item()])
        sys.stdout.write("\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" - inv "{3:.4f}" \
            - ant "{4:.4f}" -smo "{5:.4f}" '.format(step, loss.item(),loss1.item(),loss2.item(),loss3.item(),loss4.item()))
        sys.stdout.flush()
        if(step % n_checkpoint == 0):
            modelname = model_dir + '/' + str(step) + '.pth'
            torch.save(model.state_dict(), modelname)
        step+=1
    np.save(model_dir+'/loss.npy',lossall)
def test():
    """Symmetric diffeomorphic registration: compute A<->B flows and save them.

    Uses scaling-and-squaring (`diff_transform`) to integrate the half-way
    velocity fields, composes them into full flows, then saves both flows and
    both warped images under `savepath`.
    """
    model = SYMNet(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    model.load_state_dict(torch.load(opt.modelpath))
    for module in (model, transform, diff_transform, com_transform):
        module.eval()

    # Identity sampling grid with a leading batch axis.
    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1, ) + grid.shape)).cuda().float()

    use_cuda = True
    device = torch.device("cuda" if use_cuda else "cpu")

    fixed_img = load_4D(fixed_path)
    moved_img = load_4D(moving_path)
    fixed_img = torch.from_numpy(fixed_img).float().to(device).unsqueeze(dim=0)
    moved_img = torch.from_numpy(moved_img).float().to(device).unsqueeze(dim=0)

    with torch.no_grad():
        F_xy, F_yx = model(fixed_img, moved_img)

        # Half-way flows and their inverses via scaling & squaring.
        half_xy = diff_transform(F_xy, grid, range_flow)
        half_yx = diff_transform(F_yx, grid, range_flow)
        half_xy_inv = diff_transform(-F_xy, grid, range_flow)
        half_yx_inv = diff_transform(-F_yx, grid, range_flow)

        # Full flows: compose each half with the opposite inverse half.
        F_X_Y = com_transform(half_xy, half_yx_inv, grid, range_flow)
        F_Y_X = com_transform(half_yx, half_xy_inv, grid, range_flow)

        F_BA = F_Y_X.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
        F_BA = F_BA.astype(np.float32) * range_flow

        F_AB = F_X_Y.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
        F_AB = F_AB.astype(np.float32) * range_flow

        warped_B = transform(
            moved_img, F_Y_X.permute(0, 2, 3, 4, 1) * range_flow,
            grid).data.cpu().numpy()[0, 0, :, :, :]
        warped_A = transform(
            fixed_img, F_X_Y.permute(0, 2, 3, 4, 1) * range_flow,
            grid).data.cpu().numpy()[0, 0, :, :, :]

        save_flow(F_BA, savepath + '/wrapped_flow_B_to_A.nii.gz')
        save_flow(F_AB, savepath + '/wrapped_flow_A_to_B.nii.gz')
        save_img(warped_B, savepath + '/wrapped_norm_B_to_A.nii.gz')
        save_img(warped_A, savepath + '/wrapped_norm_A_to_B.nii.gz')

        print("Finished.")
def train():
    """Train ModelFlow_stride bidirectionally on a single GPU (cuda:6).

    Optimises MSE similarity, inverse-consistency, anti-folding and
    smoothness terms.  Checkpoints are written to ../Model every
    `n_checkpoint` steps; the loss history is saved as ../Model/loss.npy.
    """
    torch.cuda.empty_cache()
    device = torch.device("cuda:6")
    model = ModelFlow_stride(2, 3, start_channel).cuda(device)
    # Loss components; `inverse`, `antifold`, `smooth` weights are globals.
    loss_similarity = mse_loss
    loss_inverse = mse_loss
    loss_antifold = antifoldloss
    loss_smooth = smoothloss
    transform = SpatialTransform().cuda(device)
    # The warping layer has no trainable parameters; freeze it.
    for param in transform.parameters():
        param.requires_grad = False
        param.volatile = True
    names = glob.glob(datapath + '/*.gz')
    # Identity sampling grid with a leading batch axis.
    grid = generate_grid(imgshape)
    grid = Variable(torch.from_numpy(np.reshape(
        grid, (1, ) + grid.shape))).cuda(device).float()

    print(grid.type())
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model_dir = '../Model'

    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # Rows: total, similarity, inverse, antifold, smooth loss per step.
    lossall = np.zeros((5, iteration))
    training_generator = Data.DataLoader(Dataset(names, iteration, True),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)
    step = 0
    for X, Y in training_generator:
        X = X.cuda(device).float()
        Y = Y.cuda(device).float()
        # Pad both volumes to a common shape; train_padding returns tensors
        # that are not on the GPU, so move the results back afterwards.
        X, Y = train_padding(X, Y)
        X = X.cuda(device).float()
        Y = Y.cuda(device).float()

        F_xy = model(X, Y)
        F_yx = model(Y, X)
        X_Y = transform(X, F_xy.permute(0, 2, 3, 4, 1) * range_flow, grid)
        Y_X = transform(Y, F_yx.permute(0, 2, 3, 4, 1) * range_flow, grid)
        # Approximate inverse flows: warp each negated flow with itself.
        F_xy_ = transform(-F_xy,
                          F_xy.permute(0, 2, 3, 4, 1) * range_flow, grid)
        F_yx_ = transform(-F_yx,
                          F_yx.permute(0, 2, 3, 4, 1) * range_flow, grid)
        loss1 = loss_similarity(Y, X_Y) + loss_similarity(X, Y_X)
        loss2 = loss_inverse(F_xy * range_flow,
                             F_xy_ * range_flow) + loss_inverse(
                                 F_yx * range_flow, F_yx_ * range_flow)
        loss3 = loss_antifold(F_xy * range_flow) + loss_antifold(
            F_yx * range_flow)
        loss4 = loss_smooth(F_xy * range_flow) + loss_smooth(F_yx * range_flow)
        loss = loss1 + inverse * loss2 + antifold * loss3 + smooth * loss4
        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients
        # BUG FIX: `loss.data` is a 0-dim CUDA tensor; np.array() cannot
        # convert CUDA tensors and "{:.4f}".format() cannot format them.
        # Use .item() instead, matching the sibling train() implementation.
        lossall[:, step] = np.array(
            [loss.item(), loss1.item(), loss2.item(), loss3.item(),
             loss4.item()])
        sys.stdout.write(
            "\r" +
            'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" - inv "{3:.4f}" - ant "{4:.4f}" -smo "{5:.4f}" '
            .format(step, loss.item(), loss1.item(), loss2.item(),
                    loss3.item(), loss4.item()))
        sys.stdout.flush()
        if (step % n_checkpoint == 0):
            modelname = model_dir + '/' + str(step) + '.pth'
            torch.save(model.state_dict(), modelname)
        step += 1
        torch.cuda.empty_cache()
        del X
        del Y
    np.save(model_dir + '/loss.npy', lossall)
示例#5
0
                                                         imgshape=imgshape_2, range_flow=range_flow,
                                                         model_lvl1=model_lvl1).cuda()
# Top-level (lvl3) model: wraps the pretrained lvl2 model for full-resolution
# displacement prediction, then is loaded from the checkpoint in opt.modelpath.
model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(in_channel, n_classes, start_channel, is_train=isTrainLvl3,
                                                    imgshape=imgshape, range_flow=range_flow,
                                                    model_lvl2=model_lvl2).cuda()

transform = SpatialTransform_unit().cuda()
model.load_state_dict(torch.load(opt.modelpath))
print("pretrained model is loaded: " + opt.modelpath)
model.eval()
transform.eval()

use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
print(" processor use: ", device)
# Identity sampling grid with a leading batch axis, kept on the GPU.
grid = generate_grid(imgshape)
grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).cuda().float()
#grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).to(device).float()


def testImages(testingImagesLst):
    """Evaluate the model over a test set (function appears truncated here).

    NOTE(review): the DataLoader iterates the module-level `testingLst`, not
    the `testingImagesLst` argument — confirm which list is intended.
    """
    totalDice = 0.0  # Dice accumulator; not updated in the visible code
    testing_generator = Data.DataLoader(Dataset_epoch(testingLst, norm=doNormalisation, isSeg=isSeg), batch_size=1, shuffle=True,  num_workers=0)
    for imgPair in testing_generator:
        # imgPair = (tensor_pair, path_pair); indices below suggest
        # tensors are (moving, fixed) while paths are (fixed, moving) —
        # TODO(review): confirm the pairing against Dataset_epoch.
        moving_path = imgPair[1][1][0]
        fixed_path  = imgPair[1][0][0]
        moving_tensor = imgPair[0][0]
        fixed_tensor  = imgPair[0][1]
        # print("moving_tensor.shape : ",moving_tensor.shape)
        # print("fixed_path  : ",fixed_path)
        # print("moving_path  : ", moving_path)
示例#6
0
def train_lvl1():
    """Train the level-1 (coarsest, 1/4 resolution) Laplacian-pyramid model.

    Optimises local NCC similarity plus a negative-Jacobian (anti-folding)
    penalty and a smoothness penalty.  Checkpoints and loss history are
    written to ../Model/Stage every `n_checkpoint` steps.
    """
    print("Training lvl1...")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
        2,
        3,
        start_channel,
        is_train=True,
        imgshape=imgshape_4,
        range_flow=range_flow).to(device)

    # Loss terms: windowed NCC, negative-Jacobian penalty, flow smoothness.
    loss_similarity = NCC(win=3)
    loss_Jdet = neg_Jdet_loss
    loss_smooth = smoothloss

    transform = SpatialTransform_unit().to(device)

    # The warping layer has no trainable parameters; freeze it.
    for param in transform.parameters():
        param.requires_grad = False
        param.volatile = True

    # OASIS
    names = sorted(glob.glob(datapath + '/*.nii'))

    # Identity grid at 1/4 resolution with a leading batch axis.
    grid_4 = generate_grid(imgshape_4)
    grid_4 = torch.from_numpy(np.reshape(grid_4, (1, ) +
                                         grid_4.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model_dir = '../Model/Stage'

    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # Rows: total, NCC, Jacobian, smoothness loss per step.
    lossall = np.zeros((4, iteration_lvl1 + 1))

    training_generator = Data.DataLoader(Dataset_epoch(names, norm=False),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=2)
    step = 0
    load_model = False
    if load_model is True:
        # Optional resume from a hard-coded checkpoint taken at step 3000.
        model_path = "../Model/LDR_LPBA_NCC_lap_share_preact_1_05_3000.pth"
        print("Loading weight: ", model_path)
        step = 3000
        model.load_state_dict(torch.load(model_path))
        temp_lossall = np.load(
            "../Model/loss_LDR_LPBA_NCC_lap_share_preact_1_05_3000.npy")
        lossall[:, 0:3000] = temp_lossall[:, 0:3000]

    # Loop over epochs until the global step budget is exhausted.
    while step <= iteration_lvl1:
        for X, Y in training_generator:

            X = X.to(device).float()
            Y = Y.to(device).float()

            # output_disp_e0, warpped_inputx_lvl1_out, down_y, output_disp_e0_v, e0
            F_X_Y, X_Y, Y_4x, F_xy, _ = model(X, Y)

            # 3 level deep supervision NCC
            loss_multiNCC = loss_similarity(X_Y, Y_4x)

            # Jacobian penalty is computed on a clone so the in-place scaling
            # below does not disturb it.
            F_X_Y_norm = transform_unit_flow_to_flow_cuda(
                F_X_Y.permute(0, 2, 3, 4, 1).clone())

            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid_4)

            # reg2 - use velocity
            # Scale the unit-normalised flow to voxel units, in place, before
            # the smoothness penalty.  NOTE(review): in-place mutation of a
            # tensor that requires grad — relies on no backward node having
            # saved F_X_Y itself; confirm against the model definition.
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            loss = loss_multiNCC + antifold * loss_Jacobian + smooth * loss_regulation

            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            lossall[:, step] = np.array([
                loss.item(),
                loss_multiNCC.item(),
                loss_Jacobian.item(),
                loss_regulation.item()
            ])
            sys.stdout.write(
                "\r" +
                'step "{0}" -> training loss "{1:.4f}" - sim_NCC "{2:4f}" - Jdet "{3:.10f}" -smo "{4:.4f}"'
                .format(step, loss.item(), loss_multiNCC.item(),
                        loss_Jacobian.item(), loss_regulation.item()))
            sys.stdout.flush()

            # with lr 1e-3 + with bias
            if (step % n_checkpoint == 0):
                modelname = model_dir + '/' + model_name + "stagelvl1_" + str(
                    step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(
                    model_dir + '/loss' + model_name + "stagelvl1_" +
                    str(step) + '.npy', lossall)

            step += 1

            if step > iteration_lvl1:
                break
        print("one epoch pass")
    np.save(model_dir + '/loss' + model_name + 'stagelvl1.npy', lossall)
示例#7
0
def test():
    """Register opt.fixed / opt.moving on cuda:6 and save flows, inverse
    flows and warped images under `savepath`.

    The commented-out lines preserve the CPU-fallback variant of each call.
    """
    device = torch.device("cuda:6")
    model = ModelFlow_stride(2, 3, opt.start_channel).cuda(device)
    # model =ModelFlow_stride(2,3,opt.start_channel).cpu()
    # transform = SpatialTransform().cpu()

    transform = SpatialTransform().cuda(device)
    # model.load_state_dict(torch.load(opt.modelpath))
    # model.load_state_dict(torch.load(opt.modelpath, map_location=torch.device('cpu')))
    # map_location places the checkpoint tensors directly on the target GPU.
    model.load_state_dict(torch.load(opt.modelpath, map_location=device))

    model.eval()
    transform.eval()
    # Identity sampling grid with a leading batch axis.
    grid = generate_grid(imgshape)
    grid = Variable(torch.from_numpy(np.reshape(
        grid, (1, ) + grid.shape))).cuda(device).float()
    # grid = Variable(torch.from_numpy(np.reshape(grid, (1,) + grid.shape))).cpu().float()

    start = timeit.default_timer()
    A = sitk.GetArrayFromImage(sitk.ReadImage(opt.fixed, sitk.sitkFloat32))
    B = sitk.GetArrayFromImage(sitk.ReadImage(opt.moving, sitk.sitkFloat32))

    #A = Variable(torch.from_numpy(A)).cuda(device).float()
    #B = Variable(torch.from_numpy(B)).cuda(device).float()

    # A, B = padding(A,B)
    #A = load_5D(A).cuda(device).float()
    #B = load_5D(B).cuda(device).float()

    # Expand the raw 3-D arrays to 5-D tensors (presumably (1, 1, D, H, W)
    # given the indexing below — TODO confirm against load_5D).
    A = load_5D(A)
    B = load_5D(B)

    A = Variable(torch.from_numpy(A)).cuda(device).float()
    B = Variable(torch.from_numpy(B)).cuda(device).float()

    #A = Variable(torch.from_numpy( load_5D(opt.fixed))).cuda(device).float()
    #B = Variable(torch.from_numpy( load_5D(opt.moving))).cuda(device).float()
    start2 = timeit.default_timer()
    print('Time for loading data: ', start2 - start)
    # Forward direction: A -> B.
    pred = model(A, B)
    F_AB = pred.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
    #F_AB = pred.permute(0,2,3,4,1).data.cuda(device).numpy()[0, :, :, :, :]

    F_AB = F_AB.astype(np.float32) * range_flow
    warped_A = transform(A,
                         pred.permute(0, 2, 3, 4, 1) * range_flow,
                         grid).data.cpu().numpy()[0, 0, :, :, :]
    #warped_A = transform(A,pred.permute(0,2,3,4,1)*range_flow,grid).data.cuda(device).numpy()[0, 0, :, :, :]
    start3 = timeit.default_timer()
    print('Time for registration: ', start3 - start2)
    # Approximate inverse flow: warp the negated flow with itself.
    warped_F_BA = transform(-pred,
                            pred.permute(0, 2, 3, 4, 1) * range_flow,
                            grid).permute(0, 2, 3, 4,
                                          1).data.cpu().numpy()[0, :, :, :, :]
    #warped_F_BA = transform(-pred,pred.permute(0,2,3,4,1)*range_flow,grid).permute(0,2,3,4,1).data.cuda(device).numpy()[0, :, :, :, :]
    warped_F_BA = warped_F_BA.astype(np.float32) * range_flow
    start4 = timeit.default_timer()
    print('Time for generating inverse flow: ', start4 - start3)
    save_flow(F_AB, savepath + '/flow_A_B.nii.gz')
    save_flow(warped_F_BA, savepath + '/inverse_flow_B_A.nii.gz')
    save_img(warped_A, savepath + '/warped_A.nii.gz')
    start5 = timeit.default_timer()
    print('Time for saving results: ', start5 - start4)
    del pred
    # Reverse direction: B -> A (same pipeline with inputs swapped).
    pred = model(B, A)
    F_BA = pred.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0, :, :, :, :]
    #F_BA = pred.permute(0,2,3,4,1).data.cuda(device).numpy()[0, :, :, :, :]
    F_BA = F_BA.astype(np.float32) * range_flow
    warped_B = transform(B,
                         pred.permute(0, 2, 3, 4, 1) * range_flow,
                         grid).data.cpu().numpy()[0, 0, :, :, :]
    #warped_B = transform(B,pred.permute(0,2,3,4,1)*range_flow,grid).data.cuda(device).numpy()[0, 0, :, :, :]
    warped_F_AB = transform(-pred,
                            pred.permute(0, 2, 3, 4, 1) * range_flow,
                            grid).permute(0, 2, 3, 4,
                                          1).data.cpu().numpy()[0, :, :, :, :]
    #warped_F_AB = transform(-pred,pred.permute(0,2,3,4,1)*range_flow,grid).permute(0,2,3,4,1).data.cuda(device).numpy()[0, :, :, :, :]
    warped_F_AB = warped_F_AB.astype(np.float32) * range_flow
    save_flow(F_BA, savepath + '/flow_B_A.nii.gz')
    save_flow(warped_F_AB, savepath + '/inverse_flow_A_B.nii.gz')
    save_img(warped_B, savepath + '/warped_B.nii.gz')
示例#8
0
def train():
    """Train ModelFlow_stride with registration, cycle and identity losses.

    loss = L_regist + alpha * L_cycle + beta * L_identity, where L_regist
    combines NCC similarity with a `lambda_`-weighted smoothness term.
    Checkpoints are written to ../Model every `n_checkpoint` steps and the
    loss history is saved as ../Model/loss.npy.
    """
    model = ModelFlow_stride(2, 3, start_channel).cuda()
    loss_similarity = NCC().loss
    loss_cycle = l1_loss
    loss_smooth = smoothloss
    transform = SpatialTransform().cuda()
    # The warping layer has no trainable parameters; freeze it.
    for param in transform.parameters():
        param.requires_grad = False
        param.volatile = True
    names = glob.glob(datapath + '/*.gz')
    # Identity sampling grid with a leading batch axis.
    grid = generate_grid(imgshape)
    grid = Variable(torch.from_numpy(np.reshape(grid, (1, ) +
                                                grid.shape))).cuda().float()

    print(grid.type())
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model_dir = '../Model'

    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)
    # Rows: total, registration, cycle, identity, smoothness loss per step.
    lossall = np.zeros((5, iteration))
    training_generator = Data.DataLoader(Dataset(names, iteration, False),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=2)
    step = 0
    for X, Y in training_generator:

        X = X.cuda().float()
        Y = Y.cuda().float()
        F_xy = model(X, Y)
        F_yx = model(Y, X)

        X_Y = transform(X, F_xy.permute(0, 2, 3, 4, 1) * range_flow, grid)
        Y_X = transform(Y, F_yx.permute(0, 2, 3, 4, 1) * range_flow, grid)

        # Cycle branch: warp each warped image back with a fresh prediction.
        F_xy_ = model(Y_X, X_Y)
        F_yx_ = model(X_Y, Y_X)

        Y_X_Y = transform(Y_X, F_xy_.permute(0, 2, 3, 4, 1) * range_flow, grid)
        X_Y_X = transform(X_Y, F_yx_.permute(0, 2, 3, 4, 1) * range_flow, grid)

        # Identity branch: registering an image to itself should be identity.
        F_xx = model(X, X)
        F_yy = model(Y, Y)
        X_X = transform(X, F_xx.permute(0, 2, 3, 4, 1) * range_flow, grid)
        Y_Y = transform(Y, F_yy.permute(0, 2, 3, 4, 1) * range_flow, grid)

        L_smooth = loss_smooth(F_xy * range_flow) + loss_smooth(
            F_yx * range_flow)
        L_regist = loss_similarity(Y, X_Y) + loss_similarity(X, Y_X) + \
                      lambda_ * L_smooth
        L_cycle = loss_cycle(X, X_Y_X) + loss_cycle(Y, Y_X_Y)
        L_identity = loss_similarity(X, X_X) + loss_similarity(Y, Y_Y)
        loss = L_regist + alpha * L_cycle + beta * L_identity

        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients
        # BUG FIX: `tensor.data[0]` raises IndexError on 0-dim tensors in
        # PyTorch >= 0.4 (and CUDA tensors cannot be fed to np.array);
        # use .item(), as the sibling train() implementation does.
        lossall[:, step] = np.array([
            loss.item(), L_regist.item(), L_cycle.item(),
            L_identity.item(), L_smooth.item()
        ])
        sys.stdout.write(
            "\r" +
            'step "{0}" -> training loss "{1:.4f}" - reg "{2:.4f}" - cyc "{3:.4f}" - ind "{4:.4f}" -smo "{5:.4f}" '
            .format(step, loss.item(), L_regist.item(), L_cycle.item(),
                    L_identity.item(), L_smooth.item()))
        sys.stdout.flush()
        if (step % n_checkpoint == 0):
            modelname = model_dir + '/' + str(step) + '.pth'
            torch.save(model.state_dict(), modelname)
        step += 1
    np.save(model_dir + '/loss.npy', lossall)
示例#9
0
def train(lvlID, opt=[], model_lvl1_path="", model_lvl2_path=""):
    """Train one level (1, 2 or 3) of the Laplacian-pyramid registration net.

    Lower pyramid levels are loaded from the given checkpoints and frozen.
    Progress is appended to a per-level log file, and checkpoints are written
    every `n_checkpoint` steps (every 10 steps at level 3).

    Args:
        lvlID: pyramid level to train (1, 2 or 3).
        opt: parsed option object (iteration counts, simLossType, checkpoint,
            freeze_step, augmentation/segmentation flags, ...).
        model_lvl1_path: checkpoint for the frozen lvl1 model (lvlID == 2).
        model_lvl2_path: checkpoint for the frozen lvl2 model (lvlID == 3).

    Returns:
        Path of the final checkpoint saved for this level.

    Raises:
        ValueError: if lvlID is not 1, 2 or 3.
    """
    print("Training " + str(lvlID) +
          "===========================================================")

    model_dir = '../Model/Stage'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    result_folder_path = "../Results"
    logLvlPath = model_dir + "/logLvl_" + str(lvlID) + ".txt"
    #logLvlChrtPath = model_dir + "/logLvl_"+str(lvlID)+".png"

    numWorkers = 0  # worker processes for the data generator

    freeze_step = opt.freeze_step  # step at which lvl2 is unfrozen (lvl3 only)

    # Checkpoint names encode the similarity loss and per-level iterations.
    lossName = "_NCC_" if opt.simLossType == 0 else (
        "_MSE_" if opt.simLossType == 1 else "_DICE_")
    model_name = "LDR_OASIS" + lossName + "_disp_" + str(
        opt.start_channel) + "_" + str(opt.iteration_lvl1) + "_f" + str(
            opt.iteration_lvl2) + "_f" + str(opt.iteration_lvl3) + "_"
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    n_checkpoint = opt.checkpoint

    # Build the model for the requested level; lower levels are frozen.
    if lvlID == 1:
        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        grid = generate_grid(imgshape_4)

        start_iteration = opt.sIteration_lvl1
        num_iteration = opt.iteration_lvl1
    elif lvlID == 2:
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        model_lvl1.load_state_dict(torch.load(model_lvl1_path))
        # Freeze model_lvl1 weights.
        for param in model_lvl1.parameters():
            param.requires_grad = False

        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl2,
            imgshape=imgshape_2,
            range_flow=range_flow,
            model_lvl1=model_lvl1).to(device)

        grid = generate_grid(imgshape_2)
        start_iteration = opt.sIteration_lvl2
        num_iteration = opt.iteration_lvl2
    elif lvlID == 3:
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        model_lvl2 = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl2,
            imgshape=imgshape_2,
            range_flow=range_flow,
            model_lvl1=model_lvl1).to(device)
        model_lvl2.load_state_dict(torch.load(model_lvl2_path))
        # Freeze model_lvl2 (which contains model_lvl1) weights.
        for param in model_lvl2.parameters():
            param.requires_grad = False

        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl3,
            imgshape=imgshape,
            range_flow=range_flow,
            model_lvl2=model_lvl2).to(device)

        grid = generate_grid(imgshape)
        start_iteration = opt.sIteration_lvl3
        num_iteration = opt.iteration_lvl3
    else:
        # BUG FIX: an invalid lvlID previously fell through and crashed later
        # with a NameError on `grid`/`model`; fail fast with a clear error.
        raise ValueError("lvlID must be 1, 2 or 3, got " + str(lvlID))

    loss_Jdet = neg_Jdet_loss
    loss_smooth = smoothloss

    transform = SpatialTransform_unit().to(device)

    # The warping layer has no trainable parameters; freeze it.
    for param in transform.parameters():
        param.requires_grad = False
        param.volatile = True

    # Add a leading batch axis to the identity grid.
    grid = torch.from_numpy(np.reshape(grid,
                                       (1, ) + grid.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Rows: total, similarity, Jacobian, smoothness loss per step.
    lossall = np.zeros((4, num_iteration + 1))

    training_generator = Data.DataLoader(Dataset_epoch(trainingLst,
                                                       norm=doNormalisation,
                                                       aug=opt.doAugmentation,
                                                       isSeg=opt.isSeg),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=numWorkers)

    step = 0
    if start_iteration > 0:
        # Resume: load the last checkpoint and the stored loss history.
        model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
            lvlID) + "_" + str(num_iteration) + '.pth'
        loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(
            lvlID) + "_" + str(num_iteration) + '.npy'
        print("Loading weight and loss : ", model_lvl_path)
        step = num_iteration + 1
        model.load_state_dict(torch.load(model_lvl_path))
        temp_lossall = np.load(loss_lvl_path)
        lossall[:, 0:num_iteration] = temp_lossall[:, 0:num_iteration]
    else:
        # Fresh run: create/truncate the log file.
        # BUG FIX: the original called `logLvlFile.close` without parentheses,
        # leaking the handle; a context manager closes it reliably.
        with open(logLvlPath, "w"):
            pass

    while step <= num_iteration:
        for pair in training_generator:
            X = pair[0][0]
            Y = pair[0][1]
            movingPath = pair[1][0][0]
            fixedPath = pair[1][1][0]
            ext = ".nii.gz" if ".nii.gz" in fixedPath else (
                ".nii" if ".nii" in fixedPath else ".nrrd")

            X = X.to(device).float()
            Y = Y.to(device).float()
            assert not np.any(np.isnan(X.cpu().numpy()))
            assert not np.any(np.isnan(Y.cpu().numpy()))

            # F_X_Y: displacement field, X_Y: warped moving image,
            # Y_4x: downsampled fixed image, F_xy: velocity field.
            if lvlID == 1:
                F_X_Y, X_Y, Y_4x, F_xy, _ = model(X, Y)
            elif lvlID == 2:
                F_X_Y, X_Y, Y_4x, F_xy, F_xy_lvl1, _ = model(X, Y)
            elif lvlID == 3:
                F_X_Y, X_Y, Y_4x, F_xy, F_xy_lvl1, F_xy_lvl2, _ = model(X, Y)

            if opt.simLossType == 0:  # NCC, window grows with the level
                if lvlID == 1:
                    loss_similarity = NCC(win=3)
                elif lvlID == 2:
                    loss_similarity = multi_resolution_NCC(win=5, scale=2)
                elif lvlID == 3:
                    loss_similarity = multi_resolution_NCC(win=7, scale=3)
                loss_sim = loss_similarity(X_Y, Y_4x)
            elif opt.simLossType == 1:  # MSE
                loss_sim = mse_loss(X_Y, Y_4x)
            elif opt.simLossType == 2:  # Dice
                # Warp the moving segmentation with the predicted flow and
                # compare against the fixed segmentation.
                dv = math.pow(2, 3 - lvlID)
                fixedSeg = img2SegTensor(fixedPath, ext, dv)
                movingSeg = img2SegTensor(movingPath, ext, dv)

                movingSeg = movingSeg[np.newaxis, ...]
                movingSeg = torch.from_numpy(movingSeg).float().to(
                    device).unsqueeze(dim=0)
                transformedSeg = transform(movingSeg,
                                           F_X_Y.permute(0, 2, 3, 4, 1),
                                           grid).data.cpu().numpy()[0,
                                                                    0, :, :, :]
                transformedSeg[transformedSeg > 0] = 1.0
                # BUG FIX: the original computed diceLoss(...) and immediately
                # overwrote it; keep only the value that actually took effect.
                loss_sim = DiceLoss.getDiceLoss(fixedSeg, transformedSeg)
            else:
                print("error: not supported loss ........")

            F_X_Y_norm = transform_unit_flow_to_flow_cuda(
                F_X_Y.permute(0, 2, 3, 4, 1).clone())
            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid)
            # Scale the unit flow to voxel units (in place) before smoothing.
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            assert not np.any(np.isnan(loss_sim.item()))
            assert not np.any(np.isnan(loss_Jacobian.item()))
            assert not np.any(np.isnan(loss_regulation.item()))

            loss = loss_sim + antifold * loss_Jacobian + smooth * loss_regulation

            assert not np.any(np.isnan(loss.item()))

            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            lossall[:, step] = np.array([
                loss.item(),
                loss_sim.item(),
                loss_Jacobian.item(),
                loss_regulation.item()
            ])

            logLine = "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:4f}" - Jdet "{3:.10f}" -smo "{4:.4f}"'.format(
                step, loss.item(), loss_sim.item(), loss_Jacobian.item(),
                loss_regulation.item())
            print(logLine)
            # BUG FIX: append through a context manager so the handle is
            # closed even if the write raises.
            with open(logLvlPath, "a") as logLvlFile:
                logLvlFile.write(logLine)

            # Level 3 checkpoints more frequently.
            if lvlID == 3:
                n_checkpoint = 10

            if (step % n_checkpoint == 0):
                model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
                    lvlID) + "_" + str(step) + '.pth'
                torch.save(model.state_dict(), model_lvl_path)

                iaLog2Fig(logLvlPath)

                if doValidation:
                    iaVal = validate(model)

            if (lvlID == 3) and (step == freeze_step):
                model.unfreeze_modellvl2()

            step += 1

            if step > num_iteration:
                break
        print("one epoch pass ....")

    model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
        lvlID) + "_" + str(num_iteration) + '.pth'
    torch.save(model.state_dict(), model_lvl_path)
    return model_lvl_path
示例#10
0
def train():
    """Train SYMNet for symmetric diffeomorphic image registration.

    Loads .nii volumes from ``datapath``, optimizes a combined
    similarity / magnitude / negative-Jacobian / smoothness objective,
    and checkpoints the model weights and running loss history under
    ``../Model``.

    Relies on module-level configuration: ``start_channel``,
    ``datapath``, ``imgshape``, ``lr``, ``iteration``, ``range_flow``,
    ``magnitude``, ``local_ori``, ``smooth``, ``n_checkpoint``.
    """
    model = SYMNet(2, 3, start_channel).cuda()
    loss_similarity = NCC()
    loss_smooth = smoothloss
    loss_magnitude = magnitude_loss
    loss_Jdet = neg_Jdet_loss

    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    # The spatial transform is a fixed warping layer, not trainable.
    # (The old `param.volatile = True` was a no-op attribute assignment
    # on modern PyTorch and has been dropped; requires_grad is what
    # actually freezes the parameters.)
    for param in transform.parameters():
        param.requires_grad = False

    names = sorted(glob.glob(datapath + '/*.nii'))[0:255]
    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid,
                                       (1, ) + grid.shape)).cuda().float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model_dir = '../Model'
    os.makedirs(model_dir, exist_ok=True)

    # BUG FIX: the loop below runs steps 0..iteration inclusive, so the
    # loss buffer needs iteration + 1 columns.  The original
    # np.zeros((6, iteration)) raised IndexError on the final step,
    # which also prevented the trailing np.save from ever running.
    lossall = np.zeros((6, iteration + 1))

    training_generator = Data.DataLoader(Dataset_epoch(names, norm=False),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=2)
    step = 0

    while step <= iteration:
        for X, Y in training_generator:

            X = X.cuda().float()
            Y = Y.cuda().float()
            # Symmetric prediction: velocity fields in both directions.
            F_xy, F_yx = model(X, Y)

            # Integrate each field (and its negation) to half-way flows.
            F_X_Y_half = diff_transform(F_xy, grid, range_flow)
            F_Y_X_half = diff_transform(F_yx, grid, range_flow)

            F_X_Y_half_inv = diff_transform(-F_xy, grid, range_flow)
            F_Y_X_half_inv = diff_transform(-F_yx, grid, range_flow)

            # Warp each image to the mid-space.
            X_Y_half = transform(
                X,
                F_X_Y_half.permute(0, 2, 3, 4, 1) * range_flow, grid)
            Y_X_half = transform(
                Y,
                F_Y_X_half.permute(0, 2, 3, 4, 1) * range_flow, grid)

            # Compose the half flows into full forward/backward flows.
            F_X_Y = com_transform(F_X_Y_half, F_Y_X_half_inv, grid, range_flow)
            F_Y_X = com_transform(F_Y_X_half, F_X_Y_half_inv, grid, range_flow)

            X_Y = transform(X, F_X_Y.permute(0, 2, 3, 4, 1) * range_flow, grid)
            Y_X = transform(Y, F_Y_X.permute(0, 2, 3, 4, 1) * range_flow, grid)

            # loss1: mid-space similarity, loss2: full-warp similarity,
            # loss3: flow-magnitude symmetry, loss4: folding penalty
            # (negative Jacobian determinant), loss5: smoothness.
            loss1 = loss_similarity(X_Y_half, Y_X_half)
            loss2 = loss_similarity(Y, X_Y) + loss_similarity(X, Y_X)
            loss3 = loss_magnitude(F_X_Y_half * range_flow,
                                   F_Y_X_half * range_flow)
            loss4 = loss_Jdet(
                F_X_Y.permute(0, 2, 3, 4, 1) * range_flow, grid) + loss_Jdet(
                    F_Y_X.permute(0, 2, 3, 4, 1) * range_flow, grid)
            loss5 = loss_smooth(F_xy * range_flow) + loss_smooth(
                F_yx * range_flow)

            loss = loss1 + loss2 + magnitude * loss3 + local_ori * loss4 + smooth * loss5
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lossall[:, step] = np.array([
                loss.item(),
                loss1.item(),
                loss2.item(),
                loss3.item(),
                loss4.item(),
                loss5.item()
            ])
            sys.stdout.write(
                "\r" +
                'step "{0}" -> training loss "{1:.4f}" - sim_mid "{2:.4f}" - sim_full "{3:4f}" - mag "{4:.4f}" - Jdet "{5:.10f}" -smo "{6:.4f}" '
                .format(step, loss.item(), loss1.item(), loss2.item(),
                        loss3.item(), loss4.item(), loss5.item()))
            sys.stdout.flush()

            # Periodic checkpoint of weights and loss history.
            if (step % n_checkpoint == 0):
                modelname = model_dir + '/SYMNet_' + str(step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(model_dir + '/loss_SYMNet_' + str(step) + '.npy',
                        lossall)
            step += 1

            if step > iteration:
                break
        print("one epoch pass")
    np.save(model_dir + '/loss_SYMNet.npy', lossall)