Exemplo n.º 1
0
def save_config(cur_config, runtime):
    """Save the current preprocessing config, appending to any prior history.

    Looks for a ``preprocessing_info.json`` in the parent directory of
    ``cur_config['input_dir']``. If that file exists and yields a value, the
    current run is appended to it. Otherwise a new single-entry list is
    started. The result is written to ``preprocessing_info.json`` in
    ``cur_config['output_dir']``.

    :param dict cur_config: Preprocessing config; must contain the keys
        'input_dir' and 'output_dir'
    :param runtime: Processing time of the current run, stored under the
        'processing_time' key of the new entry
    """
    # normpath strips any trailing separator so that e.g. '/data/images/'
    # resolves to its parent '/data', not to '/data/images' itself.
    parent_dir = os.path.dirname(os.path.normpath(cur_config['input_dir']))

    # Read preprocessing_info.json from the parent of the input dir, if any
    prior_config_fname = os.path.join(parent_dir, 'preprocessing_info.json')
    processing_info = None
    if os.path.exists(prior_config_fname):
        processing_info = aux_utils.read_json(prior_config_fname)
    if processing_info is None:
        processing_info = []

    processing_info.append({'processing_time': runtime,
                            'config': cur_config})

    meta_path = os.path.join(cur_config['output_dir'],
                             'preprocessing_info.json')
    aux_utils.write_json(processing_info, meta_path)
