Ejemplo n.º 1
0
def load_data_to_crop(path_to_img, x_scale, y_scale, z_scale,
                        normalize, mu, sig):
    """Load a volume and build the three-view RGB stack used for cropping.

    The volume is resized three times. Each time one axis keeps its original
    length and the other two are resized to the given scales. The y- and
    x-views are swapped so that their full-length axis becomes axis 0, and
    the three views are concatenated along axis 0.

    Intensities are then rescaled to [0, 1]. When ``normalize`` is set they
    are matched to mean ``mu`` / std ``sig`` and clipped to [0, 1]. Finally
    they are quantised to uint8 and copied into three identical channels.

    Returns:
        (img_rgb, z_shape, y_shape, x_shape), where the shapes are those of
        the volume as loaded.

    Raises:
        InputError: if ``load_data`` returns no image.
    """
    volume, _, _ = load_data(path_to_img, 'first_queue', return_extension=True)
    if volume is None:
        InputError.message = "Invalid image data %s." %(os.path.basename(path_to_img))
        raise InputError()
    z_shape, y_shape, x_shape = volume.shape
    volume = volume.astype(np.float32)

    # one view per axis: that axis keeps its full length, the others are scaled
    view_z = img_resize(volume, z_shape, y_scale, x_scale)
    view_y = np.swapaxes(img_resize(volume, z_scale, y_shape, x_scale), 0, 1)
    view_x = np.swapaxes(img_resize(volume, z_scale, y_scale, x_shape), 0, 2)
    del volume
    stacked = np.append(np.append(view_z, view_y, axis=0), view_x, axis=0)

    # rescale to [0, 1], then optionally match the target statistics
    stacked -= np.amin(stacked)
    stacked /= np.amax(stacked)
    if normalize:
        stacked = (stacked - np.mean(stacked)) / np.std(stacked)
        stacked = np.clip(stacked * sig + mu, 0, 1)

    # quantise and replicate the grey values into three channels
    grey = np.uint8(stacked * 255)
    img_rgb = np.repeat(grey[..., np.newaxis], 3, axis=-1)
    return img_rgb, z_shape, y_shape, x_shape
Ejemplo n.º 2
0
def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
                         normalize, mu, sig, region_of_interest):
    """Load an image volume and prepare it for prediction.

    Steps:
    - Read the image.
    - Optionally crop it to ``region_of_interest``.
    - Resize it to (z_scale, y_scale, x_scale) with ``img_resize``.
    - Rescale intensities to [0, 1].
    - If ``normalize`` is set, match the mean/std to ``mu``/``sig`` and clip
      to [0, 1].
    - When ``channels == 2``, also compute a position volume.

    Args:
        path_to_img: path of the image to load.
        channels: 2 requests the additional position volume; any other
            value returns ``position=None``.
        x_scale, y_scale, z_scale: target size passed to ``img_resize``.
        normalize: whether to apply the mu/sig normalization.
        mu, sig: target mean and standard deviation.
        region_of_interest: (min_z, max_z, min_y, max_y, min_x, max_x), or a
            value for which ``np.any`` is false to skip cropping.

    Returns:
        (img, img_header, position, z_shape, y_shape, x_shape,
        region_of_interest)
        - ``img_header`` is kept only for '.am' files and is None otherwise.
        - The shapes are those of the (possibly cropped) volume before
          resizing.
        - If the cropping branch ran, ``region_of_interest`` is replaced by
          a 9-element array: the six bounds actually used, followed by the
          original z, y, x shape.

    Raises:
        InputError: if the image cannot be loaded.
    """

    # read image data
    img, img_header, img_ext = load_data(path_to_img,
                                         'first_queue',
                                         return_extension=True)
    if img is None:
        InputError.message = "Invalid image data %s." % (
            os.path.basename(path_to_img))
        raise InputError()
    # the header is only passed on for '.am' files
    if img_ext != '.am':
        img_header = None
    z_shape, y_shape, x_shape = img.shape

    # automatic cropping of image to region of interest
    if np.any(region_of_interest):
        min_z, max_z, min_y, max_y, min_x, max_x = region_of_interest[:]
        # clamp all bounds to the volume size
        min_z = min(min_z, z_shape)
        min_y = min(min_y, y_shape)
        min_x = min(min_x, x_shape)
        max_z = min(max_z, z_shape)
        max_y = min(max_y, y_shape)
        max_x = min(max_x, x_shape)
        # NOTE(review): after the clamping above, max_* - min_* cannot exceed
        # *_shape. Each condition below is therefore true whenever the ROI is
        # shorter than the full axis, and such an ROI is replaced by the full
        # extent, so in practice the crop is undone. Possibly *_scale or a
        # degenerate-ROI check (max <= min) was intended -- confirm.
        if max_z - min_z < z_shape:
            min_z, max_z = 0, z_shape
        if max_y - min_y < y_shape:
            min_y, max_y = 0, y_shape
        if max_x - min_x < x_shape:
            min_x, max_x = 0, x_shape
        img = np.copy(img[min_z:max_z, min_y:max_y, min_x:max_x], order='C')
        # bounds actually used, followed by the original volume shape
        region_of_interest = np.array([
            min_z, max_z, min_y, max_y, min_x, max_x, z_shape, y_shape, x_shape
        ])
        z_shape, y_shape, x_shape = max_z - min_z, max_y - min_y, max_x - min_x

    # scale image data
    img = img.astype(np.float32)
    img = img_resize(img, z_scale, y_scale, x_scale)
    # rescale intensities to [0, 1]
    # NOTE(review): a constant image gives 0/0 = NaN here
    img -= np.amin(img)
    img /= np.amax(img)
    if normalize:
        # standardize, shift/scale to the target mean/std, clip to [0, 1]
        mu_tmp, sig_tmp = np.mean(img), np.std(img)
        img = (img - mu_tmp) / sig_tmp
        img = img * sig + mu
        img[img < 0] = 0
        img[img > 1] = 1

    # compute position data
    position = None
    if channels == 2:
        position = np.empty((z_scale, y_scale, x_scale), dtype=np.float32)
        position = compute_position(position, z_scale, y_scale, x_scale)
        position = np.sqrt(position)
        # scaled so the maximum is 1
        position /= np.amax(position)

    return img, img_header, position, z_shape, y_shape, x_shape, region_of_interest
