Example #1
    def test(self):

        # get Dice
        flow_final = self.sess.run(self.flow)

        labels = sio.loadmat('data/labels.mat')['labels'][
            0]  # Anatomical labels we want to evaluate
        vals, _ = dice(self.img2_seg, self.img1_seg, labels=labels, nargout=2)
        print(np.mean(vals))

        # Warp segments with flow
        flow = np.zeros(
            [1, self.vol_size[0], self.vol_size[1], self.vol_size[2], 3])
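        # swap the first two flow channels so they line up with the
        # xy-indexed meshgrid built below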
        flow[0, :, :, :, 1] = flow_final[0, :, :, :, 0]
        flow[0, :, :, :, 0] = flow_final[0, :, :, :, 1]
        flow[0, :, :, :, 2] = flow_final[0, :, :, :, 2]

        xx = np.arange(self.vol_size[1])
        yy = np.arange(self.vol_size[0])
        zz = np.arange(self.vol_size[2])
        grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)
        sample = flow[0, :, :, :, :] + grid
        sample = np.stack(
            (sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
        warp_seg = interpn((yy, xx, zz),
                           self.img2_seg,
                           sample,
                           method='nearest',
                           bounds_error=False,
                           fill_value=0)
        val, _ = dice(warp_seg, self.img1_seg, labels=labels, nargout=2)
        print(np.mean(val))
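
Most examples in this listing call a dice helper that takes a labels list and a MATLAB-style nargout switch. A minimal sketch of such a helper, assuming the usual per-label Dice overlap (the actual implementation bundled with these repositories may differ):

import numpy as np

def dice(vol1, vol2, labels=None, nargout=1):
    """Per-label Dice overlap; a sketch of the helper the examples assume."""
    if labels is None:
        labels = np.unique(np.concatenate((vol1.flatten(), vol2.flatten())))
        labels = labels[labels != 0]  # skip background
    dicem = np.zeros(len(labels))
    for idx, lab in enumerate(labels):
        top = 2 * np.sum(np.logical_and(vol1 == lab, vol2 == lab))
        bottom = np.sum(vol1 == lab) + np.sum(vol2 == lab)
        dicem[idx] = top / np.maximum(bottom, np.finfo(float).eps)
    return dicem if nargout == 1 else (dicem, labels)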
Example #2
def test(model_name,
         gpu_id,
         nf_enc=[16, 32, 32, 32],
         nf_dec=[32, 32, 32, 32, 32, 16, 16]):
    """
    test

    nf_enc and nf_dec
    #nf_dec = [32,32,32,32,32,16,16,3]
    # This needs to be changed. Ideally, we could just call load_model, and we wont have to
    # specify the # of channels here, but the load_model is not working with the custom loss...
    """

    # load subject test
    print("load_data start")
    X_train, y_train = load_data(data_dir='../data',
                                 mode='test',
                                 fixed='joyoungje')
    vol_size = y_train.shape[1:-1]

    # gpu handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        net = networks.cvpr2018_net(vol_size, nf_enc, nf_dec)
        print("model load weights")
        net.load_weights(model_name)

        # NN transfer model
        nn_trf_model = networks.nn_trf(vol_size, indexing='ij')


#    # if CPU, prepare grid
#    if compute_type == 'CPU':
#        grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4)

    with tf.device(gpu):
        print("model predict")
        pred = net.predict([X_train, y_train])
        print("nn_tft_model.predict")
        X_warp = nn_trf_model.predict([X_train, pred[1]])[0, ..., 0]

    reshape_y_train = y_train.reshape(y_train.shape[1:-1])
    vals = dice(pred[0].reshape(pred[0].shape[1:-1]), reshape_y_train)
    dice_mean = np.mean(vals)
    dice_std = np.std(vals)
    print('Dice mean over structures: {:.2f} ({:.2f})'.format(
        dice_mean, dice_std))
Example #3
def test(model_name, iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]):
	"""
	test

	nf_enc and nf_dec
	#nf_dec = [32,32,32,32,32,16,16,3]
    # This needs to be changed. Ideally, we could just call load_model, and we wont have to
    # specify the # of channels here, but the load_model is not working with the custom loss...
    """  

	gpu = '/gpu:' + str(gpu_id)

	# Anatomical labels we want to evaluate
	labels = sio.loadmat('../data/labels.mat')['labels'][0]

	atlas = np.load('../data/atlas_norm.npz')
	atlas_vol = atlas['vol']
	atlas_seg = atlas['seg']
	atlas_vol = np.reshape(atlas_vol, (1,)+atlas_vol.shape+(1,))

	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True
	config.allow_soft_placement = True
	set_session(tf.Session(config=config))

	# load weights of model
	with tf.device(gpu):
		net = networks.unet(vol_size, nf_enc, nf_dec)
		net.load_weights('../models/' + model_name +
                         '/' + str(iter_num) + '.h5')

	xx = np.arange(vol_size[1])
	yy = np.arange(vol_size[0])
	zz = np.arange(vol_size[2])
	grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)

	X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')

	with tf.device(gpu):
		pred = net.predict([X_vol, atlas_vol])

	# Warp segments with flow
	flow = pred[1][0, :, :, :, :]
	sample = flow+grid
	sample = np.stack((sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
	warp_seg = interpn((yy, xx, zz), X_seg[0, :, :, :, 0], sample, method='nearest', bounds_error=False, fill_value=0)

	vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
	print(np.mean(vals), np.std(vals))
Example #4
def test(model_name, iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]):
	"""
	test

	nf_enc and nf_dec
	#nf_dec = [32,32,32,32,32,16,16,3]
    # This needs to be changed. Ideally, we could just call load_model, and we wont have to
    # specify the # of channels here, but the load_model is not working with the custom loss...
    """  

	gpu = '/gpu:' + str(gpu_id)

	# Anatomical labels we want to evaluate
	labels = sio.loadmat('../data/labels.mat')['labels'][0]

	atlas = np.load('../data/atlas_norm.npz')
	atlas_vol = atlas['vol']
	atlas_seg = atlas['seg']
	atlas_vol = np.reshape(atlas_vol, (1,)+atlas_vol.shape+(1,))

	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True
	config.allow_soft_placement = True
	set_session(tf.Session(config=config))

	# load weights of model
	with tf.device(gpu):
		net = networks.unet(vol_size, nf_enc, nf_dec)
		net.load_weights('../models/' + model_name +
                         '/' + str(iter_num) + '.h5')

	xx = np.arange(vol_size[1])#192
	yy = np.arange(vol_size[0])#160
	zz = np.arange(vol_size[2])#224
	grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)  # (160,192,224,3): stores the original coordinates of each point in the volume

	X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')

	with tf.device(gpu):
		pred = net.predict([X_vol, atlas_vol])

	# Warp segments with flow
	flow = pred[1][0, :, :, :, :]#(160,192,224,3)
	sample = flow+grid#add the original position with the shift flow the dimension is: (160,192,224,3)
	sample = np.stack((sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
	warp_seg = interpn((yy, xx, zz), X_seg[0, :, :, :, 0], sample, method='nearest', bounds_error=False, fill_value=0)

	vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
	print(np.mean(vals), np.std(vals))
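
The reordering of flow channels before interpn in the examples above comes from np.meshgrid defaulting to 'xy' indexing, where the first output varies along axis 1 rather than axis 0. A small self-contained check of that convention (illustrative only, not part of the original code):

import numpy as np

xx, yy, zz = np.arange(4), np.arange(3), np.arange(2)
grid_xy = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)  # (3, 4, 2, 3)
grid_ij = np.rollaxis(np.array(np.meshgrid(yy, xx, zz, indexing='ij')), 0, 4)

# channel 0 of the 'xy' grid equals channel 1 of the 'ij' grid and vice versa,
# which is why points are stacked as (y, x, z) before interpn((yy, xx, zz), ...)
assert (grid_xy[..., 0] == grid_ij[..., 1]).all()
assert (grid_xy[..., 1] == grid_ij[..., 0]).all()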
Example #5
def test(model_name, iters, gpu_id):
    patch_size = (64, 64, 64)
    num_labels = 30
    labels = sio.loadmat('labels.mat')['labels'][0]
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    validation_vols, vol_patch_loc = data_gen.vols_generator_patch(
        validation_vols_data,
        len(validation_vols_data),
        patch_size,
        stride_patch=32,
        out=2)

    with tf.device(gpu):
        model = network.unet(input_size=(64, 64, 64, 1))
        model.load_weights('models/' + model_name + '/' + 'weights-' +
                           str(iters) + '.hdf5')

    data_temp = np.load(validation_vols_data[0])['vol_data']
    subject_aseg = data_gen.re_label(validation_labels_data,
                                     len(validation_labels_data), labels)
    dice_scores = []

    # Assemble the patches into one volume by accumulating the probability maps, then argmax over the last axis to choose the label.
    for j in range(len(validation_vols)):
        mask = np.zeros(data_temp.shape + (num_labels, ))  # zeros, not empty: predictions are accumulated with +=
        for i in range(len(validation_vols[j])):
            pred_temp = model.predict(validation_vols[j][i])
            mask[vol_patch_loc[j][i][0].start:vol_patch_loc[j][i][0].stop,
                 vol_patch_loc[j][i][1].start:vol_patch_loc[j][i][1].stop,
                 vol_patch_loc[j][i][2].start:vol_patch_loc[j][i][2].
                 stop, :] += pred_temp[0, :, :, :, :]
        pred = np.argmax(mask, axis=-1)
        vals, _ = dice(pred, subject_aseg[j], labels=range(0, 30), nargout=2)
        dice_scores.append(np.mean(vals))
    print("dice:", np.mean(dice_scores), model_name, iters)
Example #6
def test(
        gpu_id,
        model_dir,
        iter_num,
        compute_type='GPU',  # GPU or CPU
        vol_size=(160, 192, 224),
        nf_enc=[16, 32, 32, 32],
        nf_dec=[32, 32, 32, 32, 16, 3],
        save_file=None):
    """
    test via segmetnation propagation
    works by iterating over some iamge files, registering them to atlas,
    propagating the warps, then computing Dice with atlas segmentations
    """

    # GPU handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        # if testing miccai run, should be xy indexing.
        net = networks.miccai2018_net(vol_size,
                                      nf_enc,
                                      nf_dec,
                                      use_miccai_int=True,
                                      indexing='xy')
        net.load_weights(os.path.join(model_dir, str(iter_num) + '.h5'))

        # compose diffeomorphic flow output model
        diff_net = keras.models.Model(net.inputs,
                                      net.get_layer('diffflow').output)

        # NN transfer model
        nn_trf_model = networks.nn_trf(vol_size)

    # if CPU, prepare grid
    if compute_type == 'CPU':
        grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4)

    # prepare a matrix of dice values
    dice_vals = np.zeros((len(good_labels), n_batches))
    for k in range(n_batches):
        # get data
        vol_name, seg_name = test_brain_strings[k].split(",")
        X_vol, X_seg = datagenerators.load_example_by_name(vol_name, seg_name)

        # predict transform
        with tf.device(gpu):
            pred = diff_net.predict([X_vol, atlas_vol])

        # Warp segments with flow
        if compute_type == 'CPU':
            flow = pred[0, :, :, :, :]
            warp_seg = util.warp_seg(X_seg,
                                     flow,
                                     grid=grid,
                                     xx=xx,
                                     yy=yy,
                                     zz=zz)

        else:  # GPU
            # Rigid registration only by GPU
            flow = pred[0, :, :, :, :]
            # Compute A(all about coordinate computation)
            x = np.linspace(0, 160 - 16, sample_num)
            x = x.astype(np.int32)
            y = np.linspace(0, 190 - 19, sample_num)
            y = y.astype(np.int32)
            z = np.linspace(0, 220 - 22, sample_num)
            z = z.astype(np.int32)
            index = np.rollaxis(np.array(np.meshgrid(x, y, z)), 0, 4)
            x = index[:, :, :, 0]
            y = index[:, :, :, 1]
            z = index[:, :, :, 2]

            # Y in formula
            x_flow = np.arange(vol_size[0])
            y_flow = np.arange(vol_size[1])
            z_flow = np.arange(vol_size[2])
            grid = np.rollaxis(np.array((np.meshgrid(y_flow, x_flow, z_flow))),
                               0, 4)  # original coordinate
            grid_x = grid_sample(x, y, z, grid[:, :, :, 0], sample_num)
            grid_y = grid_sample(x, y, z, grid[:, :, :, 1], sample_num)
            grid_z = grid_sample(x, y, z, grid[:, :, :, 2],
                                 sample_num)  # X (10,10,10)

            sample = flow + grid
            sample_x = grid_sample(x, y, z, sample[:, :, :, 0], sample_num)
            sample_y = grid_sample(x, y, z, sample[:, :, :, 1], sample_num)
            sample_z = grid_sample(x, y, z, sample[:, :, :, 2],
                                   sample_num)  # Y (10,10,10)

            sum_x = np.sum(flow[:, :, :, 0])
            sum_y = np.sum(flow[:, :, :, 1])
            sum_z = np.sum(flow[:, :, :, 2])

            ave_x = sum_x / (vol_size[0] * vol_size[1] * vol_size[2])
            ave_y = sum_y / (vol_size[0] * vol_size[1] * vol_size[2])
            ave_z = sum_z / (vol_size[0] * vol_size[1] * vol_size[2])

            # formula
            Y = np.zeros((10, 10, 10, grid_dimension))
            X = np.zeros((10, 10, 10, grid_dimension))
            T = np.array([ave_x, ave_y, ave_z, 1])  # shape (4,); reshaped to (4, 1) below
            # R = np.zeros((10, 10, 10, grid_dimension, grid_dimension))

            for i in np.arange(10):
                for j in np.arange(10):
                    for z in np.arange(10):
                        Y[i, j, z, :] = np.array([
                            sample_x[i, j, z], sample_y[i, j, z],
                            sample_z[i, j, z], 1
                        ])

            for i in np.arange(10):
                for j in np.arange(10):
                    for z in np.arange(10):
                        X[i, j, z, :] = np.array([
                            grid_x[i, j, z], grid_y[i, j, z], grid_z[i, j, z],
                            1
                        ])

            X = X.reshape((1000, grid_dimension))
            Y = Y.reshape((1000, grid_dimension))
            R = np.dot(
                np.dot(np.linalg.pinv(np.dot(np.transpose(X), X)),
                       np.transpose(X)), Y)  # R

            # build new grid(Use R to do the spatial transform)
            shifted_x = np.arange(vol_size[0])
            shifted_y = np.arange(vol_size[1])
            shifted_z = np.arange(vol_size[2])
            print(shifted_x.shape)
            print(shifted_y.shape)
            print(shifted_z.shape)
            shifted_grid = np.rollaxis(
                np.array((np.meshgrid(shifted_y, shifted_x, shifted_z))), 0, 4)
            print(shifted_grid.shape)
            for i in np.arange(vol_size[0]):
                for j in np.arange(vol_size[1]):
                    for z in np.arange(vol_size[2]):
                        coordinates = np.dot(
                            R,
                            np.array([i, j, z, 1]).reshape(4, 1)) + T.reshape(
                                4, 1)
                        print("voxel." + '(' + str(i) + ',' + str(j) + ',' +
                              str(z) + ')')
                        shifted_grid[i, j, z, 0] = coordinates[0]
                        shifted_grid[i, j, z, 1] = coordinates[1]
                        shifted_grid[i, j, z, 2] = coordinates[2]

            # interpolation
            xx = np.arange(vol_size[1])
            yy = np.arange(vol_size[0])
            zz = np.arange(vol_size[2])
            warp_seg = interpn((yy, xx, zz),
                               X_seg[0, :, :, :, 0],
                               shifted_grid,
                               method='nearest',
                               bounds_error=False,
                               fill_value=0)

            # CVPR
            grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)
            sample = flow + grid
            sample = np.stack(
                (sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]),
                3)
            warp_seg2 = interpn((yy, xx, zz),
                                X_seg[0, :, :, :, 0],
                                sample,
                                method='nearest',
                                bounds_error=False,
                                fill_value=0)

            # compute dice
            vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
            vals2, _ = dice(X_seg[0, :, :, :, 0],
                            atlas_seg,
                            labels=labels,
                            nargout=2)
            vals3, _ = dice(warp_seg2, atlas_seg, labels=labels, nargout=2)
            print("dice before:")
            print(np.mean(vals2), np.std(vals2))
            print("dice after deformable registration:")
            print(np.mean(vals3), np.std(vals3))
            print("dice after rigid registration:")
            print(np.mean(vals), np.std(vals))
            warp_seg = nn_trf_model.predict([X_seg, pred])[0, ..., 0]

        # compute Volume Overlap (Dice)
        dice_vals[:, k] = dice(warp_seg, atlas_seg, labels=good_labels)
        print('%3d %5.3f %5.3f' % (k, np.mean(
            dice_vals[:, k]), np.mean(np.mean(dice_vals[:, :k + 1]))))

        if save_file is not None:
            sio.savemat(save_file, {
                'dice_vals': dice_vals,
                'labels': good_labels
            })
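
The rigid fit above forms R = (X^T X)^{-1} X^T Y explicitly with np.linalg.pinv; the same least-squares solution can be computed more stably with np.linalg.lstsq. A short sketch on synthetic points (names and data here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
X = np.hstack([rng.random((1000, 3)), np.ones((1000, 1))])  # homogeneous coords
A = rng.random((4, 4))                                      # ground-truth transform
Y = X @ A                                                   # warped coordinates
R, *_ = np.linalg.lstsq(X, Y, rcond=None)                   # same R as the pinv form
assert np.allclose(R, A)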
Example #7
def test(gpu, atlas_file, model, init_model_file):
    """
    model training function
    :param gpu: integer specifying the gpu to use
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param init_model_file: the model directory to load from
    """

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    device = "cuda"

    # Load the atlas, with dims 160x192x224.
    atlas = np.load(atlas_file)
    atlas_vol = atlas['vol'][np.newaxis, ..., np.newaxis]
    atlas_seg = atlas['seg']
    vol_size = atlas_vol.shape[1:-1]

    # Test file and anatomical labels we want to evaluate
    test_file = open('../voxelmorph/data/val_examples.txt')
    test_strings = test_file.readlines()
    test_strings = [x.strip() for x in test_strings]
    good_labels = sio.loadmat(
        '../voxelmorph/data/test_labels.mat')['labels'][0]

    # Prepare the vm1 or vm2 model and send to device
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]

    # Set up model
    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)
    model.load_state_dict(
        torch.load(init_model_file, map_location=lambda storage, loc: storage))

    # set up atlas tensor
    input_fixed = torch.from_numpy(atlas_vol).to(device).float()
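    # PyTorch 3-D layers expect channels-first tensors, (N, C, D, H, W)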
    input_fixed = input_fixed.permute(0, 4, 1, 2, 3)

    # Use this to warp segments
    trf = SpatialTransformer(atlas_vol.shape[1:-1], mode='nearest')
    trf.to(device)

    for k in range(0, len(test_strings)):

        vol_name, seg_name = test_strings[k].split(",")
        X_vol, X_seg = datagenerators.load_example_by_name(vol_name, seg_name)

        input_moving = torch.from_numpy(X_vol).to(device).float()
        input_moving = input_moving.permute(0, 4, 1, 2, 3)

        warp, flow = model(input_moving, input_fixed)

        # Warp segment using flow
        moving_seg = torch.from_numpy(X_seg).to(device).float()
        moving_seg = moving_seg.permute(0, 4, 1, 2, 3)
        warp_seg = trf(moving_seg, flow).detach().cpu().numpy()

        vals, labels = dice(warp_seg, atlas_seg, labels=good_labels, nargout=2)
        #dice_vals[:, k] = vals
        #print(np.mean(dice_vals[:, k]))
        print(np.mean(vals))
Example #8
                                                        0].transpose(2, 1, 0)

        sflow = _transform(global_flow[:, 0], global_flow[:, 1],
                           global_flow[:, 2])
        nb, nc, nd, nw, nh = sflow.shape
        segflow = torch.FloatTensor(sflow.shape).zero_()
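        # F.grid_sample expects sampling coordinates normalized to [-1, 1] on each axis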
        segflow[:, 2] = (sflow[:, 0] / (nd - 1) - 0.5) * 2.0
        segflow[:, 1] = (sflow[:, 1] / (nw - 1) - 0.5) * 2.0
        segflow[:, 0] = (sflow[:, 2] / (nh - 1) - 0.5) * 2.0
        regist_seg = F.grid_sample(
            batch_s.cuda().float(),
            (segflow.cuda().float().permute(0, 2, 3, 4, 1)),
            mode='nearest')
        regist_seg = regist_seg.cpu().numpy()[0, 0].transpose(2, 1, 0)

        vals_regist, _ = dice(regist_seg, label_seg, labels=labels, nargout=2)
        vals_origin, _ = dice(data_seg, label_seg, labels=labels, nargout=2)
        registDice[isub] = vals_regist
        originDice[isub] = vals_origin
        print(np.mean(vals_regist))

        dataName = dataFile.split('\\')[-1]
        savePath = os.path.join(opt.results_dir, 'regist_' + dataName)
        result_data = {
            'seg_regist': regist_seg.astype('float32'),
            'data_regist': regist_data.astype('float32'),
            'data_field': regist_flow.astype('float32')
        }
        sio.savemat(savePath, result_data)

    dataName = 'OASIS_testSeg.mat'
    concatenate_outcome = np.empty(seg_data.shape)
    for i in range(0, 191):
        vol_train = vol_data[:, :, i, :, :]

        # concatenate slices
        #slice_outcome = load_model.predict(slice_vol)
        concatenate_outcome[:, :, i, :, :] = load_model.predict(vol_train)
        #np.concatenate([concatenate_outcome,concatenate_outcome])

    concatenate_outcome = concatenate_outcome.reshape([
        1, concatenate_outcome.shape[1], concatenate_outcome.shape[2],
        concatenate_outcome.shape[3], concatenate_outcome.shape[4]
    ])  # reshape returns a new array, so the result must be assigned
    #print('the shape of the output:')
    #print(concatenate_outcome.shape)
    # compute the dice score of test example
    print('the dice score of the test is:')
    #dice_score = nm.Dice(nb_labels = len(labels_data), input_type='prob', dice_type='hard',).dice(seg_data,concatenate_outcome)
    #dice_score = dice(concatenate_outcome,,)
    vals, _ = dice(concatenate_outcome, seg_data, nargout=2)
    print(np.mean(vals), np.std(vals))
    #print(dice_score)
"""
2d unet changed for class project
output in (1,256,256,label_nums)
input in (1,x,y,1) 4d tensor
"""
'''
we need to change the output into (1, 256, 256, label_nums)
'''
Example #10
def test(
        gpu_id,
        model_dir,
        iter_num,
        compute_type='GPU',  # GPU or CPU
        vol_size=(160, 192, 224),
        nf_enc=[16, 32, 32, 32],
        nf_dec=[32, 32, 32, 32, 16, 3],
        save_file=None):
    """
    test via segmetnation propagation
    works by iterating over some iamge files, registering them to atlas,
    propagating the warps, then computing Dice with atlas segmentations
    """

    # GPU handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        # if testing miccai run, should be xy indexing.
        net = networks.miccai2018_net(vol_size,
                                      nf_enc,
                                      nf_dec,
                                      use_miccai_int=False,
                                      indexing='ij')
        net.load_weights(os.path.join(model_dir, str(iter_num) + '.h5'))

        # compose diffeomorphic flow output model
        diff_net = keras.models.Model(net.inputs,
                                      net.get_layer('diffflow').output)

        # NN transfer model
        nn_trf_model = networks.nn_trf(vol_size, indexing='ij')

    # if CPU, prepare grid
    if compute_type == 'CPU':
        grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4)

    # prepare a matrix of dice values
    dice_vals = np.zeros((len(good_labels), n_batches))
    for k in range(n_batches):
        # get data
        vol_name, seg_name = test_brain_strings[k].split(",")
        X_vol, X_seg = datagenerators.load_example_by_name(vol_name, seg_name)

        # predict transform
        with tf.device(gpu):
            pred = diff_net.predict([X_vol, atlas_vol])

        # Warp segments with flow
        if compute_type == 'CPU':
            flow = pred[0, :, :, :, :]
            warp_seg = util.warp_seg(X_seg,
                                     flow,
                                     grid=grid,
                                     xx=xx,
                                     yy=yy,
                                     zz=zz)

        else:  # GPU
            warp_seg = nn_trf_model.predict([X_seg, pred])[0, ..., 0]

        # compute Volume Overlap (Dice)
        dice_vals[:, k] = dice(warp_seg, atlas_seg, labels=good_labels)
        print('%3d %5.3f %5.3f' % (k, np.mean(
            dice_vals[:, k]), np.mean(np.mean(dice_vals[:, :k + 1]))))

        if save_file is not None:
            sio.savemat(save_file, {
                'dice_vals': dice_vals,
                'labels': good_labels
            })
Example #11
        #pred = pred.astype(np.float64)
        #pred = pred.astype(np.float64)
        #print(seg.shape)
        #print(pred.shape)
        #dice_score = metrics.Dice(nb_labels = 30,
        #         weights=None,
        #         input_type='prob',
        #         dice_type='hard',
        #         approx_hard_max=True,
        #         vox_weights=None,
        #         crop_indices=None,
        #         area_reg=0.1).dice(seg,pred)
        #dice_score = losses.dice_coef(seg,pred).eval()
        #y_pred_op = pred
        #y_true_op = seg
        #sum_dim = 1
        #top = 2 * K.sum(y_true_op * y_pred_op, sum_dim)
        #bottom = K.sum(K.square(y_true_op), sum_dim) + K.sum(K.square(y_pred_op), sum_dim)
        # make sure we have no 0s on the bottom. K.epsilon()
        #bottom = K.maximum(bottom, self.area_reg)
        #dice_score = top / bottom
        #sum_dice = sum_dice + dice_score
        #print(dice_score.eval())
        vals, _ = dice(pred, seg, nargout=2)
        #sum_dice = sum_dice + np.mean(vals)
        dice_score = np.mean(vals)
        print(np.mean(vals), np.std(vals))
        sum_dice = sum_dice + dice_score
#
print(sum_dice / cnt2)
print(np.mean(dice(f_seg, seg, nargout=1)))
Example #12
def eval_seg_sas_from_gen(sas_model,
                          atlas_vol,
                          atlas_labels,
                          eval_gen,
                          label_mapping,
                          n_eval_examples,
                          batch_size,
                          logger=None):
    '''
    Evaluates a single-atlas segmentation method on a bunch of evaluation volumes.
    :param sas_model: spatial transform model used for SAS. Can be voxelmorph.
    :param atlas_vol: atlas volume
    :param atlas_labels: atlas segmentations
    :param eval_gen: generator that yields vols_valid, segs_valid batches
    :param label_mapping: list of label ids that will appear in segs, ordered by how they map to channels
    :param n_eval_examples: total number of examples to evaluate
    :param batch_size: batch size to use in evaluation
    :param logger: python logger if we want to log messages
    :return:
    '''
    img_shape = atlas_vol.shape[1:]

    seg_warp_model = networks.warp_model(
        img_shape=img_shape,
        interp_mode='nearest',
        indexing='xy',
    )

    from keras.models import Model
    from keras.layers import Input, Activation
    from keras.optimizers import Adam
    n_labels = len(label_mapping)

    warped_in = Input(img_shape[0:-1] + (n_labels, ))
    warped = Activation('softmax')(warped_in)

    ce_model = Model(inputs=[warped_in], outputs=[warped], name='ce_model')
    ce_model.compile(loss='categorical_crossentropy', optimizer=Adam(0.0001))

    # test metrics: categorical cross-entropy and dice
    dice_per_label = np.zeros((n_eval_examples, len(label_mapping)))
    cces = np.zeros((n_eval_examples, ))
    accs = np.zeros((n_eval_examples, ))
    all_ids = []
    for bi in range(n_eval_examples):
        if logger is not None:
            logger.debug('Testing on subject {} of {}'.format(
                bi, n_eval_examples))
        else:
            print('Testing on subject {} of {}'.format(bi, n_eval_examples))
        X, Y, _, ids = next(eval_gen)
        Y_oh = labels_to_onehot(Y, label_mapping=label_mapping)

        warped, warp = sas_model.predict([atlas_vol, X])

        # warp the source labels according to the predicted flow field; get rid of channels
        if Y.shape[-1] == 1:
            Y = Y[..., 0]
        preds_batch = seg_warp_model.predict(
            [atlas_labels[..., np.newaxis], warp])[..., 0]
        preds_oh = labels_to_onehot(preds_batch, label_mapping=label_mapping)

        cce = np.mean(ce_model.evaluate(preds_oh, Y_oh, verbose=False))
        subject_dice_per_label = medipy_metrics.dice(Y,
                                                     preds_batch,
                                                     labels=label_mapping)

        nonbkgmap = (Y > 0)
        acc = np.sum(((Y == preds_batch) *
                      nonbkgmap).astype(int)) / np.sum(nonbkgmap).astype(float)
        print(acc)
        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    if logger is not None:
        logger.debug('Dice per label: {}, {}'.format(label_mapping,
                                                     dice_per_label))
        logger.debug('Mean dice (no bkg): {}'.format(
            np.mean(dice_per_label[:, 1:])))
        logger.debug('Mean CE: {}'.format(np.mean(cces)))
        logger.debug('Mean accuracy: {}'.format(np.mean(accs)))
    else:
        print('Dice per label: {}, {}'.format(label_mapping, dice_per_label))
        print('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
        print('Mean CE: {}'.format(np.mean(cces)))
        print('Mean accuracy: {}'.format(np.mean(accs)))
    return cces, dice_per_label, accs, all_ids
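
eval_seg_sas_from_gen relies on a labels_to_onehot helper. A minimal sketch of what such a helper does, assuming channels ordered by label_mapping (the actual implementation may differ):

import numpy as np

def labels_to_onehot(labels_vol, label_mapping):
    """One-hot encode a label volume, channel c <-> label_mapping[c]."""
    onehot = np.zeros(labels_vol.shape + (len(label_mapping),), dtype=np.float32)
    for c, lab in enumerate(label_mapping):
        onehot[..., c] = (labels_vol == lab)
    return onehot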
Example #13
def test(model_name,
         iter_num,
         gpu_id,
         n_test,
         filename,
         vol_size=(160, 192, 224),
         nf_enc=[16, 32, 32, 32],
         nf_dec=[32, 32, 32, 32, 32, 16, 16]):
    """
    test

    nf_enc and nf_dec
    #nf_dec = [32,32,32,32,32,16,16,3]
    # This needs to be changed. Ideally, we could just call load_model, and we wont have to
    # specify the # of channels here, but the load_model is not working with the custom loss...
    """
    start_time = time.time()
    gpu = '/gpu:' + str(gpu_id)
    print(gpu)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # Anatomical labels we want to evaluate
    labels = sio.loadmat('../data/labels.mat')['labels'][0]

    atlas = np.load('../data/atlas_norm.npz')
    atlas_vol = atlas['vol']
    atlas_seg = atlas['seg']
    atlas_vol = np.reshape(atlas_vol, (1, ) + atlas_vol.shape + (1, ))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        net = networks.unet(vol_size, nf_enc, nf_dec)
        net.load_weights('../models/' + model_name + '/' + str(iter_num) +
                         '.h5')

    seg_path = '../models/seg_pretrained/0.h5'
    feature_model, num_features = networks.segmenter_feature_model(seg_path)

    with open('seg_feature_stats.txt', 'rb') as file:
        feature_stats = pickle.loads(
            file.read())  # use `pickle.loads` to do the reverse

    xx = np.arange(vol_size[1])
    yy = np.arange(vol_size[0])
    zz = np.arange(vol_size[2])
    grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)

    percentile = 99
    dice_means = []

    results = {}

    for step in range(0, n_test):

        res = {}

        vol_name, seg_name = test_brain_strings[step].split(",")
        X_vol, X_seg = datagenerators.load_example_by_name(vol_name, seg_name)

        with tf.device(gpu):
            pred = net.predict([X_vol, atlas_vol])
            warped_image = np.transpose(pred[0][0, :, :, :, :], (2, 0, 1, 3))
            pred_ac_features = feature_model.predict([warped_image])
            orig_ac_features = feature_model.predict(
                [np.transpose(X_vol[0, :, :, :, :], (2, 0, 1, 3))])

        # Warp segments with flow
        flow = pred[1][0, :, :, :, :]
        sample = flow + grid
        sample = np.stack(
            (sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
        warp_seg = interpn((yy, xx, zz),
                           X_seg[0, :, :, :, 0],
                           sample,
                           method='nearest',
                           bounds_error=False,
                           fill_value=0)

        vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
        # print(np.mean(vals), np.std(vals))
        mean = np.mean(vals)
        std = np.std(vals)

        res['dice_mean'] = mean
        res['dice_std'] = std

        for i in range(len(pred_ac_features)):
            normalized_pred = normalize_percentile(pred_ac_features[i],
                                                   percentile,
                                                   feature_stats,
                                                   i,
                                                   twod=True)
            normalized_orig = normalize_percentile(orig_ac_features[i],
                                                   percentile,
                                                   feature_stats,
                                                   i,
                                                   twod=True)

            for j in range(normalized_pred.shape[-1]):
                pred_feature = normalized_pred[:, :, :, j]
                orig_feature = normalized_orig[:, :, :, j]
                append_to_dict(res, 'l1_diff',
                               np.mean(np.abs(pred_feature - orig_feature)))
                append_to_dict(res, 'l2_diff',
                               np.mean(np.square(pred_feature - orig_feature)))

                append_to_dict(res, 'pred_mean', np.mean(pred_feature))
                append_to_dict(res, 'pred_std', np.std(pred_feature))
                append_to_dict(res, 'pred_99pc',
                               np.percentile(pred_feature, 99))
                append_to_dict(res, 'pred_1pc', np.percentile(pred_feature, 1))

                append_to_dict(res, 'orig_mean', np.mean(orig_feature))
                append_to_dict(res, 'orig_std', np.std(orig_feature))
                append_to_dict(res, 'orig_99pc',
                               np.percentile(orig_feature, 99))
                append_to_dict(res, 'orig_1pc', np.percentile(orig_feature, 1))

        dice_means.append(mean)

        results[vol_name] = res
        print(step, mean, std)
        print('time:', time.time() - start_time)

    print('average dice:', np.mean(dice_means))
    print('time taken:', time.time() - start_time)
    # for key, value in results.items():
    #     print(key)
    #     print(value)

    with open(filename, 'wb') as file:
        file.write(
            pickle.dumps(results))  # use `pickle.loads` to do the reverse
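
The feature-statistics loop above uses an append_to_dict helper. A plausible sketch, offered as an assumption rather than the original implementation:

def append_to_dict(d, key, value):
    # accumulate values in a list per key (assumed behavior)
    d.setdefault(key, []).append(value)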
Example #14
def train(src_dir,
          tgt_dir,
          model_dir,
          model_lr_dir,
          lr,
          nb_epochs,
          reg_param,
          steps_per_epoch,
          batch_size,
          load_model_file=None,
          data_loss='ncc',
          initial_epoch=0):
    """
    model training function
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param model_dir: the model directory to save to
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param steps_per_epoch: frequency with which to save models
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param load_model_file: optional h5 model file to initialize with
    :param data_loss: data_loss: 'mse' or 'ncc
    """

    # prepare data files
    # for the CVPR and MICCAI papers, we have data arranged in train/validate/test folders
    # inside each folder is a /vols/ and a /asegs/ folder with the volumes
    # and segmentations. All of our papers use npz formatted data.
    src_vol_names = glob.glob(os.path.join(src_dir, '*.npz'))
    tgt_vol_names = glob.glob(os.path.join(tgt_dir, '*.npz'))
    random.shuffle(src_vol_names)  # shuffle volume list
    random.shuffle(tgt_vol_names)  # shuffle volume list
    assert len(src_vol_names) > 0, "Could not find any training data"

    assert data_loss in [
        'mse', 'ncc'
    ], 'Loss should be one of mse or ncc, found %s' % data_loss
    if data_loss == 'ncc':
        data_loss = losses.NCC().loss

    # GPU handling
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # set_session(tf.Session(config=config))

    vol_size = (56, 56, 56)
    # prepare the model
    src_lr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_src_lr')
    tgt_lr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_tgt_lr')
    srm_lr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='mask_src_lr')
    attn_lr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_lr')

    src_mr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_src_mr')
    tgt_mr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_tgt_mr')
    srm_mr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='mask_src_mr')
    df_lr2mr = tf.placeholder(tf.float32, [None, *vol_size, 3],
                              name='df_lr2mr')
    attn_mr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_mr')

    src_hr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_src_hr')
    tgt_hr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_tgt_hr')
    srm_hr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='mask_src_hr')
    df_mr2hr = tf.placeholder(tf.float32, [None, *vol_size, 3],
                              name='df_mr2hr')
    attn_hr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_hr')

    model_lr = networks.net_lr(src_lr, tgt_lr, srm_lr)
    model_mr = networks.net_mr(src_mr, tgt_mr, srm_mr, df_lr2mr)
    model_hr = networks.net_hr(src_hr, tgt_hr, srm_hr, df_mr2hr)

    # the loss functions
    lr_ncc = data_loss(model_lr[0].outputs, tgt_lr)
    #lr_grd = losses.Grad('l2').loss(model_lr[0].outputs, model_lr[2].outputs)
    lr_grd = losses.Anti_Folding('l2').loss(model_lr[0].outputs,
                                            model_lr[2].outputs)

    cost_lr = lr_ncc + reg_param * lr_grd  # + lr_attn

    mr_ncc = data_loss(model_mr[0].outputs, tgt_mr)
    #mr_grd = losses.Grad('l2').loss(model_mr[0].outputs, model_mr[2].outputs)
    mr_grd = losses.Anti_Folding('l2').loss(model_mr[0].outputs,
                                            model_mr[2].outputs)

    cost_mr = mr_ncc + reg_param * mr_grd

    hr_ncc = data_loss(model_hr[0].outputs, tgt_hr)
    #hr_grd = losses.Grad('l2').loss(model_hr[0].outputs, model_hr[2].outputs)
    hr_grd = losses.Anti_Folding('l2').loss(model_hr[0].outputs,
                                            model_hr[2].outputs)

    cost_hr = hr_ncc + reg_param * hr_grd

    # the training operations
    def get_v(name):
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if name in var.name]
        return d_vars

    #attn_vars = tl.layers.get_variables_with_name('cbam_1', True, True)
    attn_vars = get_v('cbam_1')
    for a_v in attn_vars:
        print(a_v)

    train_op_lr = tf.train.AdamOptimizer(lr).minimize(cost_lr)

    train_op_mr = tf.train.AdamOptimizer(lr).minimize(cost_mr)
    train_op_hr = tf.train.AdamOptimizer(lr).minimize(cost_hr)

    # data generator
    src_example_gen = datagenerators.example_gen(src_vol_names,
                                                 batch_size=batch_size)
    tgt_example_gen = datagenerators.example_gen(tgt_vol_names,
                                                 batch_size=batch_size)

    data_gen = datagenerators.gen_with_mask(src_example_gen,
                                            tgt_example_gen,
                                            batch_size=batch_size)

    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=['net_hr'])
    saver = tf.train.Saver(variables_to_restore)

    #saver = tf.train.Saver(max_to_keep=3)
    # fit generator
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # load initial weights
        try:
            if load_model_file is not None:
                model_file = tf.train.latest_checkpoint(load_model_file)  #
                saver.restore(sess, model_file)
        except Exception:
            print('No files in', load_model_file)
        saver.save(sess, model_dir + 'dfnet', global_step=0)

        def resize_df(df, zoom):
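            # note: the displacements are stored in voxel units, so each
            # component is also scaled by its zoom factor after resampling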
            df1 = nd.interpolation.zoom(
                df[0, :, :, :, 0], zoom=zoom, mode='nearest', order=3) * zoom[
                    0]  # Cubic: order=3; Bilinear: order=1; Nearest: order=0
            df2 = nd.interpolation.zoom(
                df[0, :, :, :,
                   1], zoom=zoom, mode='nearest', order=3) * zoom[1]
            df3 = nd.interpolation.zoom(
                df[0, :, :, :,
                   2], zoom=zoom, mode='nearest', order=3) * zoom[2]
            dfs = np.stack((df1, df2, df3), axis=3)
            return dfs[np.newaxis, :, :, :]

        class logPrinter(object):
            def __init__(self):
                self.n_batch = 0
                self.dice = []  # addLog()/output() append to self.dice
                self.cost = []
                self.ncc = []
                self.grd = []

            def addLog(self, dice, cost, ncc, grd):
                self.n_batch += 1
                self.dice.append(dice)
                self.cost.append(cost)
                self.ncc.append(ncc)
                self.grd.append(grd)

            def output(self):
                dice = np.array(self.dice).mean(axis=0).round(3).tolist()
                cost = np.array(self.cost).mean()
                ncc = np.array(self.ncc).mean()
                grd = np.array(self.grd).mean()
                return dice, cost, ncc, grd, self.n_batch

            def clear(self):
                self.n_batch = 0
                self.dice = []
                self.cost = []
                self.ncc = []
                self.grd = []

        lr_log = logPrinter()
        mr_log = logPrinter()
        hr_log = logPrinter()

        # train low resolution
        # load initial weights
        saver = tf.train.Saver(max_to_keep=1)
        #if model_lr_dir is not None:
        #    model_lr_dir = tf.train.latest_checkpoint(model_lr_dir)  #
        #    print(model_lr_dir)
        #    saver.restore(sess, model_lr_dir)

        nb_epochs = 20  #20#10
        steps_per_epoch = 30 * 29
        for epoch in range(nb_epochs):
            tbar = trange(steps_per_epoch, unit='batch', ncols=100)
            lr_log.clear()
            for i in tbar:
                image, mask = data_gen.__next__()
                global_X, global_atlas = image
                global_X_mask, global_atlas_mask = mask
                global_diff = global_X[0, :, :, :,
                                       0] - global_atlas[0, :, :, :, 0]

                # low resolution
                global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :,
                                                                 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :,
                                                                   0],
                                                     zoom=(0.25, 0.25, 0.25),
                                                     mode='nearest',
                                                     order=0)
                global_AM_64 = nd.interpolation.zoom(
                    global_atlas_mask[0, :, :, :, 0],
                    zoom=(0.25, 0.25, 0.25),
                    mode='nearest',
                    order=0)
                global_diff_16 = nd.interpolation.zoom(global_diff,
                                                       zoom=(0.25, 0.25, 0.25),
                                                       mode='nearest')

                global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis]
                global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis]
                global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis]
                global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis]
                global_diff_16 = global_diff_16[np.newaxis, :, :, :,
                                                np.newaxis]

                feed_dict = {
                    src_lr: global_X_64,
                    tgt_lr: global_A_64,
                    srm_lr: global_XM_64,
                    attn_lr: global_diff_16
                }
                err_lr, _ = sess.run([cost_lr, train_op_lr],
                                     feed_dict=feed_dict)
                df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run(
                    [
                        model_lr[2].outputs, model_lr[1].outputs, lr_ncc,
                        lr_grd, model_lr[3], model_lr[4]
                    ],
                    feed_dict=feed_dict)
                # print(df_lr.shape)
                lr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                  global_AM_64[0, :, :, :, 0],
                                  labels=[0, 10, 150, 250],
                                  nargout=2)
                lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad)
                lr_out = lr_log.output()

                tbar.set_description('Epoch %d/%d ### step %i' %
                                     (epoch + 1, nb_epochs, i))
                tbar.set_postfix(lr_dice=lr_out[0],
                                 lr_cost=lr_out[1],
                                 lr_ncc=lr_out[2],
                                 lr_grd=lr_out[3])

            saver.save(sess, model_lr_dir + 'dfnet', global_step=0)
        # train middle resolution
        nb_epochs = 1  #1
        steps_per_epoch = 30 * 29
        for epoch in range(nb_epochs):
            lr_log.clear()
            for lr_step in range(steps_per_epoch):
                image, mask = data_gen.__next__()
                global_X, global_atlas = image
                global_X_mask, global_atlas_mask = mask
                global_diff = global_X[0, :, :, :,
                                       0] - global_atlas[0, :, :, :, 0]

                # low resolution
                global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :,
                                                                 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :,
                                                                   0],
                                                     zoom=(0.25, 0.25, 0.25),
                                                     mode='nearest',
                                                     order=0)
                global_AM_64 = nd.interpolation.zoom(
                    global_atlas_mask[0, :, :, :, 0],
                    zoom=(0.25, 0.25, 0.25),
                    mode='nearest',
                    order=0)
                global_diff_16 = nd.interpolation.zoom(global_diff,
                                                       zoom=(0.25, 0.25, 0.25),
                                                       mode='nearest')

                global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis]
                global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis]
                global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis]
                global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis]
                global_diff_16 = global_diff_16[np.newaxis, :, :, :,
                                                np.newaxis]

                feed_dict = {
                    src_lr: global_X_64,
                    tgt_lr: global_A_64,
                    srm_lr: global_XM_64,
                    attn_lr: global_diff_16
                }
                err_lr, _ = sess.run([cost_lr, train_op_lr],
                                     feed_dict=feed_dict)
                df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run(
                    [
                        model_lr[2].outputs, model_lr[1].outputs, lr_ncc,
                        lr_grd, model_lr[3], model_lr[4]
                    ],
                    feed_dict=feed_dict)

                lr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                  global_AM_64[0, :, :, :, 0],
                                  labels=[0, 10, 150, 250],
                                  nargout=2)
                lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad)
                lr_out = lr_log.output()

                print('\nEpoch %d/%d ### step %i' %
                      (epoch + 1, nb_epochs, lr_out[-1]))
                print(
                    '[lr] lr_dice={}, lr_cost={:.3f}, lr_ncc={:.3f}, lr_grd={:.3f}'
                    .format(lr_out[0], lr_out[1], lr_out[2], lr_out[3]))

                # middle part
                df_lr_res2mr = resize_df(df_lr, zoom=(2, 2, 2))

                select_points_lr = patch_selection_attn(lr_attn_map,
                                                        zoom_scales=[8, 8, 8],
                                                        kernel=7,
                                                        mi=10,
                                                        ma=18)
                print(select_points_lr)
                mr_log.clear()

                for sp in select_points_lr:
                    mov_img_112 = global_X[0, sp[0] - 56:sp[0] + 56,
                                           sp[1] - 56:sp[1] + 56,
                                           sp[2] - 56:sp[2] + 56, 0]
                    fix_img_112 = global_atlas[0, sp[0] - 56:sp[0] + 56,
                                               sp[1] - 56:sp[1] + 56,
                                               sp[2] - 56:sp[2] + 56, 0]
                    mov_seg_112 = global_X_mask[0, sp[0] - 56:sp[0] + 56,
                                                sp[1] - 56:sp[1] + 56,
                                                sp[2] - 56:sp[2] + 56, 0]
                    fix_seg_112 = global_atlas_mask[0, sp[0] - 56:sp[0] + 56,
                                                    sp[1] - 56:sp[1] + 56,
                                                    sp[2] - 56:sp[2] + 56, 0]
                    dif_img_112 = global_diff[sp[0] - 56:sp[0] + 56,
                                              sp[1] - 56:sp[1] + 56,
                                              sp[2] - 56:sp[2] + 56]

                    #print(mov_img_112.shape)
                    if fix_img_112.shape != (112, 112, 112):
                        print(mov_img_112.shape)
                        continue
                    fix_112_56 = nd.interpolation.zoom(fix_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    mov_112_56 = nd.interpolation.zoom(mov_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    fix_112_56m = nd.interpolation.zoom(fix_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    mov_112_56m = nd.interpolation.zoom(mov_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    dif_112_56 = nd.interpolation.zoom(dif_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')

                    mid_fix_img = fix_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_img = mov_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_fix_seg = fix_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_seg = mov_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_dif_img = dif_112_56[np.newaxis, :, :, :, np.newaxis]
                    df_mr_feed = df_lr_res2mr[:,
                                              sp[0] // 2 - 28:sp[0] // 2 + 28,
                                              sp[1] // 2 - 28:sp[1] // 2 + 28,
                                              sp[2] // 2 - 28:sp[2] // 2 +
                                              28, :]

                    feed_dict = {
                        src_mr: mid_mov_img,
                        tgt_mr: mid_fix_img,
                        srm_mr: mid_mov_seg,
                        df_lr2mr: df_mr_feed,
                        attn_mr: mid_dif_img
                    }
                    err_mr, _ = sess.run([cost_mr, train_op_mr],
                                         feed_dict=feed_dict)
                    df_mr, warp_seg, emr_ncc, emr_grad = sess.run(
                        [
                            model_mr[2].outputs, model_mr[1].outputs, mr_ncc,
                            mr_grd
                        ],
                        feed_dict=feed_dict)

                    mr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                      mid_fix_seg[0, :, :, :, 0],
                                      labels=[0, 10, 150, 250],
                                      nargout=2)
                    mr_log.addLog(mr_dice, err_mr, emr_ncc, emr_grad)
                    mr_out = mr_log.output()

                    # print('  Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1]))
                    print(
                        '  [mr] {}/{} mr_dice={}, mr_cost={:.3f}, mr_ncc={:.3f}, mr_grd={:.3f}'
                        .format(mr_out[-1], len(select_points_lr), mr_out[0],
                                mr_out[1], mr_out[2], mr_out[3]))

            saver.save(sess, model_dir + 'dfnet', global_step=0)

        # train high resolution
        nb_epochs = 1
        steps_per_epoch = 300
        for epoch in range(nb_epochs):
            lr_log.clear()
            for lr_step in range(steps_per_epoch):
                image, mask = data_gen.__next__()
                global_X, global_atlas = image
                global_X_mask, global_atlas_mask = mask
                global_diff = global_X[0, :, :, :,
                                       0] - global_atlas[0, :, :, :, 0]

                # low resolution
                global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :,
                                                                 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :,
                                                                   0],
                                                     zoom=(0.25, 0.25, 0.25),
                                                     mode='nearest',
                                                     order=0)
                global_AM_64 = nd.interpolation.zoom(
                    global_atlas_mask[0, :, :, :, 0],
                    zoom=(0.25, 0.25, 0.25),
                    mode='nearest',
                    order=0)
                global_diff_16 = nd.interpolation.zoom(global_diff,
                                                       zoom=(0.25, 0.25, 0.25),
                                                       mode='nearest')

                global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis]
                global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis]
                global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis]
                global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis]
                global_diff_16 = global_diff_16[np.newaxis, :, :, :,
                                                np.newaxis]

                feed_dict = {
                    src_lr: global_X_64,
                    tgt_lr: global_A_64,
                    srm_lr: global_XM_64,
                    attn_lr: global_diff_16
                }
                err_lr, _ = sess.run([cost_lr, train_op_lr],
                                     feed_dict=feed_dict)
                df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run(
                    [
                        model_lr[2].outputs, model_lr[1].outputs, lr_ncc,
                        lr_grd, model_lr[3], model_lr[4]
                    ],
                    feed_dict=feed_dict)

                lr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                  global_AM_64[0, :, :, :, 0],
                                  labels=[0, 10, 150, 250],
                                  nargout=2)
                lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad)
                lr_out = lr_log.output()

                print('\nEpoch %d/%d ### step %i' %
                      (epoch + 1, nb_epochs, lr_out[-1]))
                print(
                    '[lr] lr_dice={}, lr_cost={:.3f}, lr_ncc={:.3f}, lr_grd={:.3f}'
                    .format(lr_out[0], lr_out[1], lr_out[2], lr_out[3]))

                # middle part
                df_lr_res2mr = resize_df(df_lr, zoom=(2, 2, 2))

                select_points_lr = patch_selection_attn(lr_attn_map,
                                                        zoom_scales=[8, 8, 8],
                                                        kernel=7,
                                                        mi=10,
                                                        ma=18)
                print(select_points_lr)
                mr_log.clear()

                for sp in select_points_lr:
                    mov_img_112 = global_X[0, sp[0] - 56:sp[0] + 56,
                                           sp[1] - 56:sp[1] + 56,
                                           sp[2] - 56:sp[2] + 56, 0]
                    fix_img_112 = global_atlas[0, sp[0] - 56:sp[0] + 56,
                                               sp[1] - 56:sp[1] + 56,
                                               sp[2] - 56:sp[2] + 56, 0]
                    mov_seg_112 = global_X_mask[0, sp[0] - 56:sp[0] + 56,
                                                sp[1] - 56:sp[1] + 56,
                                                sp[2] - 56:sp[2] + 56, 0]
                    fix_seg_112 = global_atlas_mask[0, sp[0] - 56:sp[0] + 56,
                                                    sp[1] - 56:sp[1] + 56,
                                                    sp[2] - 56:sp[2] + 56, 0]
                    dif_img_112 = global_diff[sp[0] - 56:sp[0] + 56,
                                              sp[1] - 56:sp[1] + 56,
                                              sp[2] - 56:sp[2] + 56]

                    #print(mov_img_112.shape)
                    if fix_img_112.shape != (112, 112, 112):
                        print(mov_img_112.shape)
                        continue
                    fix_112_56 = nd.interpolation.zoom(fix_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    mov_112_56 = nd.interpolation.zoom(mov_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    fix_112_56m = nd.interpolation.zoom(fix_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    mov_112_56m = nd.interpolation.zoom(mov_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    dif_112_56 = nd.interpolation.zoom(dif_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')

                    mid_fix_img = fix_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_img = mov_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_fix_seg = fix_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_seg = mov_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_dif_img = dif_112_56[np.newaxis, :, :, :, np.newaxis]
                    df_mr_feed = df_lr_res2mr[:,
                                              sp[0] // 2 - 28:sp[0] // 2 + 28,
                                              sp[1] // 2 - 28:sp[1] // 2 + 28,
                                              sp[2] // 2 - 28:sp[2] // 2 +
                                              28, :]

                    feed_dict = {
                        src_mr: mid_mov_img,
                        tgt_mr: mid_fix_img,
                        srm_mr: mid_mov_seg,
                        df_lr2mr: df_mr_feed,
                        attn_mr: mid_dif_img
                    }
                    err_mr, _ = sess.run([cost_mr, train_op_mr],
                                         feed_dict=feed_dict)
                    df_mr, warp_seg, emr_ncc, emr_grad, mr_attn_map, mr_attn_feature = sess.run(
                        [
                            model_mr[2].outputs, model_mr[1].outputs, mr_ncc,
                            mr_grd, model_mr[3], model_mr[4]
                        ],
                        feed_dict=feed_dict)

                    mr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                      mid_fix_seg[0, :, :, :, 0],
                                      labels=[0, 10, 150, 250],
                                      nargout=2)
                    mr_log.addLog(mr_dice, err_mr, emr_ncc, emr_grad)
                    mr_out = mr_log.output()

                    # print('  Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1]))
                    print(
                        '  [mr] {}/{} mr_dice={}, mr_cost={:.3f}, mr_ncc={:.3f}, mr_grd={:.3f}'
                        .format(mr_out[-1], len(select_points_lr), mr_out[0],
                                mr_out[1], mr_out[2], mr_out[3]))

                    # high part
                    df_mr_res2hr = resize_df(df_mr, zoom=(2, 2, 2))
                    hr_log.clear()
                    select_points_mr = patch_selection_attn(
                        mr_attn_map,
                        zoom_scales=[4, 4, 4],
                        kernel=7,
                        mi=8,
                        ma=20)
                    print(30 * '-')
                    print('High Part')
                    print(select_points_mr)
                    for spm in select_points_mr:
                        fix_img_56 = fix_img_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        mov_img_56 = mov_img_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        fix_seg_56 = fix_seg_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        mov_seg_56 = mov_seg_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        dif_img_56 = dif_img_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        if fix_img_56.shape != (56, 56, 56):
                            continue

                        hig_fix_img = fix_img_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_mov_img = mov_img_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_fix_seg = fix_seg_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_mov_seg = mov_seg_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_dif_img = dif_img_56[np.newaxis, :, :, :,
                                                 np.newaxis]

                        df_hr_feed = df_mr_res2hr[:, spm[0] - 28:spm[0] + 28,
                                                  spm[1] - 28:spm[1] + 28,
                                                  spm[2] - 28:spm[2] + 28, :]

                        feed_dict = {
                            src_hr: hig_mov_img,
                            tgt_hr: hig_fix_img,
                            srm_hr: hig_mov_seg,
                            df_mr2hr: df_hr_feed,
                            attn_hr: hig_dif_img
                        }
                        err_hr, _ = sess.run([cost_hr, train_op_hr],
                                             feed_dict=feed_dict)
                        df_hr, warp_seg, ehr_ncc, ehr_grad = sess.run(
                            [
                                model_hr[2].outputs, model_hr[1].outputs,
                                hr_ncc, hr_grd
                            ],
                            feed_dict=feed_dict)

                        hr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                          hig_fix_seg[0, :, :, :, 0],
                                          labels=[0, 10, 150, 250],
                                          nargout=2)
                        hr_log.addLog(hr_dice, err_hr, ehr_ncc, ehr_grad)
                        hr_out = hr_log.output()

                        # print('  Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1]))
                        print(
                            '    [hr] {}/{} hr_dice={}, hr_cost={:.3f}, hr_ncc={:.3f}, hr_grd={:.3f}'
                            .format(hr_out[-1], len(select_points_mr),
                                    hr_out[0], hr_out[1], hr_out[2],
                                    hr_out[3]))

                saver.save(sess, model_dir + 'dfnet', global_step=lr_step)
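The resize_df helper used above to lift each flow field to the next resolution is not reproduced in this example. As a minimal sketch of the flow upsampling, assuming the helper both interpolates the field spatially and rescales the vectors by the zoom factor (displacements are measured in voxel units, so a 2x denser grid needs 2x larger displacements):

import numpy as np
import scipy.ndimage as nd

def resize_df(df, zoom=(2, 2, 2)):
    # df: displacement field of shape (1, X, Y, Z, 3).
    # Each component is spatially interpolated and then multiplied by its
    # zoom factor, because displacements are expressed in voxel units.
    # (Assumption: this mirrors what the original helper does.)
    out = [nd.interpolation.zoom(df[0, ..., c], zoom=zoom, mode='nearest') * zoom[c]
           for c in range(3)]
    return np.stack(out, axis=-1)[np.newaxis, ...]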
def test(iter_num, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3]):
 gpu = '/gpu:' + str(gpu_id)

 # Anatomical labels we want to evaluate
 labels = sio.loadmat('../data/labels.mat')['labels'][0]

 # read atlas
 atlas_vol1, atlas_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz',
                                                              '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz')# [1,160,192,224,1]
 atlas_seg1 = atlas_seg1[0,:,:,:,0]# reduce the dimension to [160,192,224]

 atlas_vol2, atlas_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz',
                                                              '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990210_vc792.npz')
 atlas_seg2 = atlas_seg2[0, :, :, :, 0]

 #gpu = '/gpu:' + str(gpu_id)
 os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
 config = tf.ConfigProto()
 config.gpu_options.allow_growth = True
 config.allow_soft_placement = True
 set_session(tf.Session(config=config))

 # load weights of model
 with tf.device(gpu):
    net = networks.unet(vol_size, nf_enc, nf_dec)
    net.load_weights('/home/ys895/MAS2_Models/'+str(iter_num)+'.h5')
    #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5')

 xx = np.arange(vol_size[1])
 yy = np.arange(vol_size[0])
 zz = np.arange(vol_size[2])
 grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4) # (160, 192, 224, 3)
 #X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')
 X_vol1, X_seg1 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz',
                                                      '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/981216_vc681.npz')

 X_vol2, X_seg2 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz',
                                                      '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990205_vc783.npz')

 X_vol3, X_seg3 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz',
                                                     '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990525_vc1024.npz')

 X_vol4, X_seg4 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz',
                                                      '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991025_vc1379.npz')

 X_vol5, X_seg5 = datagenerators.load_example_by_name('/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz',
                                                     '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991122_vc1463.npz')

 # change the direction of the atlas data and volume data
 # pred[0].shape (1, 160, 192, 224, 1)
 # pred[1].shape (1, 160, 192, 224, 3)
 # X1
 with tf.device(gpu):
    pred1 = net.predict([atlas_vol1, X_vol1])

 # Warp segments with flow
 flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)

 sample1 = flow1+grid
 sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)

 warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224)



 # label fusion: with a single atlas this reduces to copying warp_seg1
 warp_seg = warp_seg1.copy()

 vals, _ = dice(warp_seg, X_seg1[0, :, :, :, 0], labels=labels, nargout=2)
 mean1 = np.mean(vals)

 # X2
 with tf.device(gpu):
    pred1 = net.predict([atlas_vol1, X_vol2])

 # Warp segments with flow
 flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)

 sample1 = flow1+grid
 sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)

 warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224)



 # label fusion: with a single atlas this reduces to copying warp_seg1
 warp_seg = warp_seg1.copy()

 vals, _ = dice(warp_seg, X_seg2[0,:,:,:,0], labels=labels, nargout=2)
 mean2 = np.mean(vals)
 #print(np.mean(vals), np.std(vals))

# X3
 with tf.device(gpu):
    pred1 = net.predict([atlas_vol1, X_vol3])

 # Warp segments with flow
 flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)

 sample1 = flow1+grid
 sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)

 warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224)



 # label fusion: with a single atlas this reduces to copying warp_seg1
 warp_seg = warp_seg1.copy()

 vals, _ = dice(warp_seg, X_seg3[0, :, :, :, 0], labels=labels, nargout=2)
 mean3 = np.mean(vals)

# X4
 with tf.device(gpu):
    pred1 = net.predict([atlas_vol1, X_vol4])

 # Warp segments with flow
 flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)

 sample1 = flow1+grid
 sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)

 warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224)


 # label fusion: with a single atlas this reduces to copying warp_seg1
 warp_seg = warp_seg1.copy()

 vals, _ = dice(warp_seg, X_seg4[0, :, :, :, 0], labels=labels, nargout=2)
 mean4 = np.mean(vals)


# X5
 with tf.device(gpu):
    pred1 = net.predict([atlas_vol1, X_vol5])

 # Warp segments with flow
 flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)

 sample1 = flow1+grid
 sample1 = np.stack((sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)

 warp_seg1 = interpn((yy, xx, zz), atlas_seg1[ :, :, : ], sample1, method='nearest', bounds_error=False, fill_value=0) # (160, 192, 224)


 # label fusion: with a single atlas this reduces to copying warp_seg1;
 # with several warped atlases a per-voxel majority vote would go here,
 # e.g. warp_seg[x, y, z] = stats.mode(warp_arr)[0]
 warp_seg = warp_seg1.copy()

 vals, _ = dice(warp_seg, X_seg5[0, :, :, :, 0], labels=labels, nargout=2)
 mean5 = np.mean(vals)

 # compute the mean Dice score over the five test subjects
 mean_dice = (mean1 + mean2 + mean3 + mean4 + mean5) / 5
 print(mean_dice)
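Every example in this collection scores overlap with a dice(vol1, vol2, labels=..., nargout=2) helper that is never shown. A minimal per-label reconstruction, consistent with the call sites above but not guaranteed to match the original implementation:

import numpy as np

def dice(vol1, vol2, labels=None, nargout=1):
    # Per-label Dice overlap: 2*|A and B| / (|A| + |B|) for each label.
    if labels is None:
        labels = np.unique(np.concatenate((vol1.ravel(), vol2.ravel())))
        labels = labels[labels != 0]  # ignore background
    dicem = np.zeros(len(labels))
    for idx, lab in enumerate(labels):
        top = 2.0 * np.sum(np.logical_and(vol1 == lab, vol2 == lab))
        bottom = np.sum(vol1 == lab) + np.sum(vol2 == lab)
        dicem[idx] = top / np.maximum(bottom, np.finfo(float).eps)
    return (dicem, labels) if nargout == 2 else dicem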
def test(
        gpu_id,
        iter_num,
        compute_type='GPU',  # GPU or CPU
        vol_size=(160, 192, 224),
        nf_enc=[16, 32, 32, 32],
        nf_dec=[32, 32, 32, 32, 16, 3],
        save_file=None):
    """
    test by segmentation, compute dice between atlas_seg and warp_seg
    :param gpu_id: gpu id
    :param iter_num: specify the model to read
    :param compute_type: CPU/GPU
    :param vol_size: volume size
    :param nf_enc: number of encoder
    :param nf_dec: number of decoder
    :param save_file: None
    :return: None
    """

    # GPU handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        # if testing miccai run, should be xy indexing.
        net = networks.miccai2018_net(vol_size,
                                      nf_enc,
                                      nf_dec,
                                      use_miccai_int=True,
                                      indexing='xy')
        model_dir = "/home/ys895/rigid_diff_model/"
        net.load_weights(os.path.join(model_dir, str(iter_num) + '.h5'))

        # compose diffeomorphic flow output model
        diff_net = keras.models.Model(net.inputs,
                                      net.get_layer('diffflow').output)

        # NN transfer model
        nn_trf_model = networks.nn_trf(vol_size)

    # if CPU, prepare grid
    if compute_type == 'CPU':
        grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4)

    # prepare a matrix of dice values
    dice_vals = np.zeros((len(good_labels), n_batches))
    for k in range(n_batches):
        # get data
        vol_name, seg_name = test_brain_strings[k].split(",")
        X_vol, X_seg = datagenerators.load_example_by_name(vol_name, seg_name)
        orig_vol = X_vol
        orig_seg = X_seg

        theta = 0
        beta = 5
        omega = 0
        X_seg = rotate_img(X_seg[0, :, :, :, 0],
                           theta=theta,
                           beta=beta,
                           omega=omega)
        X_vol = rotate_img(X_vol[0, :, :, :, 0],
                           theta=theta,
                           beta=beta,
                           omega=omega)
        X_seg = X_seg.reshape((1, ) + X_seg.shape + (1, ))
        X_vol = X_vol.reshape((1, ) + X_vol.shape + (1, ))

        sample_num = 30
        grid_dimension = 4

        # predict transform
        with tf.device(gpu):
            pred = diff_net.predict([X_vol, atlas_vol])

        # Warp segments with flow
        if compute_type == 'CPU':
            flow = pred[0, :, :, :, :]
            warp_seg = util.warp_seg(X_seg,
                                     flow,
                                     grid=grid,
                                     xx=xx,
                                     yy=yy,
                                     zz=zz)
        else:  # GPU

            flow = pred[0, :, :, :, :]

            # sample coordinate(sample_num * sample_num * sample_num)
            x = np.linspace(0, (vol_size[0] / sample_num) * (sample_num - 1),
                            sample_num)
            x = x.astype(np.int32)
            y = np.linspace(0, (vol_size[1] / sample_num) * (sample_num - 1),
                            sample_num)
            y = y.astype(np.int32)
            z = np.linspace(0, (vol_size[2] / sample_num) * (sample_num - 1),
                            sample_num)
            z = z.astype(np.int32)
            index = np.rollaxis(np.array(np.meshgrid(y, x, z)), 0, 4)
            x = index[:, :, :, 1]
            y = index[:, :, :, 0]
            z = index[:, :, :, 2]

            # Y in formula
            x_flow = np.arange(vol_size[0])
            y_flow = np.arange(vol_size[1])
            z_flow = np.arange(vol_size[2])
            grid = np.rollaxis(np.array((np.meshgrid(y_flow, x_flow, z_flow))),
                               0, 4)  # original coordinate
            grid_x = grid_sample(x, y, z, grid[:, :, :, 1], sample_num)
            grid_y = grid_sample(x, y, z, grid[:, :, :, 0], sample_num)
            grid_z = grid_sample(x, y, z, grid[:, :, :, 2],
                                 sample_num)  # X in the formula, shape (sample_num, sample_num, sample_num)

            sample = flow + grid
            sample_x = grid_sample(x, y, z, sample[:, :, :, 1], sample_num)
            sample_y = grid_sample(x, y, z, sample[:, :, :, 0], sample_num)
            sample_z = grid_sample(x, y, z, sample[:, :, :, 2],
                                   sample_num)  # Y in the formula, shape (sample_num, sample_num, sample_num)

            sum_x = np.sum(flow[:, :, :, 1])
            sum_y = np.sum(flow[:, :, :, 0])
            sum_z = np.sum(flow[:, :, :, 2])

            ave_x = sum_x / (vol_size[0] * vol_size[1] * vol_size[2])
            ave_y = sum_y / (vol_size[0] * vol_size[1] * vol_size[2])
            ave_z = sum_z / (vol_size[0] * vol_size[1] * vol_size[2])

            # formula
            Y = np.zeros((sample_num, sample_num, sample_num, grid_dimension))
            X = np.zeros((sample_num, sample_num, sample_num, grid_dimension))
            T = np.array([ave_x, ave_y, ave_z, 1])  # mean translation, shape (4,)
            print(T)

            for i in np.arange(sample_num):
                for j in np.arange(sample_num):
                    for z in np.arange(sample_num):
                        Y[i, j, z, :] = np.array([
                            sample_x[i, j, z], sample_y[i, j, z],
                            sample_z[i, j, z], 1
                        ])
                        #Y[i, j, z, :] = Y[i, j, z, :] - np.array([ave_x, ave_y, ave_z, 0])  # amend: Y` = Y - T

            for i in np.arange(sample_num):
                for j in np.arange(sample_num):
                    for z in np.arange(sample_num):
                        X[i, j, z, :] = np.array([
                            grid_x[i, j, z], grid_y[i, j, z], grid_z[i, j, z],
                            1
                        ])

            X = X.reshape(
                (sample_num * sample_num * sample_num, grid_dimension))
            Y = Y.reshape(
                (sample_num * sample_num * sample_num, grid_dimension))
            R = np.dot(
                np.dot(np.linalg.pinv(np.dot(np.transpose(X), X)),
                       np.transpose(X)), Y)  # R(4, 4)
            print(R)
            # override the least-squares estimate with the known ground-truth
            # rotation (the input was rotated by beta degrees about y above)
            beta = -(beta / 180) * math.pi
            R = np.array([[math.cos(beta), 0, -math.sin(beta), 0],
                          [0, 1, 0, 0],
                          [math.sin(beta), 0, math.cos(beta), 0],
                          [0, 0, 0, 1]])
            #R = R.transpose()

            # build new grid(Use R to do the spatial transform)
            shifted_x = np.arange(vol_size[0])
            shifted_y = np.arange(vol_size[1])
            shifted_z = np.arange(vol_size[2])
            shifted_grid = np.rollaxis(
                np.array((np.meshgrid(shifted_y, shifted_x, shifted_z))), 0, 4)

            # translation matrices to and from the volume center
            T1 = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
                           [
                               -int(vol_size[0] / 2), -int(vol_size[1] / 2),
                               -int(vol_size[2] / 2), 1
                           ]])

            T2 = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
                           [
                               int(vol_size[0] / 2),
                               int(vol_size[1] / 2),
                               int(vol_size[2] / 2), 1
                           ]])

            for i in np.arange(vol_size[0]):
                for j in np.arange(vol_size[1]):
                    for z in np.arange(vol_size[2]):
                        #coordinates = np.dot(R, np.array([i, j, z, 1]).reshape(4, 1)) + T.reshape(4, 1)
                        coordinates = np.dot(
                            np.dot(
                                np.dot(
                                    np.array([i, j, z, 1]).reshape(1, 4), T1),
                                R), T2)  # new implementation
                        # print("voxel." + '(' + str(i) + ',' + str(j) + ',' + str(z) + ')')
                        shifted_grid[i, j, z, 1] = coordinates[0, 0]
                        shifted_grid[i, j, z, 0] = coordinates[0, 1]
                        shifted_grid[i, j, z, 2] = coordinates[0, 2]

            # interpolation
            xx = np.arange(vol_size[1])
            yy = np.arange(vol_size[0])
            zz = np.arange(vol_size[2])
            shifted_grid = np.stack(
                (shifted_grid[:, :, :, 1], shifted_grid[:, :, :, 0],
                 shifted_grid[:, :, :, 2]), 3
            )  # notice: the shifted_grid is reverse in x and y, so this step is used for making it back.
            warp_seg = interpn((yy, xx, zz),
                               X_seg[0, :, :, :, 0],
                               shifted_grid,
                               method='nearest',
                               bounds_error=False,
                               fill_value=0)  # rigid registration
            warp_vol = interpn((yy, xx, zz),
                               X_vol[0, :, :, :, 0],
                               shifted_grid,
                               method='nearest',
                               bounds_error=False,
                               fill_value=0)  # rigid registration

        # compute Volume Overlap (Dice)
        dice_vals[:, k] = dice(warp_seg,
                               orig_seg[0, :, :, :, 0],
                               labels=good_labels)
        print('%3d %5.3f %5.3f' % (k, np.mean(
            dice_vals[:, k]), np.mean(np.mean(dice_vals[:, :k + 1]))))

        if save_file is not None:
            sio.savemat(save_file, {
                'dice_vals': dice_vals,
                'labels': good_labels
            })

        # specify slice
        num_slice = 90

        plt.figure()
        plt.subplot(1, 3, 1)
        plt.imshow(orig_vol[0, :, num_slice, :, 0])
        plt.subplot(1, 3, 2)
        plt.imshow(X_vol[0, :, num_slice, :, 0])
        plt.subplot(1, 3, 3)
        plt.imshow(warp_vol[:, num_slice, :])
        plt.savefig("slice" + str(num_slice) + '_' + str(k) + ".png")

        plt.figure()
        plt.subplot(1, 3, 1)
        plt.imshow(flow[:, num_slice, :, 1])
        plt.subplot(1, 3, 2)
        plt.imshow(flow[:, num_slice, :, 0])
        plt.subplot(1, 3, 3)
        plt.imshow(flow[:, num_slice, :, 2])
        plt.savefig("flow.png")
Example #17
0
def test(
        model_name,
        gpu_id,
        compute_type='GPU',  # GPU or CPU
        nf_enc=[16, 32, 32, 32],
        nf_dec=[32, 32, 32, 32, 32, 16, 16]):
    """
    test

    nf_enc and nf_dec
    #nf_dec = [32,32,32,32,32,16,16,3]
    # This needs to be changed. Ideally, we could just call load_model, and we wont have to
    # specify the # of channels here, but the load_model is not working with the custom loss...
    """

    # Anatomical labels we want to evaluate
    labels = sio.loadmat('../data/labels.mat')['labels'][0]

    atlas = np.load('../data/atlas_norm.npz')
    atlas_vol = atlas['vol'][np.newaxis, ..., np.newaxis]
    atlas_seg = atlas['seg']
    vol_size = atlas_vol.shape[1:-1]

    # gpu handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        net = networks.cvpr2018_net(vol_size, nf_enc, nf_dec)
        net.load_weights(model_name)

        # NN transfer model
        nn_trf_model = networks.nn_trf(vol_size, indexing='ij')

    # if CPU, prepare grid
    if compute_type == 'CPU':
        grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4)

    # load subject test
    X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz',
                                                       '../data/test_seg.npz')

    with tf.device(gpu):
        pred = net.predict([X_vol, atlas_vol])

        # Warp segments with flow
        if compute_type == 'CPU':
            flow = pred[1][0, :, :, :, :]
            warp_seg = util.warp_seg(X_seg,
                                     flow,
                                     grid=grid,
                                     xx=xx,
                                     yy=yy,
                                     zz=zz)

        else:  # GPU
            warp_seg = nn_trf_model.predict([X_seg, pred[1]])[0, ..., 0]

    vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
    dice_mean = np.mean(vals)
    dice_std = np.std(vals)
    print('Dice mean over structures: {:.2f} ({:.2f})'.format(
        dice_mean, dice_std))
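For reference, a hypothetical entry point for this example; the weights path below is a placeholder, not taken from the original:

if __name__ == "__main__":
    # hypothetical invocation with placeholder weights path
    test(model_name='../models/cvpr2018_vm2_cc.h5', gpu_id=0)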
Example #18
0
def test(load_iters, gpu_id, vol_size=(160,192,224), nf_enc=[16,32,32,32], nf_dec=[32,32,32,32,32,16,16,3], sample_num=10, grid_dimension=4):
    """
    Test of the rigid registration: computes the Dice score between the atlas's
    segmentation and the warped image's segmentation.
    :param load_iters: iteration number of the saved model to load
    :param gpu_id: gpu id
    :param vol_size: volume size
    :param nf_enc: number of encoder filters per layer
    :param nf_dec: number of decoder filters per layer
    :param sample_num: sample grid size per axis; can be tuned for performance
    :param grid_dimension: dimension of R in the formula (4, for homogeneous coordinates)
    :return: None
    """
    gpu = '/gpu:' + str(gpu_id)

    # Anatomical labels we want to evaluate
    labels = sio.loadmat('../data/labels.mat')['labels'][0]

    atlas = np.load('../data/atlas_norm.npz')
    atlas_vol = atlas['vol']
    atlas_seg = atlas['seg']
    atlas_vol = np.reshape(atlas_vol, (1,)+atlas_vol.shape+(1,))

    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        net = networks.unet(vol_size, nf_enc, nf_dec)
        net.load_weights('../rigid_model/' + load_iters + '.h5', by_name=True)

    X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')

    orig_vol = X_vol

    theta = 0
    beta = 4
    omega = 0
    X_seg = rotate_img(X_seg[0, :, :, :, 0], theta=theta, beta=beta, omega=omega)
    X_vol = rotate_img(X_vol[0, :, :, :, 0], theta=theta, beta=beta, omega=omega)
    X_seg = X_seg.reshape((1,) + X_seg.shape + (1,))
    X_vol = X_vol.reshape((1,) + X_vol.shape + (1,))

    with tf.device(gpu):
        pred = net.predict([X_vol, atlas_vol])

    # get flow
    flow = pred[1][0, :, :, :, :]

    # sample coordinate(sample_num * sample_num * sample_num)
    x = np.linspace(0, (vol_size[0]/sample_num)*(sample_num-1), sample_num)
    x = x.astype(np.int32)
    y = np.linspace(0, (vol_size[1]/sample_num)*(sample_num-1), sample_num)
    y = y.astype(np.int32)
    z = np.linspace(0, (vol_size[2]/sample_num)*(sample_num-1), sample_num)
    z = z.astype(np.int32)
    index = np.rollaxis(np.array(np.meshgrid(y, x, z)), 0, 4)
    x = index[:, :, :, 1]
    y = index[:, :, :, 0]
    z = index[:, :, :, 2]

    # Y in formula
    x_flow = np.arange(vol_size[0])
    y_flow = np.arange(vol_size[1])
    z_flow = np.arange(vol_size[2])
    grid = np.rollaxis(np.array((np.meshgrid(y_flow, x_flow, z_flow))), 0, 4)# original coordinate
    grid_x = grid_sample(x, y, z, grid[:, :, :, 1], sample_num)
    grid_y = grid_sample(x, y, z, grid[:, :, :, 0], sample_num)
    grid_z = grid_sample(x, y, z, grid[:, :, :, 2], sample_num)#X (10,10,10)

    sample = flow + grid
    sample_x = grid_sample(x, y, z, sample[:, :, :, 1], sample_num)
    sample_y = grid_sample(x, y, z, sample[:, :, :, 0], sample_num)
    sample_z = grid_sample(x, y, z, sample[:, :, :, 2], sample_num)#Y (10,10,10)

    sum_x = np.sum(flow[:, :, :, 1])
    sum_y = np.sum(flow[:, :, :, 0])
    sum_z = np.sum(flow[:, :, :, 2])

    ave_x = sum_x/(vol_size[0] * vol_size[1] * vol_size[2])
    ave_y = sum_y/(vol_size[0] * vol_size[1] * vol_size[2])
    ave_z = sum_z/(vol_size[0] * vol_size[1] * vol_size[2])

    # formula
    Y = np.zeros((sample_num, sample_num, sample_num, grid_dimension))
    X = np.zeros((sample_num, sample_num, sample_num, grid_dimension))
    T = np.array([ave_x, ave_y, ave_z, 1])  # mean translation, shape (4,)
    #R = np.zeros((10, 10, 10, grid_dimension, grid_dimension))

    for i in np.arange(sample_num):
        for j in np.arange(sample_num):
            for z in np.arange(sample_num):
                Y[i, j, z, :] = np.array([sample_x[i,j,z], sample_y[i,j,z], sample_z[i,j,z], 1])
                Y[i, j, z, :] = Y[i, j, z, :] - T# amend: Y` = Y - T

    for i in np.arange(sample_num):
        for j in np.arange(sample_num):
            for z in np.arange(sample_num):
                X[i, j, z, :] = np.array([grid_x[i, j, z], grid_y[i, j, z], grid_z[i, j, z], 1])

    X = X.reshape((sample_num * sample_num * sample_num, grid_dimension))
    Y = Y.reshape((sample_num * sample_num * sample_num, grid_dimension))
    R = np.dot(np.dot(np.linalg.pinv(np.dot(np.transpose(X), X)), np.transpose(X)), Y)# R
    print(R)
    # build new grid(Use R to do the spatial transform)
    shifted_x = np.arange(vol_size[0])
    shifted_y = np.arange(vol_size[1])
    shifted_z = np.arange(vol_size[2])
    shifted_grid = np.rollaxis(np.array((np.meshgrid(shifted_y, shifted_x, shifted_z))), 0, 4)

    for i in np.arange(vol_size[0]):
        for j in np.arange(vol_size[1]):
            for z in np.arange(vol_size[2]):
                coordinates = np.dot(R, np.array([i, j, z, 1]).reshape(4,1)) +  T.reshape(4,1)
                #print("voxel." + '(' + str(i) + ',' + str(j) + ',' + str(z) + ')')
                shifted_grid[i, j, z, 1] = coordinates[0]
                shifted_grid[i, j, z, 0] = coordinates[1]
                shifted_grid[i, j, z, 2] = coordinates[2]

    # interpolation
    xx = np.arange(vol_size[1])
    yy = np.arange(vol_size[0])
    zz = np.arange(vol_size[2])
    shifted_grid = np.stack((shifted_grid[:, :, :, 1], shifted_grid[:, :, :, 0], shifted_grid[:, :, :, 2]), 3)# notice: the shifted_grid is reverse in x and y, so this step is used for making it back.
    warp_seg = interpn((yy, xx, zz), X_seg[0, :, :, :, 0], shifted_grid, method='nearest', bounds_error=False, fill_value=0)# rigid registration
    warp_vol = interpn((yy, xx, zz), X_vol[0, :, :, :, 0], shifted_grid, method='nearest', bounds_error=False, fill_value=0)# rigid registration

    # CVPR
    #grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)
    #sample = flow + grid
    #sample = np.stack((sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
    #warp_seg2 = interpn((yy, xx, zz), X_seg[0, :, :, :, 0], sample, method='nearest', bounds_error=False, fill_value=0)# deformable registration

    # compute dice
    vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
    #vals2, _ = dice(X_seg[0, :, :, :, 0], atlas_seg, labels=labels, nargout=2)
    #vals3, _ = dice(warp_seg2, atlas_seg, labels=labels, nargout=2)
    #print("dice before:")
    #print(np.mean(vals2), np.std(vals2))
    #print("dice after deformable registration:")
    #print(np.mean(vals3), np.std(vals3))
    print("dice after rigid registration:")
    print(np.mean(vals), np.std(vals))

    # plot
    #fig1, axs1 = nplt.slices(warp_seg[100, :, :], do_colorbars=True)
    #fig1.savefig('warp_seg100.png')
    #fig2, axs2 = nplt.slices(warp_seg[130, :, :], do_colorbars=True)
    #fig2.savefig('warp_seg130.png')
    #fig3, axs3 = nplt.slices(atlas_seg[100, :, :], do_colorbars=True)
    #fig3.savefig('atlas_seg100.png')
    #fig4, axs4 = nplt.slices(atlas_seg[130, :, :], do_colorbars=True)
    #fig4.savefig('atlas_seg130.png')

    # specify slice
    num_slice = 90

    plt.figure()
    plt.subplot(1, 3, 1)
    plt.imshow(orig_vol[0, :, num_slice, :, 0])
    plt.subplot(1, 3, 2)
    plt.imshow(X_vol[0, :, num_slice, :, 0])
    plt.subplot(1, 3, 3)
    plt.imshow(warp_vol[:, num_slice, :])
    plt.savefig("slice" + str(num_slice) + '_' + str(k) + ".png")
def test(gpu_id=0,
         model_dir="../model/lpba40",
         iter_num="00",
         compute_type = 'GPU',  # GPU or CPU
         vol_size=(160, 192, 224),
         nf_enc=[16,32,32,32],
         nf_dec=[32,32,32,32,16,3],
         save_file=None):
    """
    test via segmetnation propagation
    works by iterating over some iamge files, registering them to atlas,
    propagating the warps, then computing Dice with atlas segmentations
    """

    # GPU handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        # if testing miccai run, should be xy indexing.
        net = networks_lpba40.miccai2018_net(vol_size, nf_enc, nf_dec, use_miccai_int=False, indexing='ij')
        net.load_weights(os.path.join(model_dir, str(iter_num) + '.h5'))
        print(os.path.join(model_dir, str(iter_num) + '.h5'))

        # compose diffeomorphic flow output model
        diff_net = keras.models.Model(net.inputs, net.get_layer('diffflow').output)

        # NN transfer model
        nn_trf_model = networks_lpba40.nn_trf(vol_size, indexing='ij')

    # if CPU, prepare grid
    if compute_type == 'CPU':
        # grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4)
        print('Error: No GPU.')

    # prepare a matrix of dice values
    dice_vals = np.zeros((len(good_labels), n_batches))
    dice_vals_FL = np.zeros((len(good_labels_FL), n_batches))
    dice_vals_PL = np.zeros((len(good_labels_PL), n_batches))
    dice_vals_OL = np.zeros((len(good_labels_OL), n_batches))
    dice_vals_TL = np.zeros((len(good_labels_TL), n_batches))
    dice_vals_CL = np.zeros((len(good_labels_CL), n_batches))
    dice_vals_Ptm = np.zeros((len(good_labels_Ptm), n_batches))
    dice_vals_Hpcp = np.zeros((len(good_labels_Hpcp), n_batches))

    for k in range(n_batches):
        # get data
        vol_name, atlas_vol_name = test_brain_strings[k].split(",")

        # seg
        seg_name = vol_name.replace("/data/lpba40/Brains_MNIspace_reglinear/",
                                    "/data/lpba40/Segmentations/")
        atlas_seg_name = atlas_vol_name.replace("/data/lpba40/Brains_MNIspace_reglinear/",
                                    "/data/lpba40/Segmentations/")
        # vol
        X_vol, X_seg = datagenerators_lpba40.load_example_by_name(vol_name, seg_name, fixed=True)
        atlas_vol, atlas_seg = datagenerators_lpba40.load_example_by_name(atlas_vol_name, atlas_seg_name, fixed=True)
        atlas_seg = atlas_seg[0, ..., 0]

        # predict transform
        with tf.device(gpu):
            pred = diff_net.predict([X_vol, atlas_vol])
            [y, flow_params, flow_params0, flow_params1, flow_params2, y0, y1, y2] = net.predict([X_vol, atlas_vol])

        # Warp segments with flow
        if compute_type == 'CPU':
            print('Error: No GPU.')
        else:
            warp_seg = nn_trf_model.predict([X_seg, pred])[0,...,0]

        # compute Volume Overlap (Dice)
        dice_vals[:, k] = dice(warp_seg, atlas_seg, labels=good_labels)
        dice_vals_FL[:, k] = dice(warp_seg, atlas_seg, labels=good_labels_FL)
        dice_vals_PL[:, k] = dice(warp_seg, atlas_seg, labels=good_labels_PL)
        dice_vals_OL[:, k] = dice(warp_seg, atlas_seg, labels=good_labels_OL)
        dice_vals_TL[:, k] = dice(warp_seg, atlas_seg, labels=good_labels_TL)
        dice_vals_CL[:, k] = dice(warp_seg, atlas_seg, labels=good_labels_CL)
        dice_vals_Ptm[:, k] = dice(warp_seg, atlas_seg, labels=good_labels_Ptm)
        dice_vals_Hpcp[:, k] = dice(warp_seg, atlas_seg, labels=good_labels_Hpcp)

        print('%s %3d: %5.3f All: %5.3f' % (vol_name, k, np.mean(dice_vals[:, k]), np.mean(np.mean(dice_vals[:, :k+1]))))
        print('%s %3d: %5.3f All: %5.3f' % ("FL", k, np.mean(dice_vals_FL[:, k]), np.mean(np.mean(dice_vals_FL[:, :k+1]))))
        print('%s %3d: %5.3f All: %5.3f' % ("PL", k, np.mean(dice_vals_PL[:, k]), np.mean(np.mean(dice_vals_PL[:, :k+1]))))
        print('%s %3d: %5.3f All: %5.3f' % ("OL", k, np.mean(dice_vals_OL[:, k]), np.mean(np.mean(dice_vals_OL[:, :k+1]))))
        print('%s %3d: %5.3f All: %5.3f' % ("TL", k, np.mean(dice_vals_TL[:, k]), np.mean(np.mean(dice_vals_TL[:, :k+1]))))
        print('%s %3d: %5.3f All: %5.3f' % ("CL", k, np.mean(dice_vals_CL[:, k]), np.mean(np.mean(dice_vals_CL[:, :k+1]))))
        print('%s %3d: %5.3f All: %5.3f' % ("Ptm", k, np.mean(dice_vals_Ptm[:, k]), np.mean(np.mean(dice_vals_Ptm[:, :k+1]))))
        print('%s %3d: %5.3f All: %5.3f' % ("Hpcp", k, np.mean(dice_vals_Hpcp[:, k]), np.mean(np.mean(dice_vals_Hpcp[:, :k+1]))))

    if save_file is not None:
        sio.savemat(save_file, {'dice_vals': dice_vals, 'labels': good_labels})
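The eight per-region Dice matrices and print statements above differ only in the label list; a dict-driven loop expresses the same bookkeeping once. A sketch, assuming the same good_labels_* arrays, n_batches, and the dice helper are in scope:

label_groups = {'All': good_labels, 'FL': good_labels_FL, 'PL': good_labels_PL,
                'OL': good_labels_OL, 'TL': good_labels_TL, 'CL': good_labels_CL,
                'Ptm': good_labels_Ptm, 'Hpcp': good_labels_Hpcp}
dice_tables = {name: np.zeros((len(lbls), n_batches))
               for name, lbls in label_groups.items()}

# inside the k loop, after warp_seg is computed:
for name, lbls in label_groups.items():
    dice_tables[name][:, k] = dice(warp_seg, atlas_seg, labels=lbls)
    print('%s %3d: %5.3f All: %5.3f' % (name, k,
          np.mean(dice_tables[name][:, k]),
          np.mean(dice_tables[name][:, :k + 1])))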
Example #20
0
def eval_seg_from_gen(segmenter_model,
                      eval_gen,
                      label_mapping,
                      n_eval_examples,
                      batch_size,
                      logger=None):
    '''
    Evaluates accuracy of a segmentation CNN
    :param segmenter_model: keras model for segmenter
    :param eval_gen: genrator that yields vols_valid, segs_valid
    :param label_mapping: list of label ids, ordered by how they map to channels
    :param n_eval_examples: total number of volumes to evaluate
    :param batch_size: batch size (number of slices per batch)
    :param logger: python logger (optional)
    :return:
    '''

    # test metrics: categorical cross-entropy and dice
    dice_per_label = np.zeros((n_eval_examples, len(label_mapping)))
    cces = np.zeros((n_eval_examples, ))
    accs = np.zeros((n_eval_examples, ))
    all_ids = []
    for bi in range(n_eval_examples):
        if logger is not None:
            logger.debug('Testing on subject {} of {}'.format(
                bi, n_eval_examples))
        else:
            print('Testing on subject {} of {}'.format(bi, n_eval_examples))
        X, Y, _, ids = next(eval_gen)
        Y_oh = labels_to_onehot(Y, label_mapping=label_mapping)
        preds_batch, cce = segment_vol_by_slice(
            segmenter_model,
            X,
            label_mapping=label_mapping,
            batch_size=batch_size,
            Y_oh=Y_oh,
            compute_cce=True,
        )
        subject_dice_per_label = medipy_metrics.dice(Y,
                                                     preds_batch,
                                                     labels=label_mapping)

        # only consider pixels where the gt label is not bkg (if we count bkg, accuracy will be very high)
        nonbkgmap = (Y > 0)

        acc = np.sum(((Y == preds_batch) *
                      nonbkgmap).astype(int)) / np.sum(nonbkgmap).astype(float)

        print(acc)
        dice_per_label[bi] = subject_dice_per_label
        cces[bi] = cce
        accs[bi] = acc
        all_ids += ids

    if logger is not None:
        logger.debug('Dice per label: {}, {}'.format(
            label_mapping,
            np.mean(dice_per_label, axis=0).tolist()))
        logger.debug('Mean dice (no bkg): {}'.format(
            np.mean(dice_per_label[:, 1:])))
        logger.debug('Mean CE: {}'.format(np.mean(cces)))
        logger.debug('Mean accuracy: {}'.format(np.mean(accs)))
    else:
        print('Dice per label: {}, {}'.format(
            label_mapping,
            np.mean(dice_per_label, axis=0).tolist()))
        print('Mean dice (no bkg): {}'.format(np.mean(dice_per_label[:, 1:])))
        print('Mean CE: {}'.format(np.mean(cces)))
        print('Mean accuracy: {}'.format(np.mean(accs)))
    return cces, dice_per_label, accs, all_ids
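labels_to_onehot is called above but not included. A minimal sketch, assuming it expands an integer label volume into one channel per entry of label_mapping:

import numpy as np

def labels_to_onehot(Y, label_mapping):
    # Y: integer label array, possibly with a trailing singleton channel.
    # Returns a float array with one channel per label in label_mapping.
    if Y.shape[-1] == 1:
        Y = Y[..., 0]
    return np.stack([(Y == lab).astype(np.float32) for lab in label_mapping],
                    axis=-1)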
Example #21
0
	# load weights of model
	with tf.device(gpu):
		net = networks.unet(vol_size, nf_enc, nf_dec)
		# net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5')
		net.load_weights(model_name)

	xx = np.arange(vol_size[1])
	yy = np.arange(vol_size[0])
	zz = np.arange(vol_size[2])
	grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)

	X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')

	with tf.device(gpu):
		pred = net.predict([X_vol, atlas_vol])

	# Warp segments with flow
	flow = pred[1][0, :, :, :, :]
	sample = flow+grid
	sample = np.stack((sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
	warp_seg = interpn((yy, xx, zz), X_seg[0, :, :, :, 0], sample, method='nearest', bounds_error=False, fill_value=0)

	vals, _ = dice(warp_seg, atlas_seg, labels=labels, nargout=2)
	print(np.mean(vals), np.std(vals))


if __name__ == "__main__":
	# test(sys.argv[1], sys.argv[2], sys.argv[3])
	test(sys.argv[1], sys.argv[2])
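The flow-plus-grid construction, axis swap, and nearest-neighbour interpn call in this fragment recur in nearly every example above; factored out, the pattern looks like the sketch below (the same computation, not a new method):

import numpy as np
from scipy.interpolate import interpn

def warp_labels(seg, flow, vol_size):
    # Build the identity grid; np.meshgrid with default 'xy' indexing swaps
    # the first two axes, matching the examples above.
    xx = np.arange(vol_size[1])
    yy = np.arange(vol_size[0])
    zz = np.arange(vol_size[2])
    grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)
    sample = flow + grid
    # swap x/y back before sampling, as in the original snippets
    sample = np.stack((sample[..., 1], sample[..., 0], sample[..., 2]), 3)
    return interpn((yy, xx, zz), seg, sample, method='nearest',
                   bounds_error=False, fill_value=0)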
Example #22
0
def test(
        gpu_id,
        model_dir,
        iter_num,
        compute_type='GPU',  # GPU or CPU
        vol_size=(160, 192, 224),
        nf_enc=[16, 32, 32, 32],
        nf_dec=[32, 32, 32, 32, 16, 3],
        save_file=None):
    """
    test via segmetnation propagation
    works by iterating over some iamge files, registering them to atlas,
    propagating the warps, then computing Dice with atlas segmentations
    """

    # GPU handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        # if testing miccai run, should be xy indexing.
        net = networks.miccai2018_net(vol_size,
                                      nf_enc,
                                      nf_dec,
                                      use_miccai_int=False,
                                      indexing='ij')
        net.load_weights(os.path.join(model_dir, str(iter_num) + '.h5'))

        # compose diffeomorphic flow output model
        diff_net = keras.models.Model(net.inputs,
                                      net.get_layer('diffflow').output)

        # NN transfer model
        nn_trf_model = networks.nn_trf(vol_size, indexing='ij')

    # if CPU, prepare grid
    if compute_type == 'CPU':
        grid, xx, yy, zz = util.volshape2grid_3d(vol_size, nargout=4)

    # get data
    X_vol = nib.load(r'D:\users\zzx\data\2018nor\pre\01.nii').get_data()
    X_vol = X_vol[np.newaxis, ..., np.newaxis]
    atlas_vol = nib.load(r'D:\users\zzx\data\2018nor\pre\01_a.nii').get_data()
    atlas_vol = atlas_vol[np.newaxis, ..., np.newaxis]

    X_mask = nib.load(
        r'D:\users\zzx\data\2018mask\pre\pig01_pre_final_r.nii').get_data()
    X_mask = X_mask[np.newaxis, ..., np.newaxis]
    X_mask[X_mask == np.max(X_mask)] = 1
    X_mask[X_mask != 1] = 0
    atlas_mask = nib.load(
        r'D:\users\zzx\data\2018mask\aft\pig01_02_final_r.nii').get_data()
    atlas_mask[atlas_mask == np.max(atlas_mask)] = 1
    atlas_mask[atlas_mask != 1] = 0
    ## feature point
    # X_feapt = np.zeros((160,192,224))
    # X_feapt[128,51,165] = 1
    # X_feapt = X_feapt[np.newaxis,...,np.newaxis]

    # predict transform
    with tf.device(gpu):
        pred = diff_net.predict([X_vol, atlas_vol])

    # Warp segments with flow
    if compute_type == 'CPU':
        flow = pred[0, :, :, :, :]
        warp_seg = util.warp_seg(X_mask, flow, grid=grid, xx=xx, yy=yy, zz=zz)

    else:  # GPU
        warp_mask = nn_trf_model.predict([X_mask, pred])[0, ..., 0]
        warp_vol = nn_trf_model.predict([X_vol, pred])[0, ..., 0]
        # pred_point1 = nn_trf_model.predict([X_feapt, pred])[0,...,0]
    print(X_vol.shape)
    warp_vol = nib.Nifti1Image(warp_vol, np.eye(4))
    nib.save(warp_vol, r'D:\users\zzx\data\2018warp\1w.nii')
    # compute Volume Overlap (Dice)
    # X_mask = X_mask[0,...,0]
    # print(X_mask.shape, atlas_mask.shape,pred_point1.shape,np.where(pred_point1 != 0))
    dice_vals = dice(warp_mask, atlas_mask)
    # print('%3d %5.3f %5.3f' % (k, np.mean(dice_vals[:, k]), np.mean(np.mean(dice_vals[:, :k+1]))))
    print(dice_vals)
    if save_file is not None:
        sio.savemat(save_file, {'dice_vals': dice_vals})
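Here dice is applied to two binarized masks with no labels argument; for 0/1 masks the score reduces to 2|A∩B|/(|A|+|B|). A standalone sketch of that special case:

import numpy as np

def binary_dice(mask1, mask2):
    # Dice overlap of two binary masks.
    m1 = mask1.astype(bool)
    m2 = mask2.astype(bool)
    denom = m1.sum() + m2.sum()
    return 2.0 * np.logical_and(m1, m2).sum() / max(denom, 1)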
Example #23
0
def test(model_name,
         gpu_id,
         iter_num,
         vol_size=(160, 192, 224),
         nf_enc=[16, 32, 32, 32],
         nf_dec=[32, 32, 32, 32, 8, 8, 3]):
    """
	test

	nf_enc and nf_dec
	#nf_dec = [32,32,32,32,32,16,16,3]
    # This needs to be changed. Ideally, we could just call load_model, and we wont have to
    # specify the # of channels here, but the load_model is not working with the custom loss...
    """

    gpu = '/gpu:' + str(gpu_id)

    # Test file and anatomical labels we want to evaluate
    test_brain_file = open('../data/test_examples.txt')
    test_brain_strings = test_brain_file.readlines()
    test_brain_strings = [x.strip() for x in test_brain_strings]
    good_labels = sio.loadmat('../data/test_labels.mat')['labels'][0]

    atlas = np.load('../data/atlas_norm.npz')
    atlas_vol = atlas['vol']
    atlas_seg = atlas['seg']
    atlas_vol = np.reshape(atlas_vol, (1, ) + atlas_vol.shape + (1, ))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        net = networks.unet(vol_size, nf_enc, nf_dec)
        net.load_weights('../models/' + model_name + '/' + str(iter_num) +
                         '.h5')

    n_batches = len(test_brain_strings)
    xx = np.arange(vol_size[1])
    yy = np.arange(vol_size[0])
    zz = np.arange(vol_size[2])
    grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0, 4)

    dice_vals = np.zeros((len(good_labels), n_batches))

    np.random.seed(17)

    for k in range(0, n_batches):
        vol_name, seg_name = test_brain_strings[k].split(",")
        X_vol, X_seg = datagenerators.load_example_by_name(vol_name, seg_name)

        with tf.device(gpu):
            pred = net.predict([X_vol, atlas_vol])

        # Warp segments with flow
        flow = pred[1][0, :, :, :, :]
        sample = flow + grid
        sample = np.stack(
            (sample[:, :, :, 1], sample[:, :, :, 0], sample[:, :, :, 2]), 3)
        warp_seg = interpn((yy, xx, zz),
                           X_seg[0, :, :, :, 0],
                           sample,
                           method='nearest',
                           bounds_error=False,
                           fill_value=0)

        vals, labels = dice(warp_seg, atlas_seg, labels=good_labels, nargout=2)
        dice_vals[:, k] = vals
        print(np.mean(dice_vals[:, k]))
Example #24
0
def test(iter_num,
         gpu_id,
         vol_size=(160, 192, 224),
         nf_enc=[16, 32, 32, 32],
         nf_dec=[32, 32, 32, 32, 32, 16, 16, 3]):
    gpu = '/gpu:' + str(gpu_id)

    # Anatomical labels we want to evaluate
    labels = sio.loadmat('../data/labels.mat')['labels'][0]

    # read atlas
    atlas_vol1, atlas_seg1 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990114_vc722.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990114_vc722.npz'
    )  # [1,160,192,224,1]
    atlas_seg1 = atlas_seg1[0, :, :, :,
                            0]  # reduce the dimension to [160,192,224]
    atlas_seg1 = keras.utils.to_categorical(atlas_seg1)
    atlas_vol2, atlas_seg2 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990210_vc792.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990210_vc792.npz'
    )
    atlas_seg2 = atlas_seg2[0, :, :, :, 0]
    atlas_seg2 = keras.utils.to_categorical(atlas_seg2)
    atlas_vol3, atlas_seg3 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990405_vc922.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990405_vc922.npz'
    )
    atlas_seg3 = atlas_seg3[0, :, :, :, 0]
    atlas_seg3 = keras.utils.to_categorical(atlas_seg3)
    atlas_vol4, atlas_seg4 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991006_vc1337.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991006_vc1337.npz'
    )
    atlas_seg4 = atlas_seg4[0, :, :, :, 0]
    atlas_seg4 = keras.utils.to_categorical(atlas_seg4)
    atlas_vol5, atlas_seg5 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991120_vc1456.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991120_vc1456.npz'
    )
    atlas_seg5 = atlas_seg5[0, :, :, :, 0]
    atlas_seg5 = keras.utils.to_categorical(atlas_seg5)
    #atlas = np.load('../data/atlas_norm.npz')
    #atlas_vol = atlas['vol']
    #print('the size of atlas:')
    #print(atlas_vol.shape)
    #atlas_seg = atlas['seg']
    #atlas_vol = np.reshape(atlas_vol, (1,)+atlas_vol.shape+(1,))

    #gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        net = networks.unet(vol_size, nf_enc, nf_dec)
        net.load_weights('/home/ys895/MAS_Models/' + str(iter_num) + '.h5')
        #net.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5')

    xx = np.arange(vol_size[1])
    yy = np.arange(vol_size[0])
    zz = np.arange(vol_size[2])
    grid = np.rollaxis(np.array(np.meshgrid(xx, yy, zz)), 0,
                       4)  # (160, 192, 224, 3)
    #X_vol, X_seg = datagenerators.load_example_by_name('../data/test_vol.npz', '../data/test_seg.npz')
    X_vol1, X_seg1 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/981216_vc681.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/981216_vc681.npz'
    )

    X_vol2, X_seg2 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990205_vc783.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990205_vc783.npz'
    )

    X_vol3, X_seg3 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/990525_vc1024.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/990525_vc1024.npz'
    )

    X_vol4, X_seg4 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991025_vc1379.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991025_vc1379.npz'
    )

    X_vol5, X_seg5 = datagenerators.load_example_by_name(
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/vols/991122_vc1463.npz',
        '/home/ys895/resize256/resize256-crop_x32/FromEugenio_prep/labels/991122_vc1463.npz'
    )

    # Register each atlas (moving) to the test subject (fixed).
    # pred[0].shape: (1, 160, 192, 224, 1)  warped moving volume
    # pred[1].shape: (1, 160, 192, 224, 3)  dense flow field
    # X1: test subject 1
    with tf.device(gpu):
        pred1 = net.predict([atlas_vol1, X_vol1])
        pred2 = net.predict([atlas_vol2, X_vol1])
        pred3 = net.predict([atlas_vol3, X_vol1])
        pred4 = net.predict([atlas_vol4, X_vol1])
        pred5 = net.predict([atlas_vol5, X_vol1])
    # Warp segments with flow
    flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
    flow2 = pred2[1][0, :, :, :, :]
    flow3 = pred3[1][0, :, :, :, :]
    flow4 = pred4[1][0, :, :, :, :]
    flow5 = pred5[1][0, :, :, :, :]

    sample1 = flow1 + grid
    sample1 = np.stack(
        (sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
    sample2 = flow2 + grid
    sample2 = np.stack(
        (sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
    sample3 = flow3 + grid
    sample3 = np.stack(
        (sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
    sample4 = flow4 + grid
    sample4 = np.stack(
        (sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
    sample5 = flow5 + grid
    sample5 = np.stack(
        (sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)

    warp_seg1 = interpn((yy, xx, zz),
                        atlas_seg1[:, :, :, :],
                        sample1,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)  # (160, 192, 224, n_labels)
    warp_seg2 = interpn((yy, xx, zz),
                        atlas_seg2[:, :, :, :],
                        sample2,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg3 = interpn((yy, xx, zz),
                        atlas_seg3[:, :, :, :],
                        sample3,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg4 = interpn((yy, xx, zz),
                        atlas_seg4[:, :, :, :],
                        sample4,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg5 = interpn((yy, xx, zz),
                        atlas_seg5[:, :, :, :],
                        sample5,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)

    # label fusion: average the warped one-hot probabilities, then take
    # the per-voxel argmax to get the final warp_seg
    warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4 + warp_seg5) / 5
    warp_seg = np.argmax(warp_seg, axis=3)

    vals, _ = dice(warp_seg, X_seg1[0, :, :, :, 0], labels=labels, nargout=2)
    mean1 = np.mean(vals)
    var1 = np.std(vals)

    # X2
    with tf.device(gpu):
        pred1 = net.predict([atlas_vol1, X_vol2])
        pred2 = net.predict([atlas_vol2, X_vol2])
        pred3 = net.predict([atlas_vol3, X_vol2])
        pred4 = net.predict([atlas_vol4, X_vol2])
        pred5 = net.predict([atlas_vol5, X_vol2])
    # Warp segments with flow
    flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
    flow2 = pred2[1][0, :, :, :, :]
    flow3 = pred3[1][0, :, :, :, :]
    flow4 = pred4[1][0, :, :, :, :]
    flow5 = pred5[1][0, :, :, :, :]

    sample1 = flow1 + grid
    sample1 = np.stack(
        (sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
    sample2 = flow2 + grid
    sample2 = np.stack(
        (sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
    sample3 = flow3 + grid
    sample3 = np.stack(
        (sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
    sample4 = flow4 + grid
    sample4 = np.stack(
        (sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
    sample5 = flow5 + grid
    sample5 = np.stack(
        (sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)

    warp_seg1 = interpn((yy, xx, zz),
                        atlas_seg1[:, :, :, :],
                        sample1,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)  # (160, 192, 224, n_labels)
    warp_seg2 = interpn((yy, xx, zz),
                        atlas_seg2[:, :, :, :],
                        sample2,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg3 = interpn((yy, xx, zz),
                        atlas_seg3[:, :, :, :],
                        sample3,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg4 = interpn((yy, xx, zz),
                        atlas_seg4[:, :, :, :],
                        sample4,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg5 = interpn((yy, xx, zz),
                        atlas_seg5[:, :, :, :],
                        sample5,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)

    # label fusion: average the warped one-hot probabilities, then take
    # the per-voxel argmax to get the final warp_seg
    warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4 + warp_seg5) / 5
    warp_seg = np.argmax(warp_seg, axis=3)

    vals, _ = dice(warp_seg, X_seg2[0, :, :, :, 0], labels=labels, nargout=2)
    mean2 = np.mean(vals)
    var2 = np.std(vals)
    #print(np.mean(vals), np.std(vals))

    # X3
    with tf.device(gpu):
        pred1 = net.predict([atlas_vol1, X_vol3])
        pred2 = net.predict([atlas_vol2, X_vol3])
        pred3 = net.predict([atlas_vol3, X_vol3])
        pred4 = net.predict([atlas_vol4, X_vol3])
        pred5 = net.predict([atlas_vol5, X_vol3])
    # Warp segments with flow
    flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
    flow2 = pred2[1][0, :, :, :, :]
    flow3 = pred3[1][0, :, :, :, :]
    flow4 = pred4[1][0, :, :, :, :]
    flow5 = pred5[1][0, :, :, :, :]

    sample1 = flow1 + grid
    sample1 = np.stack(
        (sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
    sample2 = flow2 + grid
    sample2 = np.stack(
        (sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
    sample3 = flow3 + grid
    sample3 = np.stack(
        (sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
    sample4 = flow4 + grid
    sample4 = np.stack(
        (sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
    sample5 = flow5 + grid
    sample5 = np.stack(
        (sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)

    warp_seg1 = interpn((yy, xx, zz),
                        atlas_seg1[:, :, :, :],
                        sample1,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg2 = interpn((yy, xx, zz),
                        atlas_seg2[:, :, :, :],
                        sample2,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg3 = interpn((yy, xx, zz),
                        atlas_seg3[:, :, :, :],
                        sample3,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg4 = interpn((yy, xx, zz),
                        atlas_seg4[:, :, :, :],
                        sample4,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg5 = interpn((yy, xx, zz),
                        atlas_seg5[:, :, :, :],
                        sample5,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)

    # label fusion: average the warped one-hot probabilities, then take
    # the per-voxel argmax to get the final warp_seg
    warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4 + warp_seg5) / 5
    warp_seg = np.argmax(warp_seg, axis=3)

    vals, _ = dice(warp_seg, X_seg3[0, :, :, :, 0], labels=labels, nargout=2)
    mean3 = np.mean(vals)
    var3 = np.std(vals)

    # X4
    with tf.device(gpu):
        pred1 = net.predict([atlas_vol1, X_vol4])
        pred2 = net.predict([atlas_vol2, X_vol4])
        pred3 = net.predict([atlas_vol3, X_vol4])
        pred4 = net.predict([atlas_vol4, X_vol4])
        pred5 = net.predict([atlas_vol5, X_vol4])
    # Warp segments with flow
    flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
    flow2 = pred2[1][0, :, :, :, :]
    flow3 = pred3[1][0, :, :, :, :]
    flow4 = pred4[1][0, :, :, :, :]
    flow5 = pred5[1][0, :, :, :, :]

    sample1 = flow1 + grid
    sample1 = np.stack(
        (sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
    sample2 = flow2 + grid
    sample2 = np.stack(
        (sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
    sample3 = flow3 + grid
    sample3 = np.stack(
        (sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
    sample4 = flow4 + grid
    sample4 = np.stack(
        (sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
    sample5 = flow5 + grid
    sample5 = np.stack(
        (sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)

    warp_seg1 = interpn((yy, xx, zz),
                        atlas_seg1[:, :, :, :],
                        sample1,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg2 = interpn((yy, xx, zz),
                        atlas_seg2[:, :, :, :],
                        sample2,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg3 = interpn((yy, xx, zz),
                        atlas_seg3[:, :, :, :],
                        sample3,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg4 = interpn((yy, xx, zz),
                        atlas_seg4[:, :, :, :],
                        sample4,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg5 = interpn((yy, xx, zz),
                        atlas_seg5[:, :, :, :],
                        sample5,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)

    # label fusion: average the warped one-hot probabilities, then take
    # the per-voxel argmax to get the final warp_seg
    warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4 + warp_seg5) / 5
    warp_seg = np.argmax(warp_seg, axis=3)

    vals, _ = dice(warp_seg, X_seg4[0, :, :, :, 0], labels=labels, nargout=2)
    mean4 = np.mean(vals)
    var4 = np.std(vals)
    # X5
    with tf.device(gpu):
        pred1 = net.predict([atlas_vol1, X_vol5])
        pred2 = net.predict([atlas_vol2, X_vol5])
        pred3 = net.predict([atlas_vol3, X_vol5])
        pred4 = net.predict([atlas_vol4, X_vol5])
        pred5 = net.predict([atlas_vol5, X_vol5])
    # Warp segments with flow
    flow1 = pred1[1][0, :, :, :, :]  # (160, 192, 224, 3)
    flow2 = pred2[1][0, :, :, :, :]
    flow3 = pred3[1][0, :, :, :, :]
    flow4 = pred4[1][0, :, :, :, :]
    flow5 = pred5[1][0, :, :, :, :]

    sample1 = flow1 + grid
    sample1 = np.stack(
        (sample1[:, :, :, 1], sample1[:, :, :, 0], sample1[:, :, :, 2]), 3)
    sample2 = flow2 + grid
    sample2 = np.stack(
        (sample2[:, :, :, 1], sample2[:, :, :, 0], sample2[:, :, :, 2]), 3)
    sample3 = flow3 + grid
    sample3 = np.stack(
        (sample3[:, :, :, 1], sample3[:, :, :, 0], sample3[:, :, :, 2]), 3)
    sample4 = flow4 + grid
    sample4 = np.stack(
        (sample4[:, :, :, 1], sample4[:, :, :, 0], sample4[:, :, :, 2]), 3)
    sample5 = flow5 + grid
    sample5 = np.stack(
        (sample5[:, :, :, 1], sample5[:, :, :, 0], sample5[:, :, :, 2]), 3)

    warp_seg1 = interpn((yy, xx, zz),
                        atlas_seg1[:, :, :, :],
                        sample1,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg2 = interpn((yy, xx, zz),
                        atlas_seg2[:, :, :, :],
                        sample2,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg3 = interpn((yy, xx, zz),
                        atlas_seg3[:, :, :, :],
                        sample3,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg4 = interpn((yy, xx, zz),
                        atlas_seg4[:, :, :, :],
                        sample4,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)
    warp_seg5 = interpn((yy, xx, zz),
                        atlas_seg5[:, :, :, :],
                        sample5,
                        method='linear',
                        bounds_error=False,
                        fill_value=0)

    # label fusion: average the warped one-hot probabilities, then take
    # the per-voxel argmax to get the final warp_seg
    warp_seg = (warp_seg1 + warp_seg2 + warp_seg3 + warp_seg4 + warp_seg5) / 5
    warp_seg = np.argmax(warp_seg, axis=3)

    vals, _ = dice(warp_seg, X_seg5[0, :, :, :, 0], labels=labels, nargout=2)
    mean5 = np.mean(vals)
    var5 = np.std(vals)

    # average the per-subject Dice means and standard deviations
    mean_dice = (mean1 + mean2 + mean3 + mean4 + mean5) / 5
    std_dice = (var1 + var2 + var3 + var4 + var5) / 5
    print(str(mean_dice) + ',' + str(std_dice))
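
The five per-subject blocks in this example repeat the same register-warp-fuse-score pipeline with different inputs. A condensed sketch of that pipeline as a single hypothetical helper (assuming the net, grid, yy/xx/zz axes, and dice defined above):

def multi_atlas_dice(net, atlas_vols, atlas_segs, x_vol, x_seg,
                     grid, yy, xx, zz, labels):
    # accumulate warped one-hot probabilities from every atlas
    probs = 0
    for a_vol, a_seg in zip(atlas_vols, atlas_segs):
        flow = net.predict([a_vol, x_vol])[1][0]  # (160, 192, 224, 3)
        sample = flow + grid
        sample = np.stack(
            (sample[..., 1], sample[..., 0], sample[..., 2]), 3)
        probs = probs + interpn((yy, xx, zz), a_seg, sample, method='linear',
                                bounds_error=False, fill_value=0)
    # fuse by averaging and taking the per-voxel argmax, then score
    warp_seg = np.argmax(probs / len(atlas_vols), axis=3)
    vals, _ = dice(warp_seg, x_seg[0, ..., 0], labels=labels, nargout=2)
    return np.mean(vals), np.std(vals)

Calling this once per test subject, e.g. multi_atlas_dice(net, [atlas_vol1, atlas_vol2, atlas_vol3, atlas_vol4, atlas_vol5], [atlas_seg1, atlas_seg2, atlas_seg3, atlas_seg4, atlas_seg5], X_vol1, X_seg1, grid, yy, xx, zz, labels), reproduces the per-subject means and standard deviations computed above.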
Example #25
def test(data_dir, fixed_image, label, device, load_model_file, DLR_model):

    assert DLR_model in [
        'VM', 'FAIM'
    ], 'DLR_model should be one of VM or FAIM, found %s' % DLR_model

    # prepare data files
    # inside the folder are npz files with the 'vol' and 'label'.
    test_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    assert len(test_vol_names) > 0, "Could not find any testing data"

    fixed_vol = np.load(fixed_image)['vol'][np.newaxis, ..., np.newaxis]
    fixed_seg = np.load(fixed_image)['label']
    vol_size = fixed_vol.shape[1:-1]
    label = np.load(label)

    # device handling
    if 'gpu' in device:
        if '0' in device:
            device = '/gpu:0'
        if '1' in device:
            device = '/gpu:1'
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        set_session(tf.Session(config=config))
    else:
        device = '/cpu:0'

    # load weights of model
    with tf.device(device):
        net = networks.AAN(vol_size, DLR_model)
        net.load_weights(load_model_file)

        # NN transfer model
        nn_trf_model_nearest = networks.nn_trf(vol_size,
                                               interp_method='nearest',
                                               indexing='ij')
        nn_trf_model_linear = networks.nn_trf(vol_size,
                                              interp_method='linear',
                                              indexing='ij')
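        # Note: nearest-neighbour resampling keeps segmentation labels
        # integer, while linear resampling suits the continuous intensity
        # volume.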

    dice_result = []
    for test_image in test_vol_names:

        X_vol, X_seg, x_boundary = datagenerators.load_example_by_name(
            test_image, return_boundary=True)

        with tf.device(device):
            pred = net.predict([X_vol, fixed_vol, x_boundary])
            warp_vol = nn_trf_model_linear.predict([X_vol, pred[1]])[0, ..., 0]
            warp_seg = nn_trf_model_nearest.predict([X_seg, pred[1]])[0, ..., 0]

        vals, _ = dice(warp_seg, fixed_seg, labels=label, nargout=2)
        dice_result.append(vals)

        print('Dice mean: {:.3f} ({:.3f})'.format(np.mean(vals), np.std(vals)))

    dice_result = np.array(dice_result)
    print('Average dice mean: {:.3f} ({:.3f})'.format(np.mean(dice_result),
                                                      np.std(dice_result)))
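
All of the snippets above call a dice(vol1, vol2, labels=..., nargout=...) helper. A minimal sketch compatible with those call sites (the actual implementation ships with the VoxelMorph codebase; nargout mimics MATLAB-style multiple return values):

import numpy as np

def dice(vol1, vol2, labels=None, nargout=1):
    # Dice overlap per anatomical label; background (0) is skipped when
    # labels are derived from the volumes themselves.
    if labels is None:
        labels = np.unique(np.concatenate((vol1.ravel(), vol2.ravel())))
        labels = labels[labels != 0]
    dicem = np.zeros(len(labels))
    for i, lab in enumerate(labels):
        top = 2 * np.sum(np.logical_and(vol1 == lab, vol2 == lab))
        bottom = np.sum(vol1 == lab) + np.sum(vol2 == lab)
        dicem[i] = top / np.maximum(bottom, np.finfo(float).eps)
    return dicem if nargout == 1 else (dicem, labels)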