def _median_correct_and_normalize(img, use_median):
    """Optionally equalize per-slice medians of a 3D image, then normalize.

    When ``use_median`` is true and ``img`` is 3D, every slice along the
    first axis whose median is positive is shifted in place so that its
    median matches the median of the whole volume. The (possibly shifted)
    image is then passed to ``normalize_zero_one`` and that result returned.
    """
    if use_median and img.ndim == 3:
        global_median = np.median(img)
        for z in range(img.shape[0]):
            slice_median = np.median(img[z])
            if 0 < slice_median:
                img[z] -= slice_median - global_median
    return normalize_zero_one(img)


def _get_memmap_or_load(za, timepoint, memmap_dir=None, use_median=False,
                        img_size=None):
    """Return the preprocessed image at ``timepoint``, optionally disk-cached.

    Args:
        za: zarr array indexed by timepoint along its first axis. Its
            ``store.path`` is used to build the cache file name.
        timepoint: index of the frame to load.
        memmap_dir: directory for the ``.dat`` cache files. When falsy, the
            image is computed in memory and no file is touched.
        use_median: apply the per-slice median correction (3D images only).
        img_size: target shape. When given, the frame is resized with
            (tri/bi)linear interpolation before the median correction.

    Returns:
        A float32 array. With ``memmap_dir`` this is a copy-on-write
        (``mode='c'``) ``np.memmap`` of the cache file, otherwise a plain
        in-memory array.

    Note:
        Only the cache file of the requested variant is written. Writing the
        raw frame to the un-resized cache path while creating a resized
        variant would clobber (or plant) a file that an ``img_size=None``
        call treats as an already normalized cache, and it would do so
        without holding that file's lock.
    """
    if not memmap_dir:
        return _median_correct_and_normalize(
            za[timepoint].astype('float32'), use_median)

    key = f'{Path(za.store.path).parent.name}-t{timepoint}-{use_median}'
    if img_size is not None:
        key += '-' + '-'.join(map(str, img_size))
    fpath = Path(memmap_dir) / f'{key}.dat'
    shape = za.shape[1:] if img_size is None else img_size
    # The lock serializes creation of this cache file across processes.
    with FileLock(str(fpath) + '.lock'):
        if not fpath.exists():
            logger().info(f'creating {fpath}')
            fpath.parent.mkdir(parents=True, exist_ok=True)
            raw = za[timepoint].astype('float32')
            img = np.memmap(fpath, dtype='float32', mode='w+', shape=shape)
            if img_size is None:
                img[:] = raw
            else:
                img[:] = F.interpolate(
                    torch.from_numpy(raw)[None, None],
                    size=img_size,
                    mode='trilinear' if img.ndim == 3 else 'bilinear',
                    align_corners=True,
                )[0, 0].numpy()
            # Assign through the memmap so the normalized values reach the
            # file whether or not normalize_zero_one works in place.
            img[:] = _median_correct_and_normalize(img, use_median)
            img.flush()
    logger().info(f'loading from {fpath}')
    return np.memmap(fpath, dtype='float32', mode='c', shape=shape)
def _get_flow_prediction(img, timepoint, model_path, keep_axials, device,
                         flow_norm_factor, patch_size):
    """Run the flow models on an image pair and return the mean prediction.

    Args:
        img: array whose first axis is kept unpadded (the caller in this
            file stacks two consecutive frames along it); the remaining
            axes are spatial.
        timepoint: unused here; kept for interface compatibility.
        model_path: passed to ``load_flow_models``.
        keep_axials: sequence of booleans; the number of ``False`` entries
            determines the padding multiple of the first spatial axis.
        device: torch device the input tensor is moved to.
        flow_norm_factor: unused here; kept for interface compatibility.
        patch_size: forwarded to ``predict``. When ``None`` the image is
            reflect-padded to the multiples given by ``pad_base`` before
            inference and the padding is cropped off afterwards.

    Returns:
        The prediction averaged over all loaded models, cast to float16.
    """
    img = normalize_zero_one(img.astype('float32'))
    models = load_flow_models(
        model_path, keep_axials, device, is_eval=True)
    if patch_size is None:
        pad_base = (2**keep_axials.count(False), 16, 16)
        pad_shape = tuple(
            get_pad_size(img.shape[i + 1], pad_base[i])
            for i in range(len(img.shape) - 1)
        )
        # `-after or None` keeps the slice open-ended when no trailing
        # padding was added; a literal `slice(before, -0)` would be empty.
        slices = (slice(None),) + tuple(
            slice(before, -after or None) for before, after in pad_shape
        )
        img = np.pad(img, ((0, 0),) + pad_shape, mode='reflect')
    with torch.no_grad():
        x = torch.from_numpy(img[np.newaxis]).to(device)
        prediction = np.mean(
            [
                predict(model, x, patch_size, is_log=False)
                for model in models
            ],
            axis=0)
    if patch_size is None:
        # Remove the padding added above.
        prediction = prediction[slices]
    # save and use as float16 to save the storage
    return prediction.astype('float16')