Ejemplo n.º 3
0
def load_refine_data(path_to_img, path_to_final, patch_size, normalize,
                     allLabels, mu, sig):
    """Load an image and its segmentation as input for the refinement step.

    Args:
        path_to_img: path of the image volume.
        path_to_final: path of the label volume to be refined.
        patch_size: currently unused.
        normalize: if true, match the image's mean/std to ``mu``/``sig`` and
            clip to [0, 1].
        allLabels: label values; voxels equal to ``allLabels[k]`` are
            relabelled ``k`` (any order of ``allLabels`` is handled).
        mu, sig: target mean and standard deviation for normalization.

    Returns:
        (img, label, final, z_shape, y_shape, x_shape)
        - ``img``: the image as float32, scaled to [0, 1].
        - ``label``: the remapped label volume.
        - ``final``: the labels as float32, scaled to [0, 1] by the largest
          label index.
        - the shapes: the image shape.

    Raises:
        InputError: if the image or the label data cannot be loaded.
    """

    # read image data
    img, _ = load_data(path_to_img, 'first_queue')
    if img is None:
        InputError.message = "Invalid image data %s." % (
            os.path.basename(path_to_img))
        raise InputError()
    z_shape, y_shape, x_shape = img.shape

    # scale image data to [0,1]; a constant image becomes all zeros instead
    # of 0/0 = NaN
    img = img.astype(np.float32)
    img -= np.amin(img)
    img_max = np.amax(img)
    if img_max > 0:
        img /= img_max
    if normalize:
        mu_tmp, sig_tmp = np.mean(img), np.std(img)
        img = img - mu_tmp
        if sig_tmp > 0:
            img /= sig_tmp
        img = img * sig + mu
        img[img < 0] = 0
        img[img > 1] = 1

    # load label data
    label, _ = load_data(path_to_final, 'first_queue')
    if label is None:
        InputError.message = "Invalid label data %s." % (
            os.path.basename(path_to_final))
        raise InputError()

    # Remap allLabels[k] -> k. Matching against an untouched copy ensures a
    # voxel that was already rewritten is never matched by a later entry,
    # which would otherwise happen when allLabels is not sorted ascending.
    original_label = np.copy(label)
    for k, l in enumerate(allLabels):
        label[original_label == l] = k

    # final data scaled to [0,1]; guard the single-label case against a
    # division by zero
    final = label.astype(np.float32)
    final /= max(len(allLabels) - 1, 1)

    return img, label, final, z_shape, y_shape, x_shape
