def check_dataset(dataset, shape):
    """Check a dataset for ELEPHANT.

    Parameters
    ----------
    dataset : str or Path
        Dataset dir to check.
    shape : tuple or list of int
        Expected image shape (timepoints, depth, height, width) for 3D+t data,
        or (timepoints, height, width) for 2D+t data. A list (e.g. decoded
        from JSON) is converted to a tuple before comparison.

    Returns
    -------
    message : str
        Returns 'ready' if everything is ok, otherwise returns the first
        encountered problem.

    This function checks if the dataset has the following files. It also
    checks if zarr shapes and dtypes are consistent with imgs.zarr.

    dataset
    ├── flow_hashes.zarr
    ├── flow_labels.zarr
    ├── flow_outputs.zarr
    ├── imgs.zarr
    ├── seg_labels_vis.zarr
    ├── seg_labels.zarr
    └── seg_outputs.zarr
    """
    p = Path(dataset)
    # zarr reports shapes as tuples; a list would never compare equal to one.
    if isinstance(shape, list):
        shape = tuple(shape)
    message = 'ready'
    try:
        try:
            img = zarr.open(str(p / 'imgs.zarr'), 'r')
        except ValueError as e:
            raise Exception(
                f'{p / "imgs.zarr"} is not found or broken.') from e
        if img.shape != shape:
            raise Exception('Invalid shape for imgs.zarr\n' +
                            f'Expected: {shape} Found: {img.shape}')
        # number of spatial axes (shape[0] is the time axis)
        n_dims = len(shape) - 1
        n_timepoints = shape[0]
        spatial = shape[-n_dims:]
        # flow arrays hold one entry per consecutive pair of timepoints
        _check_zarr(p / 'flow_outputs.zarr',
                    (n_timepoints - 1, n_dims) + spatial, 'float16')
        _check_zarr(p / 'flow_hashes.zarr', (n_timepoints - 1, ), 'S16')
        _check_zarr(p / 'flow_labels.zarr',
                    (n_timepoints - 1, n_dims + 1) + spatial, 'float32')
        _check_zarr(p / 'seg_outputs.zarr',
                    (n_timepoints, ) + spatial + (3, ), 'float16')
        _check_zarr(p / 'seg_labels.zarr', (n_timepoints, ) + spatial,
                    'uint8')
        _check_zarr(p / 'seg_labels_vis.zarr',
                    (n_timepoints, ) + spatial + (3, ), 'uint8')
    except Exception as e:
        logger().info(str(e))
        message = str(e)
    return message
def _get_memmap_or_load(za, timepoint, memmap_dir=None, use_median=False,
                        img_size=None):
    """Load one timepoint of ``za`` as a preprocessed float32 image.

    The image is cast to float32, optionally resized to ``img_size``,
    optionally median-corrected per z-slice (3D only) and finally passed
    through ``normalize_zero_one``.

    Parameters
    ----------
    za : zarr.Array
        Source array, indexed as ``za[timepoint]``.
    timepoint : int
        Index along the first axis of ``za``.
    memmap_dir : str or Path, optional
        If set, the preprocessed image is cached in
        ``<memmap_dir>/<key>.dat`` (created once under a FileLock) and
        returned as a copy-on-write (``mode='c'``) memmap.
    use_median : bool
        If True and the image is 3D, each z-slice with a positive median is
        shifted so that its median matches the global median.
    img_size : tuple of int, optional
        Target spatial size; trilinear (3D) or bilinear (2D) interpolation
        with ``align_corners=True``. Applied with and without ``memmap_dir``.

    Returns
    -------
    numpy.ndarray or numpy.memmap
        The preprocessed float32 image.
    """

    def resize(img):
        return F.interpolate(
            torch.from_numpy(img)[None, None],
            size=img_size,
            mode='trilinear' if img.ndim == 3 else 'bilinear',
            align_corners=True,
        )[0, 0].numpy()

    def preprocess():
        img = za[timepoint].astype('float32')
        if img_size is not None:
            img = resize(img)
        if use_median and img.ndim == 3:
            global_median = np.median(img)
            for z in range(img.shape[0]):
                slice_median = np.median(img[z])
                if 0 < slice_median:
                    img[z] -= slice_median - global_median
        return normalize_zero_one(img)

    if not memmap_dir:
        return preprocess()

    key = f'{Path(za.store.path).parent.name}-t{timepoint}-{use_median}'
    if img_size is not None:
        key += '-' + '-'.join(map(str, img_size))
    fpath = Path(memmap_dir) / f'{key}.dat'
    shape = za.shape[1:] if img_size is None else img_size
    lock = FileLock(str(fpath) + '.lock')
    with lock:
        if not fpath.exists():
            logger().info(f'creating {fpath}')
            fpath.parent.mkdir(parents=True, exist_ok=True)
            # Preprocess in memory and copy the final values into the cache.
            # The cache must hold normalized data, and the un-resized cache
            # file of another key must not be overwritten with raw values.
            cache = np.memmap(fpath, dtype='float32', mode='w+', shape=shape)
            cache[:] = preprocess()
            cache.flush()
            del cache
        logger().info(f'loading from {fpath}')
        return np.memmap(fpath, dtype='float32', mode='c', shape=shape)