def spots_with_flow(config, spots):
    """Estimate spot positions using a flow prediction for ``config.timepoint``.

    The input frames come either from ``config.tiff_input`` (a sequence of
    image files) or from the zarr array at ``config.zpath_input``. For zarr
    input, a digest of the model file, the two input frames and the patch
    size is compared against ``config.zpath_flow_hashes``; on a match the
    stored prediction in ``config.zpath_flow`` is reused instead of running
    the models.

    Args:
        config: object providing the attributes read below (``timepoint``,
            ``model_path``, ``keep_axials``, ``device``, ``flow_norm_factor``,
            ``patch_size``, ``output_prediction``, ``scales`` and the input
            locations).
        spots: forwarded to ``_estimate_spots_with_flow``.

    Returns:
        The result of ``_estimate_spots_with_flow``.

    Raises:
        ValueError: if neither ``tiff_input`` nor ``zpath_input`` is set.
    """
    prediction = None
    # Only the zarr branch has a prediction cache; these stay None for TIFF
    # input so that the write-back below can be skipped safely.
    za_flow = za_hash = model_md5 = None
    t_index = config.timepoint - 1
    if getattr(config, 'tiff_input', None) is not None:
        img_input = np.array([skimage.io.imread(f) for f in config.tiff_input])
    elif config.zpath_input is not None:
        za_input = zarr.open(config.zpath_input, mode='a')
        za_flow = zarr.open(config.zpath_flow, mode='a')
        za_hash = zarr.open(config.zpath_flow_hashes, mode='a')

        # MD5 is used as a cache key only, not for security.
        # https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file#answer-3431838
        hash_md5 = hashlib.md5()
        with open(config.model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b''):
                hash_md5.update(chunk)
        za_md5 = zarr.array(
            za_input[config.timepoint - 1:config.timepoint + 1]).digest('md5')
        hash_md5.update(za_md5)
        hash_md5.update(json.dumps(config.patch_size).encode('utf-8'))
        model_md5 = hash_md5.digest()
        if model_md5 == za_hash[t_index]:
            prediction = za_flow[t_index]
        else:
            img_input = np.array([
                normalize_zero_one(za_input[i].astype('float32'))
                for i in range(config.timepoint - 1, config.timepoint + 1)
            ])
    else:
        raise ValueError(
            'either config.tiff_input or config.zpath_input must be set')

    if prediction is None:
        try:
            prediction = _get_flow_prediction(img_input,
                                              config.timepoint,
                                              config.model_path,
                                              config.keep_axials,
                                              config.device,
                                              config.flow_norm_factor,
                                              config.patch_size)
        finally:
            torch.cuda.empty_cache()
        if za_hash is not None:
            if config.output_prediction:
                za_flow[t_index] = prediction
                za_hash[t_index] = model_md5
            else:
                # Invalidate the stored hash: the stored prediction (if any)
                # does not correspond to the current inputs.
                za_hash[t_index] = 0

    # Restore to voxel unit
    for d in range(prediction.shape[0]):
        prediction[d] *= config.flow_norm_factor[d]
    res_spots = _estimate_spots_with_flow(spots, prediction, config.scales)
    return res_spots
