def test_clean_up_annotations(self):
        """Check that ``utils.clean_up_annotations`` relabels cells contiguously.

        A labeled image whose label values are deliberately non-contiguous
        (every label is multiplied by 3) is run through
        ``utils.clean_up_annotations`` in both the ``channels_last`` and the
        ``channels_first`` layout. For each layout the result must keep the
        same number of unique values, leave background at 0, and number the
        remaining labels consecutively starting at ``uid``.
        """
        # Multiply by 3 so the input labels are non-contiguous (0, 3, 6, ...).
        labels = sk.measure.label(sk.data.binary_blobs(length=256, n_dim=2)) * 3
        uid = 100

        # Both layouts are built from the same 2D label image so the two
        # checks are directly comparable. Each gets its own copy so that an
        # in-place implementation cannot leak state from one check into the
        # other.
        layouts = {
            # (frames, x, y, channels): one frame, one channel.
            'channels_last': labels[np.newaxis, ..., np.newaxis].copy(),
            # (channels, frames, x, y): one channel, one frame.
            'channels_first': labels[np.newaxis, np.newaxis, ...].copy(),
        }

        for data_format, img in layouts.items():
            cleaned = utils.clean_up_annotations(img,
                                                 uid=uid,
                                                 data_format=data_format)
            unique = np.unique(cleaned)
            assert len(np.unique(img)) == len(unique)
            expected = np.arange(len(unique)) + uid - 1
            expected[0] = 0  # background shouldn't get added
            np.testing.assert_equal(expected, unique)
    def __init__(self,
                 movie,
                 annotation,
                 tracking_model,
                 neighborhood_encoder=None,
                 distance_threshold=64,
                 appearance_dim=32,
                 death=0.99,
                 birth=0.99,
                 division=0.9,
                 track_length=5,
                 embedding_axis=0,
                 dtype='float32',
                 data_format='channels_last'):
        """Validate the inputs, store the settings and precompute cell features.

        Args:
            movie: rank-4 raw image stack. With ``data_format='channels_last'``
                the axes are ``(frames, x, y, channels)``; with
                ``'channels_first'`` they are ``(channels, frames, x, y)``.
            annotation: rank-4 label stack in the same layout as ``movie``.
                It must match ``movie`` on every axis except the channel axis.
            tracking_model: model stored as ``self.tracking_model``.
            neighborhood_encoder: optional encoder stored as
                ``self.neighborhood_encoder``.
            distance_threshold, appearance_dim, death, birth, division,
                track_length, embedding_axis, dtype: stored verbatim as
                attributes of the same name for use by other methods.
            data_format: either ``'channels_first'`` or ``'channels_last'``.

        Raises:
            ValueError: if ``data_format`` is not recognized, if either input
                is not rank 4, or if their non-channel dimensions differ.
        """
        # Validate the layout first: the shape checks below depend on it.
        if data_format not in {'channels_first', 'channels_last'}:
            raise ValueError('The `data_format` argument must be one of '
                             '"channels_first", "channels_last". Received: ' +
                             str(data_format))

        channels_first = data_format == 'channels_first'
        layout = ('(channels, frames, x, y)' if channels_first
                  else '(frames, x, y, channels)')

        if len(movie.shape) != 4 or len(annotation.shape) != 4:
            raise ValueError(
                'Input data and labels must be rank 4 '
                '{}.  Got {} and {}.'.format(
                    layout, len(movie.shape), len(annotation.shape)))

        # Compare every axis except the channel axis. Its position depends on
        # ``data_format``: first for channels_first, last for channels_last.
        if channels_first:
            movie_dims, label_dims = movie.shape[1:], annotation.shape[1:]
        else:
            movie_dims, label_dims = movie.shape[:-1], annotation.shape[:-1]
        if tuple(movie_dims) != tuple(label_dims):
            raise ValueError('Input data and labels should have the same shape'
                             ' except for the channel dimension.  Got {} and '
                             '{}'.format(movie.shape, annotation.shape))

        # Work on copies so the caller's arrays are not modified below.
        self.X = copy.copy(movie)
        self.y = copy.copy(annotation)
        self.tracks = {}

        self.neighborhood_encoder = neighborhood_encoder
        self.tracking_model = tracking_model
        self.distance_threshold = distance_threshold
        self.appearance_dim = appearance_dim
        self.death = death
        self.birth = birth
        self.division = division
        self.dtype = dtype
        self.track_length = track_length
        self.embedding_axis = embedding_axis

        # Start empty here; presumably filled in by the tracking methods
        # (not visible in this constructor).
        self.a_matrix = []
        self.c_matrix = []
        self.assignments = []

        self.data_format = data_format
        self.channel_axis = 0 if channels_first else -1
        self.time_axis = 1 if channels_first else 0
        self.logger = logging.getLogger(str(self.__class__.__name__))

        # Clean up annotations
        self.y = clean_up_annotations(self.y, data_format=self.data_format)

        # Accounting for 0 (background) label with 0-indexing for tracks
        self.id_to_idx = {}  # int: int mapping
        self.idx_to_id = {}  # (frame, cell_idx): cell_id mapping

        # Establish features for every instance of every cell in the movie
        adj_matrices, appearances, morphologies, centroids = self._est_feats()

        # Compute embeddings for every instance of every cell in the movie
        embeddings = self._get_neighborhood_embeddings(
            appearances=appearances,
            morphologies=morphologies,
            centroids=centroids,
            adj_matrices=adj_matrices)

        # TODO: immutable dict for safety? these values should never change.
        self.features = {
            'embedding': embeddings,
            'centroid': centroids,
        }