def post(self):
    '''
    Reset seg model.

    Accepts ``application/json`` (config in the body) or
    ``multipart/form-data`` (config JSON in the ``data`` field and a
    checkpoint in the ``file`` field). The state is set to RUN while the
    models are initialized and back to IDLE afterwards.
    '''
    content_type = request.headers['Content-Type']
    if all(ctype not in content_type
           for ctype in ('multipart/form-data', 'application/json')):
        msg = ('Content-Type should be multipart/form-data or '
               'application/json')
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running. Model cannot be reset.'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)
    try:
        device = get_device()
        if 'multipart/form-data' in content_type:
            logger().debug(f'form: {request.form}')
            req_json = json.loads(request.form.get('data'))
            file = request.files['file']
            # NOTE(security): torch.load unpickles the uploaded file, which
            # can execute arbitrary code. Only expose this endpoint to
            # trusted clients.
            checkpoint = torch.load(file.stream, map_location=device)
            state_dicts = (checkpoint
                           if isinstance(checkpoint, list) else [checkpoint])
            req_json['url'] = None
        else:
            req_json = request.get_json()
            state_dicts = None
        req_json['device'] = device
        config = ResetConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        init_seg_models(config.model_path,
                        config.keep_axials,
                        config.device,
                        config.is_3d,
                        config.n_models,
                        config.n_crops,
                        config.zpath_input,
                        config.crop_size,
                        config.scales,
                        url=config.url,
                        state_dicts=state_dicts,
                        is_cpu=config.is_cpu(),
                        cache_maxbytes=config.cache_maxbytes)
    except RuntimeError as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Download a model parameter file.

    Expects a JSON body understood by ``BaseConfig``. Responds with the file
    at ``config.model_path``, with 204 when that file does not exist, or with
    500 when preparing the download fails.
    '''
    is_json = request.headers['Content-Type'] == 'application/json'
    if not is_json:
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    config = BaseConfig(request.get_json())
    logger().info(config)
    model_path = config.model_path
    if Path(model_path).exists():
        try:
            return send_file(model_path)
        except Exception as e:
            logger().exception(
                'Failed to prepare a model parameter file for download')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
    logger().info(f'model file {model_path} not found @{request.path}')
    return make_response(
        jsonify(message=f'model file {model_path} not found'), 204)
def _update_flow_labels(spots_dict, scales, zpath_flow_label,
                        flow_norm_factor):
    """Rasterize linked spots into the flow label zarr.

    For each timepoint ``t``, every spot's ellipse/ellipsoid (from ``pos`` and
    ``covariance``) receives the spot's displacement, scaled by
    ``1 / scales`` and ``1 / flow_norm_factor``, in channels
    ``0..n_dims-1``. It receives weight 1 in the last channel. Only the
    drawn voxels are written back to zarr.

    Parameters
    ----------
    spots_dict : Mapping[int, list[dict]]
        Spots grouped by timepoint; each spot has 'pos', 'covariance'
        (flattened 3x3, XYZ order) and 'displacement' (XYZ order).
    scales : sequence of float
        Voxel scales in (Z, )Y, X order.
    zpath_flow_label : str
        Path to the flow label zarr, indexed as (t, channel, *spatial).
    flow_norm_factor : sequence of float
        Per-axis normalization in X, Y(, Z) order.

    Returns
    -------
    flask.Response
        ``{'completed': False}`` if aborted (state became IDLE), else
        ``{'completed': True}``.
    """
    za_label = zarr.open(zpath_flow_label, mode='a')
    MIN_AREA_ELLIPSOID = 9
    # za_label is (t, channel, *spatial): drop the time and channel axes
    n_dims = len(za_label.shape) - 2
    n_channels = n_dims + 1  # displacement channels + weight channel
    draw_func = ellipsoid if n_dims == 3 else ellipse
    for t, spots in spots_dict.items():
        label = np.zeros(za_label.shape[1:], dtype='float32')
        spatial_shape = label.shape[-n_dims:]
        label_indices = set()
        centroids = []
        for spot in spots:
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                return make_response(jsonify({'completed': False}))
            # positions/covariances arrive in XYZ order; reverse to (Z)YX
            centroid = np.array(spot['pos'][::-1])[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            indices = draw_func(centroid, radii, rotation, scales,
                                spatial_shape, MIN_AREA_ELLIPSOID)
            weight = 1
            displacement = spot['displacement']  # X, Y, Z
            n_voxels = len(indices[0])
            for i in range(n_dims):
                ind = (np.full(n_voxels, i), ) + indices
                # displacement[i] is along X/Y/Z; scales is in (Z)YX order
                label[ind] = (displacement[i] / scales[-1 - i] /
                              flow_norm_factor[i])
            # the last channel is for the weight
            label[(np.full(n_voxels, -1), ) + indices] = weight
            label_indices.update(
                tuple(map(tuple, np.stack(indices, axis=1).tolist())))
        logger().info(f'frame:{t+1}, {len(spots)} linkings')
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        if not label_indices:
            # nothing was drawn (e.g. all spots outside the image)
            continue
        spatial_target = tuple(np.array(list(label_indices)).T)
        n_points = len(spatial_target[0])
        # every drawn voxel is written in every channel
        target = ((np.repeat(np.arange(n_channels), n_points), ) +
                  tuple(np.tile(axis, n_channels) for axis in spatial_target))
        target_t = (np.full(len(target[0]), t), ) + target
        za_label[target_t] = label[target]
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Update process state.

    Expects a JSON body whose ``REDIS_KEY_STATE`` entry is an integer value
    of ``TrainState``. The value is stored in redis and echoed back.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    req_json = request.get_json()
    state = req_json.get(REDIS_KEY_STATE)
    # bool is a subclass of int, so without the explicit check True/False
    # would be accepted as states 1/0.
    if (isinstance(state, bool) or not isinstance(state, int)
            or state not in TrainState._value2member_map_):
        msg = f'Invalid state: {state}'
        logger().error(msg)
        # ``error`` matches the other endpoints; ``res`` kept for old clients.
        return make_response(jsonify(res=msg, error=msg), 400)
    redis_client.set(REDIS_KEY_STATE, state)
    return make_response(jsonify(success=True, state=get_state()))
def post(self):
    '''
    Validate Dataset.

    Expects a JSON body with ``dataset_name`` (a directory relative to
    DATASETS_DIR) and ``shape`` (the expected imgs.zarr shape). Responds with
    the message from ``dstool.check_dataset`` ('ready' if valid).
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    req_json = request.get_json()
    if 'dataset_name' not in req_json:
        return make_response(
            jsonify(error='dataset_name key is missing'), 400
        )
    if 'shape' not in req_json:
        return make_response(jsonify(error='shape key is missing'), 400)
    dataset_name = req_json['dataset_name']
    # dataset_name comes from the client: refuse absolute paths and '..'
    # components so that the lookup cannot escape DATASETS_DIR.
    if (not isinstance(dataset_name, str)
            or Path(dataset_name).is_absolute()
            or '..' in Path(dataset_name).parts):
        msg = f'Invalid dataset_name: {dataset_name}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    shape = req_json['shape']
    if not isinstance(shape, (list, tuple)):
        msg = f'Invalid shape: {shape}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    message = dstool.check_dataset(
        Path(DATASETS_DIR) / dataset_name,
        tuple(shape)
    )
    return make_response(jsonify(message=message), 200)
