# Example No. 1
# 0
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    Args:
        file_references (list): A list of lists containing file references,
            such as [['id_0', 'image_path_0', 'image_prefix_0'], ...,
            ['id_N', 'image_path_N', 'image_prefix_N']].
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.).

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """
    def _augment(img, lbl):
        """Apply additive Gaussian noise and a flip along axis 1."""
        img = add_gaussian_noise(img, sigma=0.1)
        [img, lbl] = flip([img, lbl], axis=1)
        return img, lbl

    for f in file_references:
        subject_id = f[0]
        img_path = f[1]
        img_prefix = f[2]

        # Read the image nii with sitk and keep the pointer to the sitk.Image
        # of the input so spatial metadata is available downstream.
        t2_sitk = sitk.ReadImage(str(img_path + img_prefix + t2_postfix))
        t2 = sitk.GetArrayFromImage(t2_sitk)

        # Normalise volume images
        t2 = whitening(t2)

        # Create a 4D multi-sequence image (i.e. [x, y, z, channels])
        images = np.stack([t2], axis=-1).astype(np.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            print("Predict not yet implemented, please try a different mode")
            yield {
                'features': {
                    'x': images
                },
                'labels': None,
                'sitk': t2_sitk,
                'subject_id': subject_id
            }
            # FIX: no label file is expected in PREDICT mode; skip the label
            # handling below instead of falling through (which yielded the
            # same subject twice and read a possibly missing label file).
            continue

        lbl = sitk.GetArrayFromImage(
            sitk.ReadImage(str(img_path + img_prefix + label_postfix))).astype(
                np.int32)

        # Remove other class labels to leave just the grey matter (label 2)
        lbl[lbl != 2] = 0
        lbl[lbl == 2] = 1

        # Augment if in training
        if mode == tf.estimator.ModeKeys.TRAIN:
            images, lbl = _augment(images, lbl)

        # Check if reader is returning training examples or full images
        if params['extract_examples']:
            n_examples = params['n_examples']
            example_size = params['example_size']
            images, lbl = extract_class_balanced_example_array(
                image=images,
                label=lbl,
                example_size=example_size,
                n_examples=n_examples,
                classes=NUM_CLASSES)

            assert not np.any(np.isnan(images))
            for e in range(n_examples):
                yield {
                    'features': {
                        'x': images[e].astype(np.float32)
                    },
                    'labels': {
                        'y': lbl[e].astype(np.int32)
                    },
                    'subject_id': subject_id
                }
        else:
            yield {
                'features': {
                    'x': images
                },
                'labels': {
                    'y': lbl
                },
                'sitk': t2_sitk,
                'subject_id': subject_id
            }

    return
# Example No. 2
# 0
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    Args:
        file_references (list): A list of lists containing file references
            (subject id, slice index, manual-annotation path, image path and
            image prefix at the indices read below).
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.).

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """

    def _augment(img, lbl):
        """Apply additive Gaussian noise and a flip along axis 1."""
        img = add_gaussian_noise(img, sigma=0.1)
        [img, lbl] = flip([img, lbl], axis=1)
        return img, lbl

    def get_config_for_app():
        """Load app_config.json from the directory above 'readers'."""
        dir_fn = os.path.dirname(__file__)[:-7]  # Remove 'readers' from filepath
        app_fn = os.path.join(dir_fn, 'app_config.json')
        with open(app_fn) as json_data:
            app_json = json.load(json_data)
        return app_json

    for f in file_references:
        subject_id = f[0]
        slice_index = int(f[3])
        man_path = f[4]
        img_path = f[5]
        img_prefix = f[6]

        app_json = get_config_for_app()
        sitk_ref = None
        inputs_to_stack = []
        for i, input_type in enumerate(app_json['input_postfix']):
            # Read each input sequence; keep the sitk.Image of the first
            # channel as the spatial reference for write-out.
            im_sitk = sitk.ReadImage(os.path.join(img_path, str(img_prefix + input_type)))
            im = sitk.GetArrayFromImage(im_sitk)
            # Keep only the annotated slice
            im = im[slice_index, :, :]
            im = whitening(im)
            inputs_to_stack.append(im)
            if i == 0:
                sitk_ref = im_sitk

        # Create a multi-sequence image (i.e. [..., channels])
        images = np.stack(inputs_to_stack, axis=-1).astype(np.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            yield {'features': {'x': images},
                   'labels': None,
                   'sitk': sitk_ref,
                   'subject_id': subject_id,
                   'path': img_path,
                   'prefix': img_prefix}
            # FIX: no manual annotation is expected in PREDICT mode; skip the
            # label handling below instead of falling through (which yielded
            # the same subject twice and read a possibly missing label file).
            continue

        lbl = sitk.GetArrayFromImage(sitk.ReadImage(man_path)).astype(
            np.int32)

        # Keep only the annotated slice
        lbl = lbl[slice_index, :, :]

        # Augment if in training
        if mode == tf.estimator.ModeKeys.TRAIN:
            images, lbl = _augment(images, lbl)

        # Check if reader is returning training examples or full images
        if params['extract_examples']:
            n_examples = params['n_examples']
            example_size = params['example_size']
            lbl = lbl.reshape([1, lbl.shape[0], lbl.shape[1]])
            images = images.reshape([lbl.shape[0], lbl.shape[1], lbl.shape[2], app_json['num_channels']])

            images, lbl = extract_class_balanced_example_array(
                image=images,
                label=lbl,
                example_size=example_size,
                n_examples=n_examples,
                classes=app_json['num_classes'])

            for e in range(n_examples):
                yield {'features': {'x': images[e].astype(np.float32)},
                       'labels': {'y': lbl[e].astype(np.int32)},
                       'subject_id': subject_id,
                       'path': img_path,
                       'prefix': img_prefix
                      }
        else:
            lbl = lbl.reshape([1, lbl.shape[0], lbl.shape[1]])
            images = images.reshape([lbl.shape[0], lbl.shape[1], lbl.shape[2], app_json['num_channels']])
            print("extracting full images (not training examples)")
            # FIX: was 'sitk': im_sitk (the *last* channel read in the loop);
            # use the first-channel reference deliberately kept in sitk_ref.
            yield {'features': {'x': images},
                   'labels': {'y': lbl},
                   'sitk': sitk_ref,
                   'subject_id': subject_id,
                   'slice_index': slice_index,
                   'path': img_path,
                   'prefix': img_prefix}

    return
# Example No. 3
# 0
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    Args:
        file_references (list): A list of lists containing file references, such
            as [['id_0', 'image_filename_0', target_value_0], ...,
            ['id_N', 'image_filename_N', target_value_N]].
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.).

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """
    def _augment(img, lbl):
        """Apply additive Gaussian noise; flipping is currently disabled."""
        img = add_gaussian_noise(img, sigma=0.1)
        # [img, lbl] = flip([img, lbl], axis=0)

        return img, lbl

    for f in file_references:
        subject_id = f[0]
        img_fn = f[1]

        # Read the preprocessed T1 nii with sitk and keep the pointer to the
        # sitk.Image of the input; intensities are clipped to [0, 100] and
        # rescaled to [0, 1].
        t1_sitk = sitk.ReadImage(
            str(img_fn).replace("/IBSR/", "/IBSR_preprocessed/").replace(
                "_ana_strip.nii.gz", "_ana_strip_1mm_center_cropped.nii.gz"))
        t1 = ((np.clip(sitk.GetArrayFromImage(t1_sitk), 0., 100.) - 0.) /
              100.).swapaxes(0, 1)

        lbl_sitk = sitk.ReadImage(
            str(img_fn).replace("/IBSR/", "/IBSR_preprocessed/").replace(
                "_ana_strip.nii.gz", "_seg_ana_1mm_center_cropped.nii.gz"))
        lbl = sitk.GetArrayFromImage(lbl_sitk).astype(np.int32).swapaxes(0, 1)

        # Create a 4D multi-sequence image (i.e. [x, y, z, channels])
        images = np.stack([t1], axis=-1).astype(np.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            yield {
                'features': {
                    'x': images
                },
                'labels': None,
                'sitk': t1_sitk,
                'subject_id': subject_id
            }
            # FIX: stop here in PREDICT mode; falling through yielded the same
            # subject a second time with labels attached.
            continue

        # Augment if used in training mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            images, lbl = _augment(images, lbl)

        # Check if the reader is supposed to return training examples or full
        # images
        if params['extract_examples']:
            n_examples = params['n_examples']
            example_size = params['example_size']
            # Optional per-class sampling weights
            class_weights = params.get('class_weights')

            images, lbl = extract_class_balanced_example_array(
                image=images,
                label=lbl,
                example_size=example_size,
                n_examples=n_examples,
                classes=NUM_CLASSES,
                class_weights=class_weights)

            for e in range(len(images)):
                yield {
                    'features': {
                        'x': images[e].astype(np.float32)
                    },
                    'labels': {
                        'y': lbl[e].astype(np.int32)
                    },
                    'subject_id': subject_id
                }
        else:
            yield {
                'features': {
                    'x': images
                },
                'labels': {
                    'y': lbl
                },
                'sitk': t1_sitk,
                'subject_id': subject_id
            }

    return
# Example No. 4
# 0
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    Args:
        file_references (list): A list of lists containing file references, such
            as [['id_0', 'image_filename_0', target_value_0], ...,
            ['id_N', 'image_filename_N', target_value_N]].
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.).

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """

    def _augment(img, lbl):
        """Apply additive Gaussian noise and a flip along axis 1."""
        img = add_gaussian_noise(img, sigma=0.1)
        [img, lbl] = flip([img, lbl], axis=1)

        return img, lbl

    for f in file_references:
        subject_id = f[0]
        img_fn = f[1]

        # Read the image niis with sitk and keep the pointer to the sitk.Image
        # of the T1 input as the spatial reference.
        t1_sitk = sitk.ReadImage(str(os.path.join(img_fn, 'T1.nii')))
        t1 = sitk.GetArrayFromImage(t1_sitk)
        t1_ir = sitk.GetArrayFromImage(
            sitk.ReadImage(str(os.path.join(img_fn, 'T1_IR.nii'))))
        t2_fl = sitk.GetArrayFromImage(
            sitk.ReadImage(str(os.path.join(img_fn, 'T2_FLAIR.nii'))))

        # Normalise volume images
        t1 = whitening(t1)
        t1_ir = whitening(t1_ir)
        t2_fl = whitening(t2_fl)

        # Create a 4D multi-sequence image (i.e. [x, y, z, channels])
        images = np.stack([t1, t1_ir, t2_fl], axis=-1).astype(np.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            yield {'features': {'x': images},
                   'labels': None,
                   'sitk': t1_sitk,
                   'subject_id': subject_id}
            # FIX: 'LabelsForTraining.nii' is not expected for PREDICT data;
            # skip the label handling below instead of falling through.
            continue

        lbl = sitk.GetArrayFromImage(sitk.ReadImage(str(os.path.join(
            img_fn,
            'LabelsForTraining.nii')))).astype(np.int32)

        # Augment if used in training mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            images, lbl = _augment(images, lbl)

        # Check if the reader is supposed to return training examples or full
        # images
        if params['extract_examples']:
            n_examples = params['n_examples']
            example_size = params['example_size']

            images, lbl = extract_class_balanced_example_array(
                image=images,
                label=lbl,
                example_size=example_size,
                n_examples=n_examples,
                classes=9)

            for e in range(len(images)):
                yield {'features': {'x': images[e].astype(np.float32)},
                       'labels': {'y': lbl[e].astype(np.int32)},
                       'subject_id': subject_id}
        else:
            yield {'features': {'x': images},
                   'labels': {'y': lbl},
                   'sitk': t1_sitk,
                   'subject_id': subject_id}

    return
# Example No. 5
# 0
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    Handles three kinds of references: pre-extracted training patches
    (ids containing 'p'), full training stacks and evaluation stacks.

    Args:
        file_references (list): A list of lists containing file references
            (subject/patch id, slice index, manual-annotation path, image
            path and image prefix at the indices read below).
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.).

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """
    def _augment(img, lbl):
        """Apply additive Gaussian noise and a flip along axis 1."""
        img = add_gaussian_noise(img, sigma=0.1)
        [img, lbl] = flip([img, lbl], axis=1)
        return img, lbl

    def get_config_for_app():
        """Load app_config.json from the directory above 'readers'."""
        dir_fn = os.path.dirname(
            __file__)[:-7]  # Remove 'readers' from filepath
        app_fn = os.path.join(dir_fn, 'app_config.json')
        with open(app_fn) as json_data:
            app_json = json.load(json_data)
        return app_json

    for f in file_references:

        if mode == tf.estimator.ModeKeys.TRAIN:

            if 'p' in f[0]:
                # Handle pre-extracted patches in here
                patch_id = f[0][-1]
                patch_path = f[1]
                assert 'p' not in patch_id
                lbl_postfix = '_emseg.nii.gz'  # TODO: change to be manual annotation patch
                app_json = get_config_for_app()
                sitk_ref = None
                inputs_to_stack = []
                for i, input_type in enumerate(app_json['input_postfix']):
                    # Read each input sequence; keep the sitk.Image of the
                    # first channel as the spatial reference.
                    im_sitk = sitk.ReadImage(
                        os.path.join(patch_path,
                                     str(str(patch_id) + '_' + input_type)))
                    im = sitk.GetArrayFromImage(im_sitk)
                    im = whitening(im)
                    inputs_to_stack.append(im)
                    if i == 0:
                        sitk_ref = im_sitk

                # Create a multi-sequence image (i.e. [..., channels])
                images = np.stack(inputs_to_stack, axis=-1).astype(np.float32)

                lbl_sitk = sitk.ReadImage(
                    os.path.join(patch_path, str(str(patch_id) + lbl_postfix)))
                lbl = sitk.GetArrayFromImage(lbl_sitk).astype(np.int32)

                images, lbl = _augment(images, lbl)

                yield {
                    'features': {
                        'x': images
                    },
                    'labels': {
                        'y': lbl
                    },
                    'sitk': sitk_ref
                }

            else:  # Read stack
                subject_id = f[0]
                slice_index = int(f[3])
                man_path = f[4]
                image_path = f[5]
                image_prefix = f[6]

                app_json = get_config_for_app()
                sitk_ref = None
                inputs_to_stack = []
                for i, input_type in enumerate(app_json['input_postfix']):
                    # Read each input sequence; keep the sitk.Image of the
                    # first channel as the spatial reference.
                    im_sitk = sitk.ReadImage(
                        os.path.join(image_path, image_prefix + input_type))
                    im = sitk.GetArrayFromImage(im_sitk)
                    im = im[slice_index, :, :]
                    im = whitening(im)
                    inputs_to_stack.append(im)
                    if i == 0:
                        sitk_ref = im_sitk

                # Create a multi-sequence image (i.e. [..., channels])
                images = np.stack(inputs_to_stack, axis=-1).astype(np.float32)

                lbl_sitk = sitk.ReadImage(os.path.join(man_path))
                lbl = sitk.GetArrayFromImage(lbl_sitk).astype(np.int32)

                # Remove other class labels to leave just the grey matter
                #lbl[lbl != 2.] = 0.
                #lbl[lbl == 2.] = 1.
                lbl = lbl[slice_index, :, :]

                # Augment if in training
                images, lbl = _augment(images, lbl)

                # Check if reader is returning training examples or full images
                if params['extract_examples']:
                    n_examples = params['n_examples']
                    example_size = params['example_size']
                    lbl = lbl.reshape([1, lbl.shape[0], lbl.shape[1]])
                    images = images.reshape([
                        lbl.shape[0], lbl.shape[1], lbl.shape[2],
                        app_json['num_channels']
                    ])

                    images, lbl = extract_class_balanced_example_array(
                        image=images,
                        label=lbl,
                        example_size=example_size,
                        n_examples=n_examples,
                        classes=app_json['num_classes'])

                    assert not np.any(np.isnan(images))
                    for e in range(n_examples):
                        yield {
                            'features': {
                                'x': images[e].astype(np.float32)
                            },
                            'labels': {
                                'y': lbl[e].astype(np.int32)
                            },
                            'subject_id': subject_id
                        }
                else:
                    lbl = lbl.reshape([1, lbl.shape[0], lbl.shape[1]])
                    images = images.reshape([
                        lbl.shape[0], lbl.shape[1], lbl.shape[2],
                        app_json['num_channels']
                    ])
                    assert not np.any(np.isnan(images))
                    assert sitk_ref is not None
                    # FIX: this branch referenced undefined subj_path /
                    # subj_prefix (only bound in the EVAL branch), raising a
                    # NameError; use the names bound above in this branch.
                    yield {
                        'features': {
                            'x': images
                        },
                        'labels': {
                            'y': lbl
                        },
                        'sitk': sitk_ref,
                        'subject_id': subject_id,
                        'path': image_path,
                        'prefix': image_prefix
                    }

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Handle Eval stacks in here
            subject_id = f[0]
            slice_index = int(f[3])
            subj_path = f[5]
            subj_prefix = f[6]
            man_path = f[4]

            app_json = get_config_for_app()
            sitk_ref = None
            inputs_to_stack = []
            for i, input_type in enumerate(app_json['input_postfix']):
                # Read each input sequence; keep the sitk.Image of the first
                # channel as the spatial reference.
                im_sitk = sitk.ReadImage(
                    os.path.join(subj_path, subj_prefix + input_type))
                im = sitk.GetArrayFromImage(im_sitk)
                im = im[slice_index, :, :]
                im = whitening(im)
                inputs_to_stack.append(im)
                if i == 0:
                    sitk_ref = im_sitk

            # Create a multi-sequence image (i.e. [..., channels])
            images = np.stack(inputs_to_stack, axis=-1).astype(np.float32)

            lbl_sitk = sitk.ReadImage(os.path.join(man_path))
            lbl = sitk.GetArrayFromImage(lbl_sitk).astype(np.int32)
            lbl = lbl[slice_index, :, :]
            # Remove other class labels to leave just the grey matter
            #lbl[lbl != 2.] = 0.
            #lbl[lbl == 2.] = 1.

            # Check if reader is returning training examples or full images
            if params['extract_examples']:
                n_examples = params['n_examples']
                example_size = params['example_size']
                lbl = lbl.reshape([1, lbl.shape[0], lbl.shape[1]])
                images = images.reshape([
                    lbl.shape[0], lbl.shape[1], lbl.shape[2],
                    app_json['num_channels']
                ])

                images, lbl = extract_class_balanced_example_array(
                    image=images,
                    label=lbl,
                    example_size=example_size,
                    n_examples=n_examples,
                    classes=app_json['num_classes'])

                assert not np.any(np.isnan(images))
                for e in range(n_examples):
                    yield {
                        'features': {
                            'x': images[e].astype(np.float32)
                        },
                        'labels': {
                            'y': lbl[e].astype(np.int32)
                        },
                        'subject_id': subject_id
                    }
            else:
                lbl = lbl.reshape([1, lbl.shape[0], lbl.shape[1]])
                images = images.reshape([
                    lbl.shape[0], lbl.shape[1], lbl.shape[2],
                    app_json['num_channels']
                ])
                assert not np.any(np.isnan(images))
                assert sitk_ref is not None
                yield {
                    'features': {
                        'x': images
                    },
                    'labels': {
                        'y': lbl
                    },
                    'sitk': sitk_ref,
                    'subject_id': subject_id,
                    'path': subj_path,
                    'prefix': subj_prefix
                }

    return
# Example No. 6
# 0
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    Args:
        file_references (list): A list of lists containing file references,
            such as [['id_0', 'image_filename_0', target_value_0], ...,
            ['id_N', 'image_filename_N', target_value_N]].
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL
            or PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.).

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """

    def _augment(img, lbl):
        """Apply a random intensity offset and flips along all three axes."""
        img = add_gaussian_offset(img, sigma=1.0)
        for axis in (0, 1, 2):
            [img, lbl] = flip([img, lbl], axis=axis)
        return img, lbl

    def _map_labels(lbl, convert_to_protocol=False):
        """Map dataset specific label id protocols to consecutive integer ids
            for training and back.

            iFind segment ids:
                0 background
                2 brain
                9 placenta
                10 uterus ROI

        Args:
            lbl (np.array): A label map to be converted.
            convert_to_protocol (bool, optional): A flag to determine to
                convert from or to the protocol ids.

        Returns:
            np.array: The converted label map
        """
        protocol_ids = [0, 2]
        converted = np.zeros_like(lbl)
        for train_id, protocol_id in enumerate(protocol_ids):
            if convert_to_protocol:
                # Map from consecutive ints to protocol labels
                converted[lbl == train_id] = protocol_id
            else:
                # Map from protocol labels to consecutive ints
                converted[lbl == protocol_id] = train_id
        return converted

    for f in file_references:
        # Read the image nii with sitk and keep the sitk.Image pointer
        img_id, img_fn = f[0], f[1]
        img_sitk = sitk.ReadImage(str(img_fn))

        # Normalise the volume and create a 4D image ([x, y, z, channels])
        img = whitening(sitk.GetArrayFromImage(img_sitk))
        images = np.expand_dims(img, axis=-1).astype(np.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            yield {'features': {'x': images},
                   'labels': {'y': np.array([0])},
                   'sitk': img_sitk,
                   'img_id': img_id}
            continue

        # Read the label nii and map protocol ids to consecutive integers
        lbl = _map_labels(
            sitk.GetArrayFromImage(sitk.ReadImage(str(f[2]))).astype(np.int32))

        # Augment only in training mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            images, lbl = _augment(images, lbl)

        # Yield class-balanced training examples or the full image
        if params['extract_examples']:
            images, lbl = extract_class_balanced_example_array(
                images,
                lbl,
                example_size=params['example_size'],
                n_examples=params['n_examples'], classes=2)

            for e in range(params['n_examples']):
                yield {'features': {'x': images[e].astype(np.float32)},
                       'labels': {'y': lbl[e].astype(np.int32)}}
        else:
            yield {'features': {'x': images},
                   'labels': {'y': lbl},
                   'sitk': img_sitk,
                   'img_id': img_id}
    return
# Example No. 7
# 0
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    Args:
        file_references (list): A list of lists containing file references, such
            as [['id_0', 'image_filename_0', target_value_0], ...,
            ['id_N', 'image_filename_N', target_value_N]].
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.).

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """
    def _augment(img, lbl):
        """Apply additive Gaussian noise and a flip along axis 1."""
        img = add_gaussian_noise(img, sigma=0.1)
        [img, lbl] = flip([img, lbl], axis=1)

        return img, lbl

    for f in file_references:
        subject_id = f[0]
        img_fn = f[1]

        # Read the image niis with sitk and keep the pointer to the sitk.Image
        # of the T1 input as the spatial reference.
        t1_sitk = sitk.ReadImage(os.path.join(str(img_fn), 'T1.nii'))
        t1 = sitk.GetArrayFromImage(t1_sitk)
        t1_ir = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(str(img_fn), 'T1_IR.nii')))
        t2_fl = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(str(img_fn), 'T2_FLAIR.nii')))

        # Normalise volume images
        t1 = whitening(t1)
        t1_ir = whitening(t1_ir)
        t2_fl = whitening(t2_fl)

        # Create a 4D multi-sequence image (i.e. [x, y, z, channels])
        images = np.stack([t1, t1_ir, t2_fl], axis=-1).astype(np.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            yield {
                'features': {
                    'x': images
                },
                'labels': None,
                'sitk': t1_sitk,
                'subject_id': subject_id
            }
            # FIX: 'LabelsForTraining.nii' is not expected for PREDICT data;
            # skip the label handling below instead of falling through.
            continue

        lbl = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(
                str(img_fn), 'LabelsForTraining.nii'))).astype(np.int32)

        # Augment if used in training mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            images, lbl = _augment(images, lbl)

        # Check if the reader is supposed to return training examples or full
        # images
        if params['extract_examples']:
            n_examples = params['n_examples']
            example_size = params['example_size']

            images, lbl = extract_class_balanced_example_array(
                image=images,
                label=lbl,
                example_size=example_size,
                n_examples=n_examples,
                classes=9)

            for e in range(n_examples):
                yield {
                    'features': {
                        'x': images[e].astype(np.float32)
                    },
                    'labels': {
                        'y': lbl[e].astype(np.int32)
                    },
                    'subject_id': subject_id
                }
        else:
            yield {
                'features': {
                    'x': images
                },
                'labels': {
                    'y': lbl
                },
                'sitk': t1_sitk,
                'subject_id': subject_id
            }

    return
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.
    Args:
        file_references (list): Entries of the form
            [scan_id, image_path, label_path].
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parametrise read_fn outputs
            (e.g. reader_params = {'n_patches': 10, 'patch_size':
            [64, 64, 64], 'extract_patches': True}, etc.).
    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """

    # Data Augmentation (Gaussian Blurring, Axial-Plane Horizontal Flip)
    def _augment(img, lbl):
        # ~30% chance of a mild Gaussian blur; the flip is applied on
        # every call (image and label flipped together to stay aligned).
        if np.random.randint(0, 10) < 3:
            img = scipy.ndimage.gaussian_filter(img, sigma=0.25)
        [img, lbl] = flip([img, lbl], axis=0)
        return img, lbl

    # Crop Central Block Matching Output of DeepMedic
    def crop_central_block_label(img, cropz, cropy, cropx):
        """Crop the central (cropz, cropy, cropx) region of a
        [batch, z, y, x] label array.

        NOTE(review): the '- 1' start offset shifts the crop one voxel off
        the exact centre -- presumably to match DeepMedic's output
        alignment; confirm against the network definition.
        """
        _, z, y, x = img.shape
        startx = x // 2 - (cropx // 2) - 1
        starty = y // 2 - (cropy // 2) - 1
        startz = z // 2 - (cropz // 2) - 1
        return img[:, startz:startz + cropz, starty:starty + cropy,
                   startx:startx + cropx]

    for f in file_references:
        t0 = time.time()

        scan_id = str(f[0])
        img_itk = sitk.ReadImage(str(f[1]), sitk.sitkFloat32)
        # [z, y, x] -> [z, y, x, 1]: add a trailing channel axis.
        img = np.expand_dims(np.array(sitk.GetArrayFromImage(img_itk)), axis=3)

        print('Loaded {}; Time = {}'.format(scan_id, (time.time() - t0)))

        # Testing Mode: yield the image only and move on to the next file.
        # BUG FIX: previously there was no `continue`, so in PREDICT mode
        # the code fell through, required a label file at f[2], and also
        # yielded training examples.
        if mode == tf.estimator.ModeKeys.PREDICT:
            yield {
                'features': {
                    'x': img
                },
                'labels': None,
                'sitk': img_itk,
                'subject_id': scan_id
            }
            continue

        # Labels are only needed (and only guaranteed to exist) outside
        # PREDICT mode, so load them after the PREDICT guard.
        lbl = np.array(sitk.GetArrayFromImage(sitk.ReadImage(str(
            f[2])))).astype(np.int32)

        # Training Mode: augment image/label pairs.
        if mode == tf.estimator.ModeKeys.TRAIN:
            img, lbl = _augment(img, lbl)

        # Return Training Examples
        if params['extract_patches']:
            # Sample class-balanced patches; background (class 0) excluded
            # via its zero class weight.
            img, lbl = extract_class_balanced_example_array(
                img,
                lbl,
                example_size=params['patch_size'],
                n_examples=params['n_patches'],
                classes=4,
                class_weights=[0, 1, 1, 1])

            # Shrink labels to the 9^3 central block DeepMedic predicts.
            lbl = crop_central_block_label(lbl, 9, 9, 9)

            for e in range(params['n_patches']):
                yield {
                    'features': {
                        'x': img[e].astype(np.float32)
                    },
                    'labels': {
                        'y': lbl[e].astype(np.int32)
                    },
                    'img_id': scan_id
                }

        # Return Full Images
        else:
            yield {
                'features': {
                    'x': img
                },
                'labels': {
                    'y': lbl
                },
                'sitk': img_itk,
                'img_id': scan_id
            }

    return
