Esempio n. 1
0
def get_image_dir_format(dataset_config):
    """Get dir with input images for generating full path from frames_meta

    If the tiled dir is passed as data dir there will be no
    preprocessing_info.json. If json present use it, else read images from the
    given dir.

    :param dict dataset_config: Dataset config; must contain 'data_dir'
    :return str tile_dir: Directory with the images. Defaults to data_dir;
        overridden by config['tile']['tile_dir'] of the most recent
        preprocessing_info.json entry if present
    :return str image_format: Image dimension order. Defaults to 'zyx';
        overridden by config['tile']['image_format'] if present
    """
    # Tile dir passed directly as data_dir
    tile_dir = dataset_config['data_dir']
    image_format = 'zyx'

    # If the parent dir with tile dir, mask dir is passed as data_dir,
    # it should contain a json with directory names
    json_fname = os.path.join(dataset_config['data_dir'],
                              'preprocessing_info.json')
    if os.path.exists(json_fname):
        preprocessing_info = aux_utils.read_json(json_filename=json_fname)

        # Preprocessing_info is a list of jsons; use the last (most recent)
        # one. A config without a 'tile' section keeps the defaults instead
        # of raising KeyError.
        pp_config = preprocessing_info[-1]['config']
        tile_config = pp_config.get('tile') or {}
        tile_dir = tile_config.get('tile_dir', tile_dir)
        image_format = tile_config.get('image_format', image_format)

    return tile_dir, image_format
Esempio n. 2
0
def test_read_json():
    """Round-trip a small dict through a json file and aux_utils.read_json."""
    expected = {"a": 5, "b": 'test'}
    with TempDirectory() as temp_dir:
        json_path = os.path.join(temp_dir.path, "json_file.json")
        temp_dir.write('json_file.json', json.dumps(expected).encode())
        loaded = aux_utils.read_json(json_path)
        nose.tools.assert_equal(loaded, expected)
 def test_save_config(self):
     """save_config writes a new info json, then appends to a prior one."""
     # Alias, not a copy: the edits below also modify self.pp_config
     config = self.pp_config
     config['masks']['mask_dir'] = os.path.join(
         self.output_dir, 'mask_channels_3')
     config['tile']['tile_dir'] = os.path.join(
         self.output_dir, 'tiles_10-10_step_10-10')
     info_fname = os.path.join(self.output_dir, 'preprocessing_info.json')

     # First save: the json holds exactly one entry
     pp.save_config(config, 11.1)
     saved_info = aux_utils.read_json(info_fname)
     self.assertEqual(len(saved_info), 1)
     self.assertDictEqual(saved_info[0]['config'], config)

     # Second save with input_dir set to the tile dir: save_config looks for
     # a prior json in the parent of input_dir (output_dir here) and appends
     config['input_dir'] = config['tile']['tile_dir']
     pp.save_config(config, 666.66)
     saved_info = aux_utils.read_json(info_fname)
     self.assertEqual(len(saved_info), 2)
     self.assertDictEqual(saved_info[1]['config'], config)
Esempio n. 4
0
    def _get_split_ids(self, data_split='test'):
        """
        Get the indices for data_split

        If model_dir contains no split_samples.json, all unique values of the
        split column in the frames metadata of image_dir are returned instead.

        :param str data_split: in [train, val, test]
        :return str split_col: Dataframe column name, which was split in training
        :return list inference_ids: Indices for inference given data split
        """
        split_col = self.config['dataset']['split_by_column']
        split_fname = os.path.join(self.model_dir, 'split_samples.json')
        # Only the file read can raise FileNotFoundError; keep the try minimal
        try:
            split_samples = aux_utils.read_json(split_fname)
        except FileNotFoundError:
            print("No split_samples file. Will predict all images in dir.")
            frames_meta = aux_utils.read_meta(self.image_dir)
            inference_ids = np.unique(frames_meta[split_col]).tolist()
        else:
            inference_ids = split_samples[data_split]
        return split_col, inference_ids