Ejemplo n.º 4
0
def load_training_data_refine(path_to_model, x_scale, y_scale, z_scale, patch_size, z_patch, y_patch, x_patch, normalize, \
                    img_list, label_list, channels, stride_size, allLabels, mu, sig, batch_size):
    """Assemble per-volume training data for the refinement network.

    For every image volume:
    - A pre-final segmentation is predicted with the model at
      ``path_to_model`` (via ``predict_pre_final``) and scaled to [0, 1] by
      the largest label index.
    - The image itself is rescaled to [0, 1] and, if ``normalize`` is set,
      matched to ``mu``/``sig`` and clipped.
    - The label volume is remapped so that ``allLabels[k]`` becomes ``k``.

    If both entries of an img/label pair end in '.tar' (optionally '.tar.gz'),
    they are replaced by the supported volumes found below the path without
    that extension, sorted per file type.
    NOTE(review): the archives are not extracted here; presumably that
    already happened -- confirm.

    Returns:
        (img, label, final): lists with one array per volume, in the order of
        the collected image / label file names.

    Raises:
        InputError: if an image or label file cannot be loaded.
    """

    # get filenames
    img_names, label_names = [], []
    for img_name, label_name in zip(img_list, label_list):

        img_dir, img_ext = os.path.splitext(img_name)
        if img_ext == '.gz':
            img_dir, img_ext = os.path.splitext(img_dir)

        label_dir, label_ext = os.path.splitext(label_name)
        if label_ext == '.gz':
            label_dir, label_ext = os.path.splitext(label_dir)

        if img_ext == '.tar' and label_ext == '.tar':
            for data_type in ['.am','.tif','.tiff','.hdr','.mhd','.mha','.nrrd','.nii','.nii.gz']:
                img_names.extend(sorted(glob(img_dir+'/**/*'+data_type, recursive=True)))
                label_names.extend(sorted(glob(label_dir+'/**/*'+data_type, recursive=True)))
        else:
            img_names.append(img_name)
            label_names.append(label_name)

    # predictions are scaled to [0,1] by the largest label index; guard the
    # single-label case against a division by zero
    label_scale = max(len(allLabels) - 1, 1)

    # predict pre-final
    final = []
    for name in img_names:
        a, _ = load_data(name, 'first_queue')
        if a is None:
            InputError.message = "Invalid image data %s." %(os.path.basename(name))
            raise InputError()
        a = predict_pre_final(a, path_to_model, x_scale, y_scale, z_scale, z_patch, y_patch, x_patch, \
                              normalize, mu, sig, channels, stride_size, batch_size)
        a = a.astype(np.float32)
        a /= label_scale
        final.append(a)

    # load img data
    img = []
    for name in img_names:
        a, _ = load_data(name, 'first_queue')
        if a is None:
            InputError.message = "Invalid image data %s." %(os.path.basename(name))
            raise InputError()
        a = a.astype(np.float32)
        # scale to [0,1]; a constant image becomes zeros instead of NaN
        a -= np.amin(a)
        a_max = np.amax(a)
        if a_max > 0:
            a /= a_max
        if normalize:
            mu_tmp, sig_tmp = np.mean(a), np.std(a)
            a = a - mu_tmp
            if sig_tmp > 0:
                a /= sig_tmp
            a = a * sig + mu
            a[a<0] = 0
            a[a>1] = 1
        img.append(a)

    # load label data
    label = []
    for name in label_names:
        a, _ = load_data(name, 'first_queue')
        if a is None:
            InputError.message = "Invalid label data %s." %(os.path.basename(name))
            raise InputError()
        label.append(a)

    # Remap allLabels[k] -> k in every volume. Matching against an untouched
    # copy ensures rewritten voxels are never matched again, so allLabels
    # does not have to be sorted ascending.
    for volume in label:
        original_volume = np.copy(volume)
        for k, l in enumerate(allLabels):
            volume[original_volume == l] = k

    return img, label, final
