Example #1
def validate(model):
    # evaluate the current model on the held-out test pairs
    total_avg_dice = 0.0
    testCounter = 0.0
    testing_generator = Data.DataLoader(Dataset_epoch(testingLst,
                                                      norm=doNormalisation,
                                                      aug=0),
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=numWorkers)
    for test_pair in testing_generator:
        tX = test_pair[0][0]
        tY = test_pair[0][1]
        moving_img_Path = test_pair[1][0][0]
        fixed_img_path = test_pair[1][1][0]
        # print(fixed_img_path)
        # print(moving_img_Path)
        fixed_seg_path = fixed_img_path[:-7] + "_seg.nii.gz"
        moving_seg_Path = moving_img_Path[:-7] + "_seg.nii.gz"

        fixed_seg = sitk.GetArrayFromImage(sitk.ReadImage(fixed_seg_path))
        fixed_seg = np.swapaxes(fixed_seg, 0, 2)
        moving_seg = sitk.GetArrayFromImage(sitk.ReadImage(moving_seg_Path))

        #convert to binary classes
        fixed_seg[fixed_seg > 0.0] = 1.0
        moving_seg[moving_seg > 0.0] = 1.0

        #moving_seg = load_4D(moving_seg_Path)
        moving_seg = moving_seg[np.newaxis, ...]
        moving_seg_tensor = torch.from_numpy(moving_seg).float().to(
            device).unsqueeze(dim=0)

        tX = tX.to(device).float()
        tY = tY.to(device).float()

        # compose_field_e0_lvl1, warpped_inputx_lvl1_out, y, output_disp_e0_v, lvl1_v, lvl2_v, e0
        with torch.no_grad():
            disp_fld, tX_Y, tY_4x, tF_xy, tF_xy_lvl1, tF_xy_lvl2, _ = model(
                tX, tY)
        transformed_seg_image = transform(moving_seg_tensor,
                                          disp_fld.permute(0, 2, 3, 4, 1),
                                          grid).data.cpu().numpy()[0,
                                                                   0, :, :, :]
        transformed_seg_image[transformed_seg_image > 0.0] = 1.0

        gt = fixed_seg.ravel()
        res = transformed_seg_image.ravel()
        total_avg_dice += np.sum(
            res[gt == 1]) * 2.0 / (np.sum(res) + np.sum(gt))
        testCounter += 1
    total_avg_dice = total_avg_dice / testCounter
    return total_avg_dice
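The per-pair score accumulated in the loop above is the standard binary Dice overlap. A minimal, self-contained sketch of the same arithmetic (the helper name `binary_dice` is illustrative, not part of the original code):

import numpy as np

def binary_dice(gt, pred):
    # Dice = 2 * |GT ∩ P| / (|GT| + |P|) on flattened binary volumes
    gt = gt.ravel()
    pred = pred.ravel()
    intersection = np.sum(pred[gt == 1])
    return 2.0 * intersection / (np.sum(pred) + np.sum(gt))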
Example #2
def testImages(testingImagesLst):
    totalDice = 0.0
    numPairs = 0
    testing_generator = Data.DataLoader(Dataset_epoch(testingImagesLst,
                                                      norm=doNormalisation,
                                                      isSeg=isSeg),
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=0)
    for imgPair in testing_generator:
        # note: the path indices here are the reverse of those used in validate()
        moving_path = imgPair[1][1][0]
        fixed_path = imgPair[1][0][0]
        moving_tensor = imgPair[0][0]
        fixed_tensor = imgPair[0][1]
        # print("moving_tensor.shape : ",moving_tensor.shape)
        # print("fixed_path  : ",fixed_path)
        # print("moving_path  : ", moving_path)
        dicePair = testOnePair(fixed_path, moving_path, 1,
                               [moving_tensor, fixed_tensor])
        totalDice += dicePair
        numPairs += 1
    # average over the number of test pairs, not the size of the last batch tuple
    avgTotalDice = totalDice / numPairs
    return avgTotalDice
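A hedged usage sketch, assuming the global `testingLst` holds the test image paths and that `testOnePair` returns a scalar Dice score for one registered pair:

avg_dice = testImages(testingLst)
print("average Dice over the test set: {:.4f}".format(avg_dice))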
Example #3
def train_lvl1():
    print("Training lvl1...")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
        2,
        3,
        start_channel,
        is_train=True,
        imgshape=imgshape_4,
        range_flow=range_flow).to(device)

    loss_similarity = NCC(win=3)
    loss_Jdet = neg_Jdet_loss
    loss_smooth = smoothloss

    transform = SpatialTransform_unit().to(device)

    # the spatial transform has no trainable parameters; freeze any that exist
    for param in transform.parameters():
        param.requires_grad = False
        param.volatile = True  # deprecated; a no-op on modern PyTorch

    # OASIS
    names = sorted(glob.glob(datapath + '/*.nii'))

    grid_4 = generate_grid(imgshape_4)
    grid_4 = torch.from_numpy(np.reshape(grid_4, (1, ) +
                                         grid_4.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model_dir = '../Model/Stage'

    os.makedirs(model_dir, exist_ok=True)  # also creates '../Model' if missing

    lossall = np.zeros((4, iteration_lvl1 + 1))

    training_generator = Data.DataLoader(Dataset_epoch(names, norm=False),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=2)
    step = 0
    load_model = False
    if load_model:
        model_path = "../Model/LDR_LPBA_NCC_lap_share_preact_1_05_3000.pth"
        print("Loading weight: ", model_path)
        step = 3000
        model.load_state_dict(torch.load(model_path))
        temp_lossall = np.load(
            "../Model/loss_LDR_LPBA_NCC_lap_share_preact_1_05_3000.npy")
        lossall[:, 0:3000] = temp_lossall[:, 0:3000]

    while step <= iteration_lvl1:
        for X, Y in training_generator:

            X = X.to(device).float()
            Y = Y.to(device).float()

            # output_disp_e0, warpped_inputx_lvl1_out, down_y, output_disp_e0_v, e0
            F_X_Y, X_Y, Y_4x, F_xy, _ = model(X, Y)

            # 3 level deep supervision NCC
            loss_multiNCC = loss_similarity(X_Y, Y_4x)

            F_X_Y_norm = transform_unit_flow_to_flow_cuda(
                F_X_Y.permute(0, 2, 3, 4, 1).clone())

            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid_4)

            # reg2 - use velocity
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            loss = loss_multiNCC + antifold * loss_Jacobian + smooth * loss_regulation

            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            lossall[:, step] = np.array([
                loss.item(),
                loss_multiNCC.item(),
                loss_Jacobian.item(),
                loss_regulation.item()
            ])
            sys.stdout.write(
                "\r" +
                'step "{0}" -> training loss "{1:.4f}" - sim_NCC "{2:.4f}" - Jdet "{3:.10f}" - smo "{4:.4f}"'
                .format(step, loss.item(), loss_multiNCC.item(),
                        loss_Jacobian.item(), loss_regulation.item()))
            sys.stdout.flush()

            # with lr 1e-3 + with bias
            if (step % n_checkpoint == 0):
                modelname = model_dir + '/' + model_name + "stagelvl1_" + str(
                    step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(
                    model_dir + '/loss' + model_name + "stagelvl1_" +
                    str(step) + '.npy', lossall)

            step += 1

            if step > iteration_lvl1:
                break
        print("one epoch pass")
    np.save(model_dir + '/loss' + model_name + 'stagelvl1.npy', lossall)
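The per-channel multiplication before `loss_smooth` converts the network's normalized (unit) displacement field into voxel units: each channel is scaled by the extent of the axis it indexes. A minimal sketch of that conversion, assuming a `(N, 3, D, H, W)` field whose channel 0 indexes the last spatial axis (the function name is illustrative):

import torch

def unit_flow_to_voxel_flow(flow):
    # flow: (N, 3, D, H, W) displacement field in normalized units
    _, _, d, h, w = flow.shape
    scaled = flow.clone()            # keep the caller's tensor unchanged
    scaled[:, 0] = scaled[:, 0] * w  # channel 0 -> extent of the last axis
    scaled[:, 1] = scaled[:, 1] * h
    scaled[:, 2] = scaled[:, 2] * d
    return scaled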
Example #4
def train(lvlID, opt, model_lvl1_path="", model_lvl2_path=""):
    print("Training " + str(lvlID) +
          "===========================================================")

    model_dir = '../Model/Stage'
    os.makedirs(model_dir, exist_ok=True)  # also creates '../Model' if missing

    result_folder_path = "../Results"
    logLvlPath = model_dir + "/logLvl_" + str(lvlID) + ".txt"
    #logLvlChrtPath = model_dir + "/logLvl_"+str(lvlID)+".png"

    numWorkers = 0  # worker processes for the data generators (0 = load in the main process)

    freeze_step = opt.freeze_step  # step at which the frozen lvl2 sub-model is unfrozen (level 3 only)

    lossName = "_NCC_" if opt.simLossType == 0 else (
        "_MSE_" if opt.simLossType == 1 else "_DICE_")
    model_name = "LDR_OASIS" + lossName + "_disp_" + str(
        opt.start_channel) + "_" + str(opt.iteration_lvl1) + "_f" + str(
            opt.iteration_lvl2) + "_f" + str(opt.iteration_lvl3) + "_"
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
        lvlID) + "_0.pth"
    loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(
        lvlID) + "_0.npy"

    n_checkpoint = opt.checkpoint

    if lvlID == 1:
        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        grid = generate_grid(imgshape_4)

        start_iteration = opt.sIteration_lvl1
        num_iteration = opt.iteration_lvl1
    elif lvlID == 2:
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        model_lvl1.load_state_dict(torch.load(model_lvl1_path))
        # Freeze model_lvl1 weight
        for param in model_lvl1.parameters():
            param.requires_grad = False

        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl2,
            imgshape=imgshape_2,
            range_flow=range_flow,
            model_lvl1=model_lvl1).to(device)

        grid = generate_grid(imgshape_2)
        start_iteration = opt.sIteration_lvl2
        num_iteration = opt.iteration_lvl2
    elif lvlID == 3:
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        model_lvl2 = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl2,
            imgshape=imgshape_2,
            range_flow=range_flow,
            model_lvl1=model_lvl1).to(device)
        model_lvl2.load_state_dict(torch.load(model_lvl2_path))
        # Freeze model_lvl2 weights
        for param in model_lvl2.parameters():
            param.requires_grad = False

        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl3,
            imgshape=imgshape,
            range_flow=range_flow,
            model_lvl2=model_lvl2).to(device)

        grid = generate_grid(imgshape)
        start_iteration = opt.sIteration_lvl3
        num_iteration = opt.iteration_lvl3
    else:
        raise ValueError("unsupported lvlID: " + str(lvlID))

    load_model_lvl = start_iteration > 0

    loss_Jdet = neg_Jdet_loss
    loss_smooth = smoothloss

    transform = SpatialTransform_unit().to(device)

    # the spatial transform has no trainable parameters; freeze any that exist
    for param in transform.parameters():
        param.requires_grad = False
        param.volatile = True  # deprecated; a no-op on modern PyTorch

    grid = torch.from_numpy(np.reshape(grid,
                                       (1, ) + grid.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    lossall = np.zeros((4, num_iteration + 1))

    #TODO: improve the data generator:
    #  - use fixed lists for training and testing
    #  - use augmentation

    # names, norm=1, aug=1, isSeg=0 , new_size=[0,0,0]):
    training_generator = Data.DataLoader(Dataset_epoch(trainingLst,
                                                       norm=doNormalisation,
                                                       aug=opt.doAugmentation,
                                                       isSeg=opt.isSeg),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=numWorkers)

    step = 0
    if start_iteration > 0:
        model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
            lvlID) + "_" + str(num_iteration) + '.pth'
        loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(
            lvlID) + "_" + str(num_iteration) + '.npy'
        print("Loading weight and loss : ", model_lvl_path)
        step = num_iteration + 1
        model.load_state_dict(torch.load(model_lvl_path))
        temp_lossall = np.load(loss_lvl_path)
        lossall[:, 0:num_iteration] = temp_lossall[:, 0:num_iteration]
    else:
        # create (truncate) the log file only when training from scratch
        open(logLvlPath, "w").close()

    stepsLst = []
    lossLst = []
    simNCCLst = []
    JdetLst = []
    smoLst = []

    # for each iteration
    #TODO: modify the iteration to be related to the number of images
    while step <= num_iteration:
        #for each pair in the data generator
        for pair in training_generator:
            X = pair[0][0]
            Y = pair[0][1]
            movingPath = pair[1][0][0]
            fixedPath = pair[1][1][0]
            ext = ".nii.gz" if ".nii.gz" in fixedPath else (
                ".nii" if ".nii" in fixedPath else ".nrrd")

            X = X.to(device).float()
            Y = Y.to(device).float()
            assert not np.any(np.isnan(X.cpu().numpy()))
            assert not np.any(np.isnan(Y.cpu().numpy()))

            # output_disp_e0, warpped_inputx_lvl1_out, down_y, output_disp_e0_v, e0
            # F_X_Y: displacement_field,
            # X_Y: wrapped_moving_image,
            # Y_4x: downsampled_fixed_image,
            # F_xy: velocity_field
            if lvlID == 1:
                F_X_Y, X_Y, Y_4x, F_xy, _ = model(X, Y)
            elif lvlID == 2:
                F_X_Y, X_Y, Y_4x, F_xy, F_xy_lvl1, _ = model(X, Y)
            elif lvlID == 3:
                F_X_Y, X_Y, Y_4x, F_xy, F_xy_lvl1, F_xy_lvl2, _ = model(X, Y)

            # print("Y_4x shape : ",Y_4x.shape)
            # print("X_Y shape  : ",X_Y.shape)
            if opt.simLossType == 0:  # NCC
                if lvlID == 1:
                    loss_similarity = NCC(win=3)
                elif lvlID == 2:
                    loss_similarity = multi_resolution_NCC(win=5, scale=2)
                elif lvlID == 3:
                    loss_similarity = multi_resolution_NCC(win=7, scale=3)
                loss_sim = loss_similarity(X_Y, Y_4x)
            elif opt.simLossType == 1:  # mse loss
                #loss_sim = mseLoss(X_Y, Y_4x)
                loss_sim = mse_loss(X_Y, Y_4x)
                #print("loss_sim : ",loss_sim)
            elif opt.simLossType == 2:  # Dice loss
                # transform seg
                dv = math.pow(2, 3 - lvlID)
                fixedSeg = img2SegTensor(fixedPath, ext, dv)
                movingSeg = img2SegTensor(movingPath, ext, dv)

                movingSeg = movingSeg[np.newaxis, ...]
                movingSeg = torch.from_numpy(movingSeg).float().to(
                    device).unsqueeze(dim=0)
                transformedSeg = transform(movingSeg,
                                           F_X_Y.permute(0, 2, 3, 4, 1),
                                           grid).data.cpu().numpy()[0,
                                                                    0, :, :, :]
                transformedSeg[transformedSeg > 0] = 1.0
                # note: transformedSeg was detached via .data.cpu().numpy(),
                # so this Dice term carries no gradient back to the network
                loss_sim = DiceLoss.getDiceLoss(fixedSeg, transformedSeg)
            else:
                raise ValueError("unsupported similarity loss type: " +
                                 str(opt.simLossType))

            # Jacobian determinant (anti-folding) penalty on the normalized flow
            F_X_Y_norm = transform_unit_flow_to_flow_cuda(
                F_X_Y.permute(0, 2, 3, 4, 1).clone())
            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid)
            # scale the unit displacement field to voxel units for the smoothness term
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            assert not np.any(np.isnan(loss_sim.item()))
            assert not np.any(np.isnan(loss_Jacobian.item()))
            assert not np.any(np.isnan(loss_regulation.item()))

            loss = loss_sim + antifold * loss_Jacobian + smooth * loss_regulation

            assert not np.any(np.isnan(loss.item()))

            # gradients must be zeroed every step, otherwise they accumulate across batches
            optimizer.zero_grad()

            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            lossall[:, step] = np.array([
                loss.item(),
                loss_sim.item(),
                loss_Jacobian.item(),
                loss_regulation.item()
            ])

            logLine = "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" - Jdet "{3:.10f}" - smo "{4:.4f}"'.format(
                step, loss.item(), loss_sim.item(), loss_Jacobian.item(),
                loss_regulation.item())
            #sys.stdout.write(logLine)
            #sys.stdout.flush()
            print(logLine)
            with open(logLvlPath, "a") as logLvlFile:
                logLvlFile.write(logLine)
            # checkpoint more frequently at the finest level
            if lvlID == 3:
                n_checkpoint = 10

            if (step % n_checkpoint == 0):
                model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
                    lvlID) + "_" + str(step) + '.pth'
                loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(
                    lvlID) + "_" + str(step) + '.npy'
                torch.save(model.state_dict(), model_lvl_path)
                # np.save(loss_lvl_path, lossall)

                iaLog2Fig(logLvlPath)

                if doValidation:
                    iaVal = validate(model)

            if (lvlID == 3) and (step == freeze_step):
                model.unfreeze_modellvl2()

            step += 1

            if step > num_iteration:
                break
        print("one epoch pass ....")

    model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
        lvlID) + "_" + str(num_iteration) + '.pth'
    loss_lvl_path = model_dir + '/' + 'loss' + model_name + "stagelvl" + str(
        lvlID) + "_" + str(num_iteration) + '.npy'
    torch.save(model.state_dict(), model_lvl_path)
    #np.save(loss_lvl_path, lossall)
    return model_lvl_path
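The checkpoint and loss-array paths are rebuilt from the same string pattern in several places above; a small hypothetical helper (not in the original code) makes the naming convention explicit:

import os

def checkpoint_paths(model_dir, model_name, lvl_id, step):
    # matches the '<model_name>stagelvl<lvlID>_<step>.pth' weights file and
    # the 'loss'-prefixed .npy file used throughout train()
    stem = "{}stagelvl{}_{}".format(model_name, lvl_id, step)
    return (os.path.join(model_dir, stem + ".pth"),
            os.path.join(model_dir, "loss" + stem + ".npy"))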
Example #5
def train():
    model = SYMNet(2, 3, start_channel).cuda()
    loss_similarity = NCC()
    loss_smooth = smoothloss
    loss_magnitude = magnitude_loss
    loss_Jdet = neg_Jdet_loss

    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    for param in transform.parameters():
        param.requires_grad = False
        param.volatile = True  # deprecated; a no-op on modern PyTorch
    names = sorted(glob.glob(datapath + '/*.nii'))[0:255]
    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid,
                                       (1, ) + grid.shape)).cuda().float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model_dir = '../Model'

    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    lossall = np.zeros((6, iteration + 1))  # step runs from 0 to iteration inclusive

    training_generator = Data.DataLoader(Dataset_epoch(names, norm=False),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=2)
    step = 0

    while step <= iteration:
        for X, Y in training_generator:

            X = X.cuda().float()
            Y = Y.cuda().float()
            F_xy, F_yx = model(X, Y)

            F_X_Y_half = diff_transform(F_xy, grid, range_flow)
            F_Y_X_half = diff_transform(F_yx, grid, range_flow)

            F_X_Y_half_inv = diff_transform(-F_xy, grid, range_flow)
            F_Y_X_half_inv = diff_transform(-F_yx, grid, range_flow)

            X_Y_half = transform(
                X,
                F_X_Y_half.permute(0, 2, 3, 4, 1) * range_flow, grid)
            Y_X_half = transform(
                Y,
                F_Y_X_half.permute(0, 2, 3, 4, 1) * range_flow, grid)

            F_X_Y = com_transform(F_X_Y_half, F_Y_X_half_inv, grid, range_flow)
            F_Y_X = com_transform(F_Y_X_half, F_X_Y_half_inv, grid, range_flow)

            X_Y = transform(X, F_X_Y.permute(0, 2, 3, 4, 1) * range_flow, grid)
            Y_X = transform(Y, F_Y_X.permute(0, 2, 3, 4, 1) * range_flow, grid)

            loss1 = loss_similarity(X_Y_half, Y_X_half)
            loss2 = loss_similarity(Y, X_Y) + loss_similarity(X, Y_X)
            loss3 = loss_magnitude(F_X_Y_half * range_flow,
                                   F_Y_X_half * range_flow)
            loss4 = loss_Jdet(
                F_X_Y.permute(0, 2, 3, 4, 1) * range_flow, grid) + loss_Jdet(
                    F_Y_X.permute(0, 2, 3, 4, 1) * range_flow, grid)
            loss5 = loss_smooth(F_xy * range_flow) + loss_smooth(
                F_yx * range_flow)

            loss = loss1 + loss2 + magnitude * loss3 + local_ori * loss4 + smooth * loss5
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lossall[:, step] = np.array([
                loss.item(),
                loss1.item(),
                loss2.item(),
                loss3.item(),
                loss4.item(),
                loss5.item()
            ])
            sys.stdout.write(
                "\r" +
                'step "{0}" -> training loss "{1:.4f}" - sim_mid "{2:.4f}" - sim_full "{3:.4f}" - mag "{4:.4f}" - Jdet "{5:.10f}" - smo "{6:.4f}" '
                .format(step, loss.item(), loss1.item(), loss2.item(),
                        loss3.item(), loss4.item(), loss5.item()))
            sys.stdout.flush()

            if (step % n_checkpoint == 0):
                modelname = model_dir + '/SYMNet_' + str(step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(model_dir + '/loss_SYMNet_' + str(step) + '.npy',
                        lossall)
            step += 1

            if step > iteration:
                break
        print("one epoch pass")
    np.save(model_dir + '/loss_SYMNet.npy', lossall)
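The saved array has one row per loss term (total, mid-point similarity, full similarity, magnitude, Jacobian, smoothness). A sketch of loading it back for inspection, assuming the '../Model/loss_SYMNet.npy' file written above:

import numpy as np

lossall = np.load('../Model/loss_SYMNet.npy')  # shape (6, iteration + 1)
total, sim_mid, sim_full, mag, jdet, smo = lossall
print("final total loss: {:.4f}".format(total[-1]))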