Exemplo n.º 2
0
def run_action(action, config, gpu_ids, gpu_mem_frac, model_fname=None):
    """
    Train a network as specified by a training config.

    Only training is supported. The model is not serialized with
    ``model.to_yaml()`` because Lambda layers throw errors when converting to
    yaml. Instead, when training from scratch, the config dict itself is
    dumped to ``config.yml`` in the model directory.

    Side effects:
    - Creates ``trainer_config['model_dir']`` if needed.
    - Writes ``split_samples.json`` there.
    - When training from scratch, also writes ``config.yml`` and
      ``model_graph.png`` there.
    - Depending on the network class, sets ``config['dataset']['squeeze']``
      in place.

    :param str action: Currently the only supported action is 'train'
    :param dict config: Training config with 'dataset', 'trainer' and
        'network' sections
    :param int gpu_ids: GPU ID; -1 skips setting up a Keras session
    :param float gpu_mem_frac: Available GPU memory fraction
    :param str model_fname: Full path to model weights if not starting
        training from scratch
    :raises TypeError: If action is not 'train'
    """
    # Validate explicitly: an assert would be stripped under `python -O`.
    if action != 'train':
        raise TypeError(('action {} not permitted. options: only train '
                         'supported currently').format(action))

    dataset_config = config['dataset']
    trainer_config = config['trainer']
    network_config = config['network']

    # Safety check: 2D UNets need the singleton dimension squeezed, while
    # stack-to-2D networks must keep it. This edits config in place, so the
    # adjusted value is also what gets dumped to config.yml below.
    if network_config['class'] == 'UNet2D':
        dataset_config['squeeze'] = True
    elif network_config['class'] == 'UNetStackTo2D':
        dataset_config['squeeze'] = False

    # Masked loss is off unless the trainer config enables it
    masked_loss = trainer_config.get('masked_loss', False)

    tile_dir, image_format = get_image_dir_format(dataset_config)

    # Create directory where model will be saved
    model_dir = trainer_config['model_dir']
    os.makedirs(model_dir, exist_ok=True)
    # Load tile metadata from the tile directory
    tiles_meta = pd.read_csv(os.path.join(tile_dir, 'frames_meta.csv'))
    tiles_meta = aux_utils.sort_meta_by_channel(tiles_meta)
    # Generate training, validation and test data sets
    all_datasets, split_samples = create_datasets(
        tiles_meta,
        tile_dir,
        dataset_config,
        trainer_config,
        image_format,
        masked_loss,
    )
    # Save train, validation and test indices
    split_idx_fname = os.path.join(model_dir, 'split_samples.json')
    aux_utils.write_json(split_samples, split_idx_fname)

    K.set_image_data_format(network_config['data_format'])

    # With gpu_ids == -1 no Keras session is configured here (sess is None)
    sess = None
    if gpu_ids != -1:
        sess = train_utils.set_keras_session(
            gpu_ids=gpu_ids,
            gpu_mem_frac=gpu_mem_frac,
        )

    if model_fname:
        # load model only loads the weights, have to save intermediate
        # states of gradients to resume training
        model = load_model(network_config, model_fname)
    else:
        with open(os.path.join(model_dir, 'config.yml'), 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
        model = create_network(network_config, gpu_ids)
        plot_model(model,
                   to_file=os.path.join(model_dir, 'model_graph.png'),
                   show_shapes=True,
                   show_layer_names=True)

    trainer = BaseKerasTrainer(
        sess=sess,
        train_config=trainer_config,
        train_dataset=all_datasets['df_train'],
        val_dataset=all_datasets['df_val'],
        model=model,
        num_target_channels=network_config['num_target_channels'],
        gpu_ids=gpu_ids,
        gpu_mem_frac=gpu_mem_frac,
    )
    trainer.train()
Exemplo n.º 3
0
    def setUp(self, mock_model):
        """
        Set up a directory with 3D images and masks, a model dir with split
        samples, and an ImagePredictor configured for 3D regression inference.

        :param mock_model: Mock presumably supplied by a patch decorator on
            this method (not visible here) — confirm; its return value is set
            to 'dummy_model'
        """
        # DataFrame.append was removed in pandas 2.0 (and is quadratic when
        # called in a loop): collect metadata rows, then concat once.
        import pandas as pd

        mock_model.return_value = 'dummy_model'

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.tempdir.makedir('image_dir')
        self.tempdir.makedir('mask_dir')
        self.tempdir.makedir('model_dir')
        self.image_dir = os.path.join(self.temp_path, 'image_dir')
        self.mask_dir = os.path.join(self.temp_path, 'mask_dir')
        self.model_dir = os.path.join(self.temp_path, 'model_dir')
        # Save one 3D .npy image per position and channel
        self.im = np.zeros((10, 10, 8), dtype=np.uint8)
        self.frames_meta = aux_utils.make_dataframe()
        self.time_idx = 2
        self.slice_idx = 0
        meta_rows = []
        for p in range(5):
            for c in range(3):
                im_name = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=c,
                    slice_idx=self.slice_idx,
                    pos_idx=p,
                    ext='.npy',
                )
                np.save(os.path.join(self.image_dir, im_name),
                        self.im + c * 10,
                        allow_pickle=True,
                        fix_imports=True)
                meta_rows.append(
                    aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                )
        self.frames_meta = pd.concat(
            [self.frames_meta, pd.DataFrame(meta_rows)],
            ignore_index=True,
        )
        # Write frames meta to image dir too
        self.frames_meta.to_csv(os.path.join(self.image_dir,
                                             'frames_meta.csv'))
        # Save masks and mask meta
        self.mask_meta = aux_utils.make_dataframe()
        self.mask_channel = 50
        # Mask half the image (first 5 rows along axis 0)
        mask = np.zeros_like(self.im)
        mask[:5, ...] = 1
        mask_rows = []
        for p in range(5):
            im_name = aux_utils.get_im_name(
                time_idx=self.time_idx,
                channel_idx=self.mask_channel,
                slice_idx=self.slice_idx,
                pos_idx=p,
                ext='.npy',
            )
            np.save(os.path.join(self.mask_dir, im_name), mask)
            mask_rows.append(
                aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
            )
        self.mask_meta = pd.concat(
            [self.mask_meta, pd.DataFrame(mask_rows)],
            ignore_index=True,
        )
        # Write frames meta to mask dir too
        self.mask_meta.to_csv(os.path.join(self.mask_dir, 'frames_meta.csv'))
        # Setup model dir
        split_samples = {
            "train": [0, 1],
            "val": [2],
            "test": [3, 4],
        }
        aux_utils.write_json(
            split_samples,
            os.path.join(self.model_dir, 'split_samples.json'),
        )
        # Make configs with fields necessary for 3D regression inference
        self.train_config = {
            'network': {
                'class': 'UNet3D',
                'data_format': 'channels_first',
                'num_filters_per_block': [8, 16],
                'depth': 5,
                'width': 5,
                'height': 5
            },
            'dataset': {
                'split_by_column': 'pos_idx',
                'input_channels': [1],
                'target_channels': [2],
                'model_task': 'regression',
            },
        }
        self.inference_config = {
            'model_dir': self.model_dir,
            'model_fname': 'dummy_weights.hdf5',
            'image_dir': self.image_dir,
            'data_split': 'test',
            'images': {
                'image_format': 'zyx',
                'image_ext': '.png',
            },
            'metrics': {
                'metrics': ['mse'],
                'metrics_orientations': ['xyz'],
            },
            'masks': {
                'mask_dir': self.mask_dir,
                'mask_type': 'metrics',
                'mask_channel': 50,
            },
            'inference_3d': {
                'tile_shape': [5, 5, 5],
                'num_overlap': [1, 1, 1],
                'overlap_operation': 'mean',
            },
        }
        # Instantiate class
        self.infer_inst = image_inference.ImagePredictor(
            train_config=self.train_config,
            inference_config=self.inference_config,
        )