Ejemplo n.º 5
0
def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale,
        crop_data, configuration_data=None, allLabels=None, x_puffer=25, y_puffer=25, z_puffer=25):
    """Load, optionally crop, resize and stack image/label pairs for training.

    Inputs:
    - ``img_list``/``label_list`` hold paired entries. Each pair is either
      two single files, or two TAR archives (optionally '.tar.gz') or
      directories.
    - Archives and directories are expanded into the supported volumes they
      contain, sorted per file type. Archives are extracted next to
      themselves if not already present.

    Per volume pair:
    - If ``crop_data`` is set, the label is cropped to the box returned by
      ``predict_blocksize`` and the image gets the same crop.
    - Labels are resized to (z_scale, y_scale, x_scale) one label value at a
      time.
    - Images are resized and rescaled to [0, 1].
    - Results are concatenated along axis 0.

    Normalization:
    - The first image provides ``mu``/``sig`` unless ``configuration_data``
      supplies them; in that case the first image is matched to them.
    - Later images are matched to ``mu``/``sig`` when ``normalize`` is set.
    - The stacked image is finally clipped to [0, 1].

    Returns:
        (img, label, position, allLabels, configuration_data, header,
        extension, counts)
        - ``label`` values are remapped so that ``allLabels[k]`` becomes
          ``k``.
        - ``counts`` holds the voxel count per label when ``allLabels`` is
          derived here; it is None when ``allLabels`` was passed in.
        - ``position`` is None unless ``channels == 2``; then it holds one
          copy per stacked volume.
        - ``header``/``extension`` come from the first label file.

    Raises:
        InputError: for an archive/directory without supported volumes or
            for unreadable image/label data.
    """

    # get filenames
    img_names, label_names = [], []
    for img_name, label_name in zip(img_list, label_list):

        # check for tarball
        img_dir, img_ext = os.path.splitext(img_name)
        if img_ext == '.gz':
            img_dir, img_ext = os.path.splitext(img_dir)

        label_dir, label_ext = os.path.splitext(label_name)
        if label_ext == '.gz':
            label_dir, label_ext = os.path.splitext(label_dir)

        if (img_ext == '.tar' and label_ext == '.tar') or (os.path.isdir(img_name) and os.path.isdir(label_name)):

            # search existing directories as given; splitext would cut a
            # dotted directory name such as 'scans.v2' down to 'scans'
            if os.path.isdir(img_name):
                img_dir = img_name
            if os.path.isdir(label_name):
                label_dir = label_name

            # extract files if necessary
            # NOTE: extractall trusts the member paths stored in the archive;
            # only use archives from trusted sources.
            if img_ext == '.tar' and not os.path.exists(img_dir):
                with tarfile.open(img_name) as tar:
                    tar.extractall(path=img_dir)
            if label_ext == '.tar' and not os.path.exists(label_dir):
                with tarfile.open(label_name) as tar:
                    tar.extractall(path=label_dir)

            # remember the list lengths so an archive/directory without any
            # volumes is reported even if earlier entries supplied files
            n_img, n_label = len(img_names), len(label_names)
            for data_type in ['.am','.tif','.tiff','.hdr','.mhd','.mha','.nrrd','.nii','.nii.gz']:
                img_names.extend(sorted(glob(img_dir+'/**/*'+data_type, recursive=True)))
                label_names.extend(sorted(glob(label_dir+'/**/*'+data_type, recursive=True)))
            if len(img_names) == n_img:
                InputError.message = "Invalid image TAR file."
                raise InputError()
            if len(label_names) == n_label:
                InputError.message = "Invalid label TAR file."
                raise InputError()
        else:
            img_names.append(img_name)
            label_names.append(label_name)

    # load first label
    a, header, extension = load_data(label_names[0], 'first_queue', True)
    if a is None:
        InputError.message = "Invalid label data %s." %(os.path.basename(label_names[0]))
        raise InputError()
    if crop_data:
        argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(a, x_puffer, y_puffer, z_puffer)
        a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
    a = a.astype(np.uint8)
    np_unique = np.unique(a)
    # resize each label value separately so values are never interpolated
    label = np.zeros((z_scale, y_scale, x_scale), dtype=a.dtype)
    for k in np_unique:
        tmp = np.zeros_like(a)
        tmp[a==k] = 1
        tmp = img_resize(tmp, z_scale, y_scale, x_scale)
        label[tmp==1] = k

    # load first img
    img, _ = load_data(img_names[0], 'first_queue')
    if img is None:
        InputError.message = "Invalid image data %s." %(os.path.basename(img_names[0]))
        raise InputError()
    if crop_data:
        img = np.copy(img[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
    img = img.astype(np.float32)
    img = img_resize(img, z_scale, y_scale, x_scale)
    img -= np.amin(img)
    img /= np.amax(img)
    if configuration_data is not None:
        mu, sig = configuration_data[5], configuration_data[6]
        mu_tmp, sig_tmp = np.mean(img), np.std(img)
        img = (img - mu_tmp) / sig_tmp
        img = img * sig + mu
    else:
        mu, sig = np.mean(img), np.std(img)

    for img_name, label_name in zip(img_names[1:], label_names[1:]):

        # append label
        a, _ = load_data(label_name, 'first_queue')
        if a is None:
            InputError.message = "Invalid label data %s." %(os.path.basename(label_name))
            raise InputError()
        if crop_data:
            argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(a, x_puffer, y_puffer, z_puffer)
            a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
        a = a.astype(np.uint8)
        np_unique = np.unique(a)
        next_label = np.zeros((z_scale, y_scale, x_scale), dtype=a.dtype)
        for k in np_unique:
            tmp = np.zeros_like(a)
            tmp[a==k] = 1
            tmp = img_resize(tmp, z_scale, y_scale, x_scale)
            next_label[tmp==1] = k
        label = np.append(label, next_label, axis=0)

        # append image (cropped with the box of the matching label)
        a, _ = load_data(img_name, 'first_queue')
        if a is None:
            InputError.message = "Invalid image data %s." %(os.path.basename(img_name))
            raise InputError()
        if crop_data:
            a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
        a = a.astype(np.float32)
        a = img_resize(a, z_scale, y_scale, x_scale)
        a -= np.amin(a)
        a /= np.amax(a)
        if normalize:
            mu_tmp, sig_tmp = np.mean(a), np.std(a)
            a = (a - mu_tmp) / sig_tmp
            a = a * sig + mu
        img = np.append(img, a, axis=0)

    # scale image data to [0,1]
    img[img<0] = 0
    img[img>1] = 1

    # compute position data: one copy per stacked volume so it lines up with
    # img/label along axis 0
    position = None
    if channels == 2:
        position = np.empty((z_scale, y_scale, x_scale), dtype=np.float32)
        position = compute_position(position, z_scale, y_scale, x_scale)
        position = np.sqrt(position)
        position /= np.amax(position)
        n_volumes = min(len(img_names), len(label_names))
        position = np.tile(position, (n_volumes, 1, 1))

    # relabel to consecutive indices allLabels[k] -> k
    if allLabels is not None:
        counts = None
    else:
        allLabels, counts = np.unique(label, return_counts=True)
    # match against an untouched copy so rewritten voxels are never matched
    # again (a user-supplied allLabels need not be sorted)
    original_label = np.copy(label)
    for k, l in enumerate(allLabels):
        label[original_label==l] = k

    # configuration data
    configuration_data = np.array([channels, x_scale, y_scale, z_scale, normalize, mu, sig])

    return img, label, position, allLabels, configuration_data, header, extension, counts
Ejemplo n.º 6
0
        if val in ['--smooth', '-s']:
            smooth = int(sys.argv[i + 1])
    # '--uncertainty' / '-uq' switches uq on; '-allx' sets allx to 1
    uq = True if any(x in sys.argv
                     for x in ['--uncertainty', '-uq']) else False
    allx = 1 if '-allx' in sys.argv else 0

    # base directory (parent of the directory containing this file)
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # clean tmp folder: remove .tif files left in BASE_DIR/tmp
    filelist = glob.glob(BASE_DIR + '/tmp/*.tif')
    for f in filelist:
        os.remove(f)

    # data shape: the volume is loaded only to read its shape, then released
    data, _ = load_data(path_to_data, 'split_volume')
    shape = np.copy(np.array(data.shape))
    zsh, ysh, xsh = shape
    del data

    # split volume: subvolume edge length per axis, rounded up so that the
    # sub_z (sub_y, sub_x) pieces cover the whole axis
    # NOTE(review): np.ceil returns floats; presumably converted to int where
    # these are used further down -- confirm
    sub_size_z = np.ceil(zsh / sub_z)
    sub_size_y = np.ceil(ysh / sub_y)
    sub_size_x = np.ceil(xsh / sub_x)

    # iterate over subvolumes; `subvolume` is a 1-based running index in
    # which x varies fastest, then y, then z
    for sub_z_i in range(sub_z):
        for sub_y_i in range(sub_y):
            for sub_x_i in range(sub_x):
                subvolume = sub_z_i * sub_y * sub_x + sub_y_i * sub_x + sub_x_i + 1
                print('Subvolume:', subvolume, '/', sub_z * sub_y * sub_x)
Ejemplo n.º 7
0
def load_training_data(normalize,
                       img_dir,
                       label_dir,
                       channels,
                       x_scale,
                       y_scale,
                       z_scale,
                       crop_data,
                       configuration_data=None,
                       allLabels=None):
    """Load, optionally crop, resize and stack all volumes of two directories.

    Inputs:
    - Image and label files are collected recursively from ``img_dir`` and
      ``label_dir``, sorted per file type, and paired in that order.

    Per volume pair:
    - If ``crop_data`` is set, the label is cropped to the box returned by
      ``predict_blocksize`` and the image gets the same crop.
    - Labels are resized to (z_scale, y_scale, x_scale) one label value at a
      time.
    - Images are resized and rescaled to [0, 1].
    - Results are concatenated along axis 0.

    Normalization:
    - The first image provides ``mu``/``sig`` unless ``configuration_data``
      supplies them; in that case the first image is matched to them.
    - Later images are matched when ``normalize`` is set.
    - The stacked image is clipped to [0, 1].

    Returns:
        (img, label, position, allLabels, configuration_data, header,
        extension, region_of_interest, counts)
        - ``region_of_interest`` is the crop box averaged over all volumes,
          rounded, with negatives set to 0, as int; it is None without
          ``crop_data``.
        - ``label`` values are remapped so that ``allLabels[k]`` becomes
          ``k``.
        - ``counts`` is None when ``allLabels`` was passed in.
        - ``position`` is None unless ``channels == 2``; then it holds one
          copy per stacked volume.

    Raises:
        InputError: for unreadable image/label data.
    """

    # get filenames
    img_names, label_names = [], []
    for data_type in [
            '.am', '.tif', '.tiff', '.hdr', '.mhd', '.mha', '.nrrd', '.nii',
            '.nii.gz'
    ]:
        img_names.extend(
            sorted(glob(img_dir + '/**/*' + data_type, recursive=True)))
        label_names.extend(
            sorted(glob(label_dir + '/**/*' + data_type, recursive=True)))

    # load first label
    region_of_interest = None
    a, header, extension = load_data(label_names[0], 'first_queue', True)
    if a is None:
        InputError.message = "Invalid label data %s." % (os.path.basename(
            label_names[0]))
        raise InputError()
    if crop_data:
        region_of_interest = np.zeros(6)
        argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = predict_blocksize(
            a)
        a = np.copy(a[argmin_z:argmax_z, argmin_y:argmax_y, argmin_x:argmax_x],
                    order='C')
        region_of_interest += [
            argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x
        ]
    a = a.astype(np.uint8)
    np_unique = np.unique(a)
    # resize each label value separately so values are never interpolated
    label = np.zeros((z_scale, y_scale, x_scale), dtype=a.dtype)
    for k in np_unique:
        tmp = np.zeros_like(a)
        tmp[a == k] = 1
        tmp = img_resize(tmp, z_scale, y_scale, x_scale)
        label[tmp == 1] = k

    # load first img
    img, _ = load_data(img_names[0], 'first_queue')
    if img is None:
        InputError.message = "Invalid image data %s." % (os.path.basename(
            img_names[0]))
        raise InputError()
    if crop_data:
        img = np.copy(img[argmin_z:argmax_z, argmin_y:argmax_y,
                          argmin_x:argmax_x],
                      order='C')
    img = img.astype(np.float32)
    img = img_resize(img, z_scale, y_scale, x_scale)
    img -= np.amin(img)
    img /= np.amax(img)
    if configuration_data is not None:
        mu, sig = configuration_data[5], configuration_data[6]
        mu_tmp, sig_tmp = np.mean(img), np.std(img)
        img = (img - mu_tmp) / sig_tmp
        img = img * sig + mu
    else:
        mu, sig = np.mean(img), np.std(img)
    for img_name, label_name in zip(img_names[1:], label_names[1:]):

        # append label
        a, _ = load_data(label_name, 'first_queue')
        if a is None:
            InputError.message = "Invalid label data %s." % (
                os.path.basename(label_name))
            raise InputError()
        if crop_data:
            argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = predict_blocksize(
                a)
            a = np.copy(a[argmin_z:argmax_z, argmin_y:argmax_y,
                          argmin_x:argmax_x],
                        order='C')
            region_of_interest += [
                argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x
            ]
        a = a.astype(np.uint8)
        np_unique = np.unique(a)
        next_label = np.zeros((z_scale, y_scale, x_scale), dtype=a.dtype)
        for k in np_unique:
            tmp = np.zeros_like(a)
            tmp[a == k] = 1
            tmp = img_resize(tmp, z_scale, y_scale, x_scale)
            next_label[tmp == 1] = k
        label = np.append(label, next_label, axis=0)

        # append image (cropped with the box of the matching label)
        a, _ = load_data(img_name, 'first_queue')
        if a is None:
            InputError.message = "Invalid image data %s." % (
                os.path.basename(img_name))
            raise InputError()
        if crop_data:
            a = np.copy(a[argmin_z:argmax_z, argmin_y:argmax_y,
                          argmin_x:argmax_x],
                        order='C')
        a = a.astype(np.float32)
        a = img_resize(a, z_scale, y_scale, x_scale)
        a -= np.amin(a)
        a /= np.amax(a)
        if normalize:
            mu_tmp, sig_tmp = np.mean(a), np.std(a)
            a = (a - mu_tmp) / sig_tmp
            a = a * sig + mu
        img = np.append(img, a, axis=0)

    # automatic cropping: average crop box over all volumes
    if crop_data:
        region_of_interest /= float(len(img_names))
        region_of_interest = np.round(region_of_interest)
        region_of_interest[region_of_interest < 0] = 0
        region_of_interest = region_of_interest.astype(int)

    # scale image data to [0,1]
    img[img < 0] = 0
    img[img > 1] = 1

    # compute position data: one copy per stacked volume so it lines up with
    # img/label along axis 0
    position = None
    if channels == 2:
        position = np.empty((z_scale, y_scale, x_scale), dtype=np.float32)
        position = compute_position(position, z_scale, y_scale, x_scale)
        position = np.sqrt(position)
        position /= np.amax(position)
        n_volumes = min(len(img_names), len(label_names))
        position = np.tile(position, (n_volumes, 1, 1))

    # relabel to consecutive indices allLabels[k] -> k
    if allLabels is not None:
        counts = None
    else:
        allLabels, counts = np.unique(label, return_counts=True)
    # match against an untouched copy so rewritten voxels are never matched
    # again (a user-supplied allLabels need not be sorted)
    original_label = np.copy(label)
    for k, l in enumerate(allLabels):
        label[original_label == l] = k

    # configuration data
    configuration_data = np.array(
        [channels, x_scale, y_scale, z_scale, normalize, mu, sig])

    return img, label, position, allLabels, configuration_data, header, extension, region_of_interest, counts
Ejemplo n.º 8
0
def load_cropping_training_data(normalize, img_list, label_list, x_scale, y_scale, z_scale, mu=None, sig=None):
    """Load image/label pairs as training data for the cropping network.

    Inputs:
    - ``img_list``/``label_list`` hold paired entries. Each pair is either
      two single files, or two TAR archives (optionally '.tar.gz') or
      directories, which are expanded into the supported volumes they
      contain.

    Labels:
    - Each label volume becomes three boolean vectors marking which z-, y-
      and x-slices contain any nonzero label.
    - These vectors are appended in z, y, x order.

    Images:
    - Each image is turned into three views. In each view one axis keeps its
      original length and the other two are resized to the scales.
    - The views are stacked along axis 0 and scaled to [0, 1].
    - The first image is matched to ``mu``/``sig`` when both ``mu`` is
      given and ``normalize`` is set. Otherwise its own mean/std become
      ``mu``/``sig``.
    - Later images are matched to ``mu``/``sig`` when ``normalize`` is set
      and clipped to [0, 1].
    - All views are quantised to uint8 and given three identical channels.

    Returns:
        (img_rgb, label, position, mu, sig, header, extension, n_images)
        - ``position`` is always None.
        - ``header``/``extension`` come from the first label file.
        - ``n_images`` is the number of collected image files.

    Raises:
        InputError: for an archive/directory without supported volumes or
            for unreadable image/label data.
    """

    # get filenames
    img_names, label_names = [], []
    for img_name, label_name in zip(img_list, label_list):

        # check for tarball
        img_dir, img_ext = os.path.splitext(img_name)
        if img_ext == '.gz':
            img_dir, img_ext = os.path.splitext(img_dir)

        label_dir, label_ext = os.path.splitext(label_name)
        if label_ext == '.gz':
            label_dir, label_ext = os.path.splitext(label_dir)

        if (img_ext == '.tar' and label_ext == '.tar') or (os.path.isdir(img_name) and os.path.isdir(label_name)):

            # search existing directories as given; splitext would cut a
            # dotted directory name such as 'scans.v2' down to 'scans'
            if os.path.isdir(img_name):
                img_dir = img_name
            if os.path.isdir(label_name):
                label_dir = label_name

            # extract files if necessary
            # NOTE: extractall trusts the member paths stored in the archive;
            # only use archives from trusted sources.
            if img_ext == '.tar' and not os.path.exists(img_dir):
                with tarfile.open(img_name) as tar:
                    tar.extractall(path=img_dir)
            if label_ext == '.tar' and not os.path.exists(label_dir):
                with tarfile.open(label_name) as tar:
                    tar.extractall(path=label_dir)

            # remember the list lengths so an archive/directory without any
            # volumes is reported even if earlier entries supplied files
            n_img, n_label = len(img_names), len(label_names)
            for data_type in ['.am','.tif','.tiff','.hdr','.mhd','.mha','.nrrd','.nii','.nii.gz']:
                img_names.extend(sorted(glob(img_dir+'/**/*'+data_type, recursive=True)))
                label_names.extend(sorted(glob(label_dir+'/**/*'+data_type, recursive=True)))
            if len(img_names) == n_img:
                InputError.message = "Invalid image TAR file."
                raise InputError()
            if len(label_names) == n_label:
                InputError.message = "Invalid label TAR file."
                raise InputError()
        else:
            img_names.append(img_name)
            label_names.append(label_name)

    # load first label: per-slice "contains any label" flags along z, y, x
    a, header, extension = load_data(label_names[0], 'first_queue', True)
    if a is None:
        InputError.message = "Invalid label data %s." %(os.path.basename(label_names[0]))
        raise InputError()
    a = a.astype(np.uint8)
    label_z = np.any(a,axis=(1,2))
    label_y = np.any(a,axis=(0,2))
    label_x = np.any(a,axis=(0,1))
    label = np.append(label_z,label_y,axis=0)
    label = np.append(label,label_x,axis=0)

    # load first img (views sized by the matching label's shape)
    img, _ = load_data(img_names[0], 'first_queue')
    if img is None:
        InputError.message = "Invalid image data %s." %(os.path.basename(img_names[0]))
        raise InputError()
    img = img.astype(np.float32)
    img_z = img_resize(img, a.shape[0], y_scale, x_scale)
    img_y = np.swapaxes(img_resize(img, z_scale, a.shape[1], x_scale),0,1)
    img_x = np.swapaxes(img_resize(img, z_scale, y_scale, a.shape[2]),0,2)
    img = np.append(img_z,img_y,axis=0)
    img = np.append(img,img_x,axis=0)
    img -= np.amin(img)
    img /= np.amax(img)
    if mu is not None and normalize:
        mu_tmp, sig_tmp = np.mean(img), np.std(img)
        img = (img - mu_tmp) / sig_tmp
        img = img * sig + mu
        img[img<0] = 0
        img[img>1] = 1
    else:
        mu, sig = np.mean(img), np.std(img)
    img = np.uint8(img*255)

    for img_name, label_name in zip(img_names[1:], label_names[1:]):

        # append label
        a, _ = load_data(label_name, 'first_queue')
        if a is None:
            InputError.message = "Invalid label data %s." %(os.path.basename(label_name))
            raise InputError()
        a = a.astype(np.uint8)
        next_label_z = np.any(a,axis=(1,2))
        next_label_y = np.any(a,axis=(0,2))
        next_label_x = np.any(a,axis=(0,1))
        label = np.append(label,next_label_z,axis=0)
        label = np.append(label,next_label_y,axis=0)
        label = np.append(label,next_label_x,axis=0)

        # append image
        a, _ = load_data(img_name, 'first_queue')
        if a is None:
            InputError.message = "Invalid image data %s." %(os.path.basename(img_name))
            raise InputError()
        a = a.astype(np.float32)
        img_z = img_resize(a, a.shape[0], y_scale, x_scale)
        img_y = np.swapaxes(img_resize(a, z_scale, a.shape[1], x_scale),0,1)
        img_x = np.swapaxes(img_resize(a, z_scale, y_scale, a.shape[2]),0,2)
        next_img = np.append(img_z,img_y,axis=0)
        next_img = np.append(next_img,img_x,axis=0)
        next_img -= np.amin(next_img)
        next_img /= np.amax(next_img)
        if normalize:
            mu_tmp, sig_tmp = np.mean(next_img), np.std(next_img)
            next_img = (next_img - mu_tmp) / sig_tmp
            next_img = next_img * sig + mu
            next_img[next_img<0] = 0
            next_img[next_img>1] = 1
        next_img = np.uint8(next_img*255)
        img = np.append(img, next_img, axis=0)

    # replicate the grey values into three channels
    img_rgb = np.empty((img.shape + (3,)), dtype=np.uint8)
    for i in range(3):
        img_rgb[...,i] = img

    # compute position data
    position = None

    return img_rgb, label, position, mu, sig, header, extension, len(img_names)
Ejemplo n.º 9
0
def crop_volume(img, path_to_volume, path_to_model, z_shape, y_shape, x_shape, batch_size, debug_cropping,
        x_puffer=25,y_puffer=25,z_puffer=25):
    """Predict a bounding box around the relevant region of a volume.

    A DenseNet (``make_densenet``) is built and loaded with the
    'cropping_weights' stored in ``path_to_model``. It predicts one
    probability per slice of ``img``; the slice order and count are read as
    ``z_shape`` z-slices, then ``y_shape`` y-slices, then the remaining
    x-slices.

    The probabilities are then processed as follows:
    - They are thresholded at 0.5.
    - Short isolated runs of 1 to 7 positive slices are cleared.
    - The first and last positive slice of each axis define the box.
    - The box is widened by the puffer values and clamped to the shape.

    Args:
        img: 4-D array (slices, y, x, channels) fed to the network.
        path_to_volume: original volume path; used for the debug outputs.
        path_to_model: HDF5 file holding the 'cropping_weights' group.
        z_shape, y_shape, x_shape: shape of the original volume.
        batch_size: prediction batch size.
        debug_cropping: if true, save a PNG plot of the raw/cleaned
            probabilities and the cropped volume as '<name>_cropped.tif'
            next to ``path_to_volume``.
        x_puffer, y_puffer, z_puffer: margin added on each side of the box.

    Returns:
        (z_upper, z_lower, y_upper, y_lower, x_upper, x_lower): start and
        end indices per axis. If an axis has no positive slice, the full
        axis is returned.
    """

    # path to cropped image
    filename = os.path.basename(path_to_volume)
    filename = os.path.splitext(filename)[0]
    if filename[-4:] in ['.nii']:
        filename = filename[:-4]
    filename = filename + '_cropped.tif'
    path_to_final = path_to_volume.replace(os.path.basename(path_to_volume), filename)

    # img shape
    zsh, ysh, xsh, channels = img.shape

    # List of slice IDs, padded by cycling through the slices up to a
    # multiple of batch_size. -n % b is 0 when n is already a multiple,
    # and cycling also works when there are fewer slices than batch_size.
    list_IDs = list(range(zsh))
    rest = -len(list_IDs) % batch_size
    list_IDs += [list_IDs[i % zsh] for i in range(rest)]

    # parameters
    params = {'dim': (ysh,xsh),
              'dim_img': (zsh, ysh, xsh),
              'batch_size': batch_size,
              'n_channels': channels}

    # data generator
    predict_generator = PredictDataGeneratorCrop(img, list_IDs, **params)

    # create a MirroredStrategy
    if os.name == 'nt':
        cdo = tf.distribute.HierarchicalCopyAllReduce()
    else:
        cdo = tf.distribute.NcclAllReduce()
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)

    # input shape
    input_shape = (ysh, xsh, channels)

    # load model
    with strategy.scope():
        model = make_densenet(input_shape)

    # load weights: datasets are stored under consecutive string keys in the
    # order of the model's weight arrays; the file is closed even on error
    with h5py.File(path_to_model, 'r') as hf:
        cropping_weights = hf.get('cropping_weights')
        iterator = 0
        for layer in model.layers:
            current_weights = layer.get_weights()
            if current_weights:
                new_weights = []
                for _ in current_weights:
                    new_weights.append(cropping_weights.get(str(iterator)))
                    iterator += 1
                layer.set_weights(new_weights)

    # predict (drop the padding slices)
    probabilities = model.predict(predict_generator, verbose=0, steps=None)
    probabilities = probabilities[:zsh]
    probabilities = np.ravel(probabilities)

    # Plot prediction. An explicit Figure avoids pyplot's global state and
    # any backend switch; switching backends after pyplot is imported closes
    # all open figures.
    if debug_cropping:
        from matplotlib.figure import Figure
        fig = Figure()
        ax = fig.subplots()
        x = range(len(probabilities))
        ax.plot(x, list(probabilities))

    # create mask
    probabilities[probabilities > 0.5] = 1
    probabilities[probabilities <= 0.5] = 0

    # remove outliers: clear isolated runs of 1-7 positive slices
    for k in range(4,zsh-4):
        if np.all(probabilities[k-1:k+2] == np.array([0,1,0])):
            probabilities[k-1:k+2] = 0
        elif np.all(probabilities[k-2:k+2] == np.array([0,1,1,0])):
            probabilities[k-2:k+2] = 0
        elif np.all(probabilities[k-2:k+3] == np.array([0,1,1,1,0])):
            probabilities[k-2:k+3] = 0
        elif np.all(probabilities[k-3:k+3] == np.array([0,1,1,1,1,0])):
            probabilities[k-3:k+3] = 0
        elif np.all(probabilities[k-3:k+4] == np.array([0,1,1,1,1,1,0])):
            probabilities[k-3:k+4] = 0
        elif np.all(probabilities[k-4:k+4] == np.array([0,1,1,1,1,1,1,0])):
            probabilities[k-4:k+4] = 0
        elif np.all(probabilities[k-4:k+5] == np.array([0,1,1,1,1,1,1,1,0])):
            probabilities[k-4:k+5] = 0

    # plot cleaned result
    if debug_cropping:
        ax.plot(x, list(probabilities), '--')
        fig.tight_layout()  # To prevent overlapping of subplots
        fig.savefig(path_to_final.replace('.tif','.png'), dpi=300)

    # create final: first/last positive slice per axis, widened by the
    # puffer and clamped to the volume
    z_upper = max(0,np.argmax(probabilities[:z_shape]) - z_puffer)
    z_lower = min(z_shape,z_shape - np.argmax(np.flip(probabilities[:z_shape])) + z_puffer +1)
    y_upper = max(0,np.argmax(probabilities[z_shape:z_shape+y_shape]) - y_puffer)
    y_lower = min(y_shape,y_shape - np.argmax(np.flip(probabilities[z_shape:z_shape+y_shape])) + y_puffer +1)
    x_upper = max(0,np.argmax(probabilities[z_shape+y_shape:]) - x_puffer)
    x_lower = min(x_shape,x_shape - np.argmax(np.flip(probabilities[z_shape+y_shape:])) + x_puffer +1)

    # crop image data
    if debug_cropping:
        volume, _ = load_data(path_to_volume)
        final = volume[z_upper:z_lower,y_upper:y_lower,x_upper:x_lower]
        save_data(path_to_final, final, compress=False)

    return z_upper, z_lower, y_upper, y_lower, x_upper, x_lower