Example 1
def _get_memmap_or_load(za,
                        timepoint,
                        memmap_dir=None,
                        use_median=False,
                        img_size=None):
    if memmap_dir:
        key = f'{Path(za.store.path).parent.name}-t{timepoint}-{use_median}'
        fpath_org = Path(memmap_dir) / f'{key}.dat'
        if img_size is not None:
            key += '-' + '-'.join(map(str, img_size))
        fpath = Path(memmap_dir) / f'{key}.dat'
        lock = FileLock(str(fpath) + '.lock')
        with lock:
            if not fpath.exists():
                logger().info(f'creating {fpath}')
                fpath.parent.mkdir(parents=True, exist_ok=True)
                img_org = np.memmap(fpath_org,
                                    dtype='float32',
                                    mode='w+',
                                    shape=za.shape[1:])
                img_org[:] = za[timepoint].astype('float32')
                if img_size is None:
                    img = img_org
                else:
                    img = np.memmap(fpath,
                                    dtype='float32',
                                    mode='w+',
                                    shape=img_size)
                    img[:] = F.interpolate(
                        torch.from_numpy(img_org)[None, None],
                        size=img_size,
                        mode='trilinear' if img.ndim == 3 else 'bilinear',
                        align_corners=True,
                    )[0, 0].numpy()
                if use_median and img.ndim == 3:
                    global_median = np.median(img)
                    for z in range(img.shape[0]):
                        slice_median = np.median(img[z])
                        if 0 < slice_median:
                            img[z] -= slice_median - global_median
                # write the normalized values through the memmap so the
                # cached file on disk holds the normalized image
                img[:] = normalize_zero_one(img)
                img.flush()
            logger().info(f'loading from {fpath}')
            return np.memmap(
                fpath,
                dtype='float32',
                mode='c',
                shape=za.shape[1:] if img_size is None else img_size)
    else:
        img = za[timepoint].astype('float32')
        if use_median and img.ndim == 3:
            global_median = np.median(img)
            for z in range(img.shape[0]):
                slice_median = np.median(img[z])
                if 0 < slice_median:
                    img[z] -= slice_median - global_median
        img = normalize_zero_one(img)
    return img
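For context, a minimal usage sketch; the zarr store layout, cache directory, and target size below are hypothetical:

import zarr

za = zarr.open('dataset.zarr/imgs', mode='r')  # hypothetical (T, Z, Y, X) array
img = _get_memmap_or_load(za,
                          timepoint=0,
                          memmap_dir='/tmp/memmap_cache',  # cache directory
                          use_median=True,
                          img_size=(16, 128, 128))         # resized (Z, Y, X)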
Example 2
def _get_flow_prediction(img, timepoint, model_path, keep_axials, device,
                         flow_norm_factor, patch_size):
    img = normalize_zero_one(img.astype('float32'))
    models = load_flow_models(
        model_path, keep_axials, device, is_eval=True)

    if patch_size is None:
        pad_base = (2**keep_axials.count(False), 16, 16)
        pad_shape = tuple(
            get_pad_size(img.shape[i + 1], pad_base[i])
            for i in range(len(img.shape) - 1)
        )
        slices = (slice(None),) + tuple(
            slice(None) if max(pad_shape[i]) == 0 else
            slice(pad_shape[i][0], -pad_shape[i][1])
            for i in range(len(pad_shape))
        )
        img = np.pad(img, ((0, 0),) + pad_shape, mode='reflect')
    with torch.no_grad():
        x = torch.from_numpy(img[np.newaxis]).to(device)
        prediction = np.mean(
            [
                predict(model, x, patch_size, is_log=False)
                for model in models
            ],
            axis=0)
    if patch_size is None:
        prediction = prediction[slices]
    # save and use as float16 to reduce storage size
    return prediction.astype('float16')
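get_pad_size is not listed on this page; a minimal sketch consistent with its use above (pad each dimension up to the next multiple of a base, split across both ends) would be:

def get_pad_size(size, base):
    # amount needed to reach the next multiple of base, split between both ends
    pad = (base - size % base) % base
    return (pad // 2, pad - pad // 2)

# e.g. get_pad_size(70, 16) == (5, 5) and get_pad_size(64, 16) == (0, 0)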
Example 3
def spots_with_flow(config, spots):
    prediction = None
    if hasattr(config, 'tiff_input') and config.tiff_input is not None:
        img_input = np.array([skimage.io.imread(f) for f in config.tiff_input])
    elif config.zpath_input is not None:
        za_input = zarr.open(config.zpath_input, mode='a')
        za_flow = zarr.open(config.zpath_flow, mode='a')
        za_hash = zarr.open(config.zpath_flow_hashes, mode='a')

        # https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file#answer-3431838
        hash_md5 = hashlib.md5()
        with open(config.model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b''):
                hash_md5.update(chunk)
        za_md5 = zarr.array(
            za_input[config.timepoint - 1:config.timepoint + 1]).digest('md5')
        hash_md5.update(za_md5)
        hash_md5.update(json.dumps(config.patch_size).encode('utf-8'))
        model_md5 = hash_md5.digest()
        if model_md5 == za_hash[config.timepoint - 1]:
            prediction = za_flow[config.timepoint - 1]
        else:
            img_input = np.array([
                normalize_zero_one(za_input[i].astype('float32'))
                for i in range(config.timepoint - 1, config.timepoint + 1)
            ])
    if prediction is None:
        try:
            prediction = _get_flow_prediction(img_input,
                                              config.timepoint,
                                              config.model_path,
                                              config.keep_axials,
                                              config.device,
                                              config.flow_norm_factor,
                                              config.patch_size)
        finally:
            torch.cuda.empty_cache()
        if config.zpath_input is not None:
            # cache the result only in the zarr-input path, where za_flow,
            # za_hash and model_md5 are defined
            if config.output_prediction:
                za_flow[config.timepoint - 1] = prediction
                za_hash[config.timepoint - 1] = model_md5
            else:
                za_hash[config.timepoint - 1] = 0
    # Restore flow values to voxel units
    for d in range(prediction.shape[0]):
        prediction[d] *= config.flow_norm_factor[d]
    res_spots = _estimate_spots_with_flow(spots, prediction, config.scales)
    return res_spots
Example 4
def _load_image(za_input, timepoint, use_median=False, img_size=None):
    img = za_input[timepoint].astype('float32')
    if use_median and img.ndim == 3:
        global_median = np.median(img)
        for z in range(img.shape[0]):
            slice_median = np.median(img[z])
            if 0 < slice_median:
                img[z] -= slice_median - global_median
    img = normalize_zero_one(img)
    if img_size is not None:
        img = F.interpolate(
            torch.from_numpy(img)[None, None],
            size=img_size,
            mode='trilinear' if img.ndim == 3 else 'bilinear',
            align_corners=True,
        )[0, 0].numpy()
    return img
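A tiny worked example of the per-slice median correction used here and in Example 1 (synthetic values):

import numpy as np

img = np.stack([np.full((4, 4), 10.), np.full((4, 4), 30.)])  # two z slices
global_median = np.median(img)                                # 20.0
for z in range(img.shape[0]):
    slice_median = np.median(img[z])
    if 0 < slice_median:
        img[z] -= slice_median - global_median
# both slices now have median 20.0, removing per-slice intensity offsets
assert np.median(img[0]) == np.median(img[1]) == 20.0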
Example 5
    def __getitem__(self, index):
        i_frame = self.indices[index // self.n_crops]
        za_input = zarr.open(self.zpath_input, mode='r')
        img_input = np.array([
            normalize_zero_one(za_input[i].astype('float32'))
            for i in range(i_frame, i_frame + 2)
        ])
        za_label = zarr.open(self.zpath_flow_label, mode='r')
        img_label = za_label[i_frame]
        assert 0 < img_label[-1].max(), (
            'positive weight should exist in the label')
        if self.rotation_angle is not None and 0 < self.rotation_angle:
            # rotate image
            theta = randint(-self.rotation_angle, self.rotation_angle)
            img_input = np.array([
                [
                    rotate(
                        img_input[c, z],
                        theta,
                        resize=True,
                        preserve_range=True,
                        order=1,  # 1: Bi-linear (default)
                    ) for z in range(img_input.shape[1])
                ] for c in range(img_input.shape[0])
            ])
            # rotate label
            img_label = np.array([
                [
                    rotate(
                        img_label[c, z],
                        theta,
                        resize=True,
                        preserve_range=True,
                        order=0,  # 0: Nearest-neighbor
                    ) for z in range(img_label.shape[1])
                ] for c in range(img_label.shape[0])
            ])
            # update flow label (the y axis is inverted in image coordinates)
            cos_theta = math.cos(math.radians(theta))
            sin_theta = math.sin(math.radians(theta))
            img_label_x = img_label[0].copy()
            img_label_y = img_label[1].copy() * -1
            img_label[0] = cos_theta * img_label_x - sin_theta * img_label_y
            img_label[1] = (sin_theta * img_label_x +
                            cos_theta * img_label_y) * -1

        if 0 < sum(self.scale_factors):
            item_crop_size = [
                randrange(
                    min(
                        img_input.shape[i + 1],
                        round(self.crop_size[i] *
                              (1. - self.scale_factors[i]))),
                    min(
                        img_input.shape[i + 1] + 1,
                        int(self.crop_size[i] * (1. + self.scale_factors[i])) +
                        1)) for i in range(self.n_dims)
            ]
            # scale labels by resize factor
            img_label[0] *= self.crop_size[2] / item_crop_size[2]  # X
            img_label[1] *= self.crop_size[1] / item_crop_size[1]  # Y
            img_label[2] *= self.crop_size[0] / item_crop_size[0]  # Z
        else:
            item_crop_size = self.crop_size
        index_pool = np.argwhere(0 < img_label[-1])
        while True:
            base_index = index_pool[randrange(len(index_pool))]
            origins = [
                randint(
                    max(0, base_index[i] - (item_crop_size[i] - 1)),
                    min(img_input.shape[i + 1] - item_crop_size[i],
                        base_index[i])) for i in range(self.n_dims)
            ]
            slices = (slice(None), ) + tuple(
                slice(origins[i], origins[i] + item_crop_size[i])
                for i in range(self.n_dims))
            # First screening
            if img_label[slices][-1].max() != 0:
                tensor_label = torch.from_numpy(img_label[slices])
                if 0 < sum(self.scale_factors):
                    tensor_label = F.interpolate(tensor_label[None].float(),
                                                 self.crop_size,
                                                 mode='nearest')
                    tensor_label = tensor_label.view((4, ) + self.crop_size)
                # Second screening
                if tensor_label[-1].max() != 0:
                    break
        tensor_input = torch.from_numpy(img_input[slices])
        if 0 < sum(self.scale_factors):
            tensor_input = F.interpolate(tensor_input[None],
                                         self.crop_size,
                                         mode='trilinear',
                                         align_corners=True)
            tensor_input = tensor_input.view((2, ) + self.crop_size)
        # Channel order: (flow_x, flow_y, flow_z, mask, input_t0, input_t1)
        tensor_target = torch.cat((tensor_label, tensor_input))
        return tensor_input, tensor_target
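A quick standalone check of the flow-vector rotation above; the sign flips account for the y axis pointing down in image coordinates (values assumed):

import math

def rotate_flow_xy(fx, fy, theta):
    # negate y before and after because the image y axis points down
    c, s = math.cos(math.radians(theta)), math.sin(math.radians(theta))
    fy = -fy
    return c * fx - s * fy, -(s * fx + c * fy)

# a flow of +1 along x, rotated 90 degrees counter-clockwise, points up,
# which is -y in image coordinates
fx, fy = rotate_flow_xy(1.0, 0.0, 90)
assert abs(fx) < 1e-9 and abs(fy + 1.0) < 1e-9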
Example 6
    def __getitem__(self, index):
        za_input = zarr.open(self.zpath_input, mode='r')
        if self.is_ae:
            i_frame = randrange(za_input.shape[0])
        else:
            # 0: unlabeled, 1: BG (LW), 2: Outer (LW), 3: Inner (LW)
            #               4: BG (HW), 5: Outer (HW), 6: Inner (HW)
            za_label = zarr.open(self.zpath_seg_label, mode='r')
            if self.is_livemode:
                while True:
                    v = self.redis_c.blpop(REDIS_KEY_TIMEPOINT, 1)
                    if v is not None:
                        i_frame = int(v[1])
                        if 0 < za_label[i_frame].max():
                            break
                    if (int(self.redis_c.get(REDIS_KEY_STATE)) ==
                            TrainState.IDLE.value):
                        return torch.tensor(-100.), torch.tensor(-100)
            else:
                i_frame = self.indices[index // self.n_crops]

            img_label = za_label[i_frame].astype('int64')
            # 0: unlabeled, 1: BG (LW), 2: Outer (LW), 3: Inner (LW)
            #               4: BG (HW), 5: Outer (HW), 6: Inner (HW)
            assert 0 < img_label.max(), (
                'positive weight should exist in the label')
        img_input = normalize_zero_one(za_input[i_frame].astype('float32'))
        if not self.is_ae:
            fg_index = np.isin(img_label, (1, 2, 4, 5))
            bg_index = np.isin(img_label, (0, 3))
            if fg_index.any() and bg_index.any():
                fg_mean = img_input[fg_index].mean()
                bg_mean = img_input[bg_index].mean()
                cr_factor = ((
                    (fg_mean - bg_mean) * uniform(0.5, 1) + bg_mean) / fg_mean)
                img_input[fg_index] *= cr_factor
        if self.rotation_angle is not None and 0 < self.rotation_angle:
            # rotate image
            theta = randint(-self.rotation_angle, self.rotation_angle)
            img_input = np.array([
                rotate(
                    img_input[z],
                    theta,
                    resize=True,
                    preserve_range=True,
                    order=1,  # 1: Bi-linear (default)
                ) for z in range(img_input.shape[0])
            ])
            # rotate label
            img_label = np.array([
                rotate(
                    img_label[z],
                    theta,
                    resize=True,
                    preserve_range=True,
                    order=0,  # 0: Nearest-neighbor
                ) for z in range(img_label.shape[0])
            ])
        if 0 < sum(self.scale_factors):
            item_crop_size = [
                randrange(
                    min(
                        img_input.shape[i],
                        round(self.crop_size[i] *
                              (1. - self.scale_factors[i]))),
                    min(
                        img_input.shape[i] + 1,
                        int(self.crop_size[i] * (1. + self.scale_factors[i])) +
                        1)) for i in range(self.n_dims)
            ]
        else:
            item_crop_size = self.crop_size
        if not self.is_ae:
            img_label -= 1
            # -1: unlabeled, 0: BG (LW), 1: Outer (LW), 2: Inner (LW)
            #                3: BG (HW), 4: Outer (HW), 5: Inner (HW)
            index_pool = np.argwhere(-1 < img_label)
        while True:
            if self.is_ae:
                origins = [
                    randint(0, img_input.shape[i] - item_crop_size[i])
                    for i in range(self.n_dims)
                ]
            else:
                base_index = index_pool[randrange(len(index_pool))]
                origins = [
                    randint(
                        max(0, base_index[i] - (item_crop_size[i] - 1)),
                        min(img_input.shape[i] - item_crop_size[i],
                            base_index[i])) for i in range(self.n_dims)
                ]
            slices = tuple(
                slice(origins[i], origins[i] + item_crop_size[i])
                for i in range(self.n_dims))
            if self.is_ae:
                break
            # First screening
            if self._is_valid(img_label[slices]):
                tensor_label = torch.from_numpy(img_label[slices])
                if 0 < sum(self.scale_factors):
                    tensor_label = F.interpolate(tensor_label[None,
                                                              None].float(),
                                                 self.crop_size,
                                                 mode='nearest').long()
                else:
                    tensor_label = tensor_label[None, None].long()
                # Second screening
                if self._is_valid(tensor_label.numpy()):
                    break
        tensor_input = torch.from_numpy(img_input[slices])
        if 0 < sum(self.scale_factors):
            tensor_input = F.interpolate(tensor_input[None, None],
                                         self.crop_size,
                                         mode='trilinear',
                                         align_corners=True)
            tensor_input = tensor_input.view((1, ) + self.crop_size)
        else:
            tensor_input = tensor_input[None]
        if self.is_ae:
            return tensor_input, tensor_input
        tensor_label = tensor_label.view(self.crop_size)
        return tensor_input, tensor_label
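Worked numbers for the contrast-reduction factor applied above (means assumed):

fg_mean, bg_mean = 0.8, 0.2
for r in (0.5, 1.0):  # bounds of uniform(0.5, 1)
    cr_factor = ((fg_mean - bg_mean) * r + bg_mean) / fg_mean
    print(cr_factor)  # 0.625 (fg mean pulled to 0.5, halving contrast) and 1.0 (unchanged)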
Example 7
def test_normalize_zero_one():
    data = np.array(range(5)).astype(float)
    data = normalize_zero_one(data)
    assert data.min() == 0 and data.max() == 1
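normalize_zero_one itself is not listed on this page; a minimal min-max implementation consistent with this test would be:

import numpy as np

def normalize_zero_one(data):
    # min-max scale to [0, 1]; a constant input maps to all zeros
    data = data.astype('float32')
    value_range = data.max() - data.min()
    if value_range == 0:
        return np.zeros_like(data)
    return (data - data.min()) / value_range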
Example 8
def _update_seg_labels(spots_dict, scales, zpath_input, zpath_seg_label,
                       zpath_seg_label_vis, auto_bg_thresh=0, c_ratio=0.5):
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 20
    for t, spots in spots_dict.items():
        # label = np.zeros(label_shape, dtype='int64') - 1
        label = np.where(
            normalize_zero_one(za_input[t].astype('float32')) < auto_bg_thresh,
            1,
            0
        ).astype('uint8')
        cnt = collections.Counter({x: 0 for x in keyorder})
        for spot in spots:
            if int(redis_client.get(REDIS_KEY_STATE)) == TrainState.IDLE.value:
                print('aborted')
                return jsonify({'completed': False})
            cnt[spot['tag']] += 1
            centroid = np.array(spot['pos'][::-1])
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            dd_outer, rr_outer, cc_outer = ellipsoid(
                centroid,
                radii,
                rotation,
                scales,
                label.shape,
                MIN_AREA_ELLIPSOID
            )
            label_offset = 0 if spot['tag'] in ['tp', 'tb', 'tn'] else 3
            if spot['tag'] in ('tp', 'fn'):
                dd_inner, rr_inner, cc_inner = ellipsoid(
                    centroid,
                    radii * c_ratio,
                    rotation,
                    scales,
                    label.shape,
                    MIN_AREA_ELLIPSOID
                )
                dd_inner_p, rr_inner_p, cc_inner_p = _dilate_3d_indices(
                    dd_inner, rr_inner, cc_inner, label.shape)
                label[dd_outer, rr_outer, cc_outer] = np.where(
                    np.fmod(label[dd_outer, rr_outer, cc_outer] - 1, 3) <= 1,
                    2 + label_offset,
                    label[dd_outer, rr_outer, cc_outer]
                )
                label[dd_inner_p, rr_inner_p, cc_inner_p] = 2 + label_offset
                label[dd_inner, rr_inner, cc_inner] = np.where(
                    np.fmod(label[dd_inner, rr_inner, cc_inner] - 1, 3) <= 2,
                    3 + label_offset,
                    label[dd_inner, rr_inner, cc_inner]
                )
            elif spot['tag'] in ('tb', 'fb'):
                label[dd_outer, rr_outer, cc_outer] = np.where(
                    np.fmod(label[dd_outer, rr_outer, cc_outer] - 1, 3) <= 1,
                    2 + label_offset,
                    label[dd_outer, rr_outer, cc_outer]
                )
            elif spot['tag'] in ('tn', 'fp'):
                label[dd_outer, rr_outer, cc_outer] = np.where(
                    np.fmod(label[dd_outer, rr_outer, cc_outer] - 1, 3) <= 0,
                    1 + label_offset,
                    label[dd_outer, rr_outer, cc_outer]
                )
        print('frame:{}, {}'.format(
            t, sorted(cnt.items(), key=lambda i: keyorder.index(i[0]))))
        za_label[t] = label
        za_label_vis[t, ..., 0] = np.where(
            label == 1, 255, 0) + np.where(label == 4, 127, 0)
        za_label_vis[t, ..., 1] = np.where(
            label == 2, 255, 0) + np.where(label == 5, 127, 0)
        za_label_vis[t, ..., 2] = np.where(
            label == 3, 255, 0) + np.where(label == 6, 127, 0)
        if redis_client.get(REDIS_KEY_NCROPS) is not None:
            for i in range(int(redis_client.get(REDIS_KEY_NCROPS))):
                redis_client.rpush(REDIS_KEY_TIMEPOINT, str(t))
    return jsonify({'completed': True})
Example 9
def _get_seg_prediction(img, timepoint, model_path, keep_axials, device,
                        use_median=True, patch_size=None, crop_box=None,
                        is_pad=False):
    img = img.astype('float32')
    if use_median:
        global_median = np.median(img)
        for z in range(img.shape[0]):
            slice_median = np.median(img[z])
            if 0 < slice_median:
                img[z] -= slice_median - global_median
    img = normalize_zero_one(img)
    models = load_seg_models(model_path, keep_axials,
                             device, is_eval=True, is_pad=is_pad)
    if crop_box is not None:
        img = img[crop_box[0]:crop_box[0] + crop_box[3],  # Z
                  crop_box[1]:crop_box[1] + crop_box[4],  # Y
                  crop_box[2]:crop_box[2] + crop_box[5]]  # X
    if patch_size is None:
        pad_base = (2**keep_axials.count(False), 16, 16)
        pad_shape = tuple(
            get_pad_size(img.shape[i], pad_base[i])
            for i in range(len(img.shape))
        )
        slices = (slice(None),) + tuple(
            slice(None) if max(pad_shape[i]) == 0 else
            slice(pad_shape[i][0], -pad_shape[i][1])
            for i in range(len(pad_shape))
        )
        img = np.pad(img, pad_shape, mode='constant', constant_values=0)
    with torch.no_grad():
        x = torch.from_numpy(img[np.newaxis, np.newaxis]).to(device)
        prediction = []
        # test-time augmentation (TTA) over axis flips
        # dims: (N, C, D, H, W)
        # NOTE: range(1) below yields only 0, so the flips are effectively
        # disabled; use range(2) to also average over flipped copies
        for flip_x in range(1):
            for flip_y in range(1):
                for flip_z in range(1):
                    fl_dims = [i-3 for i, flag in
                               enumerate([flip_z, flip_y, flip_x]) if flag]
                    xx = x.flip(fl_dims) if fl_dims else x
                    pred = np.mean([predict(model, xx, patch_size, is_log=True)
                                    for model in models],
                                   axis=0
                                   )
                    pred = np.flip(pred, fl_dims) if fl_dims else pred
                    prediction.append(pred)
        prediction = np.mean(prediction, axis=0)
    if patch_size is None:
        prediction = prediction[slices]
    for z in range(prediction.shape[1]):
        post_fg = np.maximum(
            prediction[2, z] - normalize_zero_one(
                prewitt(gaussian(prediction[0, z], sigma=3))),
            0
        )
        prediction[0, z] = np.minimum(
            prediction[0, z] + (prediction[2, z] - post_fg),
            1
        )
        prediction[2, z] = post_fg
        prediction[1, z] = 1. - (prediction[0, z] + prediction[2, z])
    return prediction
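The flip loops above use range(1) and therefore run only the identity pass; a self-contained sketch of full flip-averaged TTA (with a toy predictor, not the project's predict function) looks like:

import itertools

import torch

def tta_predict(predict_fn, x):
    # average predictions over all 8 flip combinations along (D, H, W)
    preds = []
    for flags in itertools.product((0, 1), repeat=3):
        dims = [i - 3 for i, flag in enumerate(flags) if flag]
        xx = x.flip(dims) if dims else x
        pred = predict_fn(xx)
        preds.append(pred.flip(dims) if dims else pred)
    return torch.stack(preds).mean(dim=0)

x = torch.rand(1, 1, 8, 16, 16)  # (N, C, D, H, W)
assert torch.allclose(tta_predict(lambda t: t * 2, x), x * 2)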