Esempio n. 5
0
def save_config(cur_config, runtime):
    """Save the cur_config or append to existing config

    Writes a list of {'processing_time', 'config'} entries to
    <output_dir>/preprocessing_info.json. If a preprocessing_info.json exists
    in the parent directory of cur_config['input_dir'], its entries are kept
    and the new entry is appended after them.

    :param dict cur_config: Preprocessing config; must contain 'input_dir'
        and 'output_dir'
    :param float runtime: Processing time stored alongside the config
    """
    # Look for a prior preprocessing_info.json in the parent of input_dir.
    # normpath strips trailing separators so '/a/b/' has parent '/a' just like
    # '/a/b'; dirname also handles both separators on Windows.
    parent_dir = os.path.dirname(os.path.normpath(cur_config['input_dir']))
    prior_config_fname = os.path.join(parent_dir, 'preprocessing_info.json')
    prior_pp_config = None
    if os.path.exists(prior_config_fname):
        prior_pp_config = aux_utils.read_json(prior_config_fname)

    meta_path = os.path.join(cur_config['output_dir'],
                             'preprocessing_info.json')

    processing_info = prior_pp_config if prior_pp_config is not None else []
    processing_info.append({'processing_time': runtime,
                            'config': cur_config})
    aux_utils.write_json(processing_info, meta_path)
Esempio n. 6
0
def run_inference(args, gpu_id, gpu_mem_frac):
    """Evaluate a trained model and predict on tiles and full test images.

    The test dataset is built from <model_dir>/test_metadata.csv, with masks
    when the trainer config contains 'masked_loss'. Full-image prediction uses
    the 'test' entry of <model_dir>/split_samples.json.

    :param args: Parsed arguments; uses config, model_fname, num_batches,
        image_meta_fname, focal_plane_idx, flat_field and base_image_dir
    :param gpu_id: GPU id(s) passed to ModelEvaluator as gpu_ids
    :param float gpu_mem_frac: Fraction of GPU memory to use
    :return: Result of ModelEvaluator.evaluate_model on the test dataset
    """
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)
    trainer_config = config['trainer']
    df_test = pd.read_csv(
        os.path.join(trainer_config['model_dir'], 'test_metadata.csv'))

    use_masks = 'masked_loss' in trainer_config
    input_fnames = df_test['fpaths_input']
    target_fnames = df_test['fpaths_target']
    if use_masks:
        ds_test = DataSetWithMask(input_fnames=input_fnames,
                                  target_fnames=target_fnames,
                                  mask_fnames=df_test['fpaths_mask'],
                                  batch_size=trainer_config['batch_size'])
    else:
        ds_test = BaseDataSet(input_fnames=input_fnames,
                              target_fnames=target_fnames,
                              batch_size=trainer_config['batch_size'])

    evaluator = ModelEvaluator(config,
                               model_fname=args.model_fname,
                               gpu_ids=gpu_id,
                               gpu_mem_frac=gpu_mem_frac)
    test_perf_metrics = evaluator.evaluate_model(ds_test)
    evaluator.predict_on_tiles(ds_test, nb_batches=args.num_batches)

    split_fname = os.path.join(trainer_config['model_dir'],
                               'split_samples.json')
    split_samples = aux_utils.read_json(split_fname)
    image_meta = pd.read_csv(args.image_meta_fname)
    # for regression tasks change place_operation to 'mean'
    evaluator.predict_on_full_image(image_meta=image_meta,
                                    test_samples=split_samples['test'],
                                    focal_plane_idx=args.focal_plane_idx,
                                    flat_field_correct=args.flat_field,
                                    base_image_dir=args.base_image_dir,
                                    place_operation='max')
    return test_perf_metrics