def _get_memmap_or_load_label(self, timepoint, img_size=None):
    """Return the flow label of one timepoint, optionally resized and cached.

    Parameters
    ----------
    timepoint : int
        Index into ``self.za_label``.
    img_size : tuple of int, optional
        If given, the label is resized to this spatial size with
        nearest-neighbour interpolation (axis 0 is kept as channels).

    Returns
    -------
    numpy.ndarray or numpy.memmap
        Label of shape ``(self.n_dims + 1, *spatial)``. When
        ``self.memmap_dir`` is set, the label is cached once (under a
        FileLock) in ``<memmap_dir>/<key>-flowlabel.dat`` and returned as a
        copy-on-write (``mode='c'``) float32 memmap.
    """

    def read_label():
        raw = self.za_label[timepoint]
        if img_size is None:
            return raw
        # add a batch axis so that axis 0 of the label acts as channels
        resized = F.interpolate(
            torch.from_numpy(raw)[None],
            size=img_size,
            mode='nearest',
        )
        return resized[0].numpy()

    if not self.memmap_dir:
        return read_label()

    parts = [f'{Path(self.za_label.store.path).parent.name}-t{timepoint}']
    if img_size is not None:
        parts.append('-'.join(map(str, img_size)))
    parts.append('flowlabel')
    fpath = Path(self.memmap_dir) / ('-'.join(parts) + '.dat')
    lock = FileLock(str(fpath) + '.lock')
    spatial = (self.za_label.shape[-self.n_dims:]
               if img_size is None else img_size)
    shape = (self.n_dims + 1, ) + spatial
    with lock:
        if not fpath.exists():
            logger().info(f'creating {fpath}')
            fpath.parent.mkdir(parents=True, exist_ok=True)
            data = read_label()
            cache = np.memmap(fpath, dtype='float32', mode='w+', shape=shape)
            cache[:] = data
            del cache
        logger().info(f'loading from {fpath}')
        return np.memmap(fpath, dtype='float32', mode='c', shape=shape)