Exemplo n.º 4
0
    def setUp(self, mock_model):
        """
        Set up a directory with 2D image slices and masks, a model dir with
        split samples, and an ImagePredictor configured for 2.5D segmentation
        inference.

        :param mock_model: Mock presumably supplied by a patch decorator on
            this method (not visible here) — confirm; its return value is set
            to 'dummy_model'
        """
        # DataFrame.append was removed in pandas 2.0 (and is quadratic when
        # called in a loop): collect metadata rows, then concat once.
        import pandas as pd

        mock_model.return_value = 'dummy_model'

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.tempdir.makedir('image_dir')
        self.tempdir.makedir('mask_dir')
        self.tempdir.makedir('model_dir')
        self.image_dir = os.path.join(self.temp_path, 'image_dir')
        self.mask_dir = os.path.join(self.temp_path, 'mask_dir')
        self.model_dir = os.path.join(self.temp_path, 'model_dir')
        # Write one 2D image per position, channel and slice
        self.im = np.zeros((10, 16), dtype=np.uint8)
        self.frames_meta = aux_utils.make_dataframe()
        self.time_idx = 2
        meta_rows = []
        for p in range(5):
            for c in range(3):
                for z in range(6):
                    im_name = aux_utils.get_im_name(
                        time_idx=self.time_idx,
                        channel_idx=c,
                        slice_idx=z,
                        pos_idx=p,
                    )
                    cv2.imwrite(os.path.join(self.image_dir, im_name),
                                self.im + c * 10)
                    meta_rows.append(
                        aux_utils.parse_idx_from_name(im_name,
                                                      aux_utils.DF_NAMES),
                    )
        self.frames_meta = pd.concat(
            [self.frames_meta, pd.DataFrame(meta_rows)],
            ignore_index=True,
        )
        # Write frames meta to image dir too
        self.frames_meta.to_csv(os.path.join(self.image_dir,
                                             'frames_meta.csv'))
        # Save masks (image + 1, i.e. all ones) and mask meta
        self.mask_meta = aux_utils.make_dataframe()
        self.mask_channel = 50
        mask_rows = []
        for p in range(5):
            for z in range(6):
                im_name = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=self.mask_channel,
                    slice_idx=z,
                    pos_idx=p,
                )
                cv2.imwrite(os.path.join(self.mask_dir, im_name), self.im + 1)
                mask_rows.append(
                    aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                )
        self.mask_meta = pd.concat(
            [self.mask_meta, pd.DataFrame(mask_rows)],
            ignore_index=True,
        )
        # Write frames meta to mask dir too
        self.mask_meta.to_csv(os.path.join(self.mask_dir, 'frames_meta.csv'))
        # Setup model dir
        split_samples = {
            "train": [0, 1],
            "val": [2],
            "test": [3, 4],
        }
        aux_utils.write_json(
            split_samples,
            os.path.join(self.model_dir, 'split_samples.json'),
        )
        # Make configs with fields necessary for 2.5D segmentation inference
        self.train_config = {
            'network': {
                'class': 'UNetStackTo2D',
                'data_format': 'channels_first',
                'depth': 5,
                'width': 10,
                'height': 10
            },
            'dataset': {
                'split_by_column': 'pos_idx',
                'input_channels': [1],
                'target_channels': [self.mask_channel],
                'model_task': 'segmentation',
            },
        }
        self.inference_config = {
            'model_dir': self.model_dir,
            'model_fname': 'dummy_weights.hdf5',
            'image_dir': self.image_dir,
            'data_split': 'test',
            'images': {
                'image_format': 'zyx',
                'image_ext': '.png',
            },
            'metrics': {
                'metrics': ['dice'],
                'metrics_orientations': ['xy'],
            },
            'masks': {
                'mask_dir': self.mask_dir,
                'mask_type': 'target',
                'mask_channel': 50,
            }
        }
        # Instantiate class
        self.infer_inst = image_inference.ImagePredictor(
            train_config=self.train_config,
            inference_config=self.inference_config,
        )