Esempio n. 7
0
def compute_metrics(model_dir,
                    image_dir,
                    metrics_list,
                    orientations_list,
                    test_data=True):
    """
    Compute specified metrics for given orientations for predictions, which
    are assumed to be stored in model_dir/predictions. Targets are stored in
    image_dir.
    Writes metrics csv files for each orientation in model_dir/predictions.

    :param str model_dir: Assumed to contain config.yml, test_metadata.csv,
        split_samples.json and subdirectory predictions/
    :param str image_dir: Directory containing target images with frames_meta.csv
    :param str/list metrics_list: See inference/evaluation_metrics.py for options
    :param str/list orientations_list: Any subset of {xy, xz, yz, xyz}
        (see evaluation_metrics)
    :param bool test_data: Uses test indices in split_samples.json,
        otherwise all indices. All indices are also used when
        split_samples.json does not exist.
    :raises AssertionError: If orientations_list has unknown orientations
    """
    # Load config file
    config_name = os.path.join(model_dir, 'config.yml')
    with open(config_name, 'r') as f:
        config = yaml.safe_load(f)
    # Load frames metadata and determine indices
    frames_meta = pd.read_csv(os.path.join(image_dir, 'frames_meta.csv'))

    if isinstance(metrics_list, str):
        metrics_list = [metrics_list]
    metrics_inst = metrics.MetricsEstimator(metrics_list=metrics_list)

    split_idx_name = config['dataset']['split_by_column']
    test_ids = None
    if test_data:
        idx_fname = os.path.join(model_dir, 'split_samples.json')
        try:
            split_samples = aux_utils.read_json(idx_fname)
        except FileNotFoundError:
            print("No split_samples file. Will predict all images in dir.")
        else:
            test_ids = split_samples['test']
    if test_ids is None:
        # All indices requested, or no split file: use every value of the
        # split column (without this fallback test_ids would be undefined)
        test_ids = np.unique(frames_meta[split_idx_name])

    # Find other indices to iterate over than split index name
    # E.g. if split is position, we also need to iterate over time and slice
    test_meta = pd.read_csv(os.path.join(model_dir, 'test_metadata.csv'))
    metadata_ids = {split_idx_name: test_ids}
    for idx_name in ['slice_idx', 'pos_idx', 'time_idx']:
        if idx_name != split_idx_name:
            metadata_ids[idx_name] = np.unique(test_meta[idx_name])

    # Predictions are read from this subdirectory
    pred_dir = os.path.join(model_dir, 'predictions')

    target_channel = config['dataset']['target_channels'][0]

    # Network depth; with depth 1 the singular z dimension is squeezed (2D)
    depth = config['network'].get('depth', 1)

    # Get channel index and extension of predictions from an existing file
    existing_preds = [f for f in os.listdir(pred_dir) if f.startswith('im_')]
    meta_row = aux_utils.parse_idx_from_name(existing_preds[0])
    pred_channel = meta_row['channel_idx']
    _, ext = os.path.splitext(existing_preds[0])

    if isinstance(orientations_list, str):
        orientations_list = [orientations_list]
    available_orientations = {'xy', 'xz', 'yz', 'xyz'}
    assert set(orientations_list).issubset(available_orientations), \
        "Orientations must be subset of {}".format(available_orientations)

    fn_mapping = {
        'xy': metrics_inst.estimate_xy_metrics,
        'xz': metrics_inst.estimate_xz_metrics,
        'yz': metrics_inst.estimate_yz_metrics,
        'xyz': metrics_inst.estimate_xyz_metrics,
    }
    metrics_mapping = {
        'xy': metrics_inst.get_metrics_xy,
        'xz': metrics_inst.get_metrics_xz,
        'yz': metrics_inst.get_metrics_yz,
        'xyz': metrics_inst.get_metrics_xyz,
    }
    # Per-orientation list of metric frames, concatenated once at the end
    # (DataFrame.append was removed in pandas 2.0)
    metric_frames = {orientation: [] for orientation in orientations_list}

    # Iterate over all indices for test data
    for time_idx in metadata_ids['time_idx']:
        for pos_idx in metadata_ids['pos_idx']:
            target_fnames = []
            pred_fnames = []
            for slice_idx in metadata_ids['slice_idx']:
                im_idx = aux_utils.get_meta_idx(
                    frames_metadata=frames_meta,
                    time_idx=time_idx,
                    channel_idx=target_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                )
                target_fnames.append(os.path.join(
                    image_dir,
                    frames_meta.loc[im_idx, 'file_name'],
                ))
                pred_fname = aux_utils.get_im_name(
                    time_idx=time_idx,
                    channel_idx=pred_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                    ext=ext,
                )
                pred_fnames.append(os.path.join(pred_dir, pred_fname))

            target_stack = image_utils.read_imstack(
                input_fnames=tuple(target_fnames),
            )
            pred_stack = image_utils.read_imstack(
                input_fnames=tuple(pred_fnames),
                normalize_im=False,
            )

            if depth == 1:
                # Remove singular z dimension for 2D image
                target_stack = np.squeeze(target_stack)
                pred_stack = np.squeeze(pred_stack)
            if target_stack.dtype == np.float64:
                target_stack = target_stack.astype(np.float32)
            pred_name = "t{}_p{}".format(time_idx, pos_idx)
            for orientation in orientations_list:
                metric_fn = fn_mapping[orientation]
                metric_fn(
                    target=target_stack,
                    prediction=pred_stack,
                    pred_name=pred_name,
                )
                # NOTE(review): assumes get_metrics_* returns a DataFrame;
                # copy so later estimate_* calls cannot alter stored rows
                metric_frames[orientation].append(
                    metrics_mapping[orientation]().copy(),
                )

    # Save one csv per orientation (empty if no rows were computed)
    for orientation in orientations_list:
        frames = metric_frames[orientation]
        if frames:
            metrics_df = pd.concat(frames, ignore_index=True)
        else:
            metrics_df = pd.DataFrame()
        df_name = 'metrics_{}.csv'.format(orientation)
        metrics_name = os.path.join(pred_dir, df_name)
        metrics_df.to_csv(metrics_name, sep=",", index=False)