# Example No. 9 (score: 0)
def read_fn(file_references, mode, params=None):
    """Read nii images (and labels) and yield them for dltk's abstract_reader.

    Args:
        file_references (list): Entries of the form [image_path] or
            [image_path, label_path].
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): Parametrises the outputs, e.g.
            {'n_examples': 10, 'example_size': [64, 64, 64],
            'extract_examples': True}.

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """
    for ref in file_references:
        image_path = str(ref[0])

        # Short name: file name without directory or extension.
        image_name = image_path.split('/')[-1].split('.')[0]

        # Load the volume with SimpleITK; keep the sitk.Image handle so
        # spacing/origin metadata stays available downstream.
        sitk_image = sitk.ReadImage(image_path)
        volume = np.expand_dims(sitk.GetArrayFromImage(sitk_image), axis=3)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # No labels available at prediction time.
            yield {
                'features': {
                    'x': volume
                },
                'labels': None,
                'img_name': image_name,
                'sitk': sitk_image
            }
            continue

        label_path = str(ref[1])
        label_map = sitk.GetArrayFromImage(
            sitk.ReadImage(label_path)).astype(np.int32)

        # Placeholder: no augmentation is applied in training mode.
        if mode == tf.estimator.ModeKeys.TRAIN:
            pass

        # Either sample training patches or hand back the whole volume.
        if params['extract_examples']:
            volume, label_map = extract_class_balanced_example_array(
                volume,
                label_map,
                example_size=params['example_size'],
                n_examples=params['n_examples'],
                classes=14)

            for idx in range(len(volume)):
                yield {
                    'features': {
                        'x': volume[idx].astype(np.float32)
                    },
                    'labels': {
                        'y': label_map[idx].astype(np.int32)
                    }
                }
        else:
            yield {
                'features': {
                    'x': volume
                },
                'labels': {
                    'y': label_map
                },
                'img_name': image_name,
                'sitk': sitk_image
            }

    return
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.
    Args:
        file_references (list): Entries of the form
            [scan_id, image_path, label_path] (label_path only read
            outside PREDICT mode).
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parametrise read_fn outputs
            (e.g. reader_params = {'n_patches': 10, 'patch_size':
            [64, 64, 64], 'extract_patches': True}, etc.).
    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader.
    """

    # Data Augmentation (Gaussian Blurring, Axial-Plane Horizontal Flip)
    def _augment(img, lbl):
        # NOTE(review): the flip is nested INSIDE the random-blur branch,
        # so blur and flip always happen together (~30% of calls). A
        # sibling read_fn in this file flips unconditionally after the
        # blur check -- confirm which behaviour is intended.
        if (np.random.randint(0, 10) < 3):
            img = scipy.ndimage.gaussian_filter(img, sigma=0.25)
            [img, lbl] = flip([img, lbl], axis=2)
        return img, lbl

    for f in file_references:
        t0 = time.time()

        scan_id = str(f[0])
        # Read as float32 and intensity-normalise: a mild gamma-like power
        # transform followed by zero-mean/unit-variance whitening.
        img_itk = sitk.ReadImage(str(f[1]), sitk.sitkFloat32)
        img = np.array(sitk.GetArrayFromImage(img_itk))
        img = np.power(img, 1.05)
        img = whitening(img)
        # [z, y, x] -> [z, y, x, 1]: add a trailing channel axis.
        img = np.expand_dims(img, axis=3)

        print('Loaded {}; Time = {}'.format(scan_id, (time.time() - t0)))

        # Testing Mode with No Labels
        if (mode == tf.estimator.ModeKeys.PREDICT):
            yield {
                'features': {
                    'x': img
                },
                'labels': None,
                'sitk': img_itk,
                'img_id': scan_id
            }

        # Load Labels for Training/Validation Mode
        # (elif keeps PREDICT from falling through into the label path).
        elif (mode == tf.estimator.ModeKeys.TRAIN) | (
                mode == tf.estimator.ModeKeys.EVAL):
            lbl = np.array(sitk.GetArrayFromImage(sitk.ReadImage(str(
                f[2])))).astype(np.int32)

            # Data Augmentation for Training Mode Only
            if (mode == tf.estimator.ModeKeys.TRAIN):
                img, lbl = _augment(img, lbl)

            # Return Training Examples
            if params['extract_patches']:
                # Class-balanced patch sampling; background (class 0) is
                # excluded via its zero class weight.
                img, lbl = extract_class_balanced_example_array(
                    img,
                    lbl,
                    example_size=params['patch_size'],
                    n_examples=params['n_patches'],
                    classes=4,
                    class_weights=[0, 0.5, 0.3, 0.2])

                for e in range(params['n_patches']):
                    yield {
                        'features': {
                            'x': img[e].astype(np.float32)
                        },
                        'labels': {
                            'y': lbl[e].astype(np.int32)
                        },
                        'img_id': scan_id
                    }

            # Return Full Images
            else:
                yield {
                    'features': {
                        'x': img
                    },
                    'labels': {
                        'y': lbl
                    },
                    'sitk': img_itk,
                    'img_id': scan_id
                }
    return