Example #1
 def _reset_partition_file(self):
     test_partitioner = ImageSetsPartitioner()
     test_partitioner.initialise(
         test_sections,
         new_partition=True,
         ratios=(0.2, 0.2),
         data_split_file=partition_output)
    def initialise_application(self, workflow_param, data_param):
        """
        This function receives all parameters from the user config file
        and creates an instance of the application.
        :param workflow_param: a dictionary of user parameters,
        keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
        keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            infer_param = workflow_param.get('INFERENCE', None)
            eval_param = workflow_param.get('EVALUATION', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise
        self.num_threads = 1
        # self.num_threads = max(system_param.num_threads, 1)
        # self.num_gpus = system_param.num_gpus
        # set_cuda_device(system_param.cuda_devices)

        # set output TF model folders
        self.model_dir = touch_folder(
            os.path.join(system_param.model_dir, 'models'))
        self.session_prefix = os.path.join(self.model_dir, FILE_PREFIX)

        assert infer_param, 'inference parameters not specified'

        # create an application instance
        assert app_param, 'application specific param. not specified'
        self.app_param = app_param
        app_module = ApplicationFactory.create(app_param.name)
        self.app = app_module(net_param, infer_param,
                              system_param.action)

        self.eval_param = eval_param

        data_param, self.app_param = \
            self.app.add_inferred_output(data_param, self.app_param)
        # initialise data input
        data_partitioner = ImageSetsPartitioner()
        # clear the cached file lists
        data_partitioner.reset()
        if data_param:
            data_partitioner.initialise(
                data_param=data_param,
                new_partition=False,
                ratios=None,
                data_split_file=system_param.dataset_split_file)

        # initialise data input
        self.app.initialise_dataset_loader(data_param, self.app_param,
                                           data_partitioner)
        self.app.initialise_evaluator(eval_param)
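
Most of the examples on this page follow the same pattern: construct an ImageSetsPartitioner, call initialise() with a dictionary of data sections, then query per-phase file lists. Below is a minimal sketch of that pattern; the section path, ratios and split-file name are placeholders, and the import paths may differ between NiftyNet versions.

# minimal sketch of the shared partitioner pattern (placeholder values)
from niftynet.io.image_sets_partitioner import ImageSetsPartitioner
from niftynet.engine.signal import TRAIN, VALID

data_param = {'T1': {'path_to_search': 'path/to/t1'}}   # hypothetical section
partitioner = ImageSetsPartitioner().initialise(
    data_param,
    new_partition=True,                    # write a fresh split file
    ratios=(0.1, 0.1),                     # (validation fraction, inference fraction)
    data_split_file='dataset_split.csv')   # hypothetical output path
train_files = partitioner.get_file_list(TRAIN)
valid_files = partitioner.get_file_list(VALID)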
Example #4
 def test_empty(self):
     self._reset_partition_file()
     with open(partition_output, 'w') as partition_file:
         partition_file.write('')
     test_partitioner = ImageSetsPartitioner()
     with self.assertRaisesRegexp(ValueError, ""):
         test_partitioner.initialise(test_sections,
                                     new_partition=False,
                                     data_split_file=partition_output)
Example #6
 def test_incompatible_partition_file(self):
     self._reset_partition_file()
     # adding invalid line
     with open(partition_output, 'a') as partition_file:
         partition_file.write('foo, bar')
     test_partitioner = ImageSetsPartitioner()
     with self.assertRaisesRegexp(ValueError, ""):
         test_partitioner.initialise(test_sections,
                                     new_partition=False,
                                     data_split_file=partition_output)
Example #8
def preprocess(
    input_path,
    model_path,
    output_path,
    cutoff,
):
    input_path = Path(input_path)
    output_path = Path(output_path)
    input_dir = input_path.parent

    DATA_PARAM = {
        'Modality0':
        ParserNamespace(
            path_to_search=str(input_dir),
            filename_contains=('nii.gz', ),
            interp_order=0,
            pixdim=None,
            axcodes='RAS',
            loader=None,
        )
    }

    TASK_PARAM = ParserNamespace(image=('Modality0', ))
    data_partitioner = ImageSetsPartitioner()
    file_list = data_partitioner.initialise(DATA_PARAM).get_file_list()
    reader = ImageReader(['image'])
    reader.initialise(DATA_PARAM, TASK_PARAM, file_list)

    binary_masking_func = BinaryMaskingLayer(type_str='mean_plus', )

    hist_norm = HistogramNormalisationLayer(
        image_name='image',
        modalities=['Modality0'],
        model_filename=str(model_path),
        binary_masking_func=binary_masking_func,
        cutoff=cutoff,
        name='hist_norm_layer',
    )

    image = reader.output_list[0]['image']
    data = image.get_data()
    norm_image_dict, mask_dict = hist_norm({'image': data})
    data = norm_image_dict['image']
    nii = nib.Nifti1Image(data.squeeze(), image.original_affine[0])
    dst = output_path
    nii.to_filename(str(dst))
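
A call to the preprocess() function above might look like the following; every path and the cutoff pair are placeholders, not values taken from the original code.

preprocess(
    input_path='subject01_T1.nii.gz',              # hypothetical input volume
    model_path='histogram_ref_model.txt',          # hypothetical histogram landmarks file
    output_path='subject01_T1_normalised.nii.gz',  # hypothetical output path
    cutoff=(0.01, 0.99),                           # assumed lower/upper percentile cutoffs
)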
Example #9
 def read_data(self, data_param, grouping_param, data_split_file):
     # Dictionary with parameters for NiftyNet Reader
     data_param = literal_eval(data_param)
     grouping_param = literal_eval(grouping_param)
     image_sets_partitioner = ImageSetsPartitioner().initialise(
         data_param=data_param,
         data_split_file=data_split_file,
         new_partition=False)
     return data_param, grouping_param, image_sets_partitioner
    def test_no_partition_file(self):
        if os.path.isfile(partition_output):
            os.remove(partition_output)

        data_param = test_sections
        test_partitioner = ImageSetsPartitioner()
        test_partitioner.initialise(
            data_param,
            new_partition=False,
            data_split_file=partition_output)
        self.assertEquals(
            test_partitioner.get_file_list()[COLUMN_UNIQ_ID].count(), 4)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(TRAIN)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(VALID)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(INFER)
Example #11
def main():
    opt = parsing_data()

    print("[INFO]Reading data")
    # Dictionary with data parameters for NiftyNet Reader
    if torch.cuda.is_available():
        print('[INFO] GPU available.')
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    else:
        raise Exception(
            "[INFO] No GPU found or Wrong gpu id, please run without --cuda")

    # Dictionary with data parameters for NiftyNet Reader
    data_param = {
        'image': {
            'path_to_search': opt.image_path,
            'filename_contains': 'CC'
        },
        'label': {
            'path_to_search': opt.label_path,
            'filename_contains': 'CC'
        }
    }

    image_sets_partitioner = ImageSetsPartitioner().initialise(
        data_param=data_param,
        data_split_file=opt.data_split_file,
        new_partition=False,
        ratios=opt.ratios)

    readers = {
        x: get_reader(data_param, image_sets_partitioner, x)
        for x in ['training', 'validation', 'inference']
    }
    samplers = {
        x: get_sampler(readers[x], opt.patch_size, x)
        for x in ['training', 'validation', 'inference']
    }

    # Training stage only
    dsets = {
        x: dset_utils.DatasetNiftySampler(sampler=samplers[x])
        for x in ['training', 'validation']
    }

    print("[INFO] Building model")
    model = cnn_utils.UNet3D(opt.in_channels, opt.n_classes)
    criterion = loss_utils.SoftDiceLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=opt.lr)

    print("[INFO] Training")
    train(dsets, model, criterion, optimizer, opt.num_epochs, device,
          opt.cp_path, opt.batch_size)

    print("[INFO] Inference")
    inference(samplers['inference'], model, device, opt.pred_path, opt.cp_path)
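
This example relies on helpers such as get_reader() and get_sampler() defined elsewhere in its repository. A rough sketch of what get_reader() could look like, reusing the NiftyNet pieces shown in the other examples, is given below; the section grouping and the assumption that the phase constants equal the strings 'training', 'validation' and 'inference' are not taken from the original code.

# hypothetical helper assumed by the example above
from niftynet.io.image_reader import ImageReader

def get_reader(data_param, image_sets_partitioner, phase):
    # phase is one of 'training', 'validation', 'inference'
    file_list = image_sets_partitioner.get_file_list(phase)
    reader = ImageReader(['image', 'label'])
    reader.initialise(data_param,
                      {'image': ('image',), 'label': ('label',)},
                      file_list)
    return reader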
Example #12
    def __init__(self):
        self.app = None

        self.is_training_action = True
        self.num_threads = 0
        self.num_gpus = 0
        self.model_dir = None

        self.max_checkpoints = 2
        self.save_every_n = 0
        self.tensorboard_every_n = -1

        self.initial_iter = 0
        self.final_iter = 0
        self.validation_every_n = -1
        self.validation_max_iter = 1

        self.data_partitioner = ImageSetsPartitioner()

        self._event_handlers = None
        self._generator = None
Example #13
    def initialise_application(self, workflow_param, data_param):
        """
        This function receives all parameters from the user config file
        and creates an instance of the application.
        :param workflow_param: a dictionary of user parameters,
        keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
        keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            infer_param = workflow_param.get('INFERENCE', None)
            eval_param = workflow_param.get('EVALUATION', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise
        self.num_threads = 1
        # self.num_threads = max(system_param.num_threads, 1)
        # self.num_gpus = system_param.num_gpus
        # set_cuda_device(system_param.cuda_devices)

        # set output TF model folders
        self.model_dir = touch_folder(
            os.path.join(system_param.model_dir, 'models'))
        self.session_prefix = os.path.join(self.model_dir, FILE_PREFIX)

        assert infer_param, 'inference parameters not specified'

        # create an application instance
        assert app_param, 'application specific param. not specified'
        self.app_param = app_param
        app_module = ApplicationFactory.create(app_param.name)
        self.app = app_module(net_param, infer_param, system_param.action)

        self.eval_param = eval_param

        data_param, self.app_param = \
            self.app.add_inferred_output(data_param, self.app_param)
        # initialise data input
        data_partitioner = ImageSetsPartitioner()
        # clear the cached file lists
        data_partitioner.reset()
        if data_param:
            data_partitioner.initialise(
                data_param=data_param,
                new_partition=False,
                ratios=None,
                data_split_file=system_param.dataset_split_file)

        # initialise data input
        self.app.initialise_dataset_loader(data_param, self.app_param,
                                           data_partitioner)
        self.app.initialise_evaluator(eval_param)
Example #14
 def test_replicated_ids(self):
     self._reset_partition_file()
     with open(partition_output, 'a') as partition_file:
         partition_file.write('1065,Training\n')
         partition_file.write('1065,Validation')
     test_partitioner = ImageSetsPartitioner()
     test_partitioner.initialise(test_sections,
                                 new_partition=False,
                                 data_split_file=partition_output)
     self.assertEquals(
         test_partitioner.get_file_list()[COLUMN_UNIQ_ID].count(), 4)
     self.assertEquals(
         test_partitioner.get_file_list(TRAIN)[COLUMN_UNIQ_ID].count(), 3)
     self.assertEquals(
         test_partitioner.get_file_list(VALID)[COLUMN_UNIQ_ID].count(), 2)
     self.assertEquals(
         test_partitioner.get_file_list(INFER)[COLUMN_UNIQ_ID].count(), 1)
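
The split file written above appears to be a plain two-column CSV of subject id and phase (e.g. 1065,Training). Listing the same id under two phases is accepted: the assertions show subject 1065 being returned by both the training and validation file lists.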
Example #16
    def test_no_partition_file(self):
        if os.path.isfile(partition_output):
            os.remove(partition_output)

        data_param = test_sections
        test_partitioner = ImageSetsPartitioner()
        test_partitioner.initialise(data_param,
                                    new_partition=False,
                                    data_split_file=partition_output)
        self.assertEquals(
            test_partitioner.get_file_list()[COLUMN_UNIQ_ID].count(), 4)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(TRAIN)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(VALID)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(INFER)
Example #17
LABEL_TASK = ParserNamespace(label=('parcellation',))

BAD_DATA = {
    'lesion': ParserNamespace(
        csv_file=os.path.join('testing_data', 'lesion.csv'),
        path_to_search='testing_data',
        filename_contains=('Lesion',),
        filename_not_contains=('Parcellation',),
        pixdim=None,
        axcodes=None
    )
}
BAD_TASK = ParserNamespace(image=('test',))

# default data_partitioner
data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()
single_mod_list = data_partitioner.initialise(SINGLE_MOD_DATA).get_file_list()
existing_list = data_partitioner.initialise(EXISTING_DATA).get_file_list()
label_list = data_partitioner.initialise(LABEL_DATA).get_file_list()
bad_data_list = data_partitioner.initialise(BAD_DATA).get_file_list()


class ImageReaderTest(tf.test.TestCase):
    def test_initialisation(self):
        with self.assertRaisesRegexp(ValueError, ''):
            reader = ImageReader(['test'])
            reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
        with self.assertRaisesRegexp(AssertionError, ''):
            reader = ImageReader(None)
            # reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    ),
    'FLAIR': ParserNamespace(
        csv_file=os.path.join('testing_data', 'FLAIRsampler.csv'),
        path_to_search='testing_data',
        filename_contains=('FLAIR_',),
        filename_not_contains=('Parcellation',),
        interp_order=3,
        pixdim=None,
        axcodes=None,
        spatial_window_size=(7, 10, 2),
        loader=None
    )
}
MULTI_MOD_TASK = ParserNamespace(image=('T1', 'FLAIR'))

data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()


def get_3d_reader():
    reader = ImageReader(['image'])
    reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    return reader


class LinearInterpolateSamplerTest(tf.test.TestCase):
    def test_init(self):
        sampler = LinearInterpolateSampler(
            reader=get_3d_reader(),
            data_param=MULTI_MOD_DATA,
            batch_size=1,
Example #19
}
LABEL_TASK = ParserNamespace(label=('parcellation', ))

BAD_DATA = {
    'lesion':
    ParserNamespace(csv_file=os.path.join('testing_data', 'lesion.csv'),
                    path_to_search='testing_data',
                    filename_contains=('Lesion', ),
                    filename_not_contains=('Parcellation', ),
                    pixdim=None,
                    axcodes=None)
}
BAD_TASK = ParserNamespace(image=('test', ))

# default data_partitioner
data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()
single_mod_list = data_partitioner.initialise(SINGLE_MOD_DATA).get_file_list()
existing_list = data_partitioner.initialise(EXISTING_DATA).get_file_list()
label_list = data_partitioner.initialise(LABEL_DATA).get_file_list()
bad_data_list = data_partitioner.initialise(BAD_DATA).get_file_list()


class ImageReaderTest(tf.test.TestCase):
    def test_initialisation(self):
        with self.assertRaisesRegexp(ValueError, ''):
            reader = ImageReader(['test'])
            reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
        with self.assertRaisesRegexp(AssertionError, ''):
            reader = ImageReader(None)
            # reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
Example #20
SINGLE_25D_DATA = {
    'T1':
    ParserNamespace(csv_file=os.path.join('testing_data', 'T1sampler.csv'),
                    path_to_search='testing_data',
                    filename_contains=('_o_T1_time', '106'),
                    filename_not_contains=('Parcellation', ),
                    interp_order=3,
                    pixdim=(3.0, 5.0, 5.0),
                    axcodes='LAS',
                    spatial_window_size=(40, 30, 1),
                    loader=None),
}
SINGLE_25D_TASK = ParserNamespace(image=('T1', ))

data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(
    MULTI_MOD_DATA,
    data_split_file='testing_data/resize_split.csv').get_file_list()
mod_2d_list = data_partitioner.initialise(
    MOD_2D_DATA,
    data_split_file='testing_data/resize_split.csv').get_file_list()
mod_label_list = data_partitioner.initialise(
    MOD_LABEL_DATA,
    data_split_file='testing_data/resize_split.csv').get_file_list()
single_25d_list = data_partitioner.initialise(
    SINGLE_25D_DATA,
    data_split_file='testing_data/resize_split.csv').get_file_list()


def get_3d_reader():
    def test_new_partition(self):
        data_param = test_sections
        test_partitioner = ImageSetsPartitioner()
        with self.assertRaisesRegexp(TypeError, ''):
            test_partitioner.initialise(
                data_param,
                new_partition=True,
                data_split_file=partition_output)
        test_partitioner.initialise(
            data_param,
            new_partition=True,
            ratios=(2.0, 2.0),
            data_split_file=partition_output)
        self.assertEquals(
            test_partitioner.get_file_list()[COLUMN_UNIQ_ID].count(), 4)
        self.assertEquals(
            test_partitioner.get_file_list(TRAIN), None)
        self.assertEquals(
            test_partitioner.get_file_list(VALID)[COLUMN_UNIQ_ID].count(), 4)
        self.assertEquals(
            test_partitioner.get_file_list(INFER), None)
        self.assertEquals(
            test_partitioner.get_file_list(
                VALID, 'T1', 'Flair')[COLUMN_UNIQ_ID].count(), 4)
        self.assertEquals(
            test_partitioner.get_file_list(
                VALID, 'Flair')[COLUMN_UNIQ_ID].count(), 4)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(VALID, 'foo')
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list('T1')

        self.assertFalse(test_partitioner.has_training)
        self.assertFalse(test_partitioner.has_inference)
        self.assertTrue(test_partitioner.has_validation)
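
Note that ratios here is (validation fraction, inference fraction), mirroring the driver code in the later examples. With both fractions set to 2.0, every subject is assigned to the validation split, which is why get_file_list(TRAIN) and get_file_list(INFER) return None while the validation list contains all four subjects.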
Example #22
                    filename_not_contains=('Parcellation', ),
                    interp_order=3,
                    pixdim=None,
                    axcodes=None),
    'FLAIR':
    ParserNamespace(csv_file=os.path.join('testing_data', 'FLAIR.csv'),
                    path_to_search='testing_data',
                    filename_contains=('FLAIR_', ),
                    filename_not_contains=('Parcellation', ),
                    interp_order=3,
                    pixdim=None,
                    axcodes=None)
}
TASK_PARAM = ParserNamespace(image=('T1', 'FLAIR'))
MODEL_FILE = os.path.join('testing_data', 'std_models.txt')
data_partitioner = ImageSetsPartitioner()
file_list = data_partitioner.initialise(DATA_PARAM).get_file_list()


# @unittest.skipIf(os.environ.get('QUICKTEST', "").lower() == "true", 'Skipping slow tests')
class HistTest(tf.test.TestCase):
    def test_volume_loader(self):
        expected_T1 = np.array([
            0.0, 8.24277910972, 21.4917343731, 27.0551695202, 32.6186046672,
            43.5081573038, 53.3535675285, 61.9058849776, 70.0929786194,
            73.9944243858, 77.7437509974, 88.5331971492, 100.0
        ])
        expected_FLAIR = np.array([
            0.0, 5.36540863446, 15.5386130103, 20.7431912042, 26.1536608309,
            36.669150376, 44.7821246138, 50.7930589961, 56.1703089214,
            59.2393548654, 63.1565641037, 78.7271261392, 100.0
SINGLE_25D_DATA = {
    'T1': ParserNamespace(
        csv_file=os.path.join('testing_data', 'T1sampler.csv'),
        path_to_search='testing_data',
        filename_contains=('_o_T1_time', '106'),
        filename_not_contains=('Parcellation',),
        interp_order=3,
        pixdim=(3.0, 5.0, 5.0),
        axcodes='LAS',
        spatial_window_size=(40, 30, 1)
    ),
}
SINGLE_25D_TASK = ParserNamespace(image=('T1',))

data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()
mod_2d_list = data_partitioner.initialise(MOD_2D_DATA).get_file_list()
mod_label_list = data_partitioner.initialise(MOD_LABEL_DATA).get_file_list()
single_25d_list = data_partitioner.initialise(SINGLE_25D_DATA).get_file_list()


def get_3d_reader():
    reader = ImageReader(['image'])
    reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    return reader


def get_2d_reader():
    reader = ImageReader(['image'])
    reader.initialise(MOD_2D_DATA, MOD_2D_TASK, mod_2d_list)
Example #24
    def initialise(self, data_param, task_param=None, file_list=None):
        """
        ``task_param`` specifies how to combine user input modalities.
        e.g., for multimodal segmentation 'image' corresponds to multiple
        modality sections, 'label' corresponds to one modality section

        This function converts elements of ``file_list`` into
        dictionaries of image objects, and save them to ``self.output_list``.
        e.g.::

             data_param = {'T1': {'path_to_search': 'path/to/t1'}
                           'T2': {'path_to_search': 'path/to/t2'}}

        loads pairs of T1 and T2 images (grouped by matching the filename).
        The reader's output is in the form of
        ``{'T1': np.array, 'T2': np.array}``.
        If the (optional) ``task_param`` is specified::

             task_param = {'image': ('T1', 'T2')}

        the reader loads pairs of T1 and T2 and returns the concatenated
        image (both modalities should have the same spatial dimensions).
        The reader's output is in the form of ``{'image': np.array}``.


        :param data_param: dictionary of input sections
        :param task_param: dictionary of grouping
        :param file_list: a dataframe generated by ImagePartitioner
            for cross validation, so
            that the reader only loads files in training/inference phases.
        :return: the initialised reader instance
        """
        data_param = param_to_dict(data_param)

        if not task_param:
            task_param = {mod: (mod,) for mod in list(data_param)}
        try:
            if not isinstance(task_param, dict):
                task_param = vars(task_param)
        except ValueError:
            tf.logging.fatal(
                "To concatenate multiple input data arrays,\n"
                "task_param should be a dictionary in the form:\n"
                "{'new_modality_name': ['modality_1', 'modality_2',...]}.")
            raise
        if file_list is None:
            # defaulting to all files detected by the input specification
            file_list = ImageSetsPartitioner().initialise(data_param).all_files
        if not self.names:
            # defaulting to load all sections defined in the task_param
            self.names = list(task_param)
        valid_names = [name for name in self.names
                       if task_param.get(name, None)]
        if not valid_names:
            tf.logging.fatal("Reader requires task input keywords %s, but "
                             "not exist in the config file.\n"
                             "Available task keywords: %s",
                             self.names, list(task_param))
            raise ValueError
        self.names = valid_names

        self._input_sources = dict((name, task_param.get(name))
                                   for name in self.names)
        required_sections = \
            sum([list(task_param.get(name)) for name in self.names], [])

        for required in required_sections:
            try:
                if (file_list is None) or \
                        (required not in list(file_list)) or \
                        (file_list[required].isnull().all()):
                    tf.logging.fatal('Reader required input section '
                                     'name [%s], but in the filename list '
                                     'the column is empty.', required)
                    raise ValueError
            except (AttributeError, TypeError, ValueError):
                tf.logging.fatal(
                    'file_list parameter should be a '
                    'pandas.DataFrame instance and has input '
                    'section name [%s] as a column name.', required)
                if required_sections:
                    tf.logging.fatal('Reader requires section(s): %s',
                                     required_sections)
                if file_list is not None:
                    tf.logging.fatal('Configuration input sections are: %s',
                                     list(file_list))
                raise

        self.output_list, self._file_list = _filename_to_image_list(
            file_list, self._input_sources, data_param)
        for name in self.names:
            tf.logging.info(
                'Image reader: loading %d subjects '
                'from sections %s as input [%s]',
                len(self.output_list), self.input_sources[name], name)
        return self
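
Put concretely, the grouping described in the docstring above can be exercised roughly as follows; the paths are placeholders and the exact import path may vary by NiftyNet version.

from niftynet.io.image_reader import ImageReader

data_param = {'T1': {'path_to_search': 'path/to/t1'},   # hypothetical sections
              'T2': {'path_to_search': 'path/to/t2'}}
task_param = {'image': ('T1', 'T2')}

reader = ImageReader().initialise(data_param, task_param)
subject = reader.output_list[0]    # {'image': <concatenated T1/T2 image object>}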
Example #25
# LABEL_TASK = {
#     'Parcellation': ParserNamespace(
#         csv_file=os.path.join('testing_data', 'labels.csv'),
#         path_to_search='testing_data',
#         filename_contains=('Parcellation',),
#         filename_not_constains=('FLAIR_',),
#         interp_order=1,
#         pixdim=None,
#         axcodes=None,
#         spatial_window_size=(8,2)
#     )
# }

DYNAMIC_MOD_TASK = ParserNamespace(image=('T1', 'FLAIR'), label=('Label', ))

data_partitioner = ImageSetsPartitioner()


def get_3d_reader():
    multi_mod_list = data_partitioner.initialise(
        MULTI_MOD_DATA).get_file_list()
    print(MULTI_MOD_DATA, MULTI_MOD_TASK)
    reader = ImageReader(['image', 'label'])
    reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    return reader


def get_2d_reader():
    mod_2d_list = data_partitioner.initialise(MOD_2D_DATA).get_file_list()
    reader = ImageReader(['image'])
    reader.initialise(MOD_2D_DATA, MOD_2D_TASK, mod_2d_list)
SINGLE_25D_DATA = {
    'T1':
    ParserNamespace(csv_file=os.path.join('testing_data', 'T1sampler.csv'),
                    path_to_search='testing_data',
                    filename_contains=('_o_T1_time', '106'),
                    filename_not_contains=('Parcellation', ),
                    interp_order=0,
                    pixdim=(3.0, 5.0, 5.0),
                    axcodes='LAS',
                    spatial_window_size=(40, 30, 1),
                    loader=None),
}
SINGLE_25D_TASK = ParserNamespace(image=('T1', ))

data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()
mod_2d_list = data_partitioner.initialise(MOD_2D_DATA).get_file_list()
mod_label_list = data_partitioner.initialise(MOD_LABEL_DATA).get_file_list()
single_25d_list = data_partitioner.initialise(SINGLE_25D_DATA).get_file_list()


def get_3d_reader():
    '''
    define the 3d reader
    :return: 3d reader
    '''
    reader = ImageReader(['image'])
    reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    return reader
Example #27
    ),
    'FLAIR': ParserNamespace(
        csv_file=os.path.join('testing_data', 'FLAIRsampler.csv'),
        path_to_search='testing_data',
        filename_contains=('FLAIR_',),
        filename_not_contains=('Parcellation',),
        interp_order=3,
        pixdim=None,
        axcodes=None,
        spatial_window_size=(8, 2)
    )
}
DYNAMIC_MOD_TASK = ParserNamespace(image=('T1', 'FLAIR'),
                                   sampler=('FLAIR',))

data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()
mod_2d_list = data_partitioner.initialise(MOD_2D_DATA).get_file_list()
dynamic_list = data_partitioner.initialise(DYNAMIC_MOD_DATA).get_file_list()


def get_3d_reader():
    reader = ImageReader(['image', 'sampler'])
    reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    return reader


def get_2d_reader():
    reader = ImageReader(['image', 'sampler'])
    reader.initialise(MOD_2D_DATA, MOD_2D_TASK, mod_2d_list)
    return reader
Example #28
    def test_new_partition(self):
        data_param = test_sections
        test_partitioner = ImageSetsPartitioner()
        with self.assertRaisesRegexp(TypeError, ''):
            test_partitioner.initialise(data_param,
                                        new_partition=True,
                                        data_split_file=partition_output)
        test_partitioner.initialise(data_param,
                                    new_partition=True,
                                    ratios=(2.0, 2.0),
                                    data_split_file=partition_output)
        self.assertEquals(
            test_partitioner.get_file_list()[COLUMN_UNIQ_ID].count(), 4)
        self.assertEquals(test_partitioner.get_file_list(TRAIN), None)
        self.assertEquals(
            test_partitioner.get_file_list(VALID)[COLUMN_UNIQ_ID].count(), 4)
        self.assertEquals(test_partitioner.get_file_list(INFER), None)
        self.assertEquals(
            test_partitioner.get_file_list(VALID, 'T1',
                                           'Flair')[COLUMN_UNIQ_ID].count(), 4)
        self.assertEquals(
            test_partitioner.get_file_list(VALID,
                                           'Flair')[COLUMN_UNIQ_ID].count(), 4)
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list(VALID, 'foo')
        with self.assertRaisesRegexp(ValueError, ''):
            test_partitioner.get_file_list('T1')

        self.assertFalse(test_partitioner.has_training)
        self.assertFalse(test_partitioner.has_inference)
        self.assertTrue(test_partitioner.has_validation)
Example #29
class ApplicationDriver(object):
    """
    This class initialises an application by building a TF graph,
    and maintaining a session. It controls the
    starting/stopping of an application. Applications should be
    implemented by inheriting ``niftynet.application.base_application``
    to be compatible with this driver.
    """
    def __init__(self):
        self.app = None

        self.is_training_action = True
        self.num_threads = 0
        self.num_gpus = 0
        self.model_dir = None

        self.max_checkpoints = 2
        self.save_every_n = 0
        self.tensorboard_every_n = -1

        self.initial_iter = 0
        self.final_iter = 0
        self.validation_every_n = -1
        self.validation_max_iter = 1

        self.data_partitioner = ImageSetsPartitioner()

        self._event_handlers = None
        self._generator = None

    def initialise_application(self, workflow_param, data_param=None):
        """
        This function receives all parameters from the user config file
        and creates an instance of the application.

        :param workflow_param: a dictionary of user parameters,
            keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
            keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            train_param = workflow_param.get('TRAINING', None)
            infer_param = workflow_param.get('INFERENCE', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise

        assert os.path.exists(system_param.model_dir), \
            'Model folder not exists {}'.format(system_param.model_dir)
        self.model_dir = system_param.model_dir

        self.is_training_action = TRAIN.startswith(system_param.action.lower())
        # hardware-related parameters
        self.num_threads = max(system_param.num_threads, 1) \
            if self.is_training_action else 1
        self.num_gpus = system_param.num_gpus \
            if self.is_training_action else min(system_param.num_gpus, 1)
        set_cuda_device(system_param.cuda_devices)

        # set training params.
        if self.is_training_action:
            assert train_param, 'training parameters not specified'
            self.initial_iter = train_param.starting_iter
            self.final_iter = max(train_param.max_iter, self.initial_iter)
            self.save_every_n = train_param.save_every_n
            self.tensorboard_every_n = train_param.tensorboard_every_n
            self.max_checkpoints = max(self.max_checkpoints,
                                       train_param.max_checkpoints)
            self.validation_every_n = train_param.validation_every_n
            if self.validation_every_n > 0:
                self.validation_max_iter = max(self.validation_max_iter,
                                               train_param.validation_max_iter)
            action_param = train_param
        else:  # set inference params.
            assert infer_param, 'inference parameters not specified'
            self.initial_iter = infer_param.inference_iter
            action_param = infer_param

        # infer the initial iteration from model files
        if self.initial_iter < 0:
            self.initial_iter = infer_latest_model_file(
                os.path.join(self.model_dir, 'models'))

        # create an application instance
        assert app_param, 'application specific param. not specified'
        app_module = ApplicationFactory.create(app_param.name)
        self.app = app_module(net_param, action_param, system_param.action)

        # clear the cached file lists
        self.data_partitioner.reset()
        if data_param:
            do_new_partition = \
                self.is_training_action and self.initial_iter == 0 and \
                (not os.path.isfile(system_param.dataset_split_file)) and \
                (train_param.exclude_fraction_for_validation > 0 or
                 train_param.exclude_fraction_for_inference > 0)
            data_fractions = (train_param.exclude_fraction_for_validation,
                              train_param.exclude_fraction_for_inference) \
                if do_new_partition else None

            self.data_partitioner.initialise(
                data_param=data_param,
                new_partition=do_new_partition,
                ratios=data_fractions,
                data_split_file=system_param.dataset_split_file)
            assert self.data_partitioner.has_validation or \
                self.validation_every_n <= 0, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available.\nPlease ' \
                'check dataset partition list {} ' \
                '(remove file to generate a new dataset partition), ' \
                'check "exclude_fraction_for_validation" ' \
                '(current config value: {}).\nAlternatively, ' \
                'set "validation_every_n" to -1.'.format(
                    self.validation_every_n,
                    system_param.dataset_split_file,
                    train_param.exclude_fraction_for_validation)

        # initialise readers
        self.app.initialise_dataset_loader(data_param, app_param,
                                           self.data_partitioner)

        # make the list of initialised event handler instances.
        self.load_event_handlers(system_param.event_handler
                                 or DEFAULT_EVENT_HANDLERS)
        self._generator = IteratorFactory.create(
            system_param.iteration_generator or DEFAULT_ITERATION_GENERATOR)

    def run(self, application, graph=None):
        """
        Initialise a TF graph, connect data sampler and network within
        the graph context, run training loops or inference loops.

        :param application: a niftynet application
        :param graph: default base graph to run the application
        :return:
        """
        if graph is None:
            graph = ApplicationDriver.create_graph(
                application=application,
                num_gpus=self.num_gpus,
                num_threads=self.num_threads,
                is_training_action=self.is_training_action)

        start_time = time.time()
        loop_status = {'current_iter': self.initial_iter, 'normal_exit': False}

        with tf.Session(config=tf_config(), graph=graph):
            try:
                # broadcasting event of session started
                SESS_STARTED.send(application, iter_msg=None)

                # create a iteration message generator and
                # iteratively run the graph (the main engine loop)
                iteration_messages = self._generator(**vars(self))()
                ApplicationDriver.loop(application=application,
                                       iteration_messages=iteration_messages,
                                       loop_status=loop_status)

            except KeyboardInterrupt:
                tf.logging.warning('User cancelled application')
            except (tf.errors.OutOfRangeError, EOFError):
                if not loop_status.get('normal_exit', False):
                    # reached the end of inference Dataset
                    loop_status['normal_exit'] = True
            except RuntimeError:
                import sys
                import traceback
                exc_type, exc_value, exc_traceback = sys.exc_info()
                traceback.print_exception(exc_type,
                                          exc_value,
                                          exc_traceback,
                                          file=sys.stdout)
            finally:
                tf.logging.info('cleaning up...')
                # broadcasting session finished event
                iter_msg = IterationMessage()
                iter_msg.current_iter = loop_status.get('current_iter', -1)
                SESS_FINISHED.send(application, iter_msg=iter_msg)

        application.stop()
        if not loop_status.get('normal_exit', False):
            # loop didn't finish normally
            tf.logging.warning('stopped early, incomplete iterations.')
        tf.logging.info("%s stopped (time in second %.2f).",
                        type(application).__name__, (time.time() - start_time))

    # pylint: disable=not-context-manager
    @staticmethod
    def create_graph(application,
                     num_gpus=1,
                     num_threads=1,
                     is_training_action=False):
        """
        Create a TF graph based on self.app properties
        and engine parameters.

        :return:
        """
        graph = tf.Graph()
        main_device = device_string(num_gpus, 0, False, is_training_action)
        outputs_collector = OutputsCollector(n_devices=max(num_gpus, 1))
        gradients_collector = GradientsCollector(n_devices=max(num_gpus, 1))
        # start constructing the graph, handling training and inference cases
        with graph.as_default(), tf.device(main_device):
            # initialise sampler
            with tf.name_scope('Sampler'):
                application.initialise_sampler()
                for sampler in traverse_nested(application.get_sampler()):
                    sampler.set_num_threads(num_threads)

            # initialise network, these are connected in
            # the context of multiple gpus
            application.initialise_network()
            application.add_validation_flag()

            # for data parallelism --
            #     defining and collecting variables from multiple devices
            for gpu_id in range(0, max(num_gpus, 1)):
                worker_device = device_string(num_gpus, gpu_id, True,
                                              is_training_action)
                scope_string = 'worker_{}'.format(gpu_id)
                with tf.name_scope(scope_string), tf.device(worker_device):
                    # setup network for each of the multiple devices
                    application.connect_data_and_network(
                        outputs_collector, gradients_collector)
            with tf.name_scope('MergeOutputs'):
                outputs_collector.finalise_output_op()
            application.outputs_collector = outputs_collector
            application.gradients_collector = gradients_collector
            GRAPH_CREATED.send(application, iter_msg=None)
        return graph

    def load_event_handlers(self, names):
        """
        Import event handler modules and create a list of handler instances.
        The event handler instances will be stored with this engine.

        :param names: strings of event handlers
        :return:
        """
        if not names:
            return
        if self._event_handlers:
            # disconnect all handlers (assuming always weak connection)
            for handler in list(self._event_handlers):
                del self._event_handlers[handler]
        self._event_handlers = {}
        for name in set(names):
            the_event_class = EventHandlerFactory.create(name)
            # initialise all registered event handler classes
            engine_config_dict = vars(self)
            key = '{}'.format(the_event_class)
            self._event_handlers[key] = the_event_class(**engine_config_dict)

    @staticmethod
    def loop(application, iteration_messages=(), loop_status=None):
        """
        Running ``loop_step`` with ``IterationMessage`` instances
        generated by ``iteration_generator``.

        This loop stops when any of the condition satisfied:
            1. no more element from the ``iteration_generator``;
            2. ``application.interpret_output`` returns False;
            3. any exception raised.

        Broadcasting SESS_* signals at the beginning and end of this method.

        This function should be used in a context of
        ``tf.Session`` or ``session.as_default()``.

        :param application: a niftynet.application instance, application
            will provides ``tensors`` to be fetched by ``tf.session.run()``.
        :param iteration_messages:
            a generator of ``engine.IterationMessage`` instances
        :param loop_status: optional dictionary used to capture the loop status,
            useful when the loop exited in an unexpected manner.
        :return:
        """
        loop_status = loop_status or {}
        for iter_msg in iteration_messages:
            loop_status['current_iter'] = iter_msg.current_iter

            # run an iteration
            ApplicationDriver.loop_step(application, iter_msg)

            # Checking stopping conditions
            if iter_msg.should_stop:
                tf.logging.info('stopping -- event handler: %s.',
                                iter_msg.should_stop)
                break
        # loop finished without any exception
        loop_status['normal_exit'] = True

    @staticmethod
    def loop_step(application, iteration_message):
        """
        Calling ``tf.session.run`` with parameters encapsulated in
        iteration message as an iteration.
        Broadcasting ITER_* events before and afterward.

        :param application:
        :param iteration_message: an ``engine.IterationMessage`` instances
        :return:
        """
        # broadcasting event of starting an iteration
        ITER_STARTED.send(application, iter_msg=iteration_message)

        # ``iter_msg.ops_to_run`` are populated with the ops to run in
        # each iteration, fed into ``session.run()`` and then
        # passed to the application (and observers) for interpretation.
        sess = tf.get_default_session()
        assert sess, 'method should be called within a TF session context.'

        iteration_message.current_iter_output = sess.run(
            iteration_message.ops_to_run,
            feed_dict=iteration_message.data_feed_dict)

        # broadcasting event of finishing an iteration
        ITER_FINISHED.send(application, iter_msg=iteration_message)
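
For context, this driver is normally constructed and run by NiftyNet's entry point along the lines of the minimal sketch below; workflow_param and data_param stand for the dictionaries produced by the config parser.

# minimal sketch of driving the class above (parameter dictionaries assumed)
driver = ApplicationDriver()
driver.initialise_application(workflow_param, data_param)
driver.run(driver.app)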
                    spatial_window_size=(7, 10, 2),
                    loader=None),
    'FLAIR':
    ParserNamespace(csv_file=os.path.join('testing_data', 'FLAIRsampler.csv'),
                    path_to_search='testing_data',
                    filename_contains=('FLAIR_', ),
                    filename_not_contains=('Parcellation', ),
                    interp_order=3,
                    pixdim=None,
                    axcodes=None,
                    spatial_window_size=(7, 10, 2),
                    loader=None)
}
MULTI_MOD_TASK = ParserNamespace(image=('T1', 'FLAIR'))

data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()


def get_3d_reader():
    reader = ImageReader(['image'])
    reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    return reader


class LinearInterpolateSamplerTest(tf.test.TestCase):
    def test_init(self):
        sampler = LinearInterpolateSampler(reader=get_3d_reader(),
                                           data_param=MULTI_MOD_DATA,
                                           batch_size=1,
                                           n_interpolations=8,
        pixdim=None,
        axcodes=None
    ),
    'FLAIR': ParserNamespace(
        csv_file=os.path.join('testing_data', 'FLAIR.csv'),
        path_to_search='testing_data',
        filename_contains=('FLAIR_',),
        filename_not_contains=('Parcellation',),
        interp_order=3,
        pixdim=None,
        axcodes=None
    )
}
TASK_PARAM = ParserNamespace(image=('T1', 'FLAIR'))
MODEL_FILE = os.path.join('testing_data', 'std_models.txt')
data_partitioner = ImageSetsPartitioner()
file_list = data_partitioner.initialise(DATA_PARAM).get_file_list()


# @unittest.skipIf(os.environ.get('QUICKTEST', "").lower() == "true", 'Skipping slow tests')
class HistTest(tf.test.TestCase):
    def test_volume_loader(self):
        expected_T1 = np.array(
            [0.0, 8.24277910972, 21.4917343731,
             27.0551695202, 32.6186046672, 43.5081573038,
             53.3535675285, 61.9058849776, 70.0929786194,
             73.9944243858, 77.7437509974, 88.5331971492,
             100.0])
        expected_FLAIR = np.array(
            [0.0, 5.36540863446, 15.5386130103,
             20.7431912042, 26.1536608309, 36.669150376,
Example #32
    def initialise_application(self, workflow_param, data_param):
        """
        This function receives all parameters from the user config file
        and creates an instance of the application.

        :param workflow_param: a dictionary of user parameters,
            keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
            keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            train_param = workflow_param.get('TRAINING', None)
            infer_param = workflow_param.get('INFERENCE', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise

        assert os.path.exists(system_param.model_dir), \
            'Model folder not exists {}'.format(system_param.model_dir)
        self.is_training = (system_param.action == "train")
        # hardware-related parameters
        self.num_threads = max(system_param.num_threads, 1) \
            if self.is_training else 1
        self.num_gpus = system_param.num_gpus \
            if self.is_training else min(system_param.num_gpus, 1)
        set_cuda_device(system_param.cuda_devices)

        # set output TF model folders
        self.model_dir = touch_folder(
            os.path.join(system_param.model_dir, 'models'))
        self.session_prefix = os.path.join(self.model_dir, FILE_PREFIX)

        # set training params.
        if self.is_training:
            assert train_param, 'training parameters not specified'
            summary_root = os.path.join(system_param.model_dir, 'logs')
            self.summary_dir = get_latest_subfolder(
                summary_root,
                create_new=train_param.starting_iter == 0)

            self.initial_iter = train_param.starting_iter
            self.final_iter = max(train_param.max_iter, self.initial_iter)
            self.save_every_n = train_param.save_every_n
            self.tensorboard_every_n = train_param.tensorboard_every_n
            self.max_checkpoints = \
                max(train_param.max_checkpoints, self.max_checkpoints)
            self.gradients_collector = GradientsCollector(
                n_devices=max(self.num_gpus, 1))
            self.validation_every_n = train_param.validation_every_n
            if self.validation_every_n > 0:
                self.validation_max_iter = max(self.validation_max_iter,
                                               train_param.validation_max_iter)
            action_param = train_param
        else: # set inference params.
            assert infer_param, 'inference parameters not specified'
            self.initial_iter = infer_param.inference_iter
            action_param = infer_param

        self.outputs_collector = OutputsCollector(
            n_devices=max(self.num_gpus, 1))

        # create an application instance
        assert app_param, 'application specific param. not specified'
        app_module = ApplicationDriver._create_app(app_param.name)
        self.app = app_module(net_param, action_param, system_param.action)

        # initialise data input
        data_partitioner = ImageSetsPartitioner()
        # clear the cached file lists
        data_partitioner.reset()
        do_new_partition = \
            self.is_training and self.initial_iter == 0 and \
            (not os.path.isfile(system_param.dataset_split_file)) and \
            (train_param.exclude_fraction_for_validation > 0 or
             train_param.exclude_fraction_for_inference > 0)
        data_fractions = None
        if do_new_partition:
            assert train_param.exclude_fraction_for_validation > 0 or \
                   self.validation_every_n <= 0, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available,\nplease ' \
                'check "exclude_fraction_for_validation" in the config ' \
                'file (current config value: {}).'.format(
                    self.validation_every_n,
                    train_param.exclude_fraction_for_validation)
            data_fractions = (train_param.exclude_fraction_for_validation,
                              train_param.exclude_fraction_for_inference)

        if data_param:
            data_partitioner.initialise(
                data_param=data_param,
                new_partition=do_new_partition,
                ratios=data_fractions,
                data_split_file=system_param.dataset_split_file)

        if data_param and self.is_training and self.validation_every_n > 0:
            assert data_partitioner.has_validation, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available.\nPlease ' \
                'check dataset partition list {} ' \
                '(remove file to generate a new dataset partition). ' \
                'Or set validation_every_n to -1.'.format(
                    self.validation_every_n, system_param.dataset_split_file)

        # initialise readers
        self.app.initialise_dataset_loader(
            data_param, app_param, data_partitioner)

        self._data_partitioner = data_partitioner

        # pylint: disable=not-context-manager
        with self.graph.as_default(), tf.name_scope('Sampler'):
            self.app.initialise_sampler()
Example #33
    def initialise_application(self, workflow_param, data_param):
        """
        This function receives all parameters from the user config file
        and creates an instance of the application.

        :param workflow_param: a dictionary of user parameters,
            keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
            keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            train_param = workflow_param.get('TRAINING', None)
            infer_param = workflow_param.get('INFERENCE', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise

        assert os.path.exists(system_param.model_dir), \
            'Model folder not exists {}'.format(system_param.model_dir)
        self.is_training = (system_param.action == "train")
        # hardware-related parameters
        self.num_threads = max(system_param.num_threads, 1) \
            if self.is_training else 1
        self.num_gpus = system_param.num_gpus \
            if self.is_training else min(system_param.num_gpus, 1)
        set_cuda_device(system_param.cuda_devices)

        # set output TF model folders
        self.model_dir = touch_folder(
            os.path.join(system_param.model_dir, 'models'))
        self.session_prefix = os.path.join(self.model_dir, FILE_PREFIX)

        if self.is_training:
            assert train_param, 'training parameters not specified'
            summary_root = os.path.join(system_param.model_dir, 'logs')
            self.summary_dir = get_latest_subfolder(
                summary_root, create_new=train_param.starting_iter == 0)

            self.initial_iter = train_param.starting_iter
            self.final_iter = max(train_param.max_iter, self.initial_iter)
            self.save_every_n = train_param.save_every_n
            self.tensorboard_every_n = train_param.tensorboard_every_n
            self.max_checkpoints = \
                max(train_param.max_checkpoints, self.max_checkpoints)
            self.gradients_collector = GradientsCollector(
                n_devices=max(self.num_gpus, 1))
            self.validation_every_n = train_param.validation_every_n
            if self.validation_every_n > 0:
                self.validation_max_iter = max(self.validation_max_iter,
                                               train_param.validation_max_iter)
            action_param = train_param
        else:
            assert infer_param, 'inference parameters not specified'
            self.initial_iter = infer_param.inference_iter
            action_param = infer_param

        self.outputs_collector = OutputsCollector(
            n_devices=max(self.num_gpus, 1))

        # create an application instance
        assert app_param, 'application specific param. not specified'
        app_module = ApplicationDriver._create_app(app_param.name)
        self.app = app_module(net_param, action_param, self.is_training)

        # initialise data input
        data_partitioner = ImageSetsPartitioner()
        # clear the cached file lists
        data_partitioner.reset()
        do_new_partition = \
            self.is_training and self.initial_iter == 0 and \
            (not os.path.isfile(system_param.dataset_split_file)) and \
            (train_param.exclude_fraction_for_validation > 0 or
             train_param.exclude_fraction_for_inference > 0)
        data_fractions = None
        if do_new_partition:
            assert train_param.exclude_fraction_for_validation > 0 or \
                   self.validation_every_n <= 0, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available,\nplease ' \
                'check "exclude_fraction_for_validation" in the config ' \
                'file (current config value: {}).'.format(
                    self.validation_every_n,
                    train_param.exclude_fraction_for_validation)
            data_fractions = (train_param.exclude_fraction_for_validation,
                              train_param.exclude_fraction_for_inference)

        if data_param:
            data_partitioner.initialise(
                data_param=data_param,
                new_partition=do_new_partition,
                ratios=data_fractions,
                data_split_file=system_param.dataset_split_file)

        if data_param and self.is_training and self.validation_every_n > 0:
            assert data_partitioner.has_validation, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available.\nPlease ' \
                'check dataset partition list {} ' \
                '(remove file to generate a new dataset partition). ' \
                'Or set validation_every_n to -1.'.format(
                    self.validation_every_n, system_param.dataset_split_file)

        # initialise readers
        self.app.initialise_dataset_loader(data_param, app_param,
                                           data_partitioner)

        self._data_partitioner = data_partitioner

        # pylint: disable=not-context-manager
        with self.graph.as_default(), tf.name_scope('Sampler'):
            self.app.initialise_sampler()
Example #34
                    spatial_window_size=(8, 2),
                    loader=None),
    'FLAIR':
    ParserNamespace(csv_file=os.path.join('testing_data', 'FLAIRsampler.csv'),
                    path_to_search='testing_data',
                    filename_contains=('FLAIR_', ),
                    filename_not_contains=('Parcellation', ),
                    interp_order=3,
                    pixdim=None,
                    axcodes=None,
                    spatial_window_size=(8, 2),
                    loader=None)
}
DYNAMIC_MOD_TASK = ParserNamespace(image=('T1', 'FLAIR'))

data_partitioner = ImageSetsPartitioner()
multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()
mod_2d_list = data_partitioner.initialise(MOD_2D_DATA).get_file_list()
dynamic_list = data_partitioner.initialise(DYNAMIC_MOD_DATA).get_file_list()


def get_3d_reader():
    reader = ImageReader(['image'])
    reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
    return reader


def get_2d_reader():
    reader = ImageReader(['image'])
    reader.initialise(MOD_2D_DATA, MOD_2D_TASK, mod_2d_list)
    return reader
CSVBAD_DATA = {
    'sampler':
    ParserNamespace(csv_file='',
                    path_to_search='',
                    filename_contains=(),
                    filename_not_contains=(),
                    interp_order=0,
                    pixdim=None,
                    axcodes=None,
                    spatial_window_size=(),
                    loader=None,
                    csv_data_file='data/csv_data/ICBMTest.csv')
}

data_partitioner = ImageSetsPartitioner()
# multi_mod_list = data_partitioner.initialise(MULTI_MOD_DATA).get_file_list()
# mod_2d_list = data_partitioner.initialise(MOD_2D_DATA).get_file_list()
dynamic_list = data_partitioner.initialise(DYNAMIC_MOD_DATA).get_file_list()

# def get_3d_reader():
#     reader = ImageReader(['image'])
#     reader.initialise(MULTI_MOD_DATA, MULTI_MOD_TASK, multi_mod_list)
#     return reader

# def get_2d_reader():
#     reader = ImageReader(['image'])
#     reader.initialise(MOD_2D_DATA, MOD_2D_TASK, mod_2d_list)
#     return reader