Esempio n. 8
0
def run_prediction(model_dir,
                   image_dir,
                   gpu_ids,
                   gpu_mem_frac,
                   model_fname=None,
                   metrics=None,
                   test_data=True,
                   ext='.tif',
                   save_figs=False,
                   normalize_im=False):
    """
    Predict images given model + weights.
    If the test_data flag is set to True, the test indices in
    split_samples.json file in model directory will be predicted
    Otherwise, all images in image directory will be predicted.
    It will load the config.yml file save in model_dir to reconstruct the model.
    Predictions are written to model_dir/predictions as uint16 .png, float32
    .tif or unconverted .npy, depending on ext. Per-image loss and metric
    values are written to model_dir/test_frames_meta.csv.
    If saving figures, it assumes that input as well as target channels are
    present in image_dir.

    :param str model_dir: Model directory
    :param str image_dir: Directory containing images for inference
    :param int gpu_ids: GPU ID to use for session (negative: no session setup)
    :param float gpu_mem_frac: What fraction of GPU memory to use
    :param str model_fname: Model weights file name (in model dir); if None,
        the last .hdf5 file in natural sort order is used
    :param str metrics: String or list thereof of train/metrics.py functions
        to be computed during inference
    :param bool test_data: Use test indices from metadata, else use all.
        All indices are also used when split_samples.json does not exist.
    :param str ext: File extension for inference output:
        '.png', '.tif' or '.npy'
    :param bool save_figs: Save plots of input/target/prediction
    :param bool normalize_im: Normalize input images (passed to
        preprocess_imstack)
    :raises ValueError: If ext is not a supported extension
    """
    # Validate before any expensive work (previously only detected after
    # the first prediction)
    if ext not in ('.png', '.tif', '.npy'):
        raise ValueError('Unsupported file extension')
    if gpu_ids >= 0:
        train_utils.set_keras_session(gpu_ids=gpu_ids,
                                      gpu_mem_frac=gpu_mem_frac)
    # Load config file
    config_name = os.path.join(model_dir, 'config.yml')
    with open(config_name, 'r') as f:
        config = yaml.safe_load(f)
    # Load frames metadata and determine indices
    network_config = config['network']
    dataset_config = config['dataset']
    trainer_config = config['trainer']
    frames_meta = pd.read_csv(
        os.path.join(image_dir, 'frames_meta.csv'),
        index_col=0,
    )
    test_tile_meta = pd.read_csv(
        os.path.join(model_dir, 'test_metadata.csv'),
        index_col=0,
    )
    # TODO: generate test_frames_meta.csv together with tile csv during training
    test_frames_meta_filename = os.path.join(
        model_dir,
        'test_frames_meta.csv',
    )
    if metrics is not None:
        if isinstance(metrics, str):
            metrics = [metrics]
        metrics_cls = train_utils.get_metrics(metrics)
    else:
        metrics_cls = metrics
    metric_list = metrics if metrics is not None else []
    loss = trainer_config['loss']
    loss_cls = train_utils.get_loss(loss)
    # Order of values returned by model.evaluate: loss first, then metrics
    metric_names = [loss] + metric_list
    split_idx_name = dataset_config['split_by_column']
    # Same default is used for the backend and for adding channel dims
    data_format = network_config.get('data_format', 'channels_first')
    K.set_image_data_format(data_format)

    test_ids = None
    if test_data:
        idx_fname = os.path.join(model_dir, 'split_samples.json')
        try:
            split_samples = aux_utils.read_json(idx_fname)
        except FileNotFoundError:
            print("No split_samples file. Will predict all images in dir.")
        else:
            test_ids = split_samples['test']
    if test_ids is None:
        # All indices requested, or no split file: use every value of the
        # split column (without this fallback test_ids would be undefined)
        test_ids = np.unique(frames_meta[split_idx_name])

    # Find other indices to iterate over than split index name
    # E.g. if split is position, we also need to iterate over time and slice
    metadata_ids = {split_idx_name: test_ids}
    for idx_name in ['slice_idx', 'pos_idx', 'time_idx']:
        if idx_name != split_idx_name:
            metadata_ids[idx_name] = np.unique(test_tile_meta[idx_name])

    # Output columns: frames metadata, then metrics, then loss (deduplicated)
    out_columns = list(dict.fromkeys(
        frames_meta.columns.values.tolist() + metric_list + [loss]))
    meta_rows = []
    # Get model weight file name, if none, load latest saved weights
    if model_fname is None:
        fnames = [f for f in os.listdir(model_dir) if f.endswith('.hdf5')]
        assert len(fnames) > 0, 'No weight files found in model dir'
        fnames = natsort.natsorted(fnames)
        model_fname = fnames[-1]
    weights_path = os.path.join(model_dir, model_fname)

    # Create image subdirectory to write predicted images
    pred_dir = os.path.join(model_dir, 'predictions')
    os.makedirs(pred_dir, exist_ok=True)
    target_channel = dataset_config['target_channels'][0]
    # If saving figures, create another subdirectory to predictions
    if save_figs:
        fig_dir = os.path.join(pred_dir, 'figures')
        os.makedirs(fig_dir, exist_ok=True)

    # Network depth; with depth 1 the singular z dimension is squeezed (2D)
    depth = network_config.get('depth', 1)

    # Get input channel
    # TODO: Add multi channel support once such models are tested
    input_channel = dataset_config['input_channels'][0]
    assert isinstance(input_channel, int),\
        "Only supporting single input channel for now"
    # Load model with predict = True
    model = inference.load_model(
        network_config=network_config,
        model_fname=weights_path,
        predict=True,
    )
    # summary() prints itself and returns None
    model.summary()
    optimizer = trainer_config['optimizer']['name']
    model.compile(loss=loss_cls, optimizer=optimizer, metrics=metrics_cls)
    # Iterate over all indices for test data
    for time_idx in metadata_ids['time_idx']:
        for pos_idx in metadata_ids['pos_idx']:
            for slice_idx in metadata_ids['slice_idx']:
                # TODO: Add flatfield support
                im_stack = preprocess_imstack(frames_metadata=frames_meta,
                                              input_dir=image_dir,
                                              depth=depth,
                                              time_idx=time_idx,
                                              channel_idx=input_channel,
                                              slice_idx=slice_idx,
                                              pos_idx=pos_idx,
                                              normalize_im=normalize_im)
                im_stack = _format_for_model(im_stack, depth, data_format)
                # Predict on large image
                start = time.time()
                im_pred = inference.predict_large_image(
                    model=model,
                    input_image=im_stack,
                )
                print("Inference time:", time.time() - start)
                # Write prediction image
                im_name = aux_utils.get_im_name(
                    time_idx=time_idx,
                    channel_idx=input_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                    ext=ext,
                )
                file_name = os.path.join(pred_dir, im_name)
                im_pred = _write_prediction(im_pred, file_name, ext)

                # assuming target and predicted images are always 2D for now
                # Load target
                meta_idx = aux_utils.get_meta_idx(
                    frames_meta,
                    time_idx,
                    target_channel,
                    slice_idx,
                    pos_idx,
                )
                # get a single row of frame meta data
                # (assumes get_meta_idx returns a single index label)
                test_frames_meta_row = frames_meta.loc[meta_idx].copy()
                im_target = preprocess_imstack(
                    frames_metadata=frames_meta,
                    input_dir=image_dir,
                    depth=1,
                    time_idx=time_idx,
                    channel_idx=target_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                )
                # TODO: Add image_format option to network config
                im_target = _format_for_model(im_target, depth, data_format)

                # Evaluate the model on its input; im_pred is the model
                # output (possibly rescaled), not a valid model input
                metric_vals = model.evaluate(x=im_stack, y=im_target)
                # evaluate returns a scalar when there are no metrics
                if not isinstance(metric_vals, (list, tuple)):
                    metric_vals = [metric_vals]
                for metric, metric_val in zip(metric_names, metric_vals):
                    test_frames_meta_row[metric] = metric_val
                meta_rows.append(test_frames_meta_row.to_dict())

                # Save figures if specified
                if save_figs:
                    # save predicted images assumes 2D
                    if depth > 1:
                        im_stack = im_stack[..., depth // 2, :, :]
                        im_target = im_target[0, ...]
                    plot_utils.save_predicted_images(
                        input_batch=im_stack,
                        target_batch=im_target,
                        pred_batch=im_pred,
                        output_dir=fig_dir,
                        output_fname=os.path.splitext(im_name)[0],
                        ext='jpg',
                        clip_limits=1,
                        font_size=15,
                    )

    # Save metrics as csv (one row per predicted image, RangeIndex)
    test_frames_meta = pd.DataFrame(meta_rows, columns=out_columns)
    test_frames_meta.to_csv(test_frames_meta_filename, sep=",")


def _format_for_model(im, depth, data_format):
    """Crop an image stack and add channel and batch dims for the model.

    The stack is cropped with image_utils.crop2base, its last axis is moved
    first (zyx order), a singular z axis is squeezed when depth == 1, then a
    channel axis (position given by data_format) and a leading batch axis
    are added.

    :param np.array im: Image stack as returned by preprocess_imstack
    :param int depth: Network depth
    :param str data_format: 'channels_first' or 'channels_last'
    :return np.array: Stack ready to pass to the model
    """
    # Crop image shape to nearest factor of two
    im = image_utils.crop2base(im)
    # Change image stack format to zyx
    im = np.transpose(im, [2, 0, 1])
    if depth == 1:
        # Remove singular z dimension for 2D image
        im = np.squeeze(im)
    # Add channel dimension
    if data_format == 'channels_first':
        im = im[np.newaxis, ...]
    else:
        im = im[..., np.newaxis]
    # Add batch dimension
    return im[np.newaxis, ...]


def _write_prediction(im_pred, file_name, ext):
    """Write a prediction to file_name in the format selected by ext.

    :param np.array im_pred: Model prediction
    :param str file_name: Output path
    :param str ext: '.png' (rescaled to uint16), '.tif' (float32) or '.npy'
    :return np.array: The array as written (converted for .png and .tif)
    :raises ValueError: If ext is not supported
    """
    if ext == '.png':
        # Rescale to the full uint16 range; 2**16 - 1 keeps the maximum in
        # range (2**16 wraps to 0) and a constant image maps to 0
        im_min = im_pred.min()
        im_range = im_pred.max() - im_min
        if im_range > 0:
            im_pred = (2 ** 16 - 1) * (im_pred - im_min) / im_range
        else:
            im_pred = np.zeros_like(im_pred)
        im_pred = im_pred.astype(np.uint16)
        cv2.imwrite(file_name, np.squeeze(im_pred))
    elif ext == '.tif':
        # Convert to float32 and remove singular dimensions
        im_pred = im_pred.astype(np.float32)
        cv2.imwrite(file_name, np.squeeze(im_pred))
    elif ext == '.npy':
        np.save(file_name, im_pred, allow_pickle=True)
    else:
        raise ValueError('Unsupported file extension')
    return im_pred