Beispiel #3
0
    def __init__(self,
                 movie,
                 annotation,
                 model,
                 features={'appearance', 'distance', 'neighborhood', 'regionprop'},
                 crop_dim=32,
                 death=0.95,
                 birth=0.95,
                 division=0.9,
                 max_distance=50,
                 track_length=7,
                 neighborhood_scale_size=30,
                 neighborhood_true_size=100,
                 dtype='float32',
                 data_format='channels_last'):

        if not len(movie.shape) == 4 or not len(annotation.shape) == 4:
            raise ValueError('Input data and labels but be rank 4 '
                             '(frames, x, y, channels).  Got {} and {}.'.format(
                                 len(movie.shape), len(annotation.shape)))

        if not movie.shape[:-1] == annotation.shape[:-1]:
            raise ValueError('Input data and labels should have the same shape'
                             ' except for the channel dimension.  Got {} and '
                             '{}'.format(movie.shape, annotation.shape))

        if not features:
            raise ValueError('`features` is empty but should be a list with any'
                             ' or all of the following values: "appearance", '
                             '"distance", "neighborhood" or "regionprop".')

        if data_format not in {'channels_first', 'channels_last'}:
            raise ValueError('The `data_format` argument must be one of '
                             '"channels_first", "channels_last". Received: ' +
                             str(data_format))

        self.x = copy.copy(movie)
        self.y = copy.copy(annotation)
        self.tracks = {}
        # TODO: Use a model that is served by tf-serving, not one on a local machine
        self.model = model
        self.crop_dim = crop_dim
        self.death = death
        self.birth = birth
        self.division = division
        self.max_distance = max_distance
        self.neighborhood_scale_size = neighborhood_scale_size
        self.neighborhood_true_size = neighborhood_true_size
        self.dtype = dtype
        self.data_format = data_format
        self.track_length = track_length
        self.channel_axis = 0 if data_format == 'channels_first' else -1
        self.time_axis = 1 if data_format == 'channels_first' else 0
        self.logger = logging.getLogger(str(self.__class__.__name__))

        self._track_cells = self.track_cells  # backwards compatibility

        self.features = sorted(features)

        # Clean up annotations
        self.y = clean_up_annotations(self.y, data_format=self.data_format)