def _load_image(za_input, timepoint, use_median=False, img_size=None):
    """Load one frame as float32, normalize it and optionally resize it.

    Args:
        za_input: array-like indexed by timepoint along its first axis.
        timepoint: index of the frame to load.
        use_median: for 3D frames, shift every slice with a positive median
            so that its median equals the median of the whole volume.
        img_size: target shape; when given, the normalized frame is resized
            with trilinear (3D) or bilinear (otherwise) interpolation.

    Returns:
        The preprocessed image as a NumPy array.
    """
    volume = za_input[timepoint].astype('float32')

    if use_median and volume.ndim == 3:
        overall_median = np.median(volume)
        # Each `plane` is a view, so the in-place shift updates `volume`.
        for plane in volume:
            plane_median = np.median(plane)
            if 0 < plane_median:
                plane -= plane_median - overall_median

    volume = normalize_zero_one(volume)
    if img_size is None:
        return volume

    interp_mode = 'trilinear' if volume.ndim == 3 else 'bilinear'
    # F.interpolate expects leading batch and channel axes.
    batched = torch.from_numpy(volume)[None, None]
    resized = F.interpolate(
        batched,
        size=img_size,
        mode=interp_mode,
        align_corners=True,
    )
    return resized[0, 0].numpy()
def __getitem__(self, index):
    """Return one randomly augmented (input, target) crop for flow training.

    The frame index is ``self.indices[index // self.n_crops]``. Two
    consecutive input frames (``i_frame`` and ``i_frame + 1``) are read from
    ``self.zpath_input`` and each normalized with ``normalize_zero_one``; the
    label for ``i_frame`` is read from ``self.zpath_flow_label``.

    Augmentations, applied in this order:
      * random in-plane rotation by an integer angle in
        ``[-self.rotation_angle, self.rotation_angle]`` (if enabled);
      * random crop size around ``self.crop_size`` controlled by
        ``self.scale_factors``; the crop is resized back to
        ``self.crop_size`` afterwards.

    The crop is re-drawn until the last label channel contains a non-zero
    value both before and after resizing.

    Returns:
        ``(tensor_input, tensor_target)`` where ``tensor_target`` is the
        label concatenated with the input along the first axis.

    Raises:
        AssertionError: if the last label channel has no positive value.
    """
    i_frame = self.indices[index // self.n_crops]
    za_input = zarr.open(self.zpath_input, mode='r')
    # Stack frames t and t+1 along a new leading axis.
    img_input = np.array([
        normalize_zero_one(za_input[i].astype('float32'))
        for i in range(i_frame, i_frame + 2)
    ])
    za_label = zarr.open(self.zpath_flow_label, mode='r')
    img_label = za_label[i_frame]
    # The last label channel is used as the weight/mask of labeled voxels.
    assert 0 < img_label[-1].max(), (
        'positive weight should exist in the label')
    if self.rotation_angle is not None and 0 < self.rotation_angle:
        # rotate image (every 2D slice of every channel by the same angle)
        theta = randint(-self.rotation_angle, self.rotation_angle)
        img_input = np.array([
            [
                rotate(
                    img_input[c, z],
                    theta,
                    resize=True,
                    preserve_range=True,
                    order=1,  # 1: Bi-linear (default)
                ) for z in range(img_input.shape[1])
            ] for c in range(img_input.shape[0])
        ])
        # rotate label
        img_label = np.array([
            [
                rotate(
                    img_label[c, z],
                    theta,
                    resize=True,
                    preserve_range=True,
                    order=0,  # 0: Nearest-neighbor
                ) for z in range(img_label.shape[1])
            ] for c in range(img_label.shape[0])
        ])
        # update flow label (y axis is inversed in the image coordinate)
        # The vector stored in channels 0 (x) and 1 (y) is rotated by theta;
        # y is negated before and after to account for the inverted axis.
        cos_theta = math.cos(math.radians(theta))
        sin_theta = math.sin(math.radians(theta))
        img_label_x = img_label[0].copy()
        img_label_y = img_label[1].copy() * -1
        img_label[0] = cos_theta * img_label_x - sin_theta * img_label_y
        img_label[1] = (sin_theta * img_label_x +
                        cos_theta * img_label_y) * -1
    if 0 < sum(self.scale_factors):
        # Draw a per-axis crop size in roughly
        # [crop * (1 - scale), crop * (1 + scale)], clipped to the image.
        item_crop_size = [
            randrange(
                min(
                    img_input.shape[i + 1],
                    round(self.crop_size[i] * (1. - self.scale_factors[i]))),
                min(
                    img_input.shape[i + 1] + 1,
                    int(self.crop_size[i] * (1. + self.scale_factors[i])) + 1))
            for i in range(self.n_dims)
        ]
        # scale labels by resize factor
        img_label[0] *= self.crop_size[2] / item_crop_size[2]  # X
        img_label[1] *= self.crop_size[1] / item_crop_size[1]  # Y
        img_label[2] *= self.crop_size[0] / item_crop_size[0]  # Z
    else:
        item_crop_size = self.crop_size
    # Candidate anchor voxels: positions where the last channel is positive.
    index_pool = np.argwhere(0 < img_label[-1])
    while True:
        base_index = index_pool[randrange(len(index_pool))]
        # Pick a crop origin such that the crop contains base_index and
        # stays inside the image.
        origins = [
            randint(
                max(0, base_index[i] - (item_crop_size[i] - 1)),
                min(img_input.shape[i + 1] - item_crop_size[i], base_index[i]))
            for i in range(self.n_dims)
        ]
        slices = (slice(None), ) + tuple(
            slice(origins[i], origins[i] + item_crop_size[i])
            for i in range(self.n_dims))
        # First screening
        if img_label[slices][-1].max() != 0:
            tensor_label = torch.from_numpy(img_label[slices])
            if 0 < sum(self.scale_factors):
                tensor_label = F.interpolate(tensor_label[None].float(),
                                             self.crop_size,
                                             mode='nearest')
                tensor_label = tensor_label.view((4, ) + self.crop_size)
            # Second screening
            # (nearest-neighbor resizing may have dropped all labeled voxels)
            if tensor_label[-1].max() != 0:
                break
    tensor_input = torch.from_numpy(img_input[slices])
    if 0 < sum(self.scale_factors):
        tensor_input = F.interpolate(tensor_input[None],
                                     self.crop_size,
                                     mode='trilinear',
                                     align_corners=True)
        tensor_input = tensor_input.view((2, ) + self.crop_size)
    # Channel order: (flow_x, flow_y, flow_z, mask, input_t0, input_t1)
    tensor_target = torch.cat((tensor_label, tensor_input), )
    return tensor_input, tensor_target
def __getitem__(self, index):
    """Return one randomly augmented (input, label) crop for segmentation.

    In autoencoder mode (``self.is_ae``) a random frame is used and the
    input crop is returned as both input and target. Otherwise the frame is
    ``self.indices[index // self.n_crops]`` or, in live mode, a timepoint
    popped from the Redis list ``REDIS_KEY_TIMEPOINT``.

    Augmentations, applied in this order: random contrast reduction of the
    voxels selected by ``fg_index`` (label mode only), random in-plane
    rotation, random crop size controlled by ``self.scale_factors`` (the
    crop is resized back to ``self.crop_size``).

    Returns:
        ``(tensor_input, tensor_label)``, or ``(tensor_input, tensor_input)``
        in autoencoder mode. In live mode, a pair of ``-100`` tensors is
        returned as a sentinel when the training state is ``IDLE``.

    Raises:
        AssertionError: if the selected label frame has no positive value.
    """
    za_input = zarr.open(self.zpath_input, mode='r')
    if self.is_ae:
        i_frame = randrange(za_input.shape[0])
    else:
        # 0: unlabeled, 1: BG (LW), 2: Outer (LW), 3: Inner (LW)
        # 4: BG (HW), 5: Outer (HW), 6: Inner (HW)
        za_label = zarr.open(self.zpath_seg_label, mode='r')
        if self.is_livemode:
            # Wait (1 s per poll) for a labeled timepoint; give up with the
            # sentinel pair as soon as training has been set to IDLE.
            while True:
                v = self.redis_c.blpop(REDIS_KEY_TIMEPOINT, 1)
                if v is not None:
                    i_frame = int(v[1])
                    if 0 < za_label[i_frame].max():
                        break
                if (int(self.redis_c.get(REDIS_KEY_STATE)) ==
                        TrainState.IDLE.value):
                    return torch.tensor(-100.), torch.tensor(-100)
        else:
            i_frame = self.indices[index // self.n_crops]
        img_label = za_label[i_frame].astype('int64')
        assert 0 < img_label.max(), (
            'positive weight should exist in the label')
    img_input = normalize_zero_one(za_input[i_frame].astype('float32'))
    if not self.is_ae:
        # Randomly reduce the contrast between the two voxel groups.
        fg_index = np.isin(img_label, (1, 2, 4, 5))
        bg_index = np.isin(img_label, (0, 3))
        if fg_index.any() and bg_index.any():
            fg_mean = img_input[fg_index].mean()
            bg_mean = img_input[bg_index].mean()
            # A zero mean would turn the factor into inf/nan and poison
            # the image, so the augmentation is skipped in that case.
            if fg_mean != 0:
                cr_factor = ((
                    (fg_mean - bg_mean) * uniform(0.5, 1) + bg_mean) /
                    fg_mean)
                img_input[fg_index] *= cr_factor
    if self.rotation_angle is not None and 0 < self.rotation_angle:
        # rotate image
        theta = randint(-self.rotation_angle, self.rotation_angle)
        img_input = np.array([
            rotate(
                img_input[z],
                theta,
                resize=True,
                preserve_range=True,
                order=1,  # 1: Bi-linear (default)
            ) for z in range(img_input.shape[0])
        ])
        # rotate label (there is no label in autoencoder mode)
        if not self.is_ae:
            img_label = np.array([
                rotate(
                    img_label[z],
                    theta,
                    resize=True,
                    preserve_range=True,
                    order=0,  # 0: Nearest-neighbor
                ) for z in range(img_label.shape[0])
            ])
    if 0 < sum(self.scale_factors):
        # Draw a per-axis crop size in roughly
        # [crop * (1 - scale), crop * (1 + scale)], clipped to the image.
        item_crop_size = [
            randrange(
                min(
                    img_input.shape[i],
                    round(self.crop_size[i] * (1. - self.scale_factors[i]))),
                min(
                    img_input.shape[i] + 1,
                    int(self.crop_size[i] * (1. + self.scale_factors[i])) + 1))
            for i in range(self.n_dims)
        ]
    else:
        item_crop_size = self.crop_size
    if not self.is_ae:
        img_label -= 1
        # -1: unlabeled, 0: BG (LW), 1: Outer (LW), 2: Inner (LW)
        # 3: BG (HW), 4: Outer (HW), 5: Inner (HW)
        index_pool = np.argwhere(-1 < img_label)
    while True:
        if self.is_ae:
            origins = [
                randint(0, img_input.shape[i] - item_crop_size[i])
                for i in range(self.n_dims)
            ]
        else:
            # Pick a crop that contains a randomly chosen labeled voxel.
            base_index = index_pool[randrange(len(index_pool))]
            origins = [
                randint(
                    max(0, base_index[i] - (item_crop_size[i] - 1)),
                    min(img_input.shape[i] - item_crop_size[i],
                        base_index[i]))
                for i in range(self.n_dims)
            ]
        slices = tuple(
            slice(origins[i], origins[i] + item_crop_size[i])
            for i in range(self.n_dims))
        if self.is_ae:
            break
        # First screening
        if self._is_valid(img_label[slices]):
            tensor_label = torch.from_numpy(img_label[slices])
            if 0 < sum(self.scale_factors):
                tensor_label = F.interpolate(tensor_label[None, None].float(),
                                             self.crop_size,
                                             mode='nearest').long()
            else:
                tensor_label = tensor_label[None, None].long()
            # Second screening (labels may have been lost by the resizing)
            if self._is_valid(tensor_label.numpy()):
                break
    tensor_input = torch.from_numpy(img_input[slices])
    if 0 < sum(self.scale_factors):
        tensor_input = F.interpolate(tensor_input[None, None],
                                     self.crop_size,
                                     mode='trilinear',
                                     align_corners=True)
        tensor_input = tensor_input.view((1, ) + self.crop_size)
    else:
        tensor_input = tensor_input[None]
    if self.is_ae:
        return tensor_input, tensor_input
    tensor_label = tensor_label.view(self.crop_size)
    return tensor_input, tensor_label
def test_normalize_zero_one():
    """The values 0..4 must be mapped onto a range spanning exactly [0, 1]."""
    ramp = np.arange(5).astype(float)
    normalized = normalize_zero_one(ramp)
    lowest = normalized.min()
    highest = normalized.max()
    assert lowest == 0 and highest == 1
def _update_seg_labels(spots_dict, scales, zpath_input, zpath_seg_label,
                       zpath_seg_label_vis, auto_bg_thresh=0, c_ratio=0.5):
    """Rasterize annotated spots into segmentation labels, frame by frame.

    For every timepoint in ``spots_dict`` a uint8 label volume is built:
    voxels whose normalized intensity is below ``auto_bg_thresh`` start as 1,
    all others as 0. Each spot is then drawn as an ellipsoid derived from its
    ``pos`` and ``covariance``; depending on its ``tag`` it writes label
    values 1-3 (tags ``tp``/``tb``/``tn``) or 4-6 (other tags). The result is
    stored in ``zpath_seg_label`` and an RGB visualization in
    ``zpath_seg_label_vis``. If ``REDIS_KEY_NCROPS`` is set, the timepoint is
    pushed that many times onto ``REDIS_KEY_TIMEPOINT``.

    Args:
        spots_dict: mapping of timepoint to an iterable of spot dicts with
            the keys ``tag``, ``pos`` and ``covariance``.
        scales: forwarded to ``ellipsoid``.
        zpath_input: zarr path of the input images (opened read-only).
        zpath_seg_label: zarr path of the labels (opened in append mode).
        zpath_seg_label_vis: zarr path of the label visualization.
        auto_bg_thresh: intensity threshold for the initial label value 1.
        c_ratio: radius ratio of the inner ellipsoid to the outer one.

    Returns:
        ``jsonify({'completed': True})``, or ``jsonify({'completed': False})``
        if the training state switched to ``IDLE`` while processing.
    """
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 20

    def _raise_label(volume, voxels, max_rank, new_value):
        # Overwrite only voxels whose rank, fmod(label - 1, 3), is at most
        # `max_rank`; other voxels keep their current label.
        current = volume[voxels]
        volume[voxels] = np.where(
            np.fmod(current - 1, 3) <= max_rank, new_value, current)

    for t, spots in spots_dict.items():
        normalized = normalize_zero_one(za_input[t].astype('float32'))
        label = np.where(normalized < auto_bg_thresh, 1, 0).astype('uint8')
        cnt = collections.Counter(dict.fromkeys(keyorder, 0))
        for spot in spots:
            state = int(redis_client.get(REDIS_KEY_STATE))
            if state == TrainState.IDLE.value:
                print('aborted')
                return jsonify({'completed': False})
            tag = spot['tag']
            cnt[tag] += 1
            centroid = np.array(spot['pos'][::-1])
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            eigenvalues, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(eigenvalues)
            dd_outer, rr_outer, cc_outer = ellipsoid(
                centroid, radii, rotation, scales, label.shape,
                MIN_AREA_ELLIPSOID)
            outer = (dd_outer, rr_outer, cc_outer)
            label_offset = 0 if tag in ['tp', 'tb', 'tn'] else 3
            if tag in ('tp', 'fn'):
                dd_inner, rr_inner, cc_inner = ellipsoid(
                    centroid, radii * c_ratio, rotation, scales, label.shape,
                    MIN_AREA_ELLIPSOID)
                dd_inner_p, rr_inner_p, cc_inner_p = _dilate_3d_indices(
                    dd_inner, rr_inner, cc_inner, label.shape)
                _raise_label(label, outer, 1, 2 + label_offset)
                label[dd_inner_p, rr_inner_p, cc_inner_p] = 2 + label_offset
                _raise_label(label, (dd_inner, rr_inner, cc_inner), 2,
                             3 + label_offset)
            elif tag in ('tb', 'fb'):
                _raise_label(label, outer, 1, 2 + label_offset)
            elif tag in ('tn', 'fp'):
                _raise_label(label, outer, 0, 1 + label_offset)
        ordered_counts = sorted(
            cnt.items(), key=lambda item: keyorder.index(item[0]))
        print('frame:{}, {}'.format(t, ordered_counts))
        za_label[t] = label
        # One visualization channel per class: 255 for labels 1-3 and 127
        # for the corresponding labels 4-6.
        for channel in range(3):
            za_label_vis[t, ..., channel] = (
                np.where(label == channel + 1, 255, 0) +
                np.where(label == channel + 4, 127, 0))
        if redis_client.get(REDIS_KEY_NCROPS) is not None:
            n_crops = int(redis_client.get(REDIS_KEY_NCROPS))
            for _ in range(n_crops):
                redis_client.rpush(REDIS_KEY_TIMEPOINT, str(t))
    return jsonify({'completed': True})
def _get_seg_prediction(img, timepoint, model_path, keep_axials, device,
                        use_median=True, patch_size=None, crop_box=None,
                        is_pad=False):
    """Run the segmentation models on an image and post-process the result.

    Args:
        img: image array; slices are taken along the first axis.
        timepoint: unused here; kept for interface compatibility.
        model_path: passed to ``load_seg_models``.
        keep_axials: sequence of booleans; the number of ``False`` entries
            determines the padding multiple of the first axis.
        device: torch device the input tensor is moved to.
        use_median: shift every slice with a positive median so that its
            median equals the median of the whole image.
        patch_size: forwarded to ``predict``. When ``None`` the image is
            zero-padded to the multiples given by ``pad_base`` before
            inference and the padding is cropped off afterwards.
        crop_box: optional ``(z0, y0, x0, depth, height, width)`` region to
            predict on; applied after normalization.
        is_pad: passed to ``load_seg_models``.

    Returns:
        The model-averaged prediction with channels on the first axis,
        after the per-slice post-processing below.
    """
    img = img.astype('float32')
    if use_median:
        global_median = np.median(img)
        for z in range(img.shape[0]):
            slice_median = np.median(img[z])
            if 0 < slice_median:
                img[z] -= slice_median - global_median
    img = normalize_zero_one(img)
    models = load_seg_models(model_path,
                             keep_axials,
                             device,
                             is_eval=True,
                             is_pad=is_pad)
    if crop_box is not None:
        img = img[crop_box[0]:crop_box[0] + crop_box[3],  # Z
                  crop_box[1]:crop_box[1] + crop_box[4],  # Y
                  crop_box[2]:crop_box[2] + crop_box[5]]  # X
    if patch_size is None:
        pad_base = (2**keep_axials.count(False), 16, 16)
        pad_shape = tuple(
            get_pad_size(img.shape[i], pad_base[i])
            for i in range(len(img.shape))
        )
        # `-after or None` keeps the slice open-ended when no trailing
        # padding was added; a literal `slice(before, -0)` would be empty.
        slices = (slice(None),) + tuple(
            slice(before, -after or None) for before, after in pad_shape
        )
        img = np.pad(img, pad_shape, mode='constant', constant_values=0)
    with torch.no_grad():
        x = torch.from_numpy(img[np.newaxis, np.newaxis]).to(device)
        prediction = []
        # test-time augmentation (TTA)
        # dims: (N, C, D, H, W)
        # NOTE: every range is 1, so only the unflipped input is evaluated.
        for flip_x in range(1):
            for flip_y in range(1):
                for flip_z in range(1):
                    fl_dims = [
                        i - 3
                        for i, flag in enumerate([flip_z, flip_y, flip_x])
                        if flag
                    ]
                    xx = x.flip(fl_dims) if fl_dims else x
                    pred = np.mean(
                        [
                            predict(model, xx, patch_size, is_log=True)
                            for model in models
                        ],
                        axis=0)
                    # Undo the flip so all predictions are aligned.
                    pred = np.flip(pred, fl_dims) if fl_dims else pred
                    prediction.append(pred)
        prediction = np.mean(prediction, axis=0)
    if patch_size is None:
        # Remove the padding added above.
        prediction = prediction[slices]
    for z in range(prediction.shape[1]):
        # Suppress channel 2 where the smoothed channel 0 has strong edges,
        # move the removed amount to channel 0 (capped at 1) and recompute
        # channel 1 so that the three channels sum to 1.
        post_fg = np.maximum(
            prediction[2, z] - normalize_zero_one(
                prewitt(gaussian(prediction[0, z], sigma=3))),
            0
        )
        prediction[0, z] = np.minimum(
            prediction[0, z] + (prediction[2, z] - post_fg),
            1
        )
        prediction[2, z] = post_fg
        prediction[1, z] = 1. - (prediction[0, z] + prediction[2, z])
    return prediction