def test_incorrect_config_file_backed_up(self):
        # create an incorrect config file at the correct location
        makedirs(NiftyNetGlobalConfigTest.config_home)
        incorrect_config = '\n'.join([NiftyNetGlobalConfigTest.header,
                                      'invalid_home_tag = ~/niftynet'])
        with open(NiftyNetGlobalConfigTest.config_file, 'w') as config_file:
            config_file.write(incorrect_config)

        # the following should back it up and replace it with default config
        global_config = NiftyNetGlobalConfig().setup()

        self.assertTrue(isfile(NiftyNetGlobalConfigTest.config_file))
        self.assertEqual(global_config.get_niftynet_config_folder(),
                         NiftyNetGlobalConfigTest.config_home)

        # check if incorrect file was backed up
        found_files = glob(
            join(NiftyNetGlobalConfigTest.config_home,
                 NiftyNetGlobalConfigTest.typify('config-backup-*')))
        self.assertTrue(len(found_files) == 1)
        with open(found_files[0], 'r') as backup_file:
            self.assertEqual(backup_file.read(), incorrect_config)

        # cleanup: remove backup file
        NiftyNetGlobalConfigTest.remove_path(found_files[0])
    def test_incorrect_config_file_backed_up(self):
        # create an incorrect config file at the correct location
        makedirs(NiftyNetGlobalConfigTestBase.config_home)
        incorrect_config = '\n'.join([
            NiftyNetGlobalConfigTestBase.header,
            'invalid_home_tag = ~/niftynet'
        ])
        with open(NiftyNetGlobalConfigTestBase.config_file,
                  'w') as config_file:
            config_file.write(incorrect_config)

        # the following should back it up and replace it with default config
        global_config = NiftyNetGlobalConfig().setup()

        self.assertTrue(isfile(NiftyNetGlobalConfigTestBase.config_file))
        self.assertEqual(global_config.get_niftynet_config_folder(),
                         NiftyNetGlobalConfigTestBase.config_home)

        # check if incorrect file was backed up
        found_files = glob(
            join(NiftyNetGlobalConfigTestBase.config_home,
                 NiftyNetGlobalConfigTestBase.typify('config-backup-*')))
        self.assertTrue(len(found_files) == 1)
        with open(found_files[0], 'r') as backup_file:
            self.assertEqual(backup_file.read(), incorrect_config)

        # cleanup: remove backup file
        NiftyNetGlobalConfigTestBase.remove_path(found_files[0])
Example #3
def download(example_ids,
             download_if_already_existing=False,
             verbose=True):
    """
    Downloads standard NiftyNet examples such as data and samples

    :param example_ids:
        A list of identifiers for the samples to download
    :param download_if_already_existing:
        If true, data will always be downloaded
    :param verbose:
        If true, download info will be printed
    """

    global_config = NiftyNetGlobalConfig()

    config_store = ConfigStore(global_config, verbose=verbose)

    # If a single id is specified, convert to a list
    example_ids = [example_ids] \
        if not isinstance(example_ids, (tuple, list)) else example_ids

    if not example_ids:
        return False

    # Check if the server is running by looking for a known file
    remote_base_url_test = gitlab_raw_file_url(
        global_config.get_download_server_url(), 'README.md')
    server_ok = url_exists(remote_base_url_test)
    if verbose:
        print("Accessing: {}".format(global_config.get_download_server_url()))

    any_error = False
    for example_id in example_ids:
        if not example_id:
            any_error = True
            continue
        if config_store.exists(example_id):
            update_ok = config_store.update_if_required(
                example_id,
                download_if_already_existing)
            any_error = (not update_ok) or any_error
        else:
            any_error = True
            if server_ok:
                print(example_id + ': FAIL. ')
                print('No NiftyNet example was found for ' + example_id + ".")

    # If errors occurred and the server is down report a message
    if any_error and not server_ok:
        print("The NiftyNetExamples server is not running")

    return not any_error
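# Usage sketch (illustrative, not part of the original example): the example
# id below is an assumption and must match an entry on the NiftyNet examples
# server; download() returns True only if every requested id succeeds.
if __name__ == '__main__':
    succeeded = download(['dense_vnet_abdominal_ct_model_zoo'],
                         download_if_already_existing=False,
                         verbose=True)
    print('all downloads succeeded:', succeeded)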
    def test_existing_config_file_loaded(self):
        # create a config file with a custom NiftyNet home
        makedirs(NiftyNetGlobalConfigTest.config_home)
        custom_niftynet_home = '~/customniftynethome'
        custom_niftynet_home_abs = expanduser(custom_niftynet_home)
        config = ''.join(['home = ', custom_niftynet_home])
        with open(NiftyNetGlobalConfigTest.config_file, 'w') as config_file:
            config_file.write('\n'.join(
                [NiftyNetGlobalConfigTest.header, config]))

        global_config = NiftyNetGlobalConfig().setup()
        self.assertEqual(global_config.get_niftynet_home_folder(),
                         custom_niftynet_home_abs)
        NiftyNetGlobalConfigTest.remove_path(custom_niftynet_home_abs)
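# For reference, a hedged sketch of the file the test above writes, assuming
# the test's `header` constant expands to the standard '[global]' section:
example_config_text = '\n'.join(['[global]',
                                 'home = ~/customniftynethome'])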
Example #5
def download(example_ids, download_if_already_existing=False, verbose=True):
    """
    Downloads standard NiftyNet examples such as data and samples

    :param example_ids:
        A list of identifiers for the samples to download
    :param download_if_already_existing:
        If true, data will always be downloaded
    :param verbose:
        If true, download info will be printed
    """

    global_config = NiftyNetGlobalConfig()

    config_store = ConfigStore(global_config, verbose=verbose)

    # If a single id is specified, convert to a list
    example_ids = [example_ids] \
        if not isinstance(example_ids, (tuple, list)) else example_ids

    if not example_ids:
        return False

    # Check if the server is running by looking for a known file
    remote_base_url_test = gitlab_raw_file_url(
        global_config.get_download_server_url(), 'README.md')
    server_ok = url_exists(remote_base_url_test)
    if verbose:
        print("Accessing: {}".format(global_config.get_download_server_url()))

    any_error = False
    for example_id in example_ids:
        if not example_id:
            any_error = True
            continue
        if config_store.exists(example_id):
            update_ok = config_store.update_if_required(
                example_id, download_if_already_existing)
            any_error = (not update_ok) or any_error
        else:
            any_error = True
            if server_ok:
                print(example_id + ': FAIL. ')
                print('No NiftyNet example was found for ' + example_id + ".")

    # If errors occurred and the server is down report a message
    if any_error and not server_ok:
        print("The NiftyNetExamples server is not running")

    return not any_error
def __resolve_config_file_path(cmdline_arg):
    """
    Search for the absolute file name of the configuration file,
    starting from the `-c` value provided by the user.

    :param cmdline_arg:
    :return:
    """
    if not cmdline_arg:
        raise IOError("\nNo configuration file has been provided, did you "
                      "forget '-c' command argument?{}".format(EPILOG_STRING))
    # Resolve relative configuration file location
    config_file_path = os.path.expanduser(cmdline_arg)
    try:
        config_file_path = resolve_file_name(config_file_path,
                                             ('.', NIFTYNET_HOME))
        if os.path.isfile(config_file_path):
            return config_file_path
    except (IOError, TypeError):
        config_file_path = os.path.expanduser(cmdline_arg)

    config_file_path = os.path.join(
        NiftyNetGlobalConfig().get_default_examples_folder(), config_file_path,
        config_file_path + "_config.ini")
    if os.path.isfile(config_file_path):
        return config_file_path

    # could not proceed without a configuration file
    raise IOError("\nConfiguration file not found: {}.{}".format(
        os.path.expanduser(cmdline_arg), EPILOG_STRING))
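# Illustration of the fallback above: when the `-c` value is not an existing
# file, the path is assembled from the default examples folder. The example
# name below is hypothetical.
example_name = 'dense_vnet_abdominal_ct'   # value passed to -c
fallback_path = os.path.join(
    NiftyNetGlobalConfig().get_default_examples_folder(),
    example_name,
    example_name + '_config.ini')
# -> <examples_folder>/dense_vnet_abdominal_ct/dense_vnet_abdominal_ct_config.ini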
Example #7
    def __init__(self, names):
        # list of file names
        self._file_list = None
        self._input_sources = None
        self._shapes = None
        self._dtypes = None
        self._names = None
        self.names = names

        self._global_config = NiftyNetGlobalConfig()

        # list of image objects
        self.output_list = None
        self.current_id = -1

        self.preprocessors = []
        super(ImageReader, self).__init__(name='image_reader')
Example #8
    def file_path(self, path_array):
        if isinstance(path_array, string_types):
            path_array = (path_array,)
        home_folder = NiftyNetGlobalConfig().get_niftynet_home_folder()
        try:
            self._file_path = tuple(resolve_file_name(path, ('.', home_folder))
                                    for path in path_array)
        except (TypeError, AssertionError, AttributeError, IOError):
            tf.logging.fatal(
                "unrecognised file path format, should be a valid filename, "
                "or a sequence of filenames %s", path_array)
            raise IOError
Example #9
File: misc_io.py, Project: zaksb/NiftyNet
def resolve_checkpoint(checkpoint_name):
    # For now only supports checkpoint_name where
    # checkpoint_name.index is in the file system
    # eventually will support checkpoint names that can be referenced
    # in a paths file.
    if os.path.isfile(checkpoint_name + '.index'):
        return checkpoint_name
    home_folder = NiftyNetGlobalConfig().get_niftynet_home_folder()
    checkpoint_name = to_absolute_path(input_path=checkpoint_name,
                                       model_root=home_folder)
    if os.path.isfile(checkpoint_name + '.index'):
        return checkpoint_name
    raise ValueError('Invalid checkpoint {}'.format(checkpoint_name))
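# Usage sketch (the checkpoint path is hypothetical): the name is first tried
# as given, then relative to the NiftyNet home folder; a missing .index file
# raises ValueError.
try:
    ckpt = resolve_checkpoint('models/model.ckpt-1000')
    print('restoring from', ckpt)
except ValueError:
    print('no matching .index file found')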
    def reset(self):
        """
        reset all fields of this singleton class.
        """
        self._file_list = None
        self._partition_ids = None
        self.data_param = None
        self.ratios = None
        self.new_partition = False

        self.data_split_file = ""
        self.default_image_file_location = \
            NiftyNetGlobalConfig().get_niftynet_home_folder()
Example #11
    def test_existing_niftynet_home_not_touched(self):
        niftynet_home = expanduser(
            NiftyNetGlobalConfigTest.default_config_opts['home'])
        makedirs(niftynet_home)
        niftynet_ext = join(
            niftynet_home, NiftyNetGlobalConfigTest.default_config_opts['ext'])
        makedirs(niftynet_ext)
        niftynet_ext_init = join(niftynet_ext, '__init__.py')
        open(niftynet_ext_init, 'w').close()
        mtime_before = getmtime(niftynet_ext_init)

        global_config = NiftyNetGlobalConfig()

        mtime_after = getmtime(niftynet_ext_init)
        self.assertEqual(mtime_before, mtime_after)
Example #12
    def test_non_existing_niftynet_home_created(self):
        niftynet_home = expanduser(
            NiftyNetGlobalConfigTest.default_config_opts['home'])
        NiftyNetGlobalConfigTest.remove_path(niftynet_home)
        self.assertFalse(isdir(niftynet_home))
        niftynet_ext = join(
            niftynet_home, NiftyNetGlobalConfigTest.default_config_opts['ext'])
        self.assertFalse(isfile(join(niftynet_ext, '__init__.py')))
        for mod in NiftyNetGlobalConfigTest.default_config_opts['ext_mods']:
            self.assertFalse(isfile(join(niftynet_ext, mod, '__init__.py')))

        global_config = NiftyNetGlobalConfig().setup()

        self.assertTrue(isdir(niftynet_home))
        self.assertTrue(isfile(join(niftynet_ext, '__init__.py')))
        for mod in NiftyNetGlobalConfigTest.default_config_opts['ext_mods']:
            self.assertTrue(isfile(join(niftynet_ext, mod, '__init__.py')))
    def create_instance(cls, file_path, **kwargs):
        """
        Read image headers and create image instance.

        :param file_path: a file path or a sequence of file paths
        :param kwargs: output properties for transforming the image data
            array into a desired format
        :return: an image instance
        """
        if file_path is None:
            tf.logging.fatal('No file_path provided, '
                             'please check input sources in config file')
            raise ValueError

        ndims = 0
        image_type = None
        home_folder = NiftyNetGlobalConfig().get_niftynet_home_folder()
        try:
            file_path = resolve_file_name(file_path, ('.', home_folder))
            if os.path.isfile(file_path):
                loader = kwargs.get('loader', None) or None
                ndims = misc.infer_ndims_from_file(file_path, loader)
                image_type = cls.INSTANCE_DICT.get(ndims, None)
        except (TypeError, IOError, AttributeError):
            pass

        if image_type is None:
            try:
                file_path = [
                    resolve_file_name(path, ('.', home_folder))
                    for path in file_path
                ]
                loader = kwargs.get('loader', None) or (None, )
                ndims = misc.infer_ndims_from_file(file_path[0], loader[0])
                ndims = ndims + (1 if len(file_path) > 1 else 0)
                image_type = cls.INSTANCE_DICT.get(ndims, None)
            except (AssertionError, TypeError, IOError, AttributeError):
                tf.logging.fatal('Could not load file: %s', file_path)
                raise IOError
        if image_type is None:
            tf.logging.fatal('Not supported image type from:\n%s', file_path)
            raise NotImplementedError(
                "unrecognised spatial rank {}".format(ndims))
        return image_type(file_path, **kwargs)
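# Usage sketch, assuming the classmethod above belongs to an image factory
# class (called ImageFactory here for illustration); the file paths are
# hypothetical and keyword arguments are forwarded to the image constructor.
single = ImageFactory.create_instance('./T1.nii.gz')
# a sequence of paths is treated as one multi-modal image
multi = ImageFactory.create_instance(('./T1.nii.gz', './T2.nii.gz'))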
    def test_non_existing_config_file_created(self):
        self.assertFalse(isfile(NiftyNetGlobalConfigTestBase.config_file))
        global_config = NiftyNetGlobalConfig().setup()
        self.assertTrue(isfile(NiftyNetGlobalConfigTestBase.config_file))
        self.assertEqual(global_config.get_niftynet_config_folder(),
                         NiftyNetGlobalConfigTestBase.config_home)

    def test_global_config_singleton(self):
        global_config_1 = NiftyNetGlobalConfig()
        global_config_2 = NiftyNetGlobalConfig()
        self.assertEqual(global_config_1, global_config_2)
        self.assertTrue(global_config_1 is global_config_2)
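# A minimal sketch of the singleton behaviour asserted above; this is an
# illustration only, not the actual NiftyNetGlobalConfig implementation.
class _SingletonSketch(object):
    _instance = None

    def __new__(cls):
        # every construction returns the same cached instance
        if cls._instance is None:
            cls._instance = super(_SingletonSketch, cls).__new__(cls)
        return cls._instance

assert _SingletonSketch() is _SingletonSketch()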
Example #16
class ImageReader(Layer):
    """
    For a concrete example:
    _input_sources define multiple modality mappings, e.g.,
    _input_sources {'image': ('T1', 'T2'),
                    'label': ('manual_map',)}
    means
    'image' consists of two components, formed by
    concatenating 'T1' and 'T2' input source images.
    'label' consists of one component, loaded from 'manual_map'

    self._names: a tuple of the output names of this reader.
    ('image', 'labels')

    self._shapes: the shapes after combining input sources
    {'image': (192, 160, 192, 1, 2), 'label': (192, 160, 192, 1, 1)}

    self._dtypes: stores the dictionary of tensorflow dtypes
    {'image': tf.float32, 'label': tf.float32}

    self.output_list is a list of dictionaries, with each item:
    {'image': <niftynet.io.image_type.SpatialImage4D object>,
     'label': <niftynet.io.image_type.SpatialImage3D object>}
    """
    def __init__(self, names):
        # list of file names
        self._file_list = None
        self._input_sources = None
        self._shapes = None
        self._dtypes = None
        self._names = None
        self.names = names

        self._global_config = NiftyNetGlobalConfig()

        # list of image objects
        self.output_list = None
        self.current_id = -1

        self.preprocessors = []
        super(ImageReader, self).__init__(name='image_reader')

    def initialise_reader(self, data_param, task_param):
        """
        task_param specifies how to combine user input modalities
        e.g., for multimodal segmentation 'image' corresponds to multiple
        modality sections, 'label' corresponds to one modality section
        """
        if not self.names:
            tf.logging.fatal('Please specify data names, this should '
                             'be a subset of SUPPORTED_INPUT provided '
                             'in application file')
            raise ValueError
        self._names = [
            name for name in self.names if vars(task_param).get(name, None)
        ]

        self._input_sources = {
            name: vars(task_param).get(name)
            for name in self.names
        }
        data_to_load = {}
        for name in self._names:
            for source in self._input_sources[name]:
                try:
                    data_to_load[source] = data_param[source]
                except KeyError:
                    tf.logging.fatal(
                        'reader name [%s] requires [%s], however it is not '
                        'specified as a section in the config, '
                        'current input section names: %s', name, source,
                        list(data_param))
                    raise ValueError

        default_data_folder = self._global_config.get_niftynet_home_folder()
        self._file_list = util_csv.load_and_merge_csv_files(
            data_to_load, default_data_folder)
        self.output_list = _filename_to_image_list(self._file_list,
                                                   self._input_sources,
                                                   data_param)
        for name in self.names:
            tf.logging.info('image reader: loading [%s] from %s (%d)', name,
                            self.input_sources[name], len(self.output_list))

    def prepare_preprocessors(self):
        for layer in self.preprocessors:
            if isinstance(layer, DataDependentLayer):
                layer.train(self.output_list)

    def add_preprocessing_layers(self, layers):
        assert self.output_list is not None, \
            'Please initialise the reader first, ' \
            'before adding preprocessors.'
        if isinstance(layers, Layer):
            self.preprocessors.append(layers)
        else:
            self.preprocessors.extend(layers)
        self.prepare_preprocessors()

    # pylint: disable=arguments-differ
    def layer_op(self, idx=None, shuffle=True):
        """
        this layer returns a dictionary
          keys: self.output_fields
          values: image volume array
        """
        if idx is None and shuffle:
            # training, with random list output
            idx = np.random.randint(len(self.output_list))

        if idx is None and not shuffle:
            # testing, with sequential output
            # accessing self.current_id, not suitable for multi-thread
            idx = self.current_id + 1
            self.current_id = idx

        try:
            idx = int(idx)
        except ValueError:
            idx = -1

        if idx < 0 or idx >= len(self.output_list):
            return -1, None, None

        image_dict = self.output_list[idx]
        image_data_dict = {
            field: image.get_data()
            for (field, image) in image_dict.items()
        }
        interp_order_dict = {
            field: image.interp_order
            for (field, image) in image_dict.items()
        }
        if self.preprocessors:
            preprocessors = [deepcopy(layer) for layer in self.preprocessors]
            # dictionary of masks is cached
            mask = None
            for layer in preprocessors:
                # import time; local_time = time.time()
                if layer is None:
                    continue
                if isinstance(layer, RandomisedLayer):
                    layer.randomise()
                    image_data_dict = layer(image_data_dict, interp_order_dict)
                else:
                    image_data_dict, mask = layer(image_data_dict, mask)
                # print('%s, %.3f sec'%(layer, -local_time + time.time()))
        return idx, image_data_dict, interp_order_dict

    @property
    def shapes(self):
        """
        image shapes before any preprocessing
        :return: tuple of integers as image shape
        """
        # to have fast access, the spatial dimensions are not accurate
        # 1) only read from the first image in list
        # 2) not considering effects of random augmentation layers
        # but time and modality dimensions should be correct
        if not self.output_list:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        if not self._shapes:
            first_image = self.output_list[0]
            self._shapes = {
                field: first_image[field].shape
                for field in self.names
            }
        return self._shapes

    @property
    def tf_dtypes(self):
        if not self.output_list:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        if not self._dtypes:
            first_image = self.output_list[0]
            self._dtypes = {
                field: infer_tf_dtypes(first_image[field])
                for field in self.names
            }
        return self._dtypes

    @property
    def input_sources(self):
        if not self._input_sources:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        return self._input_sources

    @property
    def names(self):
        return self._names

    @names.setter
    def names(self, fields_tuple):
        # output_fields is a sequence of output names
        # each name might correspond to a list of multiple input sources
        # this should be specified in CUSTOM section in the config
        self._names = make_input_tuple(fields_tuple, string_types)

    def get_subject_id(self, image_index):
        return self._file_list.iloc[image_index, 0]
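# A hedged end-to-end usage sketch of the reader above; `data_param` and
# `task_param` are placeholders for the objects normally produced by the
# configuration parser.
reader = ImageReader(['image', 'label'])
reader.initialise_reader(data_param, task_param)
idx, image_data, interp_orders = reader()   # layer_op: one random subject
print(reader.shapes, reader.tf_dtypes, reader.get_subject_id(idx))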
Example #17
def run():
    """
    meta_parser is first used to find out the location
    of the configuration file. Based on the application_name
    or meta_parser.prog name, the section parsers are organised
    to find system parameters and application-specific
    parameters.

    :return: system parameters is a group of parameters including
        SYSTEM_SECTIONS and app_module.REQUIRED_CONFIG_SECTION
        input_data_args is a group of input data sources to be
        used by niftynet.io.ImageReader
    """
    meta_parser = argparse.ArgumentParser(
        epilog='Please visit '
               'https://cmiclab.cs.ucl.ac.uk/CMIC/NiftyNet/tree/dev/demos '
               'for more info.')
    version_string = get_niftynet_version_string()
    meta_parser.add_argument("-v", "--version",
                             action='version',
                             version=version_string)
    meta_parser.add_argument("-c", "--conf",
                             help="Specify configurations from a file",
                             metavar="File", )
    meta_parser.add_argument("-a", "--application_name",
                             help="Specify application name",
                             default="", )
    meta_args, args_from_cmdline = meta_parser.parse_known_args()
    print(version_string)

    # read configurations, to be parsed by sections
    if meta_args.conf is None:
        raise IOError("No configuration file has been provided")

    # Read global config file
    global_config = NiftyNetGlobalConfig()

    config_path = meta_args.conf
    if not os.path.isfile(config_path):
        relative_conf_file = os.path.join(global_config.get_default_examples_folder(), config_path,
                                          config_path + "_config.ini")
        if os.path.isfile(relative_conf_file):
            config_path = relative_conf_file
            os.chdir(os.path.dirname(config_path))
        else:
            raise IOError("Configuration file not found {}".format(config_path))

    config = configparser.ConfigParser()
    config.read([config_path])
    app_module = None
    try:
        if meta_parser.prog[:-3] in SUPPORTED_APP:
            module_name = meta_parser.prog[:-3]
        elif meta_parser.prog in SUPPORTED_APP:
            module_name = meta_parser.prog
        else:
            module_name = meta_args.application_name
        app_module = ApplicationFactory.create(module_name)
        assert app_module.REQUIRED_CONFIG_SECTION, \
            "REQUIRED_CONFIG_SECTION should be static variable " \
            "in {}".format(app_module)
        has_section_in_config(config, app_module.REQUIRED_CONFIG_SECTION)
    except ValueError:
        if app_module:
            section_name = app_module.REQUIRED_CONFIG_SECTION
            raise ValueError(
                '{} requires [{}] section in the config file'.format(
                    module_name, section_name))
        else:
            raise ValueError(
                "unknown application {}, or did you forget '-a' "
                "command argument".format(module_name))

    # check keywords in configuration file
    check_keywords(config)

    # using configuration as default, and parsing all command line arguments
    all_args = {}
    for section in config.sections():
        # try to rename user-specified sections for consistency
        section = standardise_section_name(config, section)
        section_defaults = dict(config.items(section))
        section_args, args_from_cmdline = \
            _parse_arguments_by_section([],
                                        section,
                                        section_defaults,
                                        args_from_cmdline,
                                        app_module.REQUIRED_CONFIG_SECTION)
        all_args[section] = section_args
    # command line parameters should be valid
    assert not args_from_cmdline, \
        'unknown parameter: {}'.format(args_from_cmdline)

    # split parsed results in all_args
    # into dictionary of system_args and input_data_args
    system_args = {}
    input_data_args = {}
    for section in all_args:
        if section in SYSTEM_SECTIONS:
            system_args[section] = all_args[section]
        elif section == app_module.REQUIRED_CONFIG_SECTION:
            system_args['CUSTOM'] = all_args[section]
            vars(system_args['CUSTOM'])['name'] = module_name

    if all_args['SYSTEM'].model_dir is None:
        all_args['SYSTEM'].model_dir = os.path.join(
            os.path.dirname(meta_args.conf), 'model')

    for section in all_args:
        if section in SYSTEM_SECTIONS:
            continue
        if section == app_module.REQUIRED_CONFIG_SECTION:
            continue
        input_data_args[section] = all_args[section]
        # set the output path of csv list if not exists
        csv_path = input_data_args[section].csv_file
        if not os.path.isfile(csv_path):
            csv_filename = os.path.join(
                all_args['SYSTEM'].model_dir, '{}.csv'.format(section))
            input_data_args[section].csv_file = csv_filename
        else:
            # don't search files if csv specified in config
            try:
                delattr(input_data_args[section], 'path_to_search')
            except AttributeError:
                pass

    # update conf path
    system_args['CONFIG_FILE'] = argparse.Namespace(path=meta_args.conf)
    return system_args, input_data_args
Example #18
    default_keys = []
    for action in all_key_parser._actions:
        try:
            default_keys.append(action.option_strings[0][2:])
        except (IndexError, AttributeError, ValueError):
            pass
    # remove duplicates
    default_keys = list(set(default_keys))
    # remove bad names
    default_keys = [keyword for keyword in default_keys if keyword]
    return default_keys


KEYWORDS = available_keywords()
NIFTYNET_HOME = NiftyNetGlobalConfig().get_niftynet_home_folder()


# pylint: disable=too-many-branches
def run():
    """
    meta_parser is first used to find out location
    of the configuration file. Based on the application_name
    or meta_parser.prog name, the section parsers are organised
    to find system parameters and application specific
    parameters.

    :return: system parameters is a group of parameters including
        SYSTEM_SECTIONS and app_module.REQUIRED_CONFIG_SECTION
        input_data_args is a group of input data sources to be
        used by niftynet.io.ImageReader
Example #19
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os, sys, unittest
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

from niftynet.utilities.download import download
from niftynet.utilities.niftynet_global_config import NiftyNetGlobalConfig
from niftynet import main as niftynet_main
from niftynet.application.base_application import SingletonApplication

MODEL_HOME = NiftyNetGlobalConfig().get_niftynet_home_folder()


def net_run_with_sys_argv(argv):
    # for gift-adelie
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    SingletonApplication.clear()
    cache = sys.argv
    argv.extend(['--cuda_devices', '0'])
    sys.argv = argv
    niftynet_main()
    sys.argv = cache


@unittest.skipIf(
    def test_non_existing_config_file_created(self):
        self.assertFalse(isfile(NiftyNetGlobalConfigTest.config_file))
        global_config = NiftyNetGlobalConfig().setup()
        self.assertTrue(isfile(NiftyNetGlobalConfigTest.config_file))
        self.assertEqual(global_config.get_niftynet_config_folder(),
                         NiftyNetGlobalConfigTest.config_home)
Example #21
def run():
    """
    meta_parser is first used to find out the location
    of the configuration file. Based on the application_name
    or meta_parser.prog name, the section parsers are organised
    to find system parameters and application-specific
    parameters.

    :return: system parameters is a group of parameters including
        SYSTEM_SECTIONS and app_module.REQUIRED_CONFIG_SECTION
        input_data_args is a group of input data sources to be
        used by niftynet.io.ImageReader
    """
    meta_parser = argparse.ArgumentParser(
        description="Launch a NiftyNet application.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent(epilog_string))
    version_string = get_niftynet_version_string()
    meta_parser.add_argument("action",
                             help="train networks or run inferences",
                             metavar='ACTION',
                             choices=['train', 'inference'])
    meta_parser.add_argument("-v",
                             "--version",
                             action='version',
                             version=version_string)
    meta_parser.add_argument("-c",
                             "--conf",
                             help="specify configurations from a file",
                             metavar="CONFIG_FILE")
    meta_parser.add_argument("-a",
                             "--application_name",
                             help="specify an application module name",
                             metavar='APPLICATION_NAME',
                             default="")

    meta_args, args_from_cmdline = meta_parser.parse_known_args()
    print(version_string)

    # read configurations, to be parsed by sections
    if not meta_args.conf:
        print("\nNo configuration file has been provided, did you "
              "forget '-c' command argument?{}".format(epilog_string))
        raise IOError

    # Resolve relative configuration file location
    config_path = meta_args.conf
    if not os.path.isfile(config_path):
        relative_conf_file = os.path.join(
            NiftyNetGlobalConfig().get_default_examples_folder(), config_path,
            config_path + "_config.ini")
        if os.path.isfile(relative_conf_file):
            config_path = relative_conf_file
            os.chdir(os.path.dirname(config_path))
        else:
            print("\nConfiguration file not found: {}.{}".format(
                config_path, epilog_string))
            raise IOError

    config = configparser.ConfigParser()
    config.read([config_path])
    app_module = None
    module_name = None
    try:
        if meta_parser.prog[:-3] in SUPPORTED_APP:
            module_name = meta_parser.prog[:-3]
        elif meta_parser.prog in SUPPORTED_APP:
            module_name = meta_parser.prog
        else:
            module_name = meta_args.application_name
        app_module = ApplicationFactory.create(module_name)
        assert app_module.REQUIRED_CONFIG_SECTION, \
            "\nREQUIRED_CONFIG_SECTION should be static variable " \
            "in {}".format(app_module)
        has_section_in_config(config, app_module.REQUIRED_CONFIG_SECTION)
    except ValueError:
        if app_module:
            section_name = app_module.REQUIRED_CONFIG_SECTION
            print('\n{} requires [{}] section in the config file.{}'.format(
                module_name, section_name, epilog_string))
        if not module_name:
            print("\nUnknown application {}, or did you forget '-a' "
                  "command argument?{}".format(module_name, epilog_string))
        raise

    # check keywords in configuration file
    check_keywords(config)

    # using configuration as default, and parsing all command line arguments
    all_args = {}
    for section in config.sections():
        # try to rename user-specified sections for consistency
        section = standardise_section_name(config, section)
        section_defaults = dict(config.items(section))
        section_args, args_from_cmdline = \
            _parse_arguments_by_section([],
                                        section,
                                        section_defaults,
                                        args_from_cmdline,
                                        app_module.REQUIRED_CONFIG_SECTION)
        all_args[section] = section_args
    # command line parameters should be valid
    assert not args_from_cmdline, \
        '\nUnknown parameter: {}{}'.format(args_from_cmdline, epilog_string)

    # split parsed results in all_args
    # into dictionaries of system_args and input_data_args
    system_args = {}
    input_data_args = {}

    # copy system default sections to ``system_args``
    for section in all_args:
        if section in SYSTEM_SECTIONS:
            system_args[section] = all_args[section]
        elif section == app_module.REQUIRED_CONFIG_SECTION:
            system_args['CUSTOM'] = all_args[section]
            vars(system_args['CUSTOM'])['name'] = module_name

    if all_args['SYSTEM'].model_dir is None:
        all_args['SYSTEM'].model_dir = os.path.join(
            os.path.dirname(meta_args.conf), 'model')

    # copy non-default sections to ``input_data_args``
    for section in all_args:
        if section in SYSTEM_SECTIONS:
            continue
        if section == app_module.REQUIRED_CONFIG_SECTION:
            continue
        input_data_args[section] = all_args[section]
        # set the output path of csv list if not exists
        csv_path = input_data_args[section].csv_file
        if os.path.isfile(csv_path):
            # don't search files if csv specified in config
            try:
                delattr(input_data_args[section], 'path_to_search')
            except AttributeError:
                pass
        else:
            input_data_args[section].csv_file = ''

    # preserve ``config_file`` and ``action parameter`` from the meta_args
    system_args['CONFIG_FILE'] = argparse.Namespace(path=meta_args.conf)
    system_args['SYSTEM'].action = meta_args.action
    return system_args, input_data_args
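# For illustration only: run() can be driven programmatically by substituting
# sys.argv before the call, mirroring the test helper in an earlier example;
# the application alias and configuration path below are hypothetical.
import sys

sys.argv = ['net_segment', 'train', '-c', 'path/to/my_config.ini']
system_args, input_data_args = run()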
Example #22
class ImageSetsPartitioner(object):
    """
    This class maintains a pandas.dataframe of filenames for all input sections.
    The lists of filenames are obtained by searching the specified folders
    or loading from an existing csv file.

    Users can query a subset of the dataframe by train/valid/infer partition
    label and input section names.
    """

    # dataframe (table) of file names in a shape of subject x modality
    _file_list = None
    # dataframes of subject_id:phase_id
    _partition_ids = None

    data_param = None
    ratios = None
    new_partition = False

    # for saving the splitting index
    data_split_file = ""
    # default parent folder location for searching the image files
    default_image_file_location = \
        NiftyNetGlobalConfig().get_niftynet_home_folder()

    def initialise(self,
                   data_param,
                   new_partition=False,
                   data_split_file="./test.csv",
                   ratios=None):
        """
        Set the data partitioner parameters
        data_param: corresponding to all config sections
        new_partition: bool value indicating whether to generate new
                       partition ids and overwrite csv file
                       (this class will write partition file iff new_partition)
        data_split_file: location of the partition id file
        ratios: a tuple/list with two elements:
               (fraction of the validation set,
                fraction of the inference set)
               initialise to None will disable data partitioning
               and get_file_list always returns all subjects.
        """
        self.data_param = data_param
        self.data_split_file = data_split_file
        self.ratios = ratios

        self._file_list = None
        self._partition_ids = None

        self.load_data_sections_by_subject()
        self.new_partition = new_partition
        self.randomly_split_dataset(overwrite=new_partition)
        tf.logging.info(self)
        return self

    def number_of_subjects(self, phase=ALL):
        """
        query number of images according to phase
        :param phase:
        :return:
        """
        if self._file_list is None:
            return 0
        phase = look_up_operations(phase, SUPPORTED_PHASES)
        if phase == ALL:
            return self._file_list[COLUMN_UNIQ_ID].count()
        if self._partition_ids is None:
            return 0
        selector = self._partition_ids[COLUMN_PHASE] == phase
        return self._partition_ids[selector].count()[COLUMN_UNIQ_ID]

    def get_file_list(self, phase=ALL, *section_names):
        """
        get file names as a dataframe, by partitioning phase and section names.
        Set phase to ALL to load all subsets.

        :param phase: the label of the subset generated by self._partition_ids
                    should be one of the SUPPORTED_PHASES
        :param section_names: one or multiple input section names
        :return: a pandas.dataframe of file names
        """
        if self._file_list is None:
            tf.logging.warning('Empty file list, please initialise '
                               'ImageSetsPartitioner first.')
            return []
        try:
            look_up_operations(phase, SUPPORTED_PHASES)
        except ValueError:
            tf.logging.fatal('Unknown phase argument.')
            raise
        for name in section_names:
            try:
                look_up_operations(name, set(self._file_list))
            except ValueError:
                tf.logging.fatal(
                    'Requesting files under input section [%s],\n'
                    'however the section does not exist in the config.', name)
                raise
        if phase == ALL:
            self._file_list = self._file_list.sort_index()
            if section_names:
                section_names = [COLUMN_UNIQ_ID] + list(section_names)
                return self._file_list[section_names]
            return self._file_list
        if self._partition_ids is None:
            tf.logging.fatal('No partition ids available.')
            if self.new_partition:
                tf.logging.fatal(
                    'Unable to create new partitions, '
                    'splitting ratios: %s, writing file %s', self.ratios,
                    self.data_split_file)
            elif os.path.isfile(self.data_split_file):
                tf.logging.fatal(
                    'Unable to load %s, initialise the '
                    'ImageSetsPartitioner with `new_partition=True` '
                    'to overwrite the file.', self.data_split_file)
            raise ValueError

        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        if selected.empty:
            tf.logging.warning(
                'Empty subset for phase [%s], returning None as file list. '
                'Please adjust splitting fractions.', phase)
            return None
        subset = pandas.merge(self._file_list, selected, on=COLUMN_UNIQ_ID)
        if section_names:
            section_names = [COLUMN_UNIQ_ID] + list(section_names)
            return subset[list(section_names)]
        return subset

    def load_data_sections_by_subject(self):
        """
        Go through all input data sections, converting each section
        to a list of file names.  These lists are merged on COLUMN_UNIQ_ID

        This function sets self._file_list
        """
        if not self.data_param:
            tf.logging.fatal(
                'Nothing to load, please check input sections in the config.')
            raise ValueError
        self._file_list = None
        for section_name in self.data_param:
            modality_file_list = self.grep_files_by_data_section(section_name)
            if self._file_list is None:
                # adding all rows of the first modality
                self._file_list = modality_file_list
                continue
            n_rows = self._file_list[COLUMN_UNIQ_ID].count()
            self._file_list = pandas.merge(self._file_list,
                                           modality_file_list,
                                           how='outer',
                                           on=COLUMN_UNIQ_ID)
            if self._file_list[COLUMN_UNIQ_ID].count() < n_rows:
                tf.logging.warning('rows not matched in section [%s]',
                                   section_name)

        if self._file_list is None or self._file_list.size == 0:
            tf.logging.fatal(
                "empty filename lists, please check the csv "
                "files. (removing csv_file keyword if it is in the config file "
                "to automatically search folders and generate new csv "
                "files again)\n\n"
                "Please note in the matched file names, each subject id are "
                "created by removing all keywords listed `filename_contains` "
                "in the config.\n\n"
                "E.g., `filename_contains=foo, bar` will match file "
                "foo_subject42_bar.nii.gz, and the subject id is _subject42_.")
            raise IOError

    def grep_files_by_data_section(self, modality_name):
        """
        list all files by a given input data section,
        if the `csv_file` property of the section corresponds to a file,
            read the list from the file;
        otherwise
            write the list to `csv_file`.

        returns: a table with two columns,
                 the column names are (COLUMN_UNIQ_ID, modality_name)
        """
        if modality_name not in self.data_param:
            tf.logging.fatal(
                'unknown section name [%s], '
                'current input section names: %s.', modality_name,
                list(self.data_param))
            raise ValueError

        # input data section must have a `csv_file` section for loading
        # or writing filename lists
        try:
            csv_file = self.data_param[modality_name].csv_file
        except AttributeError:
            tf.logging.fatal('Missing `csv_file` field in the config file, '
                             'unknown configuration format.')
            raise

        if hasattr(self.data_param[modality_name], 'path_to_search') and \
                self.data_param[modality_name].path_to_search:
            tf.logging.info('[%s] search file folders, writing csv file %s',
                            modality_name, csv_file)
            section_properties = self.data_param[modality_name].__dict__.items(
            )
            # grep files by section properties and write csv
            try:
                matcher = KeywordsMatching.from_tuple(
                    section_properties, self.default_image_file_location)
                match_and_write_filenames_to_csv([matcher], csv_file)
            except ValueError as reading_error:
                tf.logging.warning(
                    'Ignoring input section: [%s], '
                    'due to the following error:', modality_name)
                tf.logging.warning(repr(reading_error))
                return pandas.DataFrame(
                    columns=[COLUMN_UNIQ_ID, modality_name])
        else:
            tf.logging.info(
                '[%s] using existing csv file %s, skipped filenames search',
                modality_name, csv_file)

        if not os.path.isfile(csv_file):
            tf.logging.fatal('[%s] csv file %s not found.', modality_name,
                             csv_file)
            raise IOError
        try:
            csv_list = pandas.read_csv(csv_file,
                                       header=None,
                                       dtype=(str, str),
                                       names=[COLUMN_UNIQ_ID, modality_name],
                                       skipinitialspace=True)
        except Exception as csv_error:
            tf.logging.fatal(repr(csv_error))
            raise
        return csv_list

    # pylint: disable=broad-except
    def randomly_split_dataset(self, overwrite=False):
        """
        Label each subject as one of 'TRAIN', 'VALID', 'INFER';
        use self.ratios to compute the size of each set.
        The results will be written to self.data_split_file if `overwrite`,
        otherwise partition labels are read from it.

        This function sets self._partition_ids
        """
        if overwrite:
            try:
                valid_fraction, infer_fraction = self.ratios
                valid_fraction = max(min(1.0, float(valid_fraction)), 0.0)
                infer_fraction = max(min(1.0, float(infer_fraction)), 0.0)
            except (TypeError, ValueError):
                tf.logging.fatal('Unknown format of fraction values %s',
                                 self.ratios)
                raise

            if (valid_fraction + infer_fraction) <= 0:
                tf.logging.warning(
                    'To split dataset into training/validation, '
                    'please make sure '
                    '"exclude_fraction_for_validation" parameter is set to '
                    'a float in between 0 and 1. Current value: %s.',
                    valid_fraction)
                # raise ValueError

            n_total = self.number_of_subjects()
            n_valid = int(math.ceil(n_total * valid_fraction))
            n_infer = int(math.ceil(n_total * infer_fraction))
            n_train = int(n_total - n_infer - n_valid)
            phases = [TRAIN] * n_train + \
                     [VALID] * n_valid + \
                     [INFER] * n_infer
            if len(phases) > n_total:
                phases = phases[:n_total]
            random.shuffle(phases)
            write_csv(self.data_split_file,
                      zip(self._file_list[COLUMN_UNIQ_ID], phases))
        elif os.path.isfile(self.data_split_file):
            tf.logging.warning(
                'Loading from existing partitioning file %s, '
                'ignoring partitioning ratios.', self.data_split_file)

        if os.path.isfile(self.data_split_file):
            try:
                self._partition_ids = pandas.read_csv(
                    self.data_split_file,
                    header=None,
                    dtype=(str, str),
                    names=[COLUMN_UNIQ_ID, COLUMN_PHASE],
                    skipinitialspace=True)
                assert not self._partition_ids.empty, \
                    "partition file is empty."
            except Exception as csv_error:
                tf.logging.warning(
                    "Unable to load the existing partition file %s, %s",
                    self.data_split_file, repr(csv_error))
                self._partition_ids = None

            try:
                is_valid_phase = \
                    self._partition_ids[COLUMN_PHASE].isin(SUPPORTED_PHASES)
                assert is_valid_phase.all(), \
                    "Partition file contains unknown phase id."
            except (TypeError, AssertionError):
                tf.logging.warning(
                    'Please make sure the values of the second column '
                    'of the data splitting file %s are in the set of '
                    'phases: %s.\n'
                    'Remove %s to generate a random data partition file.',
                    self.data_split_file, SUPPORTED_PHASES,
                    self.data_split_file)
                raise ValueError

    def __str__(self):
        return self.to_string()

    def to_string(self):
        """
        Print summary of the partitioner
        """
        n_subjects = self.number_of_subjects()
        summary_str = '\nNumber of subjects {}, '.format(n_subjects)
        if self._file_list is not None:
            summary_str += 'input section names: {}\n'.format(
                list(self._file_list))
        if self._partition_ids is not None and n_subjects > 0:
            n_valid = self.number_of_subjects(VALID)
            n_train = self.number_of_subjects(TRAIN)
            n_infer = self.number_of_subjects(INFER)
            summary_str += \
                'data partitioning -- number of cases:\n' \
                '-- {} {} ({:.2f}%),\n' \
                '-- {} {} ({:.2f}%),\n' \
                '-- {} {} ({:.2f}%).\n'.format(
                    VALID, n_valid, float(n_valid) / float(n_subjects) * 100.0,
                    TRAIN, n_train, float(n_train) / float(n_subjects) * 100.0,
                    INFER, n_infer, float(n_infer) / float(n_subjects) * 100.0)
        else:
            summary_str += '-- using all subjects ' \
                           '(without data partitioning).\n'
        return summary_str

    def has_phase(self, phase):
        """
        returns True if the `phase` subset of images is not empty
        """
        if self._partition_ids is None:
            return False
        return (self._partition_ids[COLUMN_PHASE] == phase).any()

    @property
    def has_training(self):
        """
        returns True if the TRAIN subset of images is not empty
        """
        return self.has_phase(TRAIN)

    @property
    def has_inference(self):
        """
        returns True if the INFER subset of images is not empty
        """
        return self.has_phase(INFER)

    @property
    def has_validation(self):
        """
        returns True if the VALID subset of images is not empty
        """
        return self.has_phase(VALID)

    @property
    def validation_files(self):
        """
        returns the list of validation filenames
        """
        return self.get_file_list(VALID)

    @property
    def train_files(self):
        """
        returns the list of training filenames
        """
        return self.get_file_list(TRAIN)

    @property
    def inference_files(self):
        """
        returns the list of inference filenames
        (defaulting to list of all filenames if no partition definition)
        """
        if self.has_inference:
            return self.get_file_list(INFER)
        return self.all_files

    @property
    def all_files(self):
        """
        returns list of all filenames
        """
        return self.get_file_list()

    def reset(self):
        """
        reset all fields of this singleton class
        """
        self._file_list = None
        self._partition_ids = None
        self.data_param = None
        self.ratios = None
        self.new_partition = False

        self.data_split_file = ""
        self.default_image_file_location = \
            NiftyNetGlobalConfig().get_niftynet_home_folder()
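# A hedged usage sketch of the partitioner above; `data_param` stands in for
# the parsed input-data sections and the split file path is illustrative.
partitioner = ImageSetsPartitioner()
partitioner.initialise(data_param,
                       new_partition=True,
                       data_split_file='dataset_split.csv',
                       ratios=(0.1, 0.1))   # 10% validation, 10% inference
train_files = partitioner.train_files       # pandas.DataFrame of training rows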
Example #23
def resolve_module_dir(module_dir_str, create_new=False):
    """
    Interpret `module_dir_str` as an absolute folder path.
    Create the folder if `create_new` is True.

    :param module_dir_str:
    :param create_new:
    :return:
    """
    try:
        # interpret input as a module string
        module_from_string = importlib.import_module(module_dir_str)
        folder_path = os.path.dirname(module_from_string.__file__)
        return os.path.abspath(folder_path)
    except (ImportError, AttributeError, TypeError):
        pass

    try:
        # interpret last part of input as a module string
        string_last_part = module_dir_str.rsplit('.', 1)
        module_from_string = importlib.import_module(string_last_part[-1])
        folder_path = os.path.dirname(module_from_string.__file__)
        return os.path.abspath(folder_path)
    except (ImportError, AttributeError, IndexError, TypeError):
        pass

    module_dir_str = os.path.expanduser(module_dir_str)
    try:
        # interpret input as a file folder path string
        if os.path.isdir(module_dir_str):
            return os.path.abspath(module_dir_str)
    except TypeError:
        pass

    try:
        # interpret input as a file path string
        if os.path.isfile(module_dir_str):
            return os.path.abspath(os.path.dirname(module_dir_str))
    except TypeError:
        pass

    try:
        # interpret input as a path string relative to the global home
        home_location = NiftyNetGlobalConfig().get_niftynet_home_folder()
        possible_dir = os.path.join(home_location, module_dir_str)
        if os.path.isdir(possible_dir):
            return os.path.abspath(possible_dir)
    except (TypeError, ImportError, AttributeError):
        pass

    if create_new:
        # try to create the folder
        folder_path = touch_folder(module_dir_str)
        init_file = os.path.join(folder_path, '__init__.py')
        try:
            file_ = os.open(init_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except OSError as sys_error:
            if sys_error.errno == errno.EEXIST:
                pass
            else:
                tf.logging.fatal(
                    "trying to use '{}' as NiftyNet writing path, "
                    "however cannot write '{}'".format(
                        folder_path, init_file))
                raise
        else:
            with os.fdopen(file_, 'w') as file_object:
                file_object.write("# Created automatically\n")
        return folder_path
    else:
        raise ValueError(
            "Could not resolve [{}].\nMake sure it is a valid folder path "
            "or a module name.\nIf it is string representing a module, "
            "the parent folder of [{}] should be on "
            "the system path.\n\nCurrent system path {}.".format(
                module_dir_str, module_dir_str, sys.path))
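# Usage sketch of the resolution order above (module and folder names are
# hypothetical): an importable module resolves to its folder, otherwise the
# string is treated as a path, possibly relative to the NiftyNet home folder.
extension_dir = resolve_module_dir('niftynet.network')
new_dir = resolve_module_dir('~/niftynet/extensions', create_new=True)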
class ImageSetsPartitioner(object):
    """
    This class maintains a pandas.dataframe of filenames for all input sections.
    The lists of filenames are obtained by searching the specified folders
    or loading from an existing csv file.

    Users can query a subset of the dataframe by train/valid/infer partition
    label and input section names.
    """

    # dataframe (table) of file names in a shape of subject x modality
    _file_list = None
    # dataframes of subject_id:phase_id
    _partition_ids = None

    data_param = None
    ratios = None
    new_partition = False

    # for saving the splitting index
    data_split_file = ""
    # default parent folder location for searching the image files
    default_image_file_location = \
        NiftyNetGlobalConfig().get_niftynet_home_folder()

    def initialise(self,
                   data_param,
                   new_partition=False,
                   data_split_file=None,
                   ratios=None):
        """
        Set the data partitioner parameters

        :param data_param: corresponding to all config sections
        :param new_partition: bool value indicating whether to generate new
            partition ids and overwrite csv file
            (this class will write partition file iff new_partition)
        :param data_split_file: location of the partition id file
        :param ratios: a tuple/list with two elements:
            ``(fraction of the validation set, fraction of the inference set)``
            initialising to ``None`` disables data partitioning,
            and ``get_file_list`` always returns all subjects.
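
        A sketch of a typical call (ratio values are illustrative): the
        following reserves 20% of subjects for validation and 10% for
        inference, writing the assignment to ``dataset_split.csv``::

            partitioner.initialise(data_param,
                                   new_partition=True,
                                   data_split_file='dataset_split.csv',
                                   ratios=(0.2, 0.1))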
        """
        self.data_param = data_param
        if data_split_file is None:
            self.data_split_file = os.path.join('.', 'dataset_split.csv')
        else:
            self.data_split_file = data_split_file
        self.ratios = ratios

        self._file_list = None
        self._partition_ids = None

        self.load_data_sections_by_subject()
        self.new_partition = new_partition
        self.randomly_split_dataset(overwrite=new_partition)
        tf.logging.info(self)
        return self

    def number_of_subjects(self, phase=ALL):
        """
        Query the number of subjects in the given phase.

        :param phase: one of the ``SUPPORTED_PHASES``, or ``ALL`` (default)
        :return: the number of subjects in the ``phase`` subset
        """
        if self._file_list is None:
            return 0
        try:
            phase = look_up_operations(phase.lower(), SUPPORTED_PHASES)
        except (ValueError, AttributeError):
            tf.logging.fatal('Unknown phase argument.')
            raise

        if phase == ALL:
            return self._file_list[COLUMN_UNIQ_ID].count()
        if self._partition_ids is None:
            return 0
        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        subset = pandas.merge(self._file_list,
                              selected,
                              on=COLUMN_UNIQ_ID,
                              sort=True)
        return subset.count()[COLUMN_UNIQ_ID]

    def get_file_list(self, phase=ALL, *section_names):
        """
        Get file names as a dataframe, filtered by partitioning phase and
        section names. Set ``phase`` to ``ALL`` to load all subsets.

        :param phase: the label of the subset generated by
            self._partition_ids; should be one of the SUPPORTED_PHASES
        :param section_names: one or multiple input section names
        :return: a pandas.dataframe of file names
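
        For example (section names are illustrative and must exist in the
        config)::

            # all subjects, all sections
            all_files = partitioner.get_file_list()
            # validation subjects: subject id plus 'image' and 'label' columns
            valid_files = partitioner.get_file_list(VALID, 'image', 'label')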
        """
        if self._file_list is None:
            tf.logging.warning('Empty file list, please initialise '
                               'ImageSetsPartitioner first.')
            return []
        try:
            phase = look_up_operations(phase.lower(), SUPPORTED_PHASES)
        except (ValueError, AttributeError):
            tf.logging.fatal('Unknown phase argument.')
            raise

        for name in section_names:
            try:
                look_up_operations(name, set(self._file_list))
            except ValueError:
                tf.logging.fatal(
                    'Requesting files under input section [%s],\n'
                    'however the section does not exist in the config.', name)
                raise
        if phase == ALL:
            self._file_list = self._file_list.sort_values(COLUMN_UNIQ_ID)
            if section_names:
                section_names = [COLUMN_UNIQ_ID] + list(section_names)
                return self._file_list[section_names]
            return self._file_list
        if self._partition_ids is None or self._partition_ids.empty:
            tf.logging.fatal('No partition ids available.')
            if self.new_partition:
                tf.logging.fatal(
                    'Unable to create new partitions, '
                    'splitting ratios: %s, writing file %s', self.ratios,
                    self.data_split_file)
            elif os.path.isfile(self.data_split_file):
                tf.logging.fatal(
                    'Unable to load %s, initialise the '
                    'ImageSetsPartitioner with `new_partition=True` '
                    'to overwrite the file.', self.data_split_file)
            raise ValueError

        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        if selected.empty:
            tf.logging.warning(
                'Empty subset for phase [%s], returning None as file list. '
                'Please adjust splitting fractions.', phase)
            return None
        subset = pandas.merge(self._file_list,
                              selected,
                              on=COLUMN_UNIQ_ID,
                              sort=True)
        if subset.empty:
            tf.logging.warning(
                'No subject id matched between file names and '
                'partition files.\nPlease check the partition file %s,\nor '
                'remove it to generate a new file automatically.',
                self.data_split_file)
        if section_names:
            section_names = [COLUMN_UNIQ_ID] + list(section_names)
            return subset[section_names]
        return subset

    def load_data_sections_by_subject(self):
        """
        Go through all input data sections, converting each section
        to a list of file names.

        These lists are merged on ``COLUMN_UNIQ_ID``.

        This function sets ``self._file_list``.
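
        For instance, with two input sections named ``image`` and ``label``
        (names and paths below are illustrative), the merged table has the
        form::

            subject_id    image                   label
            sub001        /data/sub001_img.nii    /data/sub001_seg.nii
            sub002        /data/sub002_img.nii    /data/sub002_seg.nii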
        """
        if not self.data_param:
            tf.logging.fatal(
                'Nothing to load, please check input sections in the config.')
            raise ValueError
        self._file_list = None
        for section_name in self.data_param:
            modality_file_list = self.grep_files_by_data_section(section_name)
            if self._file_list is None:
                # adding all rows of the first modality
                self._file_list = modality_file_list
                continue
            n_rows = self._file_list[COLUMN_UNIQ_ID].count()
            self._file_list = pandas.merge(self._file_list,
                                           modality_file_list,
                                           how='outer',
                                           on=COLUMN_UNIQ_ID)
            if self._file_list[COLUMN_UNIQ_ID].count() < n_rows:
                tf.logging.warning('rows not matched in section [%s]',
                                   section_name)

        if self._file_list is None or self._file_list.size == 0:
            tf.logging.fatal(
                "Empty filename lists, please check the csv "
                "files (removing csv_file keyword if it is in the config file "
                "to automatically search folders and generate new csv "
                "files again).\n\n"
                "Please note in the matched file names, each subject id are "
                "created by removing all keywords listed `filename_contains` "
                "in the config.\n"
                "E.g., `filename_contains=foo, bar` will match file "
                "foo_subject42_bar.nii.gz, and the subject id is "
                "_subject42_.\n\n")
            raise IOError

    def grep_files_by_data_section(self, modality_name):
        """
        List all files by a given input data section::

            if the ``csv_file`` property of ``data_param[modality_name]``
            corresponds to a file, read the list from the file;
            otherwise, write the list to ``csv_file``.

        :return: a table with two columns,
                 the column names are ``(COLUMN_UNIQ_ID, modality_name)``.
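
        A hypothetical section specification, using keys that appear in this
        module (values are illustrative)::

            data_param['T1'] = {'path_to_search': '~/data/T1',
                                'filename_contains': ('T1',),
                                'csv_file': '~/data/T1_files.csv'}
            t1_list = partitioner.grep_files_by_data_section('T1')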
        """
        if modality_name not in self.data_param:
            tf.logging.fatal(
                'unknown section name [%s], '
                'current input section names: %s.', modality_name,
                list(self.data_param))
            raise ValueError

        # input data section must have a ``csv_file`` section for loading
        # or writing filename lists
        if isinstance(self.data_param[modality_name], dict):
            mod_spec = self.data_param[modality_name]
        else:
            mod_spec = vars(self.data_param[modality_name])

        #########################
        # guess the csv_file path
        #########################
        temp_csv_file = None
        try:
            csv_file = os.path.expanduser(mod_spec.get('csv_file', None))
            if not os.path.isfile(csv_file):
                # writing to the same folder as data_split_file
                default_csv_file = os.path.join(
                    os.path.dirname(self.data_split_file),
                    '{}.csv'.format(modality_name))
                tf.logging.info(
                    '`csv_file = %s` not found, '
                    'writing to "%s" instead.', csv_file, default_csv_file)
                csv_file = default_csv_file
                if os.path.isfile(csv_file):
                    tf.logging.info('Overwriting existing: "%s".', csv_file)
            csv_file = os.path.abspath(csv_file)
        except (AttributeError, KeyError, TypeError):
            tf.logging.debug('`csv_file` not specified, writing the list of '
                             'filenames to a temporary file.')
            import tempfile
            temp_csv_file = os.path.join(tempfile.mkdtemp(),
                                         '{}.csv'.format(modality_name))
            csv_file = temp_csv_file

        ##############################################
        # writing csv file if path_to_search specified
        ##############################################
        if mod_spec.get('path_to_search', None):
            if not temp_csv_file:
                tf.logging.info(
                    '[%s] searching file folders, writing csv file %s',
                    modality_name, csv_file)
            # grep files by section properties and write csv
            try:
                matcher = KeywordsMatching.from_dict(
                    input_dict=mod_spec,
                    default_folder=self.default_image_file_location)
                match_and_write_filenames_to_csv([matcher], csv_file)
            except (IOError, ValueError) as reading_error:
                tf.logging.warning(
                    'Ignoring input section: [%s], '
                    'due to the following error:', modality_name)
                tf.logging.warning(repr(reading_error))
                return pandas.DataFrame(
                    columns=[COLUMN_UNIQ_ID, modality_name])
        else:
            tf.logging.info(
                '[%s] using existing csv file %s, skipping filename search',
                modality_name, csv_file)

        if not os.path.isfile(csv_file):
            tf.logging.fatal('[%s] csv file %s not found.', modality_name,
                             csv_file)
            raise IOError
        ###############################
        # loading the file as dataframe
        ###############################
        try:
            csv_list = pandas.read_csv(csv_file,
                                       header=None,
                                       dtype=(str, str),
                                       names=[COLUMN_UNIQ_ID, modality_name],
                                       skipinitialspace=True)
        except Exception as csv_error:
            tf.logging.fatal(repr(csv_error))
            raise

        if temp_csv_file:
            shutil.rmtree(os.path.dirname(temp_csv_file), ignore_errors=True)

        return csv_list

    # pylint: disable=broad-except
    def randomly_split_dataset(self, overwrite=False):
        """
        Label each subject as one of ``TRAIN``, ``VALID``, ``INFER``,
        using ``self.ratios`` to compute the size of each set.

        The result is written to ``self.data_split_file`` if ``overwrite``
        is True; otherwise this function tries to read partition labels
        from the existing file.

        This function sets ``self._partition_ids``.
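
        As a worked example (sizes follow the ``ceil`` arithmetic used
        below): with 10 subjects and ``self.ratios = (0.2, 0.1)``, the
        split produces ``ceil(10 * 0.2) = 2`` validation,
        ``ceil(10 * 0.1) = 1`` inference and ``10 - 2 - 1 = 7`` training
        subjects.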
        """
        if overwrite:
            try:
                valid_fraction, infer_fraction = self.ratios
                valid_fraction = max(min(1.0, float(valid_fraction)), 0.0)
                infer_fraction = max(min(1.0, float(infer_fraction)), 0.0)
            except (TypeError, ValueError):
                tf.logging.fatal('Unknown format of fraction values %s',
                                 self.ratios)
                raise

            if (valid_fraction + infer_fraction) <= 0:
                tf.logging.warning(
                    'To split the dataset into training/validation, '
                    'please make sure the '
                    '"exclude_fraction_for_validation" parameter is set to '
                    'a float between 0 and 1. Current value: %s.',
                    valid_fraction)
                # raise ValueError

            n_total = self.number_of_subjects()
            n_valid = int(math.ceil(n_total * valid_fraction))
            n_infer = int(math.ceil(n_total * infer_fraction))
            n_train = int(n_total - n_infer - n_valid)
            phases = [TRAIN] * n_train + [VALID] * n_valid + [INFER] * n_infer
            if len(phases) > n_total:
                phases = phases[:n_total]
            random.shuffle(phases)
            write_csv(self.data_split_file,
                      zip(self._file_list[COLUMN_UNIQ_ID], phases))
        elif os.path.isfile(self.data_split_file):
            tf.logging.warning(
                'Loading from existing partitioning file %s, '
                'ignoring partitioning ratios.', self.data_split_file)

        if os.path.isfile(self.data_split_file):
            try:
                self._partition_ids = pandas.read_csv(
                    self.data_split_file,
                    header=None,
                    dtype=(str, str),
                    names=[COLUMN_UNIQ_ID, COLUMN_PHASE],
                    skipinitialspace=True)
                assert not self._partition_ids.empty, \
                    "partition file is empty."
            except Exception as csv_error:
                tf.logging.warning(
                    "Unable to load the existing partition file %s, %s",
                    self.data_split_file, repr(csv_error))
                self._partition_ids = None

            try:
                phase_strings = self._partition_ids[COLUMN_PHASE]
                phase_strings = phase_strings.astype(str).str.lower()
                is_valid_phase = phase_strings.isin(SUPPORTED_PHASES)
                assert is_valid_phase.all(), \
                    "Partition file contains unknown phase id."
                self._partition_ids[COLUMN_PHASE] = phase_strings
            except (TypeError, AssertionError):
                tf.logging.warning(
                    'Please make sure the values in the second column '
                    'of the data splitting file %s are in the set of '
                    'phases: %s.\n'
                    'Remove %s to generate a new random data partition file.',
                    self.data_split_file, SUPPORTED_PHASES,
                    self.data_split_file)
                raise ValueError

    def __str__(self):
        return self.to_string()

    def to_string(self):
        """
        Return a summary string describing the partitioner.
        """
        n_subjects = self.number_of_subjects()
        summary_str = '\n\nNumber of subjects {}, '.format(n_subjects)
        if self._file_list is not None:
            summary_str += 'input section names: {}\n'.format(
                list(self._file_list))
        if self._partition_ids is not None and n_subjects > 0:
            n_train = self.number_of_subjects(TRAIN)
            n_valid = self.number_of_subjects(VALID)
            n_infer = self.number_of_subjects(INFER)
            summary_str += \
                'Dataset partitioning:\n' \
                '-- {} {} cases ({:.2f}%),\n' \
                '-- {} {} cases ({:.2f}%),\n' \
                '-- {} {} cases ({:.2f}%).\n'.format(
                    TRAIN, n_train, float(n_train) / float(n_subjects) * 100.0,
                    VALID, n_valid, float(n_valid) / float(n_subjects) * 100.0,
                    INFER, n_infer, float(n_infer) / float(n_subjects) * 100.0)
        else:
            summary_str += '-- using all subjects ' \
                           '(without data partitioning).\n'
        return summary_str

    def has_phase(self, phase):
        """

        :return: True if the `phase` subset of images is not empty.
        """
        if self._partition_ids is None or self._partition_ids.empty:
            return False
        selector = self._partition_ids[COLUMN_PHASE] == phase
        if not selector.any():
            return False
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        subset = pandas.merge(left=self._file_list,
                              right=selected,
                              on=COLUMN_UNIQ_ID,
                              sort=False)
        return not subset.empty

    @property
    def has_training(self):
        """

        :return: True if the TRAIN subset of images is not empty.
        """
        return self.has_phase(TRAIN)

    @property
    def has_inference(self):
        """

        :return: True if the INFER subset of images is not empty.
        """
        return self.has_phase(INFER)

    @property
    def has_validation(self):
        """

        :return: True if the VALID subset of images is not empty.
        """
        return self.has_phase(VALID)

    @property
    def validation_files(self):
        """

        :return: the list of validation filenames.
        """
        if self.has_validation:
            return self.get_file_list(VALID)
        return self.all_files

    @property
    def train_files(self):
        """

        :return: the list of training filenames.
        """
        if self.has_training:
            return self.get_file_list(TRAIN)
        return self.all_files

    @property
    def inference_files(self):
        """

        :return: the list of inference filenames
            (defaulting to list of all filenames if no partition definition)
        """
        if self.has_inference:
            return self.get_file_list(INFER)
        return self.all_files

    @property
    def all_files(self):
        """

        :return: list of all filenames
        """
        return self.get_file_list()

    def get_file_lists_by(self, phase=None, action='train'):
        """
        Get file lists by action and phase.

        This function returns file lists for training/validation/inference
        based on the phase or action specified by the user.

        ``phase`` has a higher priority:
        if ``phase`` is specified, the function returns the corresponding
        file list (as a list).

        Otherwise, the function checks ``action``:
        it returns the training and validation file lists for a training
        action, and the inference file list otherwise.

        :param action: an action string; interpreted as training
            if it is a prefix of ``TRAIN``
        :param phase: an element from ``{TRAIN, VALID, INFER, ALL}``
        :return: a list of file lists (each a pandas.DataFrame)
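
        A brief sketch of the two behaviours::

            # training action: train file list, plus validation if available
            lists = partitioner.get_file_lists_by(action='train')
            # an explicit phase takes priority: [validation file list]
            lists = partitioner.get_file_lists_by(phase=VALID, action='train')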
        """
        if phase:
            try:
                return [self.get_file_list(phase=phase)]
            except (ValueError, AttributeError):
                tf.logging.warning('`phase` parameter %s ignored', phase)

        if action and TRAIN.startswith(action):
            file_lists = [self.train_files]
            if self.has_validation:
                file_lists.append(self.validation_files)
            return file_lists

        return [self.inference_files]

    def reset(self):
        """
        Reset all fields of this singleton class.
        """
        self._file_list = None
        self._partition_ids = None
        self.data_param = None
        self.ratios = None
        self.new_partition = False

        self.data_split_file = ""
        self.default_image_file_location = \
            NiftyNetGlobalConfig().get_niftynet_home_folder()