def post(self):
    '''
    Download outputs in CTC format.

    Waits while the state is WAIT and answers 204 if the state is IDLE.
    Otherwise it holds the WAIT state while ``export_ctc_labels`` runs and
    responds with the exported file (str result), 200 (truthy result) or
    204 (falsy result). The previous state is restored unless the state
    became IDLE in the meantime.
    '''
    if request.headers['Content-Type'] != 'application/json':
        return make_response(
            jsonify(error='Content-Type should be application/json'), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response('', 204)
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = ExportConfig(req_json)
        config.shape = zarr.open(config.zpath_input, mode='r').shape[1:]
        logger().info(config)
        grouped = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            grouped[spot['t']].append(spot)
        spots_dict = collections.OrderedDict(sorted(grouped.items()))
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        result = export_ctc_labels(config, spots_dict)
        if isinstance(result, str):
            # NOTE: the exported file is not removed after sending
            resp = send_file(result)
        else:
            resp = make_response('', 200 if result else 204)
    except RuntimeError as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return resp
def post(self):
    '''
    Generate Dataset.

    Expects a JSON body with ``dataset_name`` (a directory relative to
    DATASETS_DIR) and optional ``is_uint16``, ``divisor`` and ``is_2d``.
    Uses ``<dataset_name>.h5`` if present, otherwise the first .h5 file in
    sorted order, and waits for ``generate_dataset_task`` to finish.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    req_json = request.get_json()
    if 'dataset_name' not in req_json:
        return make_response(
            jsonify(error='dataset_name key is missing'), 400
        )
    dataset_name = req_json['dataset_name']
    # dataset_name comes from the client and the task writes into this
    # directory: refuse absolute paths and '..' components so that it cannot
    # escape DATASETS_DIR.
    if (not isinstance(dataset_name, str)
            or Path(dataset_name).is_absolute()
            or '..' in Path(dataset_name).parts):
        msg = f'Invalid dataset_name: {dataset_name}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    p_dataset = Path(DATASETS_DIR) / dataset_name
    h5_files = sorted(p_dataset.glob('*.h5'))
    if not h5_files:
        logger().info(f'.h5 file not found @{request.path}')
        return make_response(
            jsonify(
                message=f'.h5 file not found in {req_json["dataset_name"]}'
            ), 204
        )
    preferred = p_dataset / (p_dataset.name + '.h5')
    if preferred in h5_files:
        h5_filename = str(preferred)
    else:
        h5_filename = str(h5_files[0])
        if 1 < len(h5_files):
            logger().info(f'multiple .h5 files found, use {h5_filename}')
    try:
        generate_dataset_task.delay(
            h5_filename,
            str(p_dataset),
            req_json.get('is_uint16', None),
            req_json.get('divisor', 1.),
            req_json.get('is_2d', False),
        ).wait()
    except Exception as e:
        logger().exception('Failed in gen_datset')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    return make_response('', 200)
def post(self):
    '''
    Reset flow model.

    Accepts ``application/json`` (config in the body) or
    ``multipart/form-data`` (config JSON in the ``data`` field and a
    checkpoint in the ``file`` field). Refused with 500 while another
    process is running.
    '''
    content_type = request.headers['Content-Type']
    if all(ctype not in content_type
           for ctype in ('multipart/form-data', 'application/json')):
        msg = ('Content-Type should be multipart/form-data or '
               'application/json')
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)
    try:
        device = get_device()
        if 'multipart/form-data' in content_type:
            logger().debug(f'form: {request.form}')
            req_json = json.loads(request.form.get('data'))
            file = request.files['file']
            # NOTE(security): torch.load unpickles the uploaded file, which
            # can execute arbitrary code. Only expose this endpoint to
            # trusted clients.
            checkpoint = torch.load(file.stream, map_location=device)
            state_dicts = (checkpoint
                           if isinstance(checkpoint, list) else [checkpoint])
            req_json['url'] = None
        else:
            req_json = request.get_json()
            state_dicts = None
        req_json['device'] = device
        config = ResetConfig(req_json)
        logger().info(config)
        init_flow_models(config.model_path,
                         config.device,
                         config.is_3d,
                         url=config.url,
                         state_dicts=state_dicts)
    except RuntimeError as e:
        logger().exception('Failed in reset_flow_models')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in reset_flow_models')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # release memory from discarded models / checkpoints
        gc.collect()
        torch.cuda.empty_cache()
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Update training parameters.

    Expects a JSON body with ``lr`` (finite, non-negative number) and
    ``n_crops`` (non-negative integer). Both are stored in redis and the
    stored values are echoed back.
    '''
    import math

    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    req_json = request.get_json()
    lr = req_json.get('lr')
    # JSON has a single number type, so an integral value such as 1 decodes
    # to int and is accepted. bool is a subclass of int and is rejected.
    if (isinstance(lr, bool) or not isinstance(lr, (int, float))
            or not math.isfinite(lr) or lr < 0):
        msg = f'Invalid learning rate: {lr}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    n_crops = req_json.get('n_crops')
    if (isinstance(n_crops, bool) or not isinstance(n_crops, int)
            or n_crops < 0):
        msg = f'Invalid number of crops: {n_crops}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    redis_client.set(REDIS_KEY_LR, str(lr))
    redis_client.set(REDIS_KEY_NCROPS, str(n_crops))
    logger().info(f'[params updated] lr: {lr}, n_crops: {n_crops}')
    return make_response(
        jsonify(success=True,
                lr=float(redis_client.get(REDIS_KEY_LR)),
                n_crops=int(redis_client.get(REDIS_KEY_NCROPS))))
def log_before_request():
    """Log the start of a request, skipping unmatched and 'gpus' endpoints."""
    endpoint = request.endpoint
    if endpoint is None or endpoint == 'gpus':
        return
    logger().info(f'START {request.method} {request.path}')
def post(self):
    '''
    Train seg model.

    Validates the request and updates the seg labels (unless in live mode).
    It then launches ``train_seg_task`` and waits for it. The request
    answers ``{'completed': False}`` when the state is switched to IDLE
    (abort), 500 when the task fails, and ``{'completed': True}`` otherwise.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    # Only the request that switched the state to RUN may reset it to IDLE.
    # Otherwise a rejected request (e.g. 'Process is running') would set the
    # shared state to IDLE and abort the process that is actually running.
    started = False
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.n_crops < 1:
            msg = 'n_crops should be a positive number'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)

        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        if not (spots_dict or config.is_livemode):
            msg = 'nothing to train'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        started = True

        if config.is_livemode:
            redis_client.delete(REDIS_KEY_TIMEPOINT)
        else:
            try:
                _update_seg_labels(spots_dict,
                                   config.scales,
                                   config.zpath_input,
                                   config.zpath_seg_label,
                                   config.zpath_seg_label_vis,
                                   config.auto_bg_thresh,
                                   config.c_ratio,
                                   memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                return make_response(jsonify({'completed': False}))

        # continue the step count after the last record of existing logs
        step_offset = 0
        for path in sorted(Path(config.log_dir).glob('event*')):
            try:
                *_, last_record = TFRecordDataset(str(path))
                last = event_pb2.Event.FromString(last_record.numpy()).step
                step_offset = max(step_offset, last + 1)
            except Exception:
                # best effort: unreadable or empty event files are skipped
                pass
        epoch_start = 0

        async_result = train_seg_task.delay(
            list(spots_dict.keys()),
            config.batch_size,
            config.crop_size,
            config.class_weights,
            config.false_weight,
            config.model_path,
            config.n_epochs,
            config.keep_axials,
            config.scales,
            config.lr,
            config.n_crops,
            config.is_3d,
            config.is_livemode,
            config.scale_factor_base,
            config.rotation_angle,
            config.contrast,
            config.zpath_input,
            config.zpath_seg_label,
            config.log_interval,
            config.log_dir,
            step_offset,
            epoch_start,
            config.is_cpu(),
            config.is_mixed_precision,
            config.cache_maxbytes,
            config.memmap_dir,
            config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('training aborted')
                return make_response(jsonify({'completed': False}))
            # poll periodically instead of busy-spinning
            time.sleep(0.1)
        if async_result.failed():
            if get_state() == TrainState.IDLE.value:
                # the task stopped because training was aborted
                logger().info('training aborted')
                return make_response(jsonify({'completed': False}))
            raise async_result.result
    except Exception as e:
        logger().exception('Failed in train_seg')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        torch.cuda.empty_cache()
        if started:
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Update flow label.

    Waits while the state is WAIT and answers ``{'completed': False}`` if
    the state is IDLE. Otherwise it holds the WAIT state and marks the
    update as ongoing, so concurrent updates get 400. It then either resets
    the flow label zarr (``reset``) or rasterizes the posted spots.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while (state == TrainState.WAIT.value):
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if (state == TrainState.IDLE.value):
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        # Set the flag checked above; without this the ongoing-update guard
        # never triggers. It is cleared in ``finally``.
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = FlowTrainConfig(req_json)
        logger().info(config)
        if req_json.get('reset'):
            zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                           config.zpath_flow_label,
                           mode='w')
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        response = _update_flow_labels(spots_dict,
                                       config.scales,
                                       config.zpath_flow_label,
                                       config.flow_norm_factor)
    except RuntimeError as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
    return response
def post(self):
    '''
    Update seg label.

    Waits while the state is WAIT and answers ``{'completed': False}`` if
    the state is IDLE. Otherwise it marks the update as ongoing (rejecting
    concurrent updates with 400) and holds the WAIT state. It then either
    resets the seg label zarrs (``reset``) or rasterizes the posted spots via
    ``_update_seg_labels``. In ``finally`` the previous state is restored
    (unless it became IDLE) and the ongoing flag is cleared.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    # `state` is the state observed after waiting; it is restored in finally
    state = get_state()
    while (state == TrainState.WAIT.value):
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if (state == TrainState.IDLE.value):
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        # in live mode a pending REDIS_KEY_TIMEPOINT means the previous
        # live update has not been consumed yet
        if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
            msg = 'Last update/training is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if req_json.get('reset'):
            # recreate both label arrays with the same metadata (mode='w')
            try:
                zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                               config.zpath_seg_label,
                               mode='w')
                zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                         mode='r'),
                               config.zpath_seg_label_vis,
                               mode='w')
            except RuntimeError as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Runtime Error: {e}'),
                                     500)
            except Exception as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if config.is_livemode and len(spots_dict.keys()) != 1:
            msg = 'Livemode should update only a single timepoint'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        try:
            response = _update_seg_labels(spots_dict,
                                          config.scales,
                                          config.zpath_input,
                                          config.zpath_seg_label,
                                          config.zpath_seg_label_vis,
                                          config.auto_bg_thresh,
                                          config.c_ratio,
                                          config.is_livemode,
                                          memmap_dir=config.memmap_dir)
        except KeyboardInterrupt:
            # raised by _update_seg_labels when the state becomes IDLE
            return make_response(jsonify({'completed': False}))
    except Exception as e:
        logger().exception('Failed in _update_seg_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
    return response
def post(self):
    '''
    Predict seg.

    Waits while the state is WAIT and answers ``{'completed': False}`` if
    the state is IDLE. Otherwise it runs ``detect_spots_task`` while holding
    the WAIT state and responds with the detected spots. If the state
    becomes IDLE, or the task returns None, an empty incomplete result is
    returned instead.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while (state == TrainState.WAIT.value):
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if (state == TrainState.IDLE.value):
        return make_response(jsonify({'completed': False}))
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationEvalConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        async_result = detect_spots_task.delay(
            str(config.device),
            config.model_path,
            config.keep_axials,
            config.is_pad,
            config.is_3d,
            config.crop_size,
            config.scales,
            config.cache_maxbytes,
            config.use_2d,
            config.use_median,
            config.patch_size,
            config.crop_box,
            config.c_ratio,
            config.p_thresh,
            config.r_min,
            config.r_max,
            config.output_prediction,
            config.zpath_input,
            config.zpath_seg_output,
            config.timepoint,
            None,
            config.memmap_dir,
            config.batch_size,
            config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('prediction aborted')
                return make_response(
                    jsonify({
                        'spots': [],
                        'completed': False
                    }))
            # poll periodically instead of busy-spinning
            time.sleep(0.1)
        if async_result.failed():
            raise async_result.result
        spots = async_result.result
        if spots is None:
            logger().info('prediction aborted')
            return make_response(jsonify({
                'spots': [],
                'completed': False
            }))
        publish_mq('prediction', 'Prediction updated')
    except Exception as e:
        logger().exception('Failed in detect_spots')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return make_response(jsonify({'spots': spots, 'completed': True}))
def _update_seg_labels(spots_dict,
                       scales,
                       zpath_input,
                       zpath_seg_label,
                       zpath_seg_label_vis,
                       auto_bg_thresh=0,
                       c_ratio=0.5,
                       is_livemode=False,
                       memmap_dir=None):
    """Rasterize annotated spots into the seg label and seg label-vis zarrs.

    For each timepoint, a fresh uint8 label and RGB-like uint8 visualization
    are built from the spots and written to zarr, replacing that timepoint's
    existing chunks. Label values written here:
      1 (+3): outer region of 'tn' ('fp') spots; also voxels below
              ``auto_bg_thresh`` (value 1) when the threshold is positive
      2 (+3): outer region of 'tb' ('fb') spots, and outer / dilated-inner
              ring of 'tp' ('fn') spots
      3 (+3): inner region (radii * c_ratio) of 'tp' ('fn') spots
    The +3 offset is used for tags other than 'tp', 'tb', 'tn'.

    Parameters
    ----------
    spots_dict : Mapping[int, list[dict]]
        Spots grouped by timepoint; each has 'tag', 'pos' and 'covariance'
        (flattened 3x3, XYZ order).
    scales : sequence of float or None
        Voxel scales in (Z, )Y, X order; None means all ones.
    zpath_input, zpath_seg_label, zpath_seg_label_vis : str
        Paths to the input, label and label-vis zarrs.
    auto_bg_thresh : float
        If positive, input voxels below this value are labeled 1.
    c_ratio : float
        Ratio of the inner ellipse radii to the outer radii.
    is_livemode : bool
        If True, exactly one timepoint is expected and it is published via
        REDIS_KEY_TIMEPOINT.
    memmap_dir : str or Path, optional
        Cached ``*-seglabel.dat`` files for updated timepoints are removed.

    Returns
    -------
    flask.Response
        ``{'completed': True}``, or a 400 response in live mode if a previous
        timepoint is still pending.

    Raises
    ------
    KeyboardInterrupt
        If the state becomes IDLE while processing spots (abort).
    """
    if is_livemode:
        assert len(spots_dict.keys()) == 1
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 9
    img_shape = za_input.shape[1:]
    n_dims = len(img_shape)
    keybase = Path(za_label.store.path).parent.name
    if scales is None:
        scales = (1., ) * n_dims
    scales = np.array(scales)
    for t, spots in spots_dict.items():
        # label_indices collects touched voxels for a sparse write; it is
        # set to None (full-frame write) once it exceeds INDICES_THRESHOLD
        label_indices = set()
        centroids = []
        label = np.zeros(img_shape, dtype='uint8')
        label_vis = np.zeros(img_shape + (3, ), dtype='uint8')
        if 0 < auto_bg_thresh:
            indices_bg = np.nonzero(
                get_input_at(za_input, t, memmap_dir=memmap_dir) <
                auto_bg_thresh)
            label[indices_bg] = 1
            label_vis[indices_bg] = (255, 0, 0)
            if INDICES_THRESHOLD < len(indices_bg[0]):
                label_indices = None
            else:
                label_indices.update(tuple(map(tuple,
                                               np.array(indices_bg).T)))
        cnt = collections.Counter({x: 0 for x in keyorder})
        for spot in spots:
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                raise KeyboardInterrupt
            cnt[spot['tag']] += 1
            # positions/covariances arrive in XYZ order; reverse to (Z)YX
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            if n_dims == 3:
                draw_func = ellipsoid
                dilate_func = _dilate_3d_indices
            else:
                draw_func = ellipse
                dilate_func = _dilate_2d_indices
            indices_outer = draw_func(centroid, radii, rotation, scales,
                                      img_shape, MIN_AREA_ELLIPSOID)
            if label_indices is not None:
                if INDICES_THRESHOLD < len(indices_outer[0]):
                    label_indices = None
                else:
                    label_indices.update(
                        tuple(map(tuple, np.array(indices_outer).T)))
                    if INDICES_THRESHOLD < len(label_indices):
                        label_indices = None
            if spot['tag'] in ['tp', 'tb', 'tn']:
                label_offset = 0
                label_vis_value = 255
            else:
                label_offset = 3
                label_vis_value = 127
            # `label` is uint8, so for unlabeled voxels 0 - 1 wraps to 255
            # and fmod(255, 3) == 0. Hence 0 satisfies the <= 0/1/2 tests
            # below. For labels 1..6, fmod(label - 1, 3) is 0, 1, 2, 0, 1, 2:
            # `<= 1` excludes the inner values 3 and 6, and `<= 0` keeps only
            # 0, 1 and 4.
            cond_outer_1 = np.fmod(label[indices_outer] - 1, 3) <= 1
            if spot['tag'] in ('tp', 'fn'):
                indices_inner = draw_func(centroid, radii * c_ratio,
                                          rotation, scales, img_shape,
                                          MIN_AREA_ELLIPSOID)
                indices_inner_p = dilate_func(*indices_inner, img_shape)
                label[indices_outer] = np.where(cond_outer_1,
                                                2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_1[..., None], (0, label_vis_value, 0),
                    label_vis[indices_outer])
                label[indices_inner_p] = 2 + label_offset
                label_vis[indices_inner_p] = (0, label_vis_value, 0)
                # always True for label values 0..6 (see note above)
                cond_inner = np.fmod(label[indices_inner] - 1, 3) <= 2
                label[indices_inner] = np.where(cond_inner,
                                                3 + label_offset,
                                                label[indices_inner])
                label_vis[indices_inner] = np.where(
                    cond_inner[..., None], (0, 0, label_vis_value),
                    label_vis[indices_inner])
            elif spot['tag'] in ('tb', 'fb'):
                label[indices_outer] = np.where(cond_outer_1,
                                                2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_1[..., None], (0, label_vis_value, 0),
                    label_vis[indices_outer])
            elif spot['tag'] in ('tn', 'fp'):
                cond_outer_0 = np.fmod(label[indices_outer] - 1, 3) <= 0
                label[indices_outer] = np.where(cond_outer_0,
                                                1 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_0[..., None], (label_vis_value, 0, 0),
                    label_vis[indices_outer])

        logger().info('frame:{}, {}'.format(
            t, sorted(cnt.items(), key=lambda i: keyorder.index(i[0]))))

        if label_indices is None:
            # too many voxels: drop the stored chunks of this timepoint and
            # write the whole frame
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[t] = label
            za_label_vis[t] = label_vis
        else:
            # sparse write: only the touched voxels (all 3 vis channels)
            target = tuple(np.array(list(label_indices)).T)
            target_t = to_fancy_index(t, *target)
            target_vis = (
                tuple(np.tile(target[i], 3) for i in range(n_dims)) +
                (np.array(
                    sum(tuple([(c, ) * len(target[0]) for c in range(3)]),
                        ())), ))
            target_vis_t = to_fancy_index(t, *target_vis)
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[target_t] = label[target]
            za_label_vis[target_vis_t] = label_vis[target_vis]
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        if memmap_dir:
            # invalidate cached label memmaps for this timepoint
            for fpath in Path(memmap_dir).glob(
                    f'{keybase}-t{t}*-seglabel.dat'):
                lock = FileLock(str(fpath) + '.lock')
                with lock:
                    if fpath.exists():
                        logger().info(f'remove {fpath}')
                        fpath.unlink()
        if is_livemode:
            # NOTE(review): the labels above are already written when this
            # 400 is returned
            if redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            redis_client.set(REDIS_KEY_TIMEPOINT, t)
    return make_response(jsonify({'completed': True}))
def log_after_request(response):
    """Log request completion with its status and return ``response``.

    Unmatched requests (no endpoint) and the 'gpus' endpoint are not logged.
    """
    endpoint = request.endpoint
    if endpoint is not None and endpoint != 'gpus':
        logger().info(
            f'DONE {request.method} {request.path} => [{response.status}]')
    return response
def generate_dataset(input,
                     output,
                     is_uint16=False,
                     divisor=1.,
                     is_2d=False,
                     is_message_queue=False,
                     is_multiprocessing=True):
    """Generate a dataset for ELEPHANT.

    Parameters
    ----------
    input : str
        Input .h5 file.
    output : str or Path
        Output directory.
    is_uint16 : bool or None
        With this flag, the original image will be stored with uint16.
        If None, determine if uint8 or uint16 dynamically (uint16 if any
        timepoint has a value above 255).
        default: False (uint8)
    divisor : float
        Divide the original pixel values by this value.
        (with uint8, the values should be scaled down to 0-255)
        default: 1.0
    is_2d : bool
        With this flag, the original image will be stored as 2d+time.
        default: False (3d+time)
    is_message_queue : bool
        With this flag, progress is reported with publish_mq.
    is_multiprocessing : bool
        With this flag, multiprocessing is enabled.

    Raises
    ------
    ValueError
        If the .h5 file contains no timepoint groups.

    This function will generate the following files.

    output
    ├── input.h5
    ├── flow_hashes.zarr
    ├── flow_labels.zarr
    ├── flow_outputs.zarr
    ├── imgs.zarr
    ├── seg_labels_vis.zarr
    ├── seg_labels.zarr
    └── seg_outputs.zarr
    """
    logger().info(f'input: {input}')
    logger().info(f'output dir: {output}')
    logger().info(f'divisor: {divisor}')
    with h5py.File(input, 'r') as f:
        # timepoints are stored as 't00000', 't00001', ...
        timepoints = list(filter(re.compile(r't\d{5}').search,
                                 list(f.keys())))
        if not timepoints:
            raise ValueError(f'no timepoint groups found in {input}')

        def func(x):
            return x[0] if is_2d else x

        # determine if uint8 or uint16 dynamically
        if is_uint16 is None:
            is_uint16 = False
            for timepoint in tqdm(timepoints):
                if 255 < np.array(
                        func(f[timepoint]['s00']['0']['cells'])).max():
                    is_uint16 = True
                    break
        shape = f[timepoints[0]]['s00']['0']['cells'].shape
    if is_2d:
        shape = shape[-2:]
    n_dims = 3 - is_2d  # 3 or 2
    n_timepoints = len(timepoints)
    p = Path(output)
    chunk_shape = tuple(min(s, 1024) for s in shape[-2:])
    if not is_2d:
        chunk_shape = (1, ) + chunk_shape
    zarr.open(str(p / 'imgs.zarr'),
              'w',
              shape=(n_timepoints, ) + shape,
              chunks=(1, ) + chunk_shape,
              dtype='u2' if is_uint16 else 'u1')
    zarr.open(str(p / 'flow_outputs.zarr'),
              'w',
              shape=(n_timepoints - 1, n_dims) + shape,
              chunks=(1, 1) + chunk_shape,
              dtype='f2')
    # a single-timepoint dataset has no flows; keep the chunk size positive
    zarr.open(str(p / 'flow_hashes.zarr'),
              'w',
              shape=(n_timepoints - 1, ),
              chunks=(max(n_timepoints - 1, 1), ),
              dtype='S16')
    zarr.open(str(p / 'flow_labels.zarr'),
              'w',
              shape=(n_timepoints - 1, n_dims + 1) + shape,
              chunks=(1, 1) + chunk_shape,
              dtype='f4')
    zarr.open(str(p / 'seg_outputs.zarr'),
              'w',
              shape=(n_timepoints, ) + shape + (3, ),
              chunks=(1, ) + chunk_shape + (1, ),
              dtype='f2')
    zarr.open(str(p / 'seg_labels.zarr'),
              'w',
              shape=(n_timepoints, ) + shape,
              chunks=(1, ) + chunk_shape,
              dtype='u1')
    zarr.open(str(p / 'seg_labels_vis.zarr'),
              'w',
              shape=(n_timepoints, ) + shape + (3, ),
              chunks=(1, ) + chunk_shape + (1, ),
              dtype='u1')
    dtype = np.uint16 if is_uint16 else np.uint8
    # spatial tiles of chunk_shape, processed independently by write_chunk
    if n_dims == 2:
        chunks = tuple((slice(None), slice(y, y + chunk_shape[-2]),
                        slice(x, x + chunk_shape[-1]))
                       for y in range(0, shape[-2], chunk_shape[-2])
                       for x in range(0, shape[-1], chunk_shape[-1]))
    else:
        chunks = tuple(
            (slice(z, z + chunk_shape[-3]), slice(y, y + chunk_shape[-2]),
             slice(x, x + chunk_shape[-1]))
            for z in range(0, shape[-3], chunk_shape[-3])
            for y in range(0, shape[-2], chunk_shape[-2])
            for x in range(0, shape[-1], chunk_shape[-1]))
    pool = mp.Pool() if is_multiprocessing else None
    try:
        for t, timepoint in tqdm(enumerate(timepoints)):
            partial_write_chunk = partial(write_chunk,
                                          zpath=str(p / 'imgs.zarr'),
                                          t=t,
                                          n_dims=n_dims,
                                          is_2d=is_2d,
                                          input=input,
                                          timepoint=timepoint,
                                          divisor=divisor,
                                          dtype=dtype)
            if pool is not None:
                pool.map(partial_write_chunk, chunks)
            else:
                for chunk in chunks:
                    partial_write_chunk(chunk)
            if is_message_queue:
                publish_mq(
                    'dataset',
                    json.dumps({
                        't_max': n_timepoints,
                        't_current': t + 1,
                    }))
    finally:
        if pool is not None:
            # close() alone leaves the worker processes to be reaped later
            pool.close()
            pool.join()