def test_incorrect_config_file_backed_up(self):
    # create an incorrect config file at the correct location
    makedirs(NiftyNetGlobalConfigTest.config_home)
    incorrect_config = '\n'.join([NiftyNetGlobalConfigTest.header,
                                  'invalid_home_tag = ~/niftynet'])
    with open(NiftyNetGlobalConfigTest.config_file, 'w') as config_file:
        config_file.write(incorrect_config)

    # the following should back it up and replace it with default config
    global_config = NiftyNetGlobalConfig().setup()
    self.assertTrue(isfile(NiftyNetGlobalConfigTest.config_file))
    self.assertEqual(global_config.get_niftynet_config_folder(),
                     NiftyNetGlobalConfigTest.config_home)

    # check if incorrect file was backed up
    found_files = glob(join(
        NiftyNetGlobalConfigTest.config_home,
        NiftyNetGlobalConfigTest.typify('config-backup-*')))
    self.assertTrue(len(found_files) == 1)
    with open(found_files[0], 'r') as backup_file:
        self.assertEqual(backup_file.read(), incorrect_config)

    # cleanup: remove backup file
    NiftyNetGlobalConfigTest.remove_path(found_files[0])
def download(example_ids, download_if_already_existing=False, verbose=True):
    """
    Downloads standard NiftyNet examples such as data and samples.

    :param example_ids: A list of identifiers for the samples to download
    :param download_if_already_existing: If True, data will always be
        downloaded
    :param verbose: If True, download info will be printed
    """
    global_config = NiftyNetGlobalConfig()
    config_store = ConfigStore(global_config, verbose=verbose)

    # If a single id is specified, convert it to a list
    example_ids = [example_ids] \
        if not isinstance(example_ids, (tuple, list)) else example_ids
    if not example_ids:
        return False

    # Check if the server is running by looking for a known file
    remote_base_url_test = gitlab_raw_file_url(
        global_config.get_download_server_url(), 'README.md')
    server_ok = url_exists(remote_base_url_test)
    if verbose:
        print("Accessing: {}".format(global_config.get_download_server_url()))

    any_error = False
    for example_id in example_ids:
        if not example_id:
            any_error = True
            continue
        if config_store.exists(example_id):
            update_ok = config_store.update_if_required(
                example_id, download_if_already_existing)
            any_error = (not update_ok) or any_error
        else:
            any_error = True
            if server_ok:
                print(example_id + ': FAIL.')
                print('No NiftyNet example was found for ' + example_id + '.')

    # If errors occurred and the server is down, report a message
    if any_error and not server_ok:
        print("The NiftyNetExamples server is not running")

    return not any_error
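
# A minimal usage sketch of ``download``; the example id below is
# hypothetical, and the call assumes the NiftyNetExamples server is
# reachable from the current machine.
from niftynet.utilities.download import download

success = download(['dense_vnet_abdominal_ct_model_zoo'])
print('all downloads succeeded:', success)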
def test_existing_config_file_loaded(self):
    # create a config file with a custom NiftyNet home
    makedirs(NiftyNetGlobalConfigTest.config_home)
    custom_niftynet_home = '~/customniftynethome'
    custom_niftynet_home_abs = expanduser(custom_niftynet_home)
    config = ''.join(['home = ', custom_niftynet_home])
    with open(NiftyNetGlobalConfigTest.config_file, 'w') as config_file:
        config_file.write('\n'.join(
            [NiftyNetGlobalConfigTest.header, config]))

    global_config = NiftyNetGlobalConfig().setup()
    self.assertEqual(global_config.get_niftynet_home_folder(),
                     custom_niftynet_home_abs)
    NiftyNetGlobalConfigTest.remove_path(custom_niftynet_home_abs)
def __resolve_config_file_path(cmdline_arg):
    """
    Search for the absolute file name of the configuration file,
    starting from the `-c` value provided by the user.

    :param cmdline_arg: the command line value of the `-c` argument
    :return: the resolved configuration file path
    """
    if not cmdline_arg:
        raise IOError("\nNo configuration file has been provided, did you "
                      "forget the '-c' command argument?{}".format(
                          EPILOG_STRING))

    # Resolve a relative configuration file location
    config_file_path = os.path.expanduser(cmdline_arg)
    try:
        config_file_path = resolve_file_name(config_file_path,
                                             ('.', NIFTYNET_HOME))
        if os.path.isfile(config_file_path):
            return config_file_path
    except (IOError, TypeError):
        # fall back to the default examples folder
        config_file_path = os.path.expanduser(cmdline_arg)
        config_file_path = os.path.join(
            NiftyNetGlobalConfig().get_default_examples_folder(),
            config_file_path,
            config_file_path + "_config.ini")
        if os.path.isfile(config_file_path):
            return config_file_path

    # could not proceed without a configuration file
    raise IOError("\nConfiguration file not found: {}.{}".format(
        os.path.expanduser(cmdline_arg), EPILOG_STRING))
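
# A minimal sketch of the fallback behaviour above, called from inside the
# same module (the function is module-private); the config name below is
# hypothetical. A bare name with no matching file is looked up under the
# default examples folder as <name>/<name>_config.ini.
config_file = __resolve_config_file_path('dense_vnet_abdominal_ct')
print('resolved configuration:', config_file)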
def __init__(self, names):
    # list of file names
    self._file_list = None
    self._input_sources = None
    self._shapes = None
    self._dtypes = None
    self._names = None
    self.names = names
    self._global_config = NiftyNetGlobalConfig()

    # list of image objects
    self.output_list = None
    self.current_id = -1

    self.preprocessors = []
    super(ImageReader, self).__init__(name='image_reader')
def file_path(self, path_array):
    if isinstance(path_array, string_types):
        path_array = (path_array,)
    home_folder = NiftyNetGlobalConfig().get_niftynet_home_folder()
    try:
        self._file_path = tuple(
            resolve_file_name(path, ('.', home_folder))
            for path in path_array)
    except (TypeError, AssertionError, AttributeError, IOError):
        tf.logging.fatal(
            "unrecognised file path format, should be a valid filename, "
            "or a sequence of filenames: %s", path_array)
        raise IOError
def resolve_checkpoint(checkpoint_name):
    # For now only supports checkpoint_name where
    # checkpoint_name.index is in the file system;
    # eventually will support checkpoint names that can be referenced
    # in a paths file.
    if os.path.isfile(checkpoint_name + '.index'):
        return checkpoint_name
    home_folder = NiftyNetGlobalConfig().get_niftynet_home_folder()
    checkpoint_name = to_absolute_path(input_path=checkpoint_name,
                                       model_root=home_folder)
    if os.path.isfile(checkpoint_name + '.index'):
        return checkpoint_name
    raise ValueError('Invalid checkpoint {}'.format(checkpoint_name))
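
# A minimal sketch of resolving a checkpoint; the path is hypothetical and
# assumes a ``<name>.index`` file written by a TensorFlow saver, either at
# that path or relative to the NiftyNet home model root.
checkpoint = resolve_checkpoint('models/model.ckpt-1000')
print('restoring from', checkpoint)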
def reset(self):
    """
    reset all fields of this singleton class.
    """
    self._file_list = None
    self._partition_ids = None
    self.data_param = None
    self.ratios = None
    self.new_partition = False

    self.data_split_file = ""
    self.default_image_file_location = \
        NiftyNetGlobalConfig().get_niftynet_home_folder()
def test_existing_niftynet_home_not_touched(self):
    niftynet_home = expanduser(
        NiftyNetGlobalConfigTest.default_config_opts['home'])
    makedirs(niftynet_home)
    niftynet_ext = join(
        niftynet_home, NiftyNetGlobalConfigTest.default_config_opts['ext'])
    makedirs(niftynet_ext)
    niftynet_ext_init = join(niftynet_ext, '__init__.py')
    open(niftynet_ext_init, 'w').close()

    mtime_before = getmtime(niftynet_ext_init)
    global_config = NiftyNetGlobalConfig()
    mtime_after = getmtime(niftynet_ext_init)
    self.assertEqual(mtime_before, mtime_after)
def test_non_existing_niftynet_home_created(self):
    niftynet_home = expanduser(
        NiftyNetGlobalConfigTest.default_config_opts['home'])
    NiftyNetGlobalConfigTest.remove_path(niftynet_home)
    self.assertFalse(isdir(niftynet_home))

    niftynet_ext = join(
        niftynet_home, NiftyNetGlobalConfigTest.default_config_opts['ext'])
    self.assertFalse(isfile(join(niftynet_ext, '__init__.py')))
    for mod in NiftyNetGlobalConfigTest.default_config_opts['ext_mods']:
        self.assertFalse(isfile(join(niftynet_ext, mod, '__init__.py')))

    global_config = NiftyNetGlobalConfig().setup()
    self.assertTrue(isdir(niftynet_home))
    self.assertTrue(isfile(join(niftynet_ext, '__init__.py')))
    for mod in NiftyNetGlobalConfigTest.default_config_opts['ext_mods']:
        self.assertTrue(isfile(join(niftynet_ext, mod, '__init__.py')))
def create_instance(cls, file_path, **kwargs):
    """
    Read image headers and create an image instance.

    :param file_path: a file path or a sequence of file paths
    :param kwargs: output properties for transforming the image data
        array into a desired format
    :return: an image instance
    """
    if file_path is None:
        tf.logging.fatal('No file_path provided, '
                         'please check input sources in config file')
        raise ValueError

    ndims = 0
    image_type = None
    home_folder = NiftyNetGlobalConfig().get_niftynet_home_folder()
    try:
        # file_path is a single file name
        file_path = resolve_file_name(file_path, ('.', home_folder))
        if os.path.isfile(file_path):
            loader = kwargs.get('loader', None) or None
            ndims = misc.infer_ndims_from_file(file_path, loader)
            image_type = cls.INSTANCE_DICT.get(ndims, None)
    except (TypeError, IOError, AttributeError):
        pass

    if image_type is None:
        try:
            # file_path is a sequence of file names
            file_path = [
                resolve_file_name(path, ('.', home_folder))
                for path in file_path
            ]
            loader = kwargs.get('loader', None) or (None,)
            ndims = misc.infer_ndims_from_file(file_path[0], loader[0])
            # a multi-file image gains an extra (modality) dimension
            ndims = ndims + (1 if len(file_path) > 1 else 0)
            image_type = cls.INSTANCE_DICT.get(ndims, None)
        except (AssertionError, TypeError, IOError, AttributeError):
            tf.logging.fatal('Could not load file: %s', file_path)
            raise IOError

    if image_type is None:
        tf.logging.fatal('Not supported image type from:\n%s', file_path)
        raise NotImplementedError(
            "unrecognised spatial rank {}".format(ndims))
    return image_type(file_path, **kwargs)
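
# A minimal sketch of the factory in use. The file name and the
# ``interp_order`` output property are hypothetical, and ``ImageFactory``
# is assumed to be the class exposing ``create_instance``.
image = ImageFactory.create_instance('T1_subject42.nii.gz', interp_order=3)
print(image.shape)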
def test_non_existing_config_file_created(self):
    self.assertFalse(isfile(NiftyNetGlobalConfigTestBase.config_file))
    global_config = NiftyNetGlobalConfig().setup()
    self.assertTrue(isfile(NiftyNetGlobalConfigTestBase.config_file))
    self.assertEqual(global_config.get_niftynet_config_folder(),
                     NiftyNetGlobalConfigTestBase.config_home)
def test_global_config_singleton(self):
    global_config_1 = NiftyNetGlobalConfig()
    global_config_2 = NiftyNetGlobalConfig()
    self.assertEqual(global_config_1, global_config_2)
    self.assertTrue(global_config_1 is global_config_2)
class ImageReader(Layer):
    """
    For a concrete example:

    ``_input_sources`` defines multiple modality mappings, e.g.,
    ``_input_sources = {'image': ('T1', 'T2'), 'label': ('manual_map',)}``
    means 'image' consists of two components, formed by concatenating
    'T1' and 'T2' input source images; 'label' consists of one component,
    loaded from 'manual_map'.

    ``self._names``: a tuple of the output names of this reader,
    e.g. ``('image', 'label')``.

    ``self._shapes``: the shapes after combining input sources, e.g.
    ``{'image': (192, 160, 192, 1, 2), 'label': (192, 160, 192, 1, 1)}``.

    ``self._dtypes``: the dictionary of tensorflow dtypes, e.g.
    ``{'image': tf.float32, 'label': tf.float32}``.

    ``self.output_list``: a list of dictionaries, with each item like
    ``{'image': <niftynet.io.image_type.SpatialImage4D object>,
    'label': <niftynet.io.image_type.SpatialImage3D object>}``.
    """

    def __init__(self, names):
        # list of file names
        self._file_list = None
        self._input_sources = None
        self._shapes = None
        self._dtypes = None
        self._names = None
        self.names = names
        self._global_config = NiftyNetGlobalConfig()

        # list of image objects
        self.output_list = None
        self.current_id = -1

        self.preprocessors = []
        super(ImageReader, self).__init__(name='image_reader')

    def initialise_reader(self, data_param, task_param):
        """
        ``task_param`` specifies how to combine user input modalities,
        e.g., for multimodal segmentation 'image' corresponds to multiple
        modality sections, 'label' corresponds to one modality section.
        """
        if not self.names:
            tf.logging.fatal('Please specify data names, this should '
                             'be a subset of SUPPORTED_INPUT provided '
                             'in application file')
            raise ValueError
        self._names = [
            name for name in self.names
            if vars(task_param).get(name, None)
        ]

        self._input_sources = {
            name: vars(task_param).get(name)
            for name in self.names
        }
        data_to_load = {}
        for name in self._names:
            for source in self._input_sources[name]:
                try:
                    data_to_load[source] = data_param[source]
                except KeyError:
                    tf.logging.fatal(
                        'reader name [%s] requires [%s], however it is not '
                        'specified as a section in the config, '
                        'current input section names: %s',
                        name, source, list(data_param))
                    raise ValueError

        default_data_folder = self._global_config.get_niftynet_home_folder()
        self._file_list = util_csv.load_and_merge_csv_files(
            data_to_load, default_data_folder)
        self.output_list = _filename_to_image_list(self._file_list,
                                                   self._input_sources,
                                                   data_param)
        for name in self.names:
            tf.logging.info('image reader: loading [%s] from %s (%d)',
                            name, self.input_sources[name],
                            len(self.output_list))

    def prepare_preprocessors(self):
        for layer in self.preprocessors:
            if isinstance(layer, DataDependentLayer):
                layer.train(self.output_list)

    def add_preprocessing_layers(self, layers):
        assert self.output_list is not None, \
            'Please initialise the reader first, ' \
            'before adding preprocessors.'
        if isinstance(layers, Layer):
            self.preprocessors.append(layers)
        else:
            self.preprocessors.extend(layers)
        self.prepare_preprocessors()

    # pylint: disable=arguments-differ
    def layer_op(self, idx=None, shuffle=True):
        """
        this layer returns a dictionary
          keys: self.output_fields
          values: image volume array
        """
        if idx is None and shuffle:
            # training, with random list output
            idx = np.random.randint(len(self.output_list))

        if idx is None and not shuffle:
            # testing, with sequential output;
            # accessing self.current_id, not suitable for multi-thread
            idx = self.current_id + 1
            self.current_id = idx

        try:
            idx = int(idx)
        except ValueError:
            idx = -1

        if idx < 0 or idx >= len(self.output_list):
            return -1, None, None

        image_dict = self.output_list[idx]
        image_data_dict = {
            field: image.get_data()
            for (field, image) in image_dict.items()
        }
        interp_order_dict = {
            field: image.interp_order
            for (field, image) in image_dict.items()
        }

        if self.preprocessors:
            preprocessors = [deepcopy(layer) for layer in self.preprocessors]
            # dictionary of masks is cached
            mask = None
            for layer in preprocessors:
                if layer is None:
                    continue
                if isinstance(layer, RandomisedLayer):
                    layer.randomise()
                    image_data_dict = layer(image_data_dict,
                                            interp_order_dict)
                else:
                    image_data_dict, mask = layer(image_data_dict, mask)
        return idx, image_data_dict, interp_order_dict

    @property
    def shapes(self):
        """
        image shapes before any preprocessing

        :return: a dictionary of image shapes
        """
        # to have fast access, the spatial dimensions are not accurate:
        # 1) only read from the first image in list
        # 2) not considering effects of random augmentation layers
        # but time and modality dimensions should be correct
        if not self.output_list:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        if not self._shapes:
            first_image = self.output_list[0]
            self._shapes = {
                field: first_image[field].shape
                for field in self.names
            }
        return self._shapes

    @property
    def tf_dtypes(self):
        if not self.output_list:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        if not self._dtypes:
            first_image = self.output_list[0]
            self._dtypes = {
                field: infer_tf_dtypes(first_image[field])
                for field in self.names
            }
        return self._dtypes

    @property
    def input_sources(self):
        if not self._input_sources:
            tf.logging.fatal("please initialise the reader first")
            raise RuntimeError
        return self._input_sources

    @property
    def names(self):
        return self._names

    @names.setter
    def names(self, fields_tuple):
        # output_fields is a sequence of output names;
        # each name might correspond to a list of multiple input sources,
        # and this should be specified in the CUSTOM section of the config
        self._names = make_input_tuple(fields_tuple, string_types)

    def get_subject_id(self, image_index):
        return self._file_list.iloc[image_index, 0]
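
# A minimal sketch of driving the reader directly. The section contents
# below are hypothetical stand-ins for parsed config sections, and
# 'T1.csv' is a hypothetical filename list.
from argparse import Namespace

data_param = {'T1': Namespace(csv_file='T1.csv')}
task_param = Namespace(image=('T1',))
reader = ImageReader(names=('image',))
reader.initialise_reader(data_param, task_param)
idx, image_data, interp_orders = reader()  # layer_op via Layer.__call__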
def run():
    """
    meta_parser is first used to find out the location of the
    configuration file. Based on the application_name or
    meta_parser.prog name, the section parsers are organised to find
    system parameters and application specific parameters.

    :return: system parameters is a group of parameters including
        SYSTEM_SECTIONS and app_module.REQUIRED_CONFIG_SECTION,
        input_data_args is a group of input data sources to be used
        by niftynet.io.ImageReader
    """
    meta_parser = argparse.ArgumentParser(
        epilog='Please visit '
               'https://cmiclab.cs.ucl.ac.uk/CMIC/NiftyNet/tree/dev/demos '
               'for more info.')
    version_string = get_niftynet_version_string()
    meta_parser.add_argument("-v", "--version",
                             action='version',
                             version=version_string)
    meta_parser.add_argument("-c", "--conf",
                             help="Specify configurations from a file",
                             metavar="File")
    meta_parser.add_argument("-a", "--application_name",
                             help="Specify application name",
                             default="")
    meta_args, args_from_cmdline = meta_parser.parse_known_args()
    print(version_string)

    # read configurations, to be parsed by sections
    if meta_args.conf is None:
        raise IOError("No configuration file has been provided")

    # Read global config file
    global_config = NiftyNetGlobalConfig()
    config_path = meta_args.conf
    if not os.path.isfile(config_path):
        relative_conf_file = os.path.join(
            global_config.get_default_examples_folder(),
            config_path,
            config_path + "_config.ini")
        if os.path.isfile(relative_conf_file):
            config_path = relative_conf_file
            os.chdir(os.path.dirname(config_path))
        else:
            raise IOError(
                "Configuration file not found {}".format(config_path))
    config = configparser.ConfigParser()
    config.read([config_path])

    app_module = None
    try:
        if meta_parser.prog[:-3] in SUPPORTED_APP:
            module_name = meta_parser.prog[:-3]
        elif meta_parser.prog in SUPPORTED_APP:
            module_name = meta_parser.prog
        else:
            module_name = meta_args.application_name
        app_module = ApplicationFactory.create(module_name)
        assert app_module.REQUIRED_CONFIG_SECTION, \
            "REQUIRED_CONFIG_SECTION should be static variable " \
            "in {}".format(app_module)
        has_section_in_config(config, app_module.REQUIRED_CONFIG_SECTION)
    except ValueError:
        if app_module:
            section_name = app_module.REQUIRED_CONFIG_SECTION
            raise ValueError(
                '{} requires [{}] section in the config file'.format(
                    module_name, section_name))
        else:
            raise ValueError(
                "unknown application {}, or did you forget '-a' "
                "command argument".format(module_name))

    # check keywords in configuration file
    check_keywords(config)

    # using configuration as default, and parsing all command line arguments
    all_args = {}
    for section in config.sections():
        # try to rename user-specified sections for consistency
        section = standardise_section_name(config, section)
        section_defaults = dict(config.items(section))
        section_args, args_from_cmdline = \
            _parse_arguments_by_section([],
                                        section,
                                        section_defaults,
                                        args_from_cmdline,
                                        app_module.REQUIRED_CONFIG_SECTION)
        all_args[section] = section_args
    # command line parameters should be valid
    assert not args_from_cmdline, \
        'unknown parameter: {}'.format(args_from_cmdline)

    # split parsed results in all_args
    # into dictionaries of system_args and input_data_args
    system_args = {}
    input_data_args = {}
    for section in all_args:
        if section in SYSTEM_SECTIONS:
            system_args[section] = all_args[section]
        elif section == app_module.REQUIRED_CONFIG_SECTION:
            system_args['CUSTOM'] = all_args[section]
            vars(system_args['CUSTOM'])['name'] = module_name
    if all_args['SYSTEM'].model_dir is None:
        all_args['SYSTEM'].model_dir = os.path.join(
            os.path.dirname(meta_args.conf), 'model')

    for section in all_args:
        if section in SYSTEM_SECTIONS:
            continue
        if section == app_module.REQUIRED_CONFIG_SECTION:
            continue
        input_data_args[section] = all_args[section]
        # set the output path of the csv list if it does not exist
        csv_path = input_data_args[section].csv_file
        if not os.path.isfile(csv_path):
            csv_filename = os.path.join(
                all_args['SYSTEM'].model_dir, '{}.csv'.format(section))
            input_data_args[section].csv_file = csv_filename
        else:
            # don't search files if csv specified in config
            try:
                delattr(input_data_args[section], 'path_to_search')
            except AttributeError:
                pass

    # update conf path
    system_args['CONFIG_FILE'] = argparse.Namespace(path=meta_args.conf)
    return system_args, input_data_args
    # (tail of available_keywords(), which collects the known '--option'
    # names from the section parsers)
    default_keys = []
    for action in all_key_parser._actions:
        try:
            default_keys.append(action.option_strings[0][2:])
        except (IndexError, AttributeError, ValueError):
            pass
    # remove duplicates
    default_keys = list(set(default_keys))
    # remove bad names
    default_keys = [keyword for keyword in default_keys if keyword]
    return default_keys


KEYWORDS = available_keywords()
NIFTYNET_HOME = NiftyNetGlobalConfig().get_niftynet_home_folder()


# pylint: disable=too-many-branches
def run():
    """
    meta_parser is first used to find out the location of the
    configuration file. Based on the application_name or
    meta_parser.prog name, the section parsers are organised to find
    system parameters and application specific parameters.

    :return: system parameters is a group of parameters including
        SYSTEM_SECTIONS and app_module.REQUIRED_CONFIG_SECTION,
        input_data_args is a group of input data sources to be used
        by niftynet.io.ImageReader
    """
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import unittest

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

from niftynet.utilities.download import download
from niftynet.utilities.niftynet_global_config import NiftyNetGlobalConfig
from niftynet import main as niftynet_main
from niftynet.application.base_application import SingletonApplication

MODEL_HOME = NiftyNetGlobalConfig().get_niftynet_home_folder()


def net_run_with_sys_argv(argv):
    # for gift-adelie
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    SingletonApplication.clear()
    cache = sys.argv
    argv.extend(['--cuda_devices', '0'])
    sys.argv = argv
    niftynet_main()
    sys.argv = cache


@unittest.skipIf(
def run():
    """
    meta_parser is first used to find out the location of the
    configuration file. Based on the application_name or
    meta_parser.prog name, the section parsers are organised to find
    system parameters and application specific parameters.

    :return: system parameters is a group of parameters including
        SYSTEM_SECTIONS and app_module.REQUIRED_CONFIG_SECTION,
        input_data_args is a group of input data sources to be used
        by niftynet.io.ImageReader
    """
    meta_parser = argparse.ArgumentParser(
        description="Launch a NiftyNet application.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent(epilog_string))
    version_string = get_niftynet_version_string()
    meta_parser.add_argument("action",
                             help="train networks or run inferences",
                             metavar='ACTION',
                             choices=['train', 'inference'])
    meta_parser.add_argument("-v", "--version",
                             action='version',
                             version=version_string)
    meta_parser.add_argument("-c", "--conf",
                             help="specify configurations from a file",
                             metavar="CONFIG_FILE")
    meta_parser.add_argument("-a", "--application_name",
                             help="specify an application module name",
                             metavar='APPLICATION_NAME',
                             default="")
    meta_args, args_from_cmdline = meta_parser.parse_known_args()
    print(version_string)

    # read configurations, to be parsed by sections
    if not meta_args.conf:
        print("\nNo configuration file has been provided, did you "
              "forget the '-c' command argument?{}".format(epilog_string))
        raise IOError

    # Resolve a relative configuration file location
    config_path = meta_args.conf
    if not os.path.isfile(config_path):
        relative_conf_file = os.path.join(
            NiftyNetGlobalConfig().get_default_examples_folder(),
            config_path,
            config_path + "_config.ini")
        if os.path.isfile(relative_conf_file):
            config_path = relative_conf_file
            os.chdir(os.path.dirname(config_path))
        else:
            print("\nConfiguration file not found: {}.{}".format(
                config_path, epilog_string))
            raise IOError
    config = configparser.ConfigParser()
    config.read([config_path])

    app_module = None
    module_name = None
    try:
        if meta_parser.prog[:-3] in SUPPORTED_APP:
            module_name = meta_parser.prog[:-3]
        elif meta_parser.prog in SUPPORTED_APP:
            module_name = meta_parser.prog
        else:
            module_name = meta_args.application_name
        app_module = ApplicationFactory.create(module_name)
        assert app_module.REQUIRED_CONFIG_SECTION, \
            "\nREQUIRED_CONFIG_SECTION should be static variable " \
            "in {}".format(app_module)
        has_section_in_config(config, app_module.REQUIRED_CONFIG_SECTION)
    except ValueError:
        if app_module:
            section_name = app_module.REQUIRED_CONFIG_SECTION
            print('\n{} requires [{}] section in the config file.{}'.format(
                module_name, section_name, epilog_string))
        if not module_name:
            print("\nUnknown application {}, or did you forget '-a' "
                  "command argument?{}".format(module_name, epilog_string))
        raise

    # check keywords in configuration file
    check_keywords(config)

    # using configuration as default, and parsing all command line arguments
    all_args = {}
    for section in config.sections():
        # try to rename user-specified sections for consistency
        section = standardise_section_name(config, section)
        section_defaults = dict(config.items(section))
        section_args, args_from_cmdline = \
            _parse_arguments_by_section([],
                                        section,
                                        section_defaults,
                                        args_from_cmdline,
                                        app_module.REQUIRED_CONFIG_SECTION)
        all_args[section] = section_args
    # command line parameters should be valid
    assert not args_from_cmdline, \
        '\nUnknown parameter: {}{}'.format(args_from_cmdline, epilog_string)

    # split parsed results in all_args
    # into dictionaries of system_args and input_data_args
    system_args = {}
    input_data_args = {}

    # copy system default sections to ``system_args``
    for section in all_args:
        if section in SYSTEM_SECTIONS:
            system_args[section] = all_args[section]
        elif section == app_module.REQUIRED_CONFIG_SECTION:
            system_args['CUSTOM'] = all_args[section]
            vars(system_args['CUSTOM'])['name'] = module_name
    if all_args['SYSTEM'].model_dir is None:
        all_args['SYSTEM'].model_dir = os.path.join(
            os.path.dirname(meta_args.conf), 'model')

    # copy non-default sections to ``input_data_args``
    for section in all_args:
        if section in SYSTEM_SECTIONS:
            continue
        if section == app_module.REQUIRED_CONFIG_SECTION:
            continue
        input_data_args[section] = all_args[section]
        # set the output path of the csv list if it does not exist
        csv_path = input_data_args[section].csv_file
        if os.path.isfile(csv_path):
            # don't search files if csv specified in config
            try:
                delattr(input_data_args[section], 'path_to_search')
            except AttributeError:
                pass
        else:
            input_data_args[section].csv_file = ''

    # preserve ``config_file`` and ``action`` parameters from meta_args
    system_args['CONFIG_FILE'] = argparse.Namespace(path=meta_args.conf)
    system_args['SYSTEM'].action = meta_args.action
    return system_args, input_data_args
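
# A hedged sketch of invoking ``run()`` programmatically by faking the
# command line; the application alias and the config path below are
# hypothetical.
import sys

sys.argv = ['net_run', 'train',
            '-a', 'net_segment',
            '-c', 'my_app_config.ini']
system_args, input_data_args = run()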
class ImageSetsPartitioner(object):
    """
    This class maintains a pandas.DataFrame of filenames for all input
    sections. The list of filenames is obtained by searching the
    specified folders or loading from an existing csv file.

    Users can query a subset of the dataframe by train/valid/infer
    partition label and input section names.
    """

    # dataframe (table) of file names in a shape of subject x modality
    _file_list = None
    # dataframes of subject_id:phase_id
    _partition_ids = None

    data_param = None
    ratios = None
    new_partition = False

    # for saving the splitting index
    data_split_file = ""
    # default parent folder location for searching the image files
    default_image_file_location = \
        NiftyNetGlobalConfig().get_niftynet_home_folder()

    def initialise(self,
                   data_param,
                   new_partition=False,
                   data_split_file="./test.csv",
                   ratios=None):
        """
        Set the data partitioner parameters.

        :param data_param: corresponding to all config sections
        :param new_partition: bool value indicating whether to generate
            new partition ids and overwrite the csv file
            (this class will write the partition file iff new_partition)
        :param data_split_file: location of the partition id file
        :param ratios: a tuple/list with two elements:
            (fraction of the validation set, fraction of the inference
            set); initialising to None will disable data partitioning
            and get_file_list always returns all subjects.
        """
        self.data_param = data_param
        self.data_split_file = data_split_file
        self.ratios = ratios

        self._file_list = None
        self._partition_ids = None

        self.load_data_sections_by_subject()
        self.new_partition = new_partition
        self.randomly_split_dataset(overwrite=new_partition)
        tf.logging.info(self)
        return self

    def number_of_subjects(self, phase=ALL):
        """
        query the number of subjects according to phase.

        :param phase: one of the SUPPORTED_PHASES
        :return: number of subjects
        """
        if self._file_list is None:
            return 0
        phase = look_up_operations(phase, SUPPORTED_PHASES)
        if phase == ALL:
            return self._file_list[COLUMN_UNIQ_ID].count()
        if self._partition_ids is None:
            return 0
        selector = self._partition_ids[COLUMN_PHASE] == phase
        return self._partition_ids[selector].count()[COLUMN_UNIQ_ID]

    def get_file_list(self, phase=ALL, *section_names):
        """
        get file names as a dataframe, by partitioning phase and section
        names; set phase to ALL to load all subsets.

        :param phase: the label of the subset generated by
            self._partition_ids; should be one of the SUPPORTED_PHASES
        :param section_names: one or multiple input section names
        :return: a pandas.DataFrame of file names
        """
        if self._file_list is None:
            tf.logging.warning('Empty file list, please initialise '
                               'ImageSetsPartitioner first.')
            return []
        try:
            look_up_operations(phase, SUPPORTED_PHASES)
        except ValueError:
            tf.logging.fatal('Unknown phase argument.')
            raise
        for name in section_names:
            try:
                look_up_operations(name, set(self._file_list))
            except ValueError:
                tf.logging.fatal(
                    'Requesting files under input section [%s],\n'
                    'however the section does not exist in the config.',
                    name)
                raise
        if phase == ALL:
            self._file_list = self._file_list.sort_index()
            if section_names:
                section_names = [COLUMN_UNIQ_ID] + list(section_names)
                return self._file_list[section_names]
            return self._file_list

        if self._partition_ids is None:
            tf.logging.fatal('No partition ids available.')
            if self.new_partition:
                tf.logging.fatal(
                    'Unable to create new partitions, '
                    'splitting ratios: %s, writing file %s',
                    self.ratios, self.data_split_file)
            elif os.path.isfile(self.data_split_file):
                tf.logging.fatal(
                    'Unable to load %s, initialise the '
                    'ImageSetsPartitioner with `new_partition=True` '
                    'to overwrite the file.',
                    self.data_split_file)
            raise ValueError

        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        if selected.empty:
            tf.logging.warning(
                'Empty subset for phase [%s], returning None as file list. '
                'Please adjust splitting fractions.', phase)
            return None
        subset = pandas.merge(self._file_list, selected, on=COLUMN_UNIQ_ID)
        if section_names:
            section_names = [COLUMN_UNIQ_ID] + list(section_names)
            return subset[list(section_names)]
        return subset

    def load_data_sections_by_subject(self):
        """
        Go through all input data sections, converting each section to a
        list of file names. These lists are merged on COLUMN_UNIQ_ID.

        This function sets self._file_list.
        """
        if not self.data_param:
            tf.logging.fatal(
                'Nothing to load, please check input sections in the config.')
            raise ValueError
        self._file_list = None
        for section_name in self.data_param:
            modality_file_list = self.grep_files_by_data_section(section_name)
            if self._file_list is None:
                # adding all rows of the first modality
                self._file_list = modality_file_list
                continue
            n_rows = self._file_list[COLUMN_UNIQ_ID].count()
            self._file_list = pandas.merge(self._file_list,
                                           modality_file_list,
                                           how='outer',
                                           on=COLUMN_UNIQ_ID)
            if self._file_list[COLUMN_UNIQ_ID].count() < n_rows:
                tf.logging.warning('rows not matched in section [%s]',
                                   section_name)

        if self._file_list is None or self._file_list.size == 0:
            tf.logging.fatal(
                "empty filename lists, please check the csv files "
                "(removing the csv_file keyword if it is in the config file "
                "to automatically search folders and generate new csv "
                "files again).\n\n"
                "Please note that in the matched file names, each subject "
                "id is created by removing all keywords listed in "
                "`filename_contains` in the config.\n\n"
                "E.g., `filename_contains=foo, bar` will match file "
                "foo_subject42_bar.nii.gz, and the subject id is "
                "_subject42_.")
            raise IOError

    def grep_files_by_data_section(self, modality_name):
        """
        list all files by a given input data section:
        if the `csv_file` property of the section corresponds to a file,
        read the list from the file; otherwise write the list to
        `csv_file`.

        :return: a table with two columns, the column names are
            (COLUMN_UNIQ_ID, modality_name)
        """
        if modality_name not in self.data_param:
            tf.logging.fatal(
                'unknown section name [%s], '
                'current input section names: %s.',
                modality_name, list(self.data_param))
            raise ValueError

        # input data section must have a `csv_file` field for loading
        # or writing filename lists
        try:
            csv_file = self.data_param[modality_name].csv_file
        except AttributeError:
            tf.logging.fatal('Missing `csv_file` field in the config file, '
                             'unknown configuration format.')
            raise

        if hasattr(self.data_param[modality_name], 'path_to_search') and \
                self.data_param[modality_name].path_to_search:
            tf.logging.info('[%s] search file folders, writing csv file %s',
                            modality_name, csv_file)
            section_properties = \
                self.data_param[modality_name].__dict__.items()
            # grep files by section properties and write csv
            try:
                matcher = KeywordsMatching.from_tuple(
                    section_properties,
                    self.default_image_file_location)
                match_and_write_filenames_to_csv([matcher], csv_file)
            except ValueError as reading_error:
                tf.logging.warning(
                    'Ignoring input section: [%s], '
                    'due to the following error:', modality_name)
                tf.logging.warning(repr(reading_error))
                return pandas.DataFrame(
                    columns=[COLUMN_UNIQ_ID, modality_name])
        else:
            tf.logging.info(
                '[%s] using existing csv file %s, skipped filenames search',
                modality_name, csv_file)

        if not os.path.isfile(csv_file):
            tf.logging.fatal('[%s] csv file %s not found.',
                             modality_name, csv_file)
            raise IOError
        try:
            csv_list = pandas.read_csv(
                csv_file,
                header=None,
                dtype=(str, str),
                names=[COLUMN_UNIQ_ID, modality_name],
                skipinitialspace=True)
        except Exception as csv_error:
            tf.logging.fatal(repr(csv_error))
            raise
        return csv_list

    # pylint: disable=broad-except
    def randomly_split_dataset(self, overwrite=False):
        """
        Label each subject as one of 'TRAIN', 'VALID', 'INFER', using
        self.ratios to compute the size of each set.

        The results will be written to self.data_split_file if overwrite,
        otherwise it tries to read partition labels from it.

        This function sets self._partition_ids.
        """
        if overwrite:
            try:
                valid_fraction, infer_fraction = self.ratios
                valid_fraction = max(min(1.0, float(valid_fraction)), 0.0)
                infer_fraction = max(min(1.0, float(infer_fraction)), 0.0)
            except (TypeError, ValueError):
                tf.logging.fatal('Unknown format of fraction values %s',
                                 self.ratios)
                raise

            if (valid_fraction + infer_fraction) <= 0:
                tf.logging.warning(
                    'To split dataset into training/validation, '
                    'please make sure '
                    '"exclude_fraction_for_validation" parameter is set to '
                    'a float in between 0 and 1. Current value: %s.',
                    valid_fraction)
                # raise ValueError

            n_total = self.number_of_subjects()
            n_valid = int(math.ceil(n_total * valid_fraction))
            n_infer = int(math.ceil(n_total * infer_fraction))
            n_train = int(n_total - n_infer - n_valid)
            phases = [TRAIN] * n_train + \
                     [VALID] * n_valid + \
                     [INFER] * n_infer
            if len(phases) > n_total:
                phases = phases[:n_total]
            random.shuffle(phases)
            write_csv(self.data_split_file,
                      zip(self._file_list[COLUMN_UNIQ_ID], phases))
        elif os.path.isfile(self.data_split_file):
            tf.logging.warning(
                'Loading from existing partitioning file %s, '
                'ignoring partitioning ratios.', self.data_split_file)

        if os.path.isfile(self.data_split_file):
            try:
                self._partition_ids = pandas.read_csv(
                    self.data_split_file,
                    header=None,
                    dtype=(str, str),
                    names=[COLUMN_UNIQ_ID, COLUMN_PHASE],
                    skipinitialspace=True)
                assert not self._partition_ids.empty, \
                    "partition file is empty."
            except Exception as csv_error:
                tf.logging.warning(
                    "Unable to load the existing partition file %s, %s",
                    self.data_split_file, repr(csv_error))
                self._partition_ids = None

            try:
                is_valid_phase = \
                    self._partition_ids[COLUMN_PHASE].isin(SUPPORTED_PHASES)
                assert is_valid_phase.all(), \
                    "Partition file contains unknown phase id."
            except (TypeError, AssertionError):
                tf.logging.warning(
                    'Please make sure the values of the second column '
                    'of data splitting file %s are in the set of phases: '
                    '%s.\nRemove %s to generate a random data partition '
                    'file.',
                    self.data_split_file, SUPPORTED_PHASES,
                    self.data_split_file)
                raise ValueError

    def __str__(self):
        return self.to_string()

    def to_string(self):
        """
        Print summary of the partitioner.
        """
        n_subjects = self.number_of_subjects()
        summary_str = '\nNumber of subjects {}, '.format(n_subjects)
        if self._file_list is not None:
            summary_str += 'input section names: {}\n'.format(
                list(self._file_list))
        if self._partition_ids is not None and n_subjects > 0:
            n_valid = self.number_of_subjects(VALID)
            n_train = self.number_of_subjects(TRAIN)
            n_infer = self.number_of_subjects(INFER)
            summary_str += \
                'data partitioning -- number of cases:\n' \
                '-- {} {} ({:.2f}%),\n' \
                '-- {} {} ({:.2f}%),\n' \
                '-- {} {} ({:.2f}%).\n'.format(
                    VALID, n_valid,
                    float(n_valid) / float(n_subjects) * 100.0,
                    TRAIN, n_train,
                    float(n_train) / float(n_subjects) * 100.0,
                    INFER, n_infer,
                    float(n_infer) / float(n_subjects) * 100.0)
        else:
            summary_str += '-- using all subjects ' \
                           '(without data partitioning).\n'
        return summary_str

    def has_phase(self, phase):
        """
        :return: True if the `phase` subset of images is not empty.
        """
        if self._partition_ids is None:
            return False
        return (self._partition_ids[COLUMN_PHASE] == phase).any()

    @property
    def has_training(self):
        """
        :return: True if the TRAIN subset of images is not empty.
        """
        return self.has_phase(TRAIN)

    @property
    def has_inference(self):
        """
        :return: True if the INFER subset of images is not empty.
        """
        return self.has_phase(INFER)

    @property
    def has_validation(self):
        """
        :return: True if the VALID subset of images is not empty.
        """
        return self.has_phase(VALID)

    @property
    def validation_files(self):
        """
        :return: the list of validation filenames.
        """
        return self.get_file_list(VALID)

    @property
    def train_files(self):
        """
        :return: the list of training filenames.
        """
        return self.get_file_list(TRAIN)

    @property
    def inference_files(self):
        """
        :return: the list of inference filenames (defaulting to the list
            of all filenames if no partition definition is available).
        """
        if self.has_inference:
            return self.get_file_list(INFER)
        return self.all_files

    @property
    def all_files(self):
        """
        :return: the list of all filenames.
        """
        return self.get_file_list()

    def reset(self):
        """
        reset all fields of this singleton class.
        """
        self._file_list = None
        self._partition_ids = None
        self.data_param = None
        self.ratios = None
        self.new_partition = False

        self.data_split_file = ""
        self.default_image_file_location = \
            NiftyNetGlobalConfig().get_niftynet_home_folder()
def resolve_module_dir(module_dir_str, create_new=False):
    """
    Interpret `module_dir_str` as an absolute folder path;
    create the folder if `create_new`.

    :param module_dir_str: a module name or a folder/file path string
    :param create_new: whether to create the folder if it does not exist
    :return: the resolved absolute folder path
    """
    try:
        # interpret input as a module string
        module_from_string = importlib.import_module(module_dir_str)
        folder_path = os.path.dirname(module_from_string.__file__)
        return os.path.abspath(folder_path)
    except (ImportError, AttributeError, TypeError):
        pass

    try:
        # interpret the last part of input as a module string
        string_last_part = module_dir_str.rsplit('.', 1)
        module_from_string = importlib.import_module(string_last_part[-1])
        folder_path = os.path.dirname(module_from_string.__file__)
        return os.path.abspath(folder_path)
    except (ImportError, AttributeError, IndexError, TypeError):
        pass

    module_dir_str = os.path.expanduser(module_dir_str)
    try:
        # interpret input as a folder path string
        if os.path.isdir(module_dir_str):
            return os.path.abspath(module_dir_str)
    except TypeError:
        pass

    try:
        # interpret input as a file path string
        if os.path.isfile(module_dir_str):
            return os.path.abspath(os.path.dirname(module_dir_str))
    except TypeError:
        pass

    try:
        # interpret input as a path string relative to the global home
        home_location = NiftyNetGlobalConfig().get_niftynet_home_folder()
        possible_dir = os.path.join(home_location, module_dir_str)
        if os.path.isdir(possible_dir):
            return os.path.abspath(possible_dir)
    except (TypeError, ImportError, AttributeError):
        pass

    if create_new:
        # try to create the folder
        folder_path = touch_folder(module_dir_str)
        init_file = os.path.join(folder_path, '__init__.py')
        try:
            file_ = os.open(init_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except OSError as sys_error:
            if sys_error.errno == errno.EEXIST:
                pass
            else:
                tf.logging.fatal(
                    "trying to use '{}' as NiftyNet writing path, "
                    "however cannot write '{}'".format(
                        folder_path, init_file))
                raise
        else:
            with os.fdopen(file_, 'w') as file_object:
                file_object.write("# Created automatically\n")
        return folder_path

    raise ValueError(
        "Could not resolve [{}].\nMake sure it is a valid folder path "
        "or a module name.\nIf it is a string representing a module, "
        "the parent folder of [{}] should be on "
        "the system path.\n\nCurrent system path: {}.".format(
            module_dir_str, module_dir_str, sys.path))
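
# A minimal sketch of the fall-through resolution order implemented above
# (module import, folder path, file path, then the NiftyNet home folder);
# the extension folder name below is hypothetical.
network_dir = resolve_module_dir('extensions/network', create_new=True)
print('custom network modules live in', network_dir)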
class ImageSetsPartitioner(object):
    """
    This class maintains a pandas.DataFrame of filenames for all input
    sections. The list of filenames is obtained by searching the
    specified folders or loading from an existing csv file.

    Users can query a subset of the dataframe by train/valid/infer
    partition label and input section names.
    """

    # dataframe (table) of file names in a shape of subject x modality
    _file_list = None
    # dataframes of subject_id:phase_id
    _partition_ids = None

    data_param = None
    ratios = None
    new_partition = False

    # for saving the splitting index
    data_split_file = ""
    # default parent folder location for searching the image files
    default_image_file_location = \
        NiftyNetGlobalConfig().get_niftynet_home_folder()

    def initialise(self,
                   data_param,
                   new_partition=False,
                   data_split_file=None,
                   ratios=None):
        """
        Set the data partitioner parameters.

        :param data_param: corresponding to all config sections
        :param new_partition: bool value indicating whether to generate
            new partition ids and overwrite the csv file
            (this class will write the partition file iff new_partition)
        :param data_split_file: location of the partition id file
        :param ratios: a tuple/list with two elements:
            ``(fraction of the validation set, fraction of the inference
            set)``; initialising to None will disable data partitioning
            and get_file_list always returns all subjects.
        """
        self.data_param = data_param
        if data_split_file is None:
            self.data_split_file = os.path.join('.', 'dataset_split.csv')
        else:
            self.data_split_file = data_split_file
        self.ratios = ratios

        self._file_list = None
        self._partition_ids = None

        self.load_data_sections_by_subject()
        self.new_partition = new_partition
        self.randomly_split_dataset(overwrite=new_partition)
        tf.logging.info(self)
        return self

    def number_of_subjects(self, phase=ALL):
        """
        query the number of subjects according to phase.

        :param phase: one of the SUPPORTED_PHASES
        :return: number of subjects
        """
        if self._file_list is None:
            return 0
        try:
            phase = look_up_operations(phase.lower(), SUPPORTED_PHASES)
        except (ValueError, AttributeError):
            tf.logging.fatal('Unknown phase argument.')
            raise
        if phase == ALL:
            return self._file_list[COLUMN_UNIQ_ID].count()
        if self._partition_ids is None:
            return 0
        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        subset = pandas.merge(self._file_list, selected,
                              on=COLUMN_UNIQ_ID, sort=True)
        return subset.count()[COLUMN_UNIQ_ID]

    def get_file_list(self, phase=ALL, *section_names):
        """
        get file names as a dataframe, by partitioning phase and section
        names; set phase to ALL to load all subsets.

        :param phase: the label of the subset generated by
            self._partition_ids; should be one of the SUPPORTED_PHASES
        :param section_names: one or multiple input section names
        :return: a pandas.DataFrame of file names
        """
        if self._file_list is None:
            tf.logging.warning('Empty file list, please initialise '
                               'ImageSetsPartitioner first.')
            return []
        try:
            phase = look_up_operations(phase.lower(), SUPPORTED_PHASES)
        except (ValueError, AttributeError):
            tf.logging.fatal('Unknown phase argument.')
            raise
        for name in section_names:
            try:
                look_up_operations(name, set(self._file_list))
            except ValueError:
                tf.logging.fatal(
                    'Requesting files under input section [%s],\n'
                    'however the section does not exist in the config.',
                    name)
                raise
        if phase == ALL:
            self._file_list = self._file_list.sort_values(COLUMN_UNIQ_ID)
            if section_names:
                section_names = [COLUMN_UNIQ_ID] + list(section_names)
                return self._file_list[section_names]
            return self._file_list

        if self._partition_ids is None or self._partition_ids.empty:
            tf.logging.fatal('No partition ids available.')
            if self.new_partition:
                tf.logging.fatal(
                    'Unable to create new partitions, '
                    'splitting ratios: %s, writing file %s',
                    self.ratios, self.data_split_file)
            elif os.path.isfile(self.data_split_file):
                tf.logging.fatal(
                    'Unable to load %s, initialise the '
                    'ImageSetsPartitioner with `new_partition=True` '
                    'to overwrite the file.',
                    self.data_split_file)
            raise ValueError

        selector = self._partition_ids[COLUMN_PHASE] == phase
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        if selected.empty:
            tf.logging.warning(
                'Empty subset for phase [%s], returning None as file list. '
                'Please adjust splitting fractions.', phase)
            return None
        subset = pandas.merge(self._file_list, selected,
                              on=COLUMN_UNIQ_ID, sort=True)
        if subset.empty:
            tf.logging.warning(
                'No subject id matched in between file names and '
                'partition files.\nPlease check the partition files %s,\n'
                'or remove it to generate a new file automatically.',
                self.data_split_file)
        if section_names:
            section_names = [COLUMN_UNIQ_ID] + list(section_names)
            return subset[section_names]
        return subset

    def load_data_sections_by_subject(self):
        """
        Go through all input data sections, converting each section to a
        list of file names. These lists are merged on ``COLUMN_UNIQ_ID``.

        This function sets ``self._file_list``.
        """
        if not self.data_param:
            tf.logging.fatal(
                'Nothing to load, please check input sections in the config.')
            raise ValueError
        self._file_list = None
        for section_name in self.data_param:
            modality_file_list = self.grep_files_by_data_section(section_name)
            if self._file_list is None:
                # adding all rows of the first modality
                self._file_list = modality_file_list
                continue
            n_rows = self._file_list[COLUMN_UNIQ_ID].count()
            self._file_list = pandas.merge(self._file_list,
                                           modality_file_list,
                                           how='outer',
                                           on=COLUMN_UNIQ_ID)
            if self._file_list[COLUMN_UNIQ_ID].count() < n_rows:
                tf.logging.warning('rows not matched in section [%s]',
                                   section_name)

        if self._file_list is None or self._file_list.size == 0:
            tf.logging.fatal(
                "Empty filename lists, please check the csv files "
                "(removing the csv_file keyword if it is in the config file "
                "to automatically search folders and generate new csv "
                "files again).\n\n"
                "Please note that in the matched file names, each subject "
                "id is created by removing all keywords listed in "
                "`filename_contains` in the config.\n"
                "E.g., `filename_contains=foo, bar` will match file "
                "foo_subject42_bar.nii.gz, and the subject id is "
                "_subject42_.\n\n")
            raise IOError

    def grep_files_by_data_section(self, modality_name):
        """
        list all files by a given input data section:
        if the ``csv_file`` property of ``data_param[modality_name]``
        corresponds to a file, read the list from the file;
        otherwise write the list to ``csv_file``.

        :return: a table with two columns, the column names are
            ``(COLUMN_UNIQ_ID, modality_name)``.
        """
        if modality_name not in self.data_param:
            tf.logging.fatal(
                'unknown section name [%s], '
                'current input section names: %s.',
                modality_name, list(self.data_param))
            raise ValueError

        # input data section must have a ``csv_file`` field for loading
        # or writing filename lists
        if isinstance(self.data_param[modality_name], dict):
            mod_spec = self.data_param[modality_name]
        else:
            mod_spec = vars(self.data_param[modality_name])

        #########################
        # guess the csv_file path
        #########################
        temp_csv_file = None
        try:
            csv_file = os.path.expanduser(mod_spec.get('csv_file', None))
            if not os.path.isfile(csv_file):
                # writing to the same folder as data_split_file
                default_csv_file = os.path.join(
                    os.path.dirname(self.data_split_file),
                    '{}.csv'.format(modality_name))
                tf.logging.info(
                    '`csv_file = %s` not found, '
                    'writing to "%s" instead.', csv_file, default_csv_file)
                csv_file = default_csv_file
                if os.path.isfile(csv_file):
                    tf.logging.info('Overwriting existing: "%s".', csv_file)
            csv_file = os.path.abspath(csv_file)
        except (AttributeError, KeyError, TypeError):
            tf.logging.debug('`csv_file` not specified, writing the list of '
                             'filenames to a temporary file.')
            import tempfile
            temp_csv_file = os.path.join(tempfile.mkdtemp(),
                                         '{}.csv'.format(modality_name))
            csv_file = temp_csv_file

        ###############################################
        # writing csv file if path_to_search specified
        ###############################################
        if mod_spec.get('path_to_search', None):
            if not temp_csv_file:
                tf.logging.info(
                    '[%s] search file folders, writing csv file %s',
                    modality_name, csv_file)
            # grep files by section properties and write csv
            try:
                matcher = KeywordsMatching.from_dict(
                    input_dict=mod_spec,
                    default_folder=self.default_image_file_location)
                match_and_write_filenames_to_csv([matcher], csv_file)
            except (IOError, ValueError) as reading_error:
                tf.logging.warning(
                    'Ignoring input section: [%s], '
                    'due to the following error:', modality_name)
                tf.logging.warning(repr(reading_error))
                return pandas.DataFrame(
                    columns=[COLUMN_UNIQ_ID, modality_name])
        else:
            tf.logging.info(
                '[%s] using existing csv file %s, skipped filenames search',
                modality_name, csv_file)
            if not os.path.isfile(csv_file):
                tf.logging.fatal('[%s] csv file %s not found.',
                                 modality_name, csv_file)
                raise IOError

        ###############################
        # loading the file as dataframe
        ###############################
        try:
            csv_list = pandas.read_csv(
                csv_file,
                header=None,
                dtype=(str, str),
                names=[COLUMN_UNIQ_ID, modality_name],
                skipinitialspace=True)
        except Exception as csv_error:
            tf.logging.fatal(repr(csv_error))
            raise

        if temp_csv_file:
            shutil.rmtree(os.path.dirname(temp_csv_file), ignore_errors=True)

        return csv_list

    # pylint: disable=broad-except
    def randomly_split_dataset(self, overwrite=False):
        """
        Label each subject as one of ``TRAIN``, ``VALID``, ``INFER``,
        using ``self.ratios`` to compute the size of each set.

        The results will be written to ``self.data_split_file`` if
        overwrite, otherwise it tries to read partition labels from it.

        This function sets ``self._partition_ids``.
        """
        if overwrite:
            try:
                valid_fraction, infer_fraction = self.ratios
                valid_fraction = max(min(1.0, float(valid_fraction)), 0.0)
                infer_fraction = max(min(1.0, float(infer_fraction)), 0.0)
            except (TypeError, ValueError):
                tf.logging.fatal('Unknown format of fraction values %s',
                                 self.ratios)
                raise

            if (valid_fraction + infer_fraction) <= 0:
                tf.logging.warning(
                    'To split dataset into training/validation, '
                    'please make sure '
                    '"exclude_fraction_for_validation" parameter is set to '
                    'a float in between 0 and 1. Current value: %s.',
                    valid_fraction)
                # raise ValueError

            n_total = self.number_of_subjects()
            n_valid = int(math.ceil(n_total * valid_fraction))
            n_infer = int(math.ceil(n_total * infer_fraction))
            n_train = int(n_total - n_infer - n_valid)
            phases = [TRAIN] * n_train + [VALID] * n_valid + [INFER] * n_infer
            if len(phases) > n_total:
                phases = phases[:n_total]
            random.shuffle(phases)
            write_csv(self.data_split_file,
                      zip(self._file_list[COLUMN_UNIQ_ID], phases))
        elif os.path.isfile(self.data_split_file):
            tf.logging.warning(
                'Loading from existing partitioning file %s, '
                'ignoring partitioning ratios.', self.data_split_file)

        if os.path.isfile(self.data_split_file):
            try:
                self._partition_ids = pandas.read_csv(
                    self.data_split_file,
                    header=None,
                    dtype=(str, str),
                    names=[COLUMN_UNIQ_ID, COLUMN_PHASE],
                    skipinitialspace=True)
                assert not self._partition_ids.empty, \
                    "partition file is empty."
            except Exception as csv_error:
                tf.logging.warning(
                    "Unable to load the existing partition file %s, %s",
                    self.data_split_file, repr(csv_error))
                self._partition_ids = None

            try:
                phase_strings = self._partition_ids[COLUMN_PHASE]
                phase_strings = phase_strings.astype(str).str.lower()
                is_valid_phase = phase_strings.isin(SUPPORTED_PHASES)
                assert is_valid_phase.all(), \
                    "Partition file contains unknown phase id."
                self._partition_ids[COLUMN_PHASE] = phase_strings
            except (TypeError, AssertionError):
                tf.logging.warning(
                    'Please make sure the values of the second column '
                    'of data splitting file %s are in the set of phases: '
                    '%s.\nRemove %s to generate a random data partition '
                    'file.',
                    self.data_split_file, SUPPORTED_PHASES,
                    self.data_split_file)
                raise ValueError

    def __str__(self):
        return self.to_string()

    def to_string(self):
        """
        Print summary of the partitioner.
        """
        n_subjects = self.number_of_subjects()
        summary_str = '\n\nNumber of subjects {}, '.format(n_subjects)
        if self._file_list is not None:
            summary_str += 'input section names: {}\n'.format(
                list(self._file_list))
        if self._partition_ids is not None and n_subjects > 0:
            n_train = self.number_of_subjects(TRAIN)
            n_valid = self.number_of_subjects(VALID)
            n_infer = self.number_of_subjects(INFER)
            summary_str += \
                'Dataset partitioning:\n' \
                '-- {} {} cases ({:.2f}%),\n' \
                '-- {} {} cases ({:.2f}%),\n' \
                '-- {} {} cases ({:.2f}%).\n'.format(
                    TRAIN, n_train,
                    float(n_train) / float(n_subjects) * 100.0,
                    VALID, n_valid,
                    float(n_valid) / float(n_subjects) * 100.0,
                    INFER, n_infer,
                    float(n_infer) / float(n_subjects) * 100.0)
        else:
            summary_str += '-- using all subjects ' \
                           '(without data partitioning).\n'
        return summary_str

    def has_phase(self, phase):
        """
        :return: True if the `phase` subset of images is not empty.
        """
        if self._partition_ids is None or self._partition_ids.empty:
            return False
        selector = self._partition_ids[COLUMN_PHASE] == phase
        if not selector.any():
            return False
        selected = self._partition_ids[selector][[COLUMN_UNIQ_ID]]
        subset = pandas.merge(left=self._file_list,
                              right=selected,
                              on=COLUMN_UNIQ_ID,
                              sort=False)
        return not subset.empty

    @property
    def has_training(self):
        """
        :return: True if the TRAIN subset of images is not empty.
        """
        return self.has_phase(TRAIN)

    @property
    def has_inference(self):
        """
        :return: True if the INFER subset of images is not empty.
        """
        return self.has_phase(INFER)

    @property
    def has_validation(self):
        """
        :return: True if the VALID subset of images is not empty.
        """
        return self.has_phase(VALID)

    @property
    def validation_files(self):
        """
        :return: the list of validation filenames.
        """
        if self.has_validation:
            return self.get_file_list(VALID)
        return self.all_files

    @property
    def train_files(self):
        """
        :return: the list of training filenames.
        """
        if self.has_training:
            return self.get_file_list(TRAIN)
        return self.all_files

    @property
    def inference_files(self):
        """
        :return: the list of inference filenames (defaulting to the list
            of all filenames if no partition definition is available).
        """
        if self.has_inference:
            return self.get_file_list(INFER)
        return self.all_files

    @property
    def all_files(self):
        """
        :return: the list of all filenames.
        """
        return self.get_file_list()

    def get_file_lists_by(self, phase=None, action='train'):
        """
        Get file lists by action and phase.

        This function returns file lists for training/validation/inference
        based on the phase or action specified by the user; ``phase`` has
        a higher priority:

        If ``phase`` is specified, the function returns the corresponding
        file list (as a list).

        Otherwise, the function checks ``action``: it returns train and
        validation file lists if it's a training action, otherwise it
        returns the inference file list.

        :param action: an action string
        :param phase: an element from ``{TRAIN, VALID, INFER, ALL}``
        :return: a list of file lists
        """
        if phase:
            try:
                return [self.get_file_list(phase=phase)]
            except (ValueError, AttributeError):
                tf.logging.warning('`phase` parameter %s ignored', phase)
        if action and TRAIN.startswith(action):
            file_lists = [self.train_files]
            if self.has_validation:
                file_lists.append(self.validation_files)
            return file_lists
        return [self.inference_files]

    def reset(self):
        """
        reset all fields of this singleton class.
        """
        self._file_list = None
        self._partition_ids = None
        self.data_param = None
        self.ratios = None
        self.new_partition = False

        self.data_split_file = ""
        self.default_image_file_location = \
            NiftyNetGlobalConfig().get_niftynet_home_folder()