def test_existing_config_file_loaded(self):
        # create a config file with a custom NiftyNet home
        makedirs(NiftyNetGlobalConfigTest.config_home)
        custom_niftynet_home = '~/customniftynethome'
        custom_niftynet_home_abs = expanduser(custom_niftynet_home)
        config = ''.join(['home = ', custom_niftynet_home])
        with open(NiftyNetGlobalConfigTest.config_file, 'w') as config_file:
            config_file.write('\n'.join(
                [NiftyNetGlobalConfigTest.header, config]))

        global_config = NiftyNetGlobalConfig().setup()
        self.assertEqual(global_config.get_niftynet_home_folder(),
                         custom_niftynet_home_abs)
        NiftyNetGlobalConfigTest.remove_path(custom_niftynet_home_abs)
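
# For reference, a sketch of the config file written by the test above,
# assuming NiftyNetGlobalConfigTest.header is an INI-style section header
# such as '[global]' (the exact value lives in the test class):
#
#   [global]
#   home = ~/customniftynethome
#
# NiftyNetGlobalConfig().setup() is then expected to expand '~' so that
# get_niftynet_home_folder() returns the absolute path asserted above.
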
class ImageReader(Layer):
    """
    For a concrete example:
    _input_sources defines multiple modality mappings, e.g.,
    _input_sources = {'image': ('T1', 'T2'),
                      'label': ('manual_map',)}
    means:
    'image' consists of two components, formed by
    concatenating the 'T1' and 'T2' input source images;
    'label' consists of one component, loaded from 'manual_map'.

    self._names: a tuple of the output names of this reader,
    e.g. ('image', 'label')

    self._shapes: the shapes after combining input sources
    {'image': (192, 160, 192, 1, 2), 'label': (192, 160, 192, 1, 1)}

    self._dtypes: the dictionary of TensorFlow data types
    {'image': tf.float32, 'label': tf.float32}

    self.output_list is a list of dictionaries, with each item:
    {'image': <niftynet.io.image_type.SpatialImage4D object>,
     'label': <niftynet.io.image_type.SpatialImage3D object>}
    """
    def __init__(self, names):
        # list of file names
        self._file_list = None
        self._input_sources = None
        self._shapes = None
        self._dtypes = None
        self._names = None
        self.names = names

        self._global_config = NiftyNetGlobalConfig()

        # list of image objects
        self.output_list = None
        self.current_id = -1

        self.preprocessors = []
        super(ImageReader, self).__init__(name='image_reader')

    def initialise_reader(self, data_param, task_param):
        """
        task_param specifies how to combine user input modalities:
        e.g., for multimodal segmentation, 'image' corresponds to multiple
        modality sections, while 'label' corresponds to one modality section
        """
        if not self.names:
            tf.logging.fatal('Please specify data names; these should '
                             'be a subset of SUPPORTED_INPUT provided '
                             'in the application file.')
            raise ValueError
        self._names = [
            name for name in self.names if vars(task_param).get(name, None)
        ]

        self._input_sources = {
            name: vars(task_param).get(name)
            for name in self.names
        }
        data_to_load = {}
        for name in self._names:
            for source in self._input_sources[name]:
                try:
                    data_to_load[source] = data_param[source]
                except KeyError:
                    tf.logging.fatal(
                        'reader name [%s] requires [%s], but it is not '
                        'specified as a section in the config; '
                        'current input section names: %s', name, source,
                        list(data_param))
                    raise ValueError

        default_data_folder = self._global_config.get_niftynet_home_folder()
        self._file_list = util_csv.load_and_merge_csv_files(
            data_to_load, default_data_folder)
        self.output_list = _filename_to_image_list(self._file_list,
                                                   self._input_sources,
                                                   data_param)
        for name in self.names:
            tf.logging.info('image reader: loading [%s] from %s (%d)', name,
                            self.input_sources[name], len(self.output_list))

    def prepare_preprocessors(self):
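        # Data-dependent layers (e.g. histogram normalisation) derive their
        # parameters from the full list of loaded images; other preprocessing
        # layers need no training and are left untouched here.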
        for layer in self.preprocessors:
            if isinstance(layer, DataDependentLayer):
                layer.train(self.output_list)

    def add_preprocessing_layers(self, layers):
        assert self.output_list is not None, \
            'Please initialise the reader first, ' \
            'before adding preprocessors.'
        if isinstance(layers, Layer):
            self.preprocessors.append(layers)
        else:
            self.preprocessors.extend(layers)
        self.prepare_preprocessors()

    # pylint: disable=arguments-differ
    def layer_op(self, idx=None, shuffle=True):
        """
        this layer returns an image index, a dictionary of image volume
        arrays keyed by self.names, and a matching dictionary of
        interpolation orders
        """
        if idx is None and shuffle:
            # training, with random list output
            idx = np.random.randint(len(self.output_list))

        if idx is None and not shuffle:
            # testing, with sequential output
            # accessing self.current_id is not thread-safe
            idx = self.current_id + 1
            self.current_id = idx

        try:
            idx = int(idx)
        except (TypeError, ValueError):
            idx = -1

        if idx < 0 or idx >= len(self.output_list):
            return -1, None, None

        image_dict = self.output_list[idx]
        image_data_dict = {
            field: image.get_data()
            for (field, image) in image_dict.items()
        }
        interp_order_dict = {
            field: image.interp_order
            for (field, image) in image_dict.items()
        }
        if self.preprocessors:
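            # work on fresh copies so that the per-call state set below (e.g.
            # parameters drawn by randomised layers) does not modify the
            # shared self.preprocessors list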
            preprocessors = [deepcopy(layer) for layer in self.preprocessors]
            # dictionary of masks is cached
            mask = None
            for layer in preprocessors:
                # import time; local_time = time.time()
                if layer is None:
                    continue
                if isinstance(layer, RandomisedLayer):
                    layer.randomise()
                    image_data_dict = layer(image_data_dict, interp_order_dict)
                else:
                    image_data_dict, mask = layer(image_data_dict, mask)
                # print('%s, %.3f sec'%(layer, -local_time + time.time()))
        return idx, image_data_dict, interp_order_dict

    @property
    def shapes(self):
        """
        image shapes before any preprocessing
        :return: tuple of integers as image shape
        """
        # to have fast access, the spatial dimensions are not accurate
        # 1) only read from the first image in list
        # 2) not considering effects of random augmentation layers
        # but time and modality dimensions should be correct
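        # e.g. {'image': (192, 160, 192, 1, 2), 'label': (192, 160, 192, 1, 1)},
        # as in the class docstring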
        if not self.output_list:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        if not self._shapes:
            first_image = self.output_list[0]
            self._shapes = {
                field: first_image[field].shape
                for field in self.names
            }
        return self._shapes

    @property
    def tf_dtypes(self):
        if not self.output_list:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        if not self._dtypes:
            first_image = self.output_list[0]
            self._dtypes = {
                field: infer_tf_dtypes(first_image[field])
                for field in self.names
            }
        return self._dtypes

    @property
    def input_sources(self):
        if not self._input_sources:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        return self._input_sources

    @property
    def names(self):
        return self._names

    @names.setter
    def names(self, fields_tuple):
        # fields_tuple is a sequence of output names;
        # each name might correspond to a list of multiple input sources,
        # as specified in the CUSTOM section of the config
        self._names = make_input_tuple(fields_tuple, string_types)

    def get_subject_id(self, image_index):
        return self._file_list.iloc[image_index, 0]
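
# A minimal usage sketch; data_param / task_param and the preprocessing layer
# are hypothetical placeholders, normally produced from the parsed config:
#
#   reader = ImageReader(['image', 'label'])
#   reader.initialise_reader(data_param, task_param)
#   reader.add_preprocessing_layers([normalisation_layer])
#   image_id, data_dict, interp_orders = reader(idx=0, shuffle=False)
#   # data_dict['image'] is a numpy array with shape
#   # (x, y, z, time, modalities), e.g. (192, 160, 192, 1, 2)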