Exemplo n.º 5
0
    def setUp(self):
        """
        Set up a directory with some images to generate frames_meta.csv for.

        Also populates a model dir with:
        - z-scored predictions for channel 2;
        - test metadata;
        - split samples;
        - a config.yml.
        """
        # DataFrame.append was removed in pandas 2.0 (and is quadratic when
        # called in a loop): collect metadata rows, then concat once.
        import pandas as pd

        self.tempdir = TempDirectory()
        self.temp_dir = self.tempdir.path
        self.model_dir = os.path.join(self.temp_dir, 'model_dir')
        self.pred_dir = os.path.join(self.model_dir, 'predictions')
        self.image_dir = os.path.join(self.temp_dir, 'image_dir')
        self.tempdir.makedir(self.model_dir)
        self.tempdir.makedir(self.pred_dir)
        self.tempdir.makedir(self.image_dir)
        # Write images
        self.time_idx = 5
        self.pos_idx = 7
        self.im = 1500 * np.ones((30, 20), dtype=np.uint16)
        # Offset the bottom half so predictions differ from the images
        im_add = np.zeros((30, 20), dtype=np.uint16)
        im_add[15:, :] = 10
        self.ext = '.tif'
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        self.frames_meta = aux_utils.make_dataframe()

        meta_rows = []
        for c in range(3):
            for z in range(5, 10):
                im_name = aux_utils.get_im_name(
                    channel_idx=c,
                    slice_idx=z,
                    time_idx=self.time_idx,
                    pos_idx=self.pos_idx,
                    ext=self.ext,
                )
                cv2.imwrite(os.path.join(self.image_dir, im_name), self.im)
                if c == 2:
                    # Predictions are only written for the target channel (2)
                    norm_im = normalize.zscore(
                        self.im + im_add,
                    ).astype(np.float32)
                    cv2.imwrite(
                        os.path.join(self.pred_dir, im_name),
                        norm_im,
                    )
                meta_rows.append(aux_utils.parse_idx_from_name(im_name))
        self.frames_meta = pd.concat(
            [self.frames_meta, pd.DataFrame(meta_rows)],
            ignore_index=True,
        )
        # Write metadata
        self.frames_meta.to_csv(
            os.path.join(self.image_dir, self.meta_name),
            sep=',',
        )
        # Write as test metadata in model dir too
        self.frames_meta.to_csv(
            os.path.join(self.model_dir, 'test_metadata.csv'),
            sep=',',
        )
        # Write split samples (test slices match z in range(5, 10) above)
        split_idx_fname = os.path.join(self.model_dir, 'split_samples.json')
        split_samples = {'test': [5, 6, 7, 8, 9]}
        aux_utils.write_json(split_samples, split_idx_fname)
        # Write config in model dir
        config = {
            'dataset': {
                'input_channels': [0, 1],
                'target_channels': [2],
                'split_by_column': 'slice_idx'
            },
            'network': {}
        }
        config_name = os.path.join(self.model_dir, 'config.yml')
        with open(config_name, 'w') as outfile:
            yaml.dump(config, outfile, default_flow_style=False)