Example #1
def train_lvl1():
    print("Training lvl1...")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
        2,  # in_channel: the moving and fixed images are concatenated
        3,  # n_classes: one output channel per displacement axis
        start_channel,
        is_train=True,
        imgshape=imgshape_4,
        range_flow=range_flow).to(device)

    loss_similarity = NCC(win=3)  # local normalized cross-correlation, 3^3 window
    loss_Jdet = neg_Jdet_loss  # penalises folding (negative Jacobian determinant)
    loss_smooth = smoothloss  # smoothness regulariser on the displacement field

    transform = SpatialTransform_unit().to(device)

    # The spatial transformer has no weights to train; freeze defensively.
    # (`volatile` was a pre-0.4 PyTorch Variable flag and is a no-op on
    # modern tensors, so it is dropped here.)
    for param in transform.parameters():
        param.requires_grad = False

    # OASIS dataset: collect every NIfTI volume in datapath
    names = sorted(glob.glob(datapath + '/*.nii'))

    # Identity sampling grid at the level-1 (1/4) resolution, used by the Jacobian loss
    grid_4 = generate_grid(imgshape_4)
    grid_4 = torch.from_numpy(np.reshape(grid_4, (1, ) +
                                         grid_4.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model_dir = '../Model/Stage'

    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # Loss history rows: total, NCC, Jacobian, smoothness
    lossall = np.zeros((4, iteration_lvl1 + 1))

    training_generator = Data.DataLoader(Dataset_epoch(names, norm=False),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=2)
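    # batch_size=1: full 3-D volume pairs are memory-heavy, so one pair is processed per step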
    step = 0
    load_model = False
    if load_model:
        model_path = "../Model/LDR_LPBA_NCC_lap_share_preact_1_05_3000.pth"
        print("Loading weight: ", model_path)
        step = 3000
        model.load_state_dict(torch.load(model_path))
        temp_lossall = np.load(
            "../Model/loss_LDR_LPBA_NCC_lap_share_preact_1_05_3000.npy")
        lossall[:, 0:3000] = temp_lossall[:, 0:3000]

    while step <= iteration_lvl1:
        for X, Y in training_generator:

            X = X.to(device).float()
            Y = Y.to(device).float()

            # output_disp_e0, warpped_inputx_lvl1_out, down_y, output_disp_e0_v, e0
            F_X_Y, X_Y, Y_4x, F_xy, _ = model(X, Y)

            # Similarity between the warped moving image and the downsampled fixed image
            loss_multiNCC = loss_similarity(X_Y, Y_4x)

            F_X_Y_norm = transform_unit_flow_to_flow_cuda(
                F_X_Y.permute(0, 2, 3, 4, 1).clone())

            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid_4)

            # Smoothness penalty (reg2): rescale the unit-normalised flow
            # components from [-1, 1] grid units back to voxel units first
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            # Total loss: similarity + weighted anti-folding + weighted smoothness
            # (`antifold` and `smooth` are module-level weighting factors)
            loss = loss_multiNCC + antifold * loss_Jacobian + smooth * loss_regulation

            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            lossall[:, step] = np.array([
                loss.item(),
                loss_multiNCC.item(),
                loss_Jacobian.item(),
                loss_regulation.item()
            ])
            sys.stdout.write(
                "\r" +
                'step "{0}" -> training loss "{1:.4f}" - sim_NCC "{2:.4f}" - Jdet "{3:.10f}" -smo "{4:.4f}"'
                .format(step, loss.item(), loss_multiNCC.item(),
                        loss_Jacobian.item(), loss_regulation.item()))
            sys.stdout.flush()

            # Checkpoint the model weights and the loss history periodically
            if step % n_checkpoint == 0:
                modelname = model_dir + '/' + model_name + "stagelvl1_" + str(
                    step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(
                    model_dir + '/loss' + model_name + "stagelvl1_" +
                    str(step) + '.npy', lossall)

            step += 1

            if step > iteration_lvl1:
                break
        print("one epoch passed")
    np.save(model_dir + '/loss' + model_name + 'stagelvl1.npy', lossall)
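train_lvl1 reads its configuration from module-level globals rather than arguments. Below is a minimal driver sketch under that assumption; the variable names come from the function body above, but every value is illustrative, not taken from the source, and the model/loss definitions (Miccai2020_LDR_laplacian_unit_disp_add_lvl1, NCC, neg_Jdet_loss, smoothloss, SpatialTransform_unit, generate_grid, transform_unit_flow_to_flow_cuda, Dataset_epoch) are assumed to be imported from the accompanying modules:

import glob
import os
import sys

import numpy as np
import torch
import torch.utils.data as Data

# Illustrative module-level configuration; the original script defines these elsewhere.
datapath = '../Data/OASIS'
lr = 1e-4
start_channel = 7
range_flow = 0.4
antifold = 0.0
smooth = 3.5
iteration_lvl1 = 30001
n_checkpoint = 1000
model_name = 'LDR_OASIS_NCC_unit_disp_'
imgshape = (160, 192, 144)
imgshape_4 = tuple(s // 4 for s in imgshape)  # level-1 quarter resolution

if __name__ == '__main__':
    train_lvl1()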
Example #2
def train(lvlID, opt=None, model_lvl1_path="", model_lvl2_path=""):
    print("Training " + str(lvlID) +
          "===========================================================")

    model_dir = '../Model/Stage'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    result_folder_path = "../Results"
    logLvlPath = model_dir + "/logLvl_" + str(lvlID) + ".txt"
    #logLvlChrtPath = model_dir + "/logLvl_"+str(lvlID)+".png"

    numWorkers = 0  # worker processes for the data loader (0 = load in the main process)

    # Iteration at which the frozen lvl2 sub-network is unfrozen (level 3 only)
    freeze_step = opt.freeze_step

    lossName = "_NCC_" if opt.simLossType == 0 else (
        "_MSE_" if opt.simLossType == 1 else "_DICE_")
    model_name = "LDR_OASIS" + lossName + "_disp_" + str(
        opt.start_channel) + "_" + str(opt.iteration_lvl1) + "_f" + str(
            opt.iteration_lvl2) + "_f" + str(opt.iteration_lvl3) + "_"
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
        lvlID) + "_0.pth"
    loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(
        lvlID) + "_0.npy"

    n_checkpoint = opt.checkpoint

    if lvlID == 1:
        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        grid = generate_grid(imgshape_4)

        start_iteration = opt.sIteration_lvl1
        num_iteration = opt.iteration_lvl1
    elif lvlID == 2:
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        model_lvl1.load_state_dict(torch.load(model_lvl1_path))
        # Freeze model_lvl1 weights
        for param in model_lvl1.parameters():
            param.requires_grad = False

        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl2,
            imgshape=imgshape_2,
            range_flow=range_flow,
            model_lvl1=model_lvl1).to(device)

        grid = generate_grid(imgshape_2)
        start_iteration = opt.sIteration_lvl2
        num_iteration = opt.iteration_lvl2
    elif lvlID == 3:
        model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl1,
            imgshape=imgshape_4,
            range_flow=range_flow).to(device)
        model_lvl2 = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl2,
            imgshape=imgshape_2,
            range_flow=range_flow,
            model_lvl1=model_lvl1).to(device)
        model_lvl2.load_state_dict(torch.load(model_lvl2_path))
        # Freeze model_lvl2 weights
        for param in model_lvl2.parameters():
            param.requires_grad = False

        model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(
            in_channel,
            n_classes,
            start_channel,
            is_train=isTrainLvl3,
            imgshape=imgshape,
            range_flow=range_flow,
            model_lvl2=model_lvl2).to(device)

        grid = generate_grid(imgshape)
        start_iteration = opt.sIteration_lvl3
        num_iteration = opt.iteration_lvl3

    load_model_lvl = start_iteration > 0

    loss_Jdet = neg_Jdet_loss
    loss_smooth = smoothloss

    transform = SpatialTransform_unit().to(device)

    for param in transform.parameters():
        param.requires_grad = False  # the spatial transform has nothing to learn

    grid = torch.from_numpy(np.reshape(grid,
                                       (1, ) + grid.shape)).to(device).float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Loss history rows: total, similarity, Jacobian, smoothness
    lossall = np.zeros((4, num_iteration + 1))

    #TODO: improve the data generator:
    #  - use fixed lists for training and testing
    #  - use augmentation

    # Dataset_epoch(names, norm=1, aug=1, isSeg=0, new_size=[0, 0, 0])
    training_generator = Data.DataLoader(Dataset_epoch(trainingLst,
                                                       norm=doNormalisation,
                                                       aug=opt.doAugmentation,
                                                       isSeg=opt.isSeg),
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=numWorkers)
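    # One volume pair per step (batch_size=1)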

    step = 0
    if start_iteration > 0:
        # Resume: load the checkpoint and loss history saved at start_iteration
        model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
            lvlID) + "_" + str(start_iteration) + '.pth'
        loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(
            lvlID) + "_" + str(start_iteration) + '.npy'
        print("Loading weight and loss : ", model_lvl_path)
        step = start_iteration + 1
        model.load_state_dict(torch.load(model_lvl_path))
        temp_lossall = np.load(loss_lvl_path)
        lossall[:, 0:start_iteration] = temp_lossall[:, 0:start_iteration]
    else:
        # Create a fresh log file only when starting from scratch
        logLvlFile = open(logLvlPath, "w")
        logLvlFile.close()

    stepsLst = []
    lossLst = []
    simNCCLst = []
    JdetLst = []
    smoLst = []

    # for each iteration
    #TODO: modify the iteration to be related to the number of images
    while step <= num_iteration:
        #for each pair in the data generator
        for pair in training_generator:
            X = pair[0][0]  # moving image
            Y = pair[0][1]  # fixed image
            movingPath = pair[1][0][0]
            fixedPath = pair[1][1][0]
            ext = ".nii.gz" if ".nii.gz" in fixedPath else (
                ".nii" if ".nii" in fixedPath else ".nrrd")

            X = X.to(device).float()
            Y = Y.to(device).float()
            assert not torch.isnan(X).any()
            assert not torch.isnan(Y).any()

            # output_disp_e0, warpped_inputx_lvl1_out, down_y, output_disp_e0_v, e0
            # F_X_Y: displacement field
            # X_Y: warped moving image
            # Y_4x: downsampled fixed image
            # F_xy: velocity field
            if lvlID == 1:
                F_X_Y, X_Y, Y_4x, F_xy, _ = model(X, Y)
            elif lvlID == 2:
                F_X_Y, X_Y, Y_4x, F_xy, F_xy_lvl1, _ = model(X, Y)
            elif lvlID == 3:
                F_X_Y, X_Y, Y_4x, F_xy, F_xy_lvl1, F_xy_lvl2, _ = model(X, Y)

            # print("Y_4x shape : ",Y_4x.shape)
            # print("X_Y shape  : ",X_Y.shape)
            if opt.simLossType == 0:  # NCC
                # Larger windows / more pyramid scales at finer levels
                if lvlID == 1:
                    loss_similarity = NCC(win=3)
                elif lvlID == 2:
                    loss_similarity = multi_resolution_NCC(win=5, scale=2)
                elif lvlID == 3:
                    loss_similarity = multi_resolution_NCC(win=7, scale=3)
                loss_sim = loss_similarity(X_Y, Y_4x)
            elif opt.simLossType == 1:  # MSE
                loss_sim = mse_loss(X_Y, Y_4x)
            elif opt.simLossType == 2:  # Dice loss
                # Warp the moving segmentation with the predicted field and
                # compare it against the fixed segmentation
                dv = math.pow(2, 3 - lvlID)  # downsample factor for this level
                fixedSeg = img2SegTensor(fixedPath, ext, dv)
                movingSeg = img2SegTensor(movingPath, ext, dv)

                movingSeg = movingSeg[np.newaxis, ...]
                movingSeg = torch.from_numpy(movingSeg).float().to(
                    device).unsqueeze(dim=0)
                transformedSeg = transform(movingSeg,
                                           F_X_Y.permute(0, 2, 3, 4, 1),
                                           grid).data.cpu().numpy()[0,
                                                                    0, :, :, :]
                transformedSeg[transformedSeg > 0] = 1.0
                # NOTE: both segmentations are detached numpy arrays here, so
                # no gradient flows through this similarity term as written
                loss_sim = DiceLoss.getDiceLoss(fixedSeg, transformedSeg)
            else:
                raise ValueError("unsupported similarity loss type: " +
                                 str(opt.simLossType))

            # Anti-folding penalty: negative Jacobian determinant of the flow
            F_X_Y_norm = transform_unit_flow_to_flow_cuda(
                F_X_Y.permute(0, 2, 3, 4, 1).clone())
            loss_Jacobian = loss_Jdet(F_X_Y_norm, grid)
            # Smoothness penalty (reg2): rescale the unit-normalised flow to voxel units first
            _, _, x, y, z = F_X_Y.shape
            F_X_Y[:, 0, :, :, :] = F_X_Y[:, 0, :, :, :] * z
            F_X_Y[:, 1, :, :, :] = F_X_Y[:, 1, :, :, :] * y
            F_X_Y[:, 2, :, :, :] = F_X_Y[:, 2, :, :, :] * x
            loss_regulation = loss_smooth(F_X_Y)

            assert not math.isnan(loss_sim.item())
            assert not math.isnan(loss_Jacobian.item())
            assert not math.isnan(loss_regulation.item())

            loss = loss_sim + antifold * loss_Jacobian + smooth * loss_regulation

            assert not math.isnan(loss.item())

            # PyTorch accumulates gradients across backward() calls, so they
            # must be cleared before each training step
            optimizer.zero_grad()

            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            lossall[:, step] = np.array([
                loss.item(),
                loss_sim.item(),
                loss_Jacobian.item(),
                loss_regulation.item()
            ])

            logLine = "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" - Jdet "{3:.10f}" -smo "{4:.4f}"'.format(
                step, loss.item(), loss_sim.item(), loss_Jacobian.item(),
                loss_regulation.item())
            print(logLine)
            logLvlFile = open(logLvlPath, "a")
            logLvlFile.write(logLine)
            logLvlFile.close()
            # Checkpoint more frequently at the finest level
            if lvlID == 3:
                n_checkpoint = 10

            if step % n_checkpoint == 0:
                model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
                    lvlID) + "_" + str(step) + '.pth'
                loss_lvl_path = model_dir + '/loss' + model_name + "stagelvl" + str(
                    lvlID) + "_" + str(step) + '.npy'
                torch.save(model.state_dict(), model_lvl_path)
                np.save(loss_lvl_path, lossall)  # needed so that resuming can reload it

                iaLog2Fig(logLvlPath)

                if doValidation:
                    iaVal = validate(model)

            # After freeze_step iterations, let the frozen lvl2 sub-network fine-tune again
            if (lvlID == 3) and (step == freeze_step):
                model.unfreeze_modellvl2()

            step += 1

            if step > num_iteration:
                break
        print("one epoch passed")

    model_lvl_path = model_dir + '/' + model_name + "stagelvl" + str(
        lvlID) + "_" + str(num_iteration) + '.pth'
    loss_lvl_path = model_dir + '/' + 'loss' + model_name + "stagelvl" + str(
        lvlID) + "_" + str(num_iteration) + '.npy'
    torch.save(model.state_dict(), model_lvl_path)
    #np.save(loss_lvl_path, lossall)
    return model_lvl_path
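A plausible driver that chains the three levels, feeding each level's returned checkpoint path into the next. This is a sketch: the argparse fields mirror the attributes of opt used inside train, all defaults are illustrative, and train additionally relies on module-level globals (in_channel, n_classes, trainingLst, doNormalisation, lr, antifold, smooth, doValidation, the imgshape variants, isTrainLvl1..3, range_flow) being defined as in the surrounding script:

import argparse

parser = argparse.ArgumentParser()
# Only the fields referenced by train() are declared; defaults are illustrative.
parser.add_argument("--simLossType", type=int, default=0)  # 0 = NCC, 1 = MSE, 2 = Dice
parser.add_argument("--start_channel", type=int, default=7)
parser.add_argument("--iteration_lvl1", type=int, default=30001)
parser.add_argument("--iteration_lvl2", type=int, default=30001)
parser.add_argument("--iteration_lvl3", type=int, default=60001)
parser.add_argument("--sIteration_lvl1", type=int, default=0)  # > 0 resumes from a checkpoint
parser.add_argument("--sIteration_lvl2", type=int, default=0)
parser.add_argument("--sIteration_lvl3", type=int, default=0)
parser.add_argument("--checkpoint", type=int, default=1000)
parser.add_argument("--freeze_step", type=int, default=2000)
parser.add_argument("--doAugmentation", type=int, default=0)
parser.add_argument("--isSeg", type=int, default=0)
opt = parser.parse_args()

# Train coarse-to-fine; each level consumes the previous levels' frozen weights.
model_lvl1_path = train(1, opt)
model_lvl2_path = train(2, opt, model_lvl1_path=model_lvl1_path)
model_lvl3_path = train(3, opt,
                        model_lvl1_path=model_lvl1_path,
                        model_lvl2_path=model_lvl2_path)
print("final model saved to:", model_lvl3_path)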