Example #1
 def shard(self, worker_index, num_workers):
     custom_print('SHARDS: Worker %s/%s' % (worker_index + 1, num_workers))
     assert (worker_index < num_workers)
     self.files = [
         f for i, f in enumerate(self.all_files)
         if (i % num_workers) == worker_index
     ]
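A standalone sketch of the modulo sharding above, assuming a hypothetical file list and plain print in place of the project's custom_print: each worker keeps every num_workers-th file, so the shards are disjoint and together cover all files.

    all_files = ['f0.nii', 'f1.nii', 'f2.nii', 'f3.nii', 'f4.nii']  # hypothetical names

    def shard_files(all_files, worker_index, num_workers):
        assert worker_index < num_workers
        return [f for i, f in enumerate(all_files)
                if i % num_workers == worker_index]

    for w in range(2):
        print('Worker %s/2 gets:' % (w + 1), shard_files(all_files, w, 2))
    # Worker 1/2 gets: ['f0.nii', 'f2.nii', 'f4.nii']
    # Worker 2/2 gets: ['f1.nii', 'f3.nii']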
Example #2
    def fit(self, X, y):
        if "hooks" in self.params and "LogTotalSteps" in self.params["hooks"]:
            self.params["hooks"]["LogTotalSteps"][
                "batch_size"] = self.input_fn_config["batch_size"]
            self.params["hooks"]["LogTotalSteps"][
                "epochs"] = self.input_fn_config["num_epochs"]
            self.params["hooks"]["LogTotalSteps"]["train_size"] = train_size(X)

        with BaseTF.lock:
            config = self.config
            if BaseTF.num_instances > 1:
                config["model_dir"] = os.path.join(
                    config["model_dir"], "inst-" + str(self.instance_id))

        self.est_config = config
        self.estimator = tf.estimator.Estimator(
            model_fn=self.model_fn,
            params=self.params,
            config=tf.estimator.RunConfig(**config))

        if self.feature_spec is None:
            self.feature_spec = feature_spec_from(X)

        tf.logging.set_verbosity(tf.logging.ERROR)
        try:
            self.fit_main_training_loop(X, y)
        except KeyboardInterrupt:
            custom_print("\nEarly stop of training, saving model...")
        self.export_estimator()
        return self
Example #3
 def create_graph(self, resized_image_tensor):
     input_tensor_name = self.model_info['resized_input_tensor_name']
     try:
         with gfile.FastGFile(self.model_path, 'rb') as f:
             graph_def = tf.GraphDef()
             graph_def.ParseFromString(f.read())
             custom_print('[Pretrained] Model loaded: ', self.model_path)
             model_features = tf.import_graph_def(
                 graph_def,
                 name='inceptionv3',
                 input_map={
                     input_tensor_name: resized_image_tensor,
                 },
                 return_elements=[
                     self.model_info['bottleneck_tensor_name'],
                 ],
             )
         return tf.reshape(
             model_features[0],
             [-1, self.model_info['bottleneck_tensor_size']],
         )
     except NotFoundError:
         custom_print(('Unable to open file %s; please download the model' +
                       ' there (%s)') %
                      (self.model_path, self.model_info['data_url']))
         sys.exit(1)  # exit non-zero: the model file is missing
Example #4
 def fit_main_training_loop(self, X, y):
     self.estimator.train(
         input_fn=self.gen_input_fn(X, y, True, self.input_fn_config))
     evaluation_fn = self.gen_input_fn(X, y, False, self.input_fn_config)
     if evaluation_fn is None:
         custom_print("No evaluation data available - skipping evaluation.")
         return
     evaluation = self.estimator.evaluate(input_fn=evaluation_fn)
     custom_print(evaluation)
Example #5
    def _generate_masked_tractography(self,
                                      reseed_endpoints=False,
                                      affine=None,
                                      predictor=None):
        """Generate the tractography using the white matter mask.

        Args:
            reseed_endpoints: Boolean. If True, use the end points of the fibers
                produced to generate another tractography. This is to symmetrize
                the process.
            affine: Voxel-to-world affine passed through to the prediction input.
            predictor: Callable mapping a batch of inputs to a dict with a
                "predictions" key.
        """

        i = 0
        while self.ongoing_fibers:
            i += 1
            predictions = predictor(self._build_next_X(affine))["predictions"]
            directions = self.get_directions_from_predictions(
                predictions, affine)

            # Update the positions of the fibers and check if they are still ongoing
            cur_ongoing = []
            for j, fiber in enumerate(self.ongoing_fibers):
                new_position = fiber[-1] + directions[j] * self.args.step_size

                if i == 1 and self._is_border(new_position):
                    # First step is ambiguous and leads into the border -> flip it.
                    new_position = fiber[-1] - directions[j] * self.args.step_size

                # Only continue fibers inside the boundaries and short enough
                if self._is_border(new_position) or \
                        i * self.args.step_size > self.max_fiber_length:
                    self.tractography.append(fiber)
                else:
                    fiber.append(new_position)
                    cur_ongoing.append(fiber)
            self.ongoing_fibers = cur_ongoing

            end = "\r"
            if i % 25 == 0:
                end = "\n"
            custom_print("Round num:",
                         '%4d' % i,
                         "; ongoing:",
                         '%7d' % len(self.ongoing_fibers),
                         "; completed:",
                         '%7d' % len(self.tractography),
                         end=end)

        if reseed_endpoints:
            ending_seeds = [[fiber[-1]] for fiber in self.tractography]
            self.ongoing_fibers = ending_seeds
            self._generate_masked_tractography(reseed_endpoints=False,
                                               affine=affine)
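A toy single-fiber step under the same update rule, with hypothetical values: each round moves the fiber tip by step_size along the predicted direction, and on the first round the ambiguous sign is flipped if the forward step would leave the mask.

    import numpy as np

    step_size = 1.0                          # hypothetical
    fiber = [np.array([5.0, 5.0, 5.0])]      # positions so far; fiber[-1] is the tip
    direction = np.array([1.0, 0.0, 0.0])    # unit direction from the model

    new_position = fiber[-1] + direction * step_size
    # On round i == 1 the method instead uses fiber[-1] - direction * step_size
    # when new_position lands on the border.
    fiber.append(new_position)
    print(fiber[-1])  # [6. 5. 5.]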
Example #6
    def predict(self, X, args):
        """Generate the tracktography with the current model on the given brain."""
        # Check model
        check_is_fitted(self, ["n_incoming", "block_size"])
        assert isinstance(X, dict)

        #predictor = tf.contrib.predictor.from_saved_model(self._restore_path)

        predictor = self.predictor(self.feature_spec)

        self.args = args

        try:
            self.wm_mask = X['mask']
            self.nii = X['dwi']
            if 'max_fiber_length' in args:
                self.max_fiber_length = args.max_fiber_length
            else:
                self.max_fiber_length = 400
        except KeyError as err:
            custom_print("KeyError: {}".format(err))

        # Get brain information
        self.brain_data = X['dwi']

        # If no seeds are specified, build them from the wm mask
        if 'seeds' not in self.args:
            seeds = self._seeds_from_wm_mask()
        else:
            seeds = self.args.seeds  # use the provided seeds; otherwise `seeds` is unbound below

        # The final result will be here
        self.tractography = []
        # Fibers that are still under construction. At first seeds.
        self.ongoing_fibers = seeds

        if predictor is not None:
            # Start tractography generation
            if 'reseed_endpoints' in self.args:
                self._generate_masked_tractography(
                    self.args.reseed_endpoints,
                    affine=X["header"]["vox_to_ras"],
                    predictor=predictor)
            else:
                self._generate_masked_tractography(
                    affine=X["header"]["vox_to_ras"], predictor=predictor)

            # Save the Fibers
            fiber_path = os.path.join(self.save_path, "fibers.trk")
            save_fibers(self.tractography, X["header"], fiber_path)
Example #7
    def __init__(self, run_config, sumatra_outcome_config, *args, **kwargs):
        self.is_model_first_run = True
        self.run_config = run_config
        self.sumatra_outcome_config = sumatra_outcome_config
        self.sumatra_outcome = {}
        self.sumatra_outcome['run_tags'] = []

        tf_run_config = copy.deepcopy(run_config['tf_estimator_run_config'])
        if 'session_config' in tf_run_config:
            session_config = tf_run_config['session_config']
            gpu_config = session_config.get('gpu_options')
            if gpu_config is not None:
                del session_config['gpu_options']
            session_config_obj = tf.ConfigProto(**session_config)
            if gpu_config is not None:
                for conf_key, conf_val in gpu_config.items():
                    setattr(session_config_obj.gpu_options, conf_key, conf_val)
            tf_run_config['session_config'] = session_config_obj

        super(Estimator, self).__init__(
            config=tf_run_config,
            *args,
            **kwargs
        )

        # Data provider
        provider = self.input_fn_config['data_provider']
        module = getattr(src.data.providers, provider, None)
        if module is None:
            custom_print(
                'FATAL: Data provider ' +
                self.input_fn_config['data_provider'] +
                ' not found'
            )
        assert(module is not None)
        self.data_provider = module.DataProvider(self.input_fn_config)
        ft_def.all_features.feature_info[ft_def.MRI]['shape'] = \
            self.data_provider.get_mri_shape()

        self.feature_spec = {
            name: tf.placeholder(
                    shape=[1] + ft_info['shape'],
                    dtype=ft_info['type']
                )
            for name, ft_info in ft_def.all_features.feature_info.items()
        }
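A minimal sketch of the session-config conversion above, assuming TensorFlow 1.x (the API used throughout these examples): flat keys go straight into tf.ConfigProto, while the nested gpu_options keys are applied with setattr.

    import tensorflow as tf  # TF 1.x

    session_config = {
        'allow_soft_placement': True,
        'gpu_options': {'allow_growth': True},  # nested, handled separately
    }

    gpu_config = session_config.pop('gpu_options', None)
    config_obj = tf.ConfigProto(**session_config)
    if gpu_config is not None:
        for conf_key, conf_val in gpu_config.items():
            setattr(config_obj.gpu_options, conf_key, conf_val)

    print(config_obj.gpu_options.allow_growth)  # True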
Example #8
    def _seeds_from_wm_mask(self):
        """Compute the seeds for the streamlining from the white matter mask.

        This is invoked only if no seeds are specified.
        The seeds are selected on the interface between white and gray matter, i.e. they are the
        white matter voxels that have at least one gray matter neighboring voxel.
        These points are furthermore perturbed with some Gaussian noise to have a wider range of
        starting points.

        Returns:
            seeds: The list of voxels to use as seeds.
        """
        # Take the border voxels as seeds
        seeds = self._find_borders()
        custom_print("Number of seeds on the white matter mask:", len(seeds))
        custom_print("Number of requested seeds:", self.args.n_fibers)
        new_idxs = np.random.choice(len(seeds),
                                    self.args.n_fibers,
                                    replace=True)
        new_seeds = [[
            seeds[i] + np.clip(np.random.normal(0, 0.25, 3), -0.5, 0.5)
        ] for i in new_idxs]
        return new_seeds
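A standalone sketch of the perturbation step above, with hypothetical border voxels: the Gaussian noise (sigma 0.25) is clipped to +/-0.5, so every sampled seed stays within half a voxel of its source point.

    import numpy as np

    seeds = [np.array([10.0, 20.0, 30.0]),
             np.array([11.0, 20.0, 30.0])]  # hypothetical border voxels
    n_fibers = 4

    idxs = np.random.choice(len(seeds), n_fibers, replace=True)
    new_seeds = [[seeds[i] + np.clip(np.random.normal(0, 0.25, 3), -0.5, 0.5)]
                 for i in idxs]
    print(new_seeds)  # n_fibers one-point fibers near the border voxels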
Example #9
    def filter_xml(self,
                   files,
                   xml_image_id,
                   filters,
                   xml_class=None,
                   xml_patient_id=None):
        self.image_id_enabled = set()
        self.image_id_to_class = {}
        discarded_count = 0
        for f in glob.glob(files):
            tree = ET.parse(f)
            root = tree.getroot()
            image_id = int(xml_elem_unique(root, xml_image_id))

            pass_all_filters = True
            for filter_def in filters:  # avoid shadowing the builtin `filter`
                value = xml_elem_unique(root, filter_def['key'], allowNone=True)
                if not filters_match(value, filter_def['value']):
                    pass_all_filters = False
                    break
            if not pass_all_filters:
                discarded_count += 1
                continue
            self.image_id_enabled.add(image_id)
            if xml_class is not None:
                self.image_id_to_class[image_id] = xml_elem_unique(
                    root,
                    xml_class,
                )
            if xml_patient_id is not None:
                self.image_id_to_patient_id[image_id] = xml_elem_unique(
                    root,
                    xml_patient_id,
                )

        custom_print('[filter_xml] %s images discarded' % (discarded_count))
Example #10
    def transform(self, X=None):
        steps_registered = {
            'no_operation': self.no_operation,
            'exec_command': self.exec_command,
            'brain_extraction': self.brain_extraction,
            'template_registration': self.template_registration,
            'image_crop': self.image_crop,
            'eddy_correct': self.eddy_correct,
            'dtifit': self.dtifit,
        }
        for step in self.steps:
            self._mkdir(step['subfolder'])
        custom_print('Applying MRI pipeline to %s files' % (len(self.files)))
        all_images_ids = []
        for i, mri_raw in enumerate(self.files):
            image_id, patient_id = self.extract_image_and_patient(mri_raw)
            custom_print('Image %s/%s [image_id = %s]' %
                         (i + 1, len(self.files), image_id))
            if self.image_id_enabled is not None:
                if image_id not in self.image_id_enabled:
                    custom_print('... Skipped')
                    continue

            for step_id, step in enumerate(self.steps):
                if 'skip' in step:
                    continue
                if step['from_subfolder'] == 'raw':
                    path_from = mri_raw
                else:
                    path_from = os.path.join(
                        self.path,
                        step['from_subfolder'],
                        'I{image_id}.nii.gz'.format(image_id=image_id),
                    )
                path_to = os.path.join(
                    self.path,
                    step['subfolder'],
                    'I{image_id}.nii.gz'.format(image_id=image_id),
                )
                if not os.path.exists(path_from):
                    continue
                if os.path.exists(path_to) and not step['overwrite']:
                    continue
                steps_registered[step['type']](
                    path_from,
                    path_to,
                    image_id,
                    step,
                )
            all_images_ids.append(image_id)

        # Split train/test
        if self.split_train_test is not None:
            self.do_split_train_test(all_images_ids, **self.split_train_test)
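Each step in self.steps is a plain dict dispatched on its 'type' key; the loop above reads 'subfolder', 'from_subfolder', 'overwrite' and an optional 'skip' flag. A hypothetical two-step configuration illustrating the expected shape:

    steps = [
        {
            'type': 'brain_extraction',
            'from_subfolder': 'raw',         # read the raw scan
            'subfolder': 'brain_extracted',  # write I<image_id>.nii.gz here
            'overwrite': False,
        },
        {
            'type': 'template_registration',
            'from_subfolder': 'brain_extracted',
            'subfolder': 'registered',
            'overwrite': False,
            # 'skip': True,  # the mere presence of 'skip' disables a step
        },
    ]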
Example #11
    def fit_main_training_loop(self, X, y):
        if self.streamer is not None:
            self.streamer.dump_split(self.save_path)
            self.streamer.dump_normalization(self.save_path)
            self.streamer.dump_train_val_test_split(self.save_path)

        n_epochs = self.input_fn_config["num_epochs"]
        self.n_epochs = n_epochs
        self.input_fn_config["num_epochs"] = 1

        # Compute the number of steps per epoch
        #n_train_steps = self.streamer.get_number_train_batches()
        #self.config["keep_checkpoint_max"] = n_epochs
        #self.config["save_checkpoints_steps"] = n_train_steps + 1

        output_dir = self.config["model_dir"]
        self.metric_logger = MetricLogger(output_dir, "Evaluation metrics")
        for i in range(n_epochs):
            self.current_epoch = i
            # train
            self.params["validation"] = False
            self.estimator = tf.estimator.Estimator(
                model_fn=self.model_fn,
                params=self.params,
                config=tf.estimator.RunConfig(**self.est_config))
            self.estimator.train(input_fn=self.gen_input_fn(
                X, y, "train", self.input_fn_config))

            # validation
            self.params["validation"] = True
            self.estimator = tf.estimator.Estimator(
                model_fn=self.model_fn,
                params=self.params,
                config=tf.estimator.RunConfig(**self.est_config))
            if "do_validation" in self.sumatra_params and \
                    self.sumatra_params["do_validation"]:
                validation_fn = self.gen_input_fn(X, y, "validation",
                                                  self.input_fn_config)
                if validation_fn is None:
                    custom_print("No validation data - skipping validation.")
                    return
                validation = self.estimator.evaluate(input_fn=validation_fn,
                                                     name="validation")
                print(validation)
                self.metric_logger.add_evaluations("validation", validation)

            # evaluate
            # evaluation on test set
            self.params["validation"] = False
            self.estimator = tf.estimator.Estimator(
                model_fn=self.model_fn,
                params=self.params,
                config=tf.estimator.RunConfig(**self.est_config))
            if "no_test" not in self.sumatra_params:
                evaluation_fn = self.gen_input_fn(X, y, "test",
                                                  self.input_fn_config)
                if evaluation_fn is None:
                    custom_print("No test data - skipping evaluation.")
                    return
                evaluation = self.estimator.evaluate(input_fn=evaluation_fn,
                                                     name="test")
                print(evaluation)
                self.metric_logger.add_evaluations("test", evaluation)

            # persist evaluations to json file
            self.metric_logger.dump()
            sys.stdout.flush()

            if "keep_epoch_checkpoints" in self.data_params and \
                    self.data_params['keep_epoch_checkpoints']:
                # Copy checkpoint to new epoch folder
                dest_path = os.path.join(self.save_path, "epoch{}".format(i))
                os.makedirs(dest_path)
                for fname in os.listdir(self.save_path):
                    full_path = os.path.join(self.save_path, fname)
                    if os.path.isfile(
                            full_path) and "outcome" not in full_path:
                        shutil.copy(full_path, dest_path)

        if "keep_embeddings" in self.data_params:
            self.compress_data(True)
        else:
            self.compress_data(False)
        if "keep_checkpoint" in self.data_params:
            if not self.data_params["keep_checkpoint"]:
                self.remove_checkpoints()
        else:
            self.remove_checkpoints()
        if "keep_tfevents" in self.data_params:
            if not self.data_params["keep_tfevents"]:
                self.remove_tfevents()
        else:
            self.remove_tfevents()
        self.streamer = None
Example #12
 def export_estimator(self):
     receiver_fn = input_receiver_fn(self.feature_spec)
     self._restore_path = self.estimator.export_savedmodel(
         self.save_path, receiver_fn)
     custom_print("Model saved to {}".format(self._restore_path))
Example #13
    def fit_main_training_loop(self, X, y):
        """
        Trains and runs validation regularly at the same time
        """
        self.evaluations = []
        self.training_metrics = []

        custom_print('[INFO] Main training loop. Model dir is %s' % (
            self.config["model_dir"]))
        if 'reason' in self.sumatra_outcome_config:
            custom_print('[INFO] Provided run reason in config: %s' % (
                self.sumatra_outcome_config['reason'],
            ))

        # Dump list of train/test images
        dataset = self.data_provider.export_dataset()
        if dataset is not None:
            with open('%s/dataset.json' % (
                self.config["model_dir"]
            ), 'w') as outfile:
                json.dump(dataset, outfile)

        printed_count = [0]  # Workaround to modify variable inside nested func

        def one_run_finished(t='.'):
            sys.stdout.write(t)
            printed_count[0] += 1
            if printed_count[0] % 10 == 0:
                sys.stdout.write('\n')
            sys.stdout.flush()
            self.write_outcome()

        def do_evaluate():
            evaluate_fn = self.gen_input_fn(X, y, False, self.input_fn_config)
            assert(evaluate_fn is not None)
            self.evaluations.append(
                self.estimator.evaluate(input_fn=evaluate_fn))
            one_run_finished('V')

        num_epochs = self.run_config['num_epochs']
        validations_per_epoch = self.run_config['validations_per_epoch']

        one_run_finished('\n')
        # 1st case, evaluation every few epochs
        if validations_per_epoch <= 1:
            validation_counter = 0
            train_fn = self.gen_input_fn(X, y, True, self.input_fn_config)
            for i in range(num_epochs):
                self.estimator.train(input_fn=train_fn)
                one_run_finished()

                # Check if we need to run validation
                validation_counter += validations_per_epoch
                if validation_counter >= 1:
                    validation_counter -= 1
                    do_evaluate()

        # 2nd case, several evaluations per epoch (should be an int then!)
        else:
            iters = num_epochs * validations_per_epoch
            for i in range(iters):
                train_fn = self.gen_input_fn(
                    X, y, True, self.input_fn_config,
                    shard=(i % validations_per_epoch, validations_per_epoch),
                )
                self.estimator.train(input_fn=train_fn)
                one_run_finished()
                do_evaluate()
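The validation_counter accumulator is what lets validations_per_epoch be fractional in the first branch: with 0.5, for instance, validation fires after every second epoch. A toy trace of that scheduling with training stubbed out:

    num_epochs, validations_per_epoch = 6, 0.5  # hypothetical

    validation_counter = 0
    for epoch in range(num_epochs):
        # self.estimator.train(...) would run here
        validation_counter += validations_per_epoch
        if validation_counter >= 1:
            validation_counter -= 1
            print('epoch %d: validate' % epoch)
    # prints: epoch 1, epoch 3, epoch 5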
Example #14
    def do_split_train_test(
        self,
        image_ids,
        random_seed,
        pkl_prefix,
        test_images_def,
    ):
        if self.image_id_to_class is None:
            if self.set_all_single_class is None:
                custom_print('Train/Test split: No class loaded from XML.')
                return
            self.image_id_to_class = {
                img_id: self.set_all_single_class
                for img_id in image_ids
            }

        num_images = len(image_ids)
        image_ids = [img_id for img_id in image_ids
                     if img_id in self.image_id_to_class]
        if num_images != len(image_ids):
            custom_print('Train/Test split: %d/%d images with unknown class!' %
                         (
                             num_images - len(image_ids),
                             num_images,
                         ))
        # r = random.Random(random_seed)
        all_test = []
        all_train = []
        patients_dict = {}
        for class_idx, class_def in enumerate(test_images_def):
            if not isinstance(class_def['class'], list):
                assert (isinstance(class_def['class'], str))
                class_def['class'] = [class_def['class']]
            # We want to split by patient ID
            # and reach a minimum number of images
            class_images = [
                img_id for img_id in image_ids
                if self.image_id_to_class[img_id] in class_def['class']
            ]
            class_images = sorted(
                class_images,
                key=self.image_id_to_patient_id.__getitem__,
            )
            custom_print('Class %d [%s]' % (
                class_idx,
                ','.join(class_def['class']),
            ))
            num_test_images = class_def['count']
            if num_test_images == 0:
                test_images = []
                train_images = class_images
            else:
                last_test_patient_id = self.image_id_to_patient_id[
                    class_images[num_test_images - 1]]
                while (num_test_images < len(class_images)
                       and self.image_id_to_patient_id[
                           class_images[num_test_images]] == last_test_patient_id):
                    num_test_images += 1
                test_images = class_images[:num_test_images]
                train_images = class_images[num_test_images:]
            custom_print('  TRAIN: %d images / %d patients' % (
                len(train_images),
                len(
                    set([
                        self.image_id_to_patient_id[img_id]
                        for img_id in train_images
                    ])),
            ))
            custom_print('  VALID: %d images / %d patients' % (
                len(test_images),
                len(
                    set([
                        self.image_id_to_patient_id[img_id]
                        for img_id in test_images
                    ])),
            ))
            all_test += test_images
            all_train += train_images
            patients_dict.update(
                {'I%d' % img_id: class_idx
                 for img_id in class_images})
        with open(os.path.join(self.path, '%stest.pkl' % pkl_prefix), 'wb') as f:
            pickle.dump(['I%d' % i for i in all_test], f)
        with open(os.path.join(self.path, '%strain.pkl' % pkl_prefix), 'wb') as f:
            pickle.dump(['I%d' % i for i in all_train], f)
        with open(os.path.join(self.path, '%slabels.pkl' % pkl_prefix), 'wb') as f:
            pickle.dump(patients_dict, f)
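The split extends the test set past the requested count until it reaches a patient boundary, so no patient's images end up in both train and test. A standalone sketch of that boundary logic with hypothetical IDs:

    image_id_to_patient_id = {1: 'A', 2: 'A', 3: 'B', 4: 'B', 5: 'C'}  # hypothetical
    class_images = sorted(image_id_to_patient_id,
                          key=image_id_to_patient_id.__getitem__)

    num_test_images = 1  # requested count
    last_patient = image_id_to_patient_id[class_images[num_test_images - 1]]
    while (num_test_images < len(class_images)
           and image_id_to_patient_id[class_images[num_test_images]] == last_patient):
        num_test_images += 1

    print(class_images[:num_test_images])  # test : [1, 2] -- all of patient 'A'
    print(class_images[num_test_images:])  # train: [3, 4, 5]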
Example #15
 def print_shape(self, text):
     if self.enable_print_shapes:
         custom_print(text)
Example #16
 def log(text):
     if text in UniqueLogger.printed:
         return
     UniqueLogger.printed.add(text)
     custom_print(text)
Example #17
 def _exec(self, cmd):
     custom_print('[Exec] ' + cmd)
     # FSL tools read their output format and install prefix from the environment.
     os.environ['FSLOUTPUTTYPE'] = 'NIFTI_GZ'
     os.environ['FSLDIR'] = '/local/fsl'
     subprocess.call(cmd, shell=True)