def post(self):
    ''' Reset seg model. '''
    if all(ctype not in request.headers['Content-Type'] for ctype
           in ('multipart/form-data', 'application/json')):
        msg = ('Content-Type should be multipart/form-data or '
               'application/json')
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running. Model cannot be reset.'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)
    try:
        device = get_device()
        if 'multipart/form-data' in request.headers['Content-Type']:
            logger().info(request.form)  # was a stray debug print
            req_json = json.loads(request.form.get('data'))
            file = request.files['file']
            checkpoint = torch.load(file.stream, map_location=device)
            state_dicts = (checkpoint if isinstance(checkpoint, list)
                           else [checkpoint])
            req_json['url'] = None
        else:
            req_json = request.get_json()
            state_dicts = None
        req_json['device'] = device
        config = ResetConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        init_seg_models(config.model_path,
                        config.keep_axials,
                        config.device,
                        config.is_3d,
                        config.n_models,
                        config.n_crops,
                        config.zpath_input,
                        config.crop_size,
                        config.scales,
                        url=config.url,
                        state_dicts=state_dicts,
                        is_cpu=config.is_cpu(),
                        cache_maxbytes=config.cache_maxbytes)
    except RuntimeError as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
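# A minimal client-side sketch of calling this endpoint (the route path,
# port, and accepted JSON keys are assumptions for illustration; the actual
# keys depend on what ResetConfig parses):
#
#   import requests
#   resp = requests.post('http://localhost:8080/seg/reset',
#                        json={'model_path': 'models/seg.pth', 'is_3d': True})
#   assert resp.json().get('completed')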
def post(self):
    ''' Update flow label. '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        # Mark the update as ongoing so concurrent requests are rejected by
        # the guard above; without this line the flag checked above and
        # deleted in the finally block would never be set.
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = FlowTrainConfig(req_json)
        logger().info(config)
        if req_json.get('reset'):
            zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                           config.zpath_flow_label,
                           mode='w')
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        response = _update_flow_labels(spots_dict,
                                       config.scales,
                                       config.zpath_flow_label,
                                       config.flow_norm_factor)
    except RuntimeError as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
    return response
def post(self):
    ''' Download outputs in CTC format. '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response('', 204)
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = ExportConfig(req_json)
        za_input = zarr.open(config.zpath_input, mode='r')
        config.shape = za_input.shape[1:]
        logger().info(config)
        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        result = export_ctc_labels(config, spots_dict)
        if isinstance(result, str):
            resp = send_file(result)
            # file_remover.cleanup_once_done(resp, result)
        elif not result:
            resp = make_response('', 204)
        else:
            resp = make_response('', 200)
    except RuntimeError as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return resp
def post(self):
    ''' Update process state. '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    req_json = request.get_json()
    state = req_json.get(REDIS_KEY_STATE)
    if (not isinstance(state, int)
            or state not in TrainState._value2member_map_):
        msg = f'Invalid state: {state}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    redis_client.set(REDIS_KEY_STATE, state)
    return make_response(jsonify(success=True, state=get_state()))
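# Example request body (sketch): the JSON key equals the literal value of
# REDIS_KEY_STATE (defined elsewhere in this module) and the value must be a
# member of TrainState:
#
#   {"state": 0}   # assuming REDIS_KEY_STATE == 'state' and IDLE.value == 0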
def post(self):
    ''' Update training parameters (learning rate and number of crops). '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    req_json = request.get_json()
    lr = req_json.get('lr')
    if not isinstance(lr, float) or lr < 0:
        msg = f'Invalid learning rate: {lr}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    n_crops = req_json.get('n_crops')
    if not isinstance(n_crops, int) or n_crops < 0:
        msg = f'Invalid number of crops: {n_crops}'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    redis_client.set(REDIS_KEY_LR, str(lr))
    redis_client.set(REDIS_KEY_NCROPS, str(n_crops))
    logger().info(f'[params updated] lr: {lr}, n_crops: {n_crops}')
    return make_response(
        jsonify(success=True,
                lr=float(redis_client.get(REDIS_KEY_LR)),
                n_crops=int(redis_client.get(REDIS_KEY_NCROPS))))
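# Example payload (sketch). Note that 'lr' must parse as a JSON float; an
# integer literal such as 1 would be rejected by the isinstance(lr, float)
# check above, so write 1.0 (or exponent notation) instead:
#
#   {"lr": 5e-4, "n_crops": 5}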
def post(self):
    ''' Train seg model. '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.n_crops < 1:
            msg = 'n_crops should be a positive number'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)

        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        if not (spots_dict or config.is_livemode):
            msg = 'nothing to train'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        if config.is_livemode:
            redis_client.delete(REDIS_KEY_TIMEPOINT)
        else:
            try:
                _update_seg_labels(spots_dict,
                                   config.scales,
                                   config.zpath_input,
                                   config.zpath_seg_label,
                                   config.zpath_seg_label_vis,
                                   config.auto_bg_thresh,
                                   config.c_ratio,
                                   memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                return make_response(jsonify({'completed': False}))
        # Resume the global step from existing TensorBoard event files so
        # logging continues from the last recorded step.
        step_offset = 0
        for path in sorted(Path(config.log_dir).glob('event*')):
            try:
                *_, last_record = TFRecordDataset(str(path))
                last = event_pb2.Event.FromString(last_record.numpy()).step
                step_offset = max(step_offset, last + 1)
            except Exception:
                pass
        epoch_start = 0
        async_result = train_seg_task.delay(
            list(spots_dict.keys()), config.batch_size, config.crop_size,
            config.class_weights, config.false_weight, config.model_path,
            config.n_epochs, config.keep_axials, config.scales, config.lr,
            config.n_crops, config.is_3d, config.is_livemode,
            config.scale_factor_base, config.rotation_angle,
            config.contrast, config.zpath_input, config.zpath_seg_label,
            config.log_interval, config.log_dir, step_offset, epoch_start,
            config.is_cpu(), config.is_mixed_precision,
            config.cache_maxbytes, config.memmap_dir, config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('training aborted')
                return make_response(jsonify({'completed': False}))
            time.sleep(1)  # avoid busy-waiting while polling the task
    except Exception as e:
        logger().exception('Failed in train_seg')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def post(self):
    ''' Update seg label. '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
            msg = 'Last update/training is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if req_json.get('reset'):
            try:
                zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                               config.zpath_seg_label,
                               mode='w')
                zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                         mode='r'),
                               config.zpath_seg_label_vis,
                               mode='w')
            except RuntimeError as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Runtime Error: {e}'),
                                     500)
            except Exception as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if config.is_livemode and len(spots_dict.keys()) != 1:
            msg = 'Livemode should update only a single timepoint'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        try:
            response = _update_seg_labels(spots_dict,
                                          config.scales,
                                          config.zpath_input,
                                          config.zpath_seg_label,
                                          config.zpath_seg_label_vis,
                                          config.auto_bg_thresh,
                                          config.c_ratio,
                                          config.is_livemode,
                                          memmap_dir=config.memmap_dir)
        except KeyboardInterrupt:
            return make_response(jsonify({'completed': False}))
        except Exception as e:
            logger().exception('Failed in _update_seg_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
    return response
def post(self):
    ''' Predict seg. '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response(jsonify({'completed': False}))
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationEvalConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        async_result = detect_spots_task.delay(
            str(config.device), config.model_path, config.keep_axials,
            config.is_pad, config.is_3d, config.crop_size, config.scales,
            config.cache_maxbytes, config.use_2d, config.use_median,
            config.patch_size, config.crop_box, config.c_ratio,
            config.p_thresh, config.r_min, config.r_max,
            config.output_prediction, config.zpath_input,
            config.zpath_seg_output, config.timepoint, None,
            config.memmap_dir, config.batch_size, config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('prediction aborted')
                return make_response(jsonify({'spots': [],
                                              'completed': False}))
            time.sleep(1)  # avoid busy-waiting while polling the task
        if async_result.failed():
            raise async_result.result
        spots = async_result.result
        if spots is None:
            logger().info('prediction aborted')
            return make_response(jsonify({'spots': [],
                                          'completed': False}))
        publish_mq('prediction', 'Prediction updated')
    except Exception as e:
        logger().exception('Failed in detect_spots')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return make_response(jsonify({'spots': spots, 'completed': True}))
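# Note on the polling loops above: the .delay()/.ready()/.failed() calls
# match Celery's task API, so a blocking `async_result.get()` would also
# retrieve the result (and re-raise task errors). The explicit ready() loop
# is kept instead so an in-flight request can be aborted by flipping the
# shared Redis state to IDLE from another endpoint.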
def _update_seg_labels(spots_dict, scales, zpath_input, zpath_seg_label,
                       zpath_seg_label_vis, auto_bg_thresh=0, c_ratio=0.5,
                       is_livemode=False, memmap_dir=None):
    if is_livemode:
        assert len(spots_dict.keys()) == 1
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 9
    img_shape = za_input.shape[1:]
    n_dims = len(img_shape)
    keybase = Path(za_label.store.path).parent.name
    if scales is None:
        scales = (1.,) * n_dims
    scales = np.array(scales)
    for t, spots in spots_dict.items():
        label_indices = set()
        centroids = []
        label = np.zeros(img_shape, dtype='uint8')
        label_vis = np.zeros(img_shape + (3,), dtype='uint8')
        if 0 < auto_bg_thresh:
            indices_bg = np.nonzero(
                get_input_at(za_input, t, memmap_dir=memmap_dir)
                < auto_bg_thresh)
            label[indices_bg] = 1
            label_vis[indices_bg] = (255, 0, 0)
            if INDICES_THRESHOLD < len(indices_bg[0]):
                label_indices = None
            else:
                label_indices.update(
                    tuple(map(tuple, np.array(indices_bg).T)))
        cnt = collections.Counter({x: 0 for x in keyorder})
        for spot in spots:
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                raise KeyboardInterrupt
            cnt[spot['tag']] += 1
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            if n_dims == 3:
                draw_func = ellipsoid
                dilate_func = _dilate_3d_indices
            else:
                draw_func = ellipse
                dilate_func = _dilate_2d_indices
            indices_outer = draw_func(centroid, radii, rotation, scales,
                                      img_shape, MIN_AREA_ELLIPSOID)
            if label_indices is not None:
                if INDICES_THRESHOLD < len(indices_outer[0]):
                    label_indices = None
                else:
                    label_indices.update(
                        tuple(map(tuple, np.array(indices_outer).T)))
                    if INDICES_THRESHOLD < len(label_indices):
                        label_indices = None
            if spot['tag'] in ['tp', 'tb', 'tn']:
                label_offset = 0
                label_vis_value = 255
            else:
                label_offset = 3
                label_vis_value = 127
            cond_outer_1 = np.fmod(label[indices_outer] - 1, 3) <= 1
            if spot['tag'] in ('tp', 'fn'):
                indices_inner = draw_func(centroid, radii * c_ratio,
                                          rotation, scales, img_shape,
                                          MIN_AREA_ELLIPSOID)
                indices_inner_p = dilate_func(*indices_inner, img_shape)
                label[indices_outer] = np.where(cond_outer_1,
                                                2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_1[..., None], (0, label_vis_value, 0),
                    label_vis[indices_outer])
                label[indices_inner_p] = 2 + label_offset
                label_vis[indices_inner_p] = (0, label_vis_value, 0)
                cond_inner = np.fmod(label[indices_inner] - 1, 3) <= 2
                label[indices_inner] = np.where(cond_inner,
                                                3 + label_offset,
                                                label[indices_inner])
                label_vis[indices_inner] = np.where(
                    cond_inner[..., None], (0, 0, label_vis_value),
                    label_vis[indices_inner])
            elif spot['tag'] in ('tb', 'fb'):
                label[indices_outer] = np.where(cond_outer_1,
                                                2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_1[..., None], (0, label_vis_value, 0),
                    label_vis[indices_outer])
            elif spot['tag'] in ('tn', 'fp'):
                cond_outer_0 = np.fmod(label[indices_outer] - 1, 3) <= 0
                label[indices_outer] = np.where(cond_outer_0,
                                                1 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_0[..., None], (label_vis_value, 0, 0),
                    label_vis[indices_outer])
        logger().info('frame:{}, {}'.format(
            t, sorted(cnt.items(), key=lambda i: keyorder.index(i[0]))))
        if label_indices is None:
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[t] = label
            za_label_vis[t] = label_vis
        else:
            target = tuple(np.array(list(label_indices)).T)
            target_t = to_fancy_index(t, *target)
            target_vis = (
                tuple(np.tile(target[i], 3) for i in range(n_dims))
                + (np.array(sum(tuple([(c,) * len(target[0])
                                       for c in range(3)]), ())),))
            target_vis_t = to_fancy_index(t, *target_vis)
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[target_t] = label[target]
            za_label_vis[target_vis_t] = label_vis[target_vis]
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        if memmap_dir:
            for fpath in Path(memmap_dir).glob(
                    f'{keybase}-t{t}*-seglabel.dat'):
                lock = FileLock(str(fpath) + '.lock')
                with lock:
                    if fpath.exists():
                        logger().info(f'remove {fpath}')
                        fpath.unlink()
        if is_livemode:
            if redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            redis_client.set(REDIS_KEY_TIMEPOINT, t)
    return make_response(jsonify({'completed': True}))
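# Reading of the encoding above (inferred from this function, not from an
# external spec): label 0 = unlabeled, 1 = background, 2 = outer ellipse
# (border), 3 = inner ellipse (center), with +3 added to each class when the
# spot tag is a "false" one (fp/fn/fb), presumably so training can weight
# those annotations differently (cf. false_weight in the train endpoint).
# A spot record is expected to look roughly like:
#
#   {'t': 0, 'pos': [x, y, z],         # reversed to (z, y, x) internally
#    'covariance': [...],              # 9 values, row-major 3x3
#    'tag': 'tp'}                      # one of tp, fp, tn, fn, tb, fb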
def __init__(self,
             zpath_input,
             zpath_seg_label,
             indices,
             img_size,
             crop_size,
             n_crops,
             keep_axials=(True,) * 4,
             scales=None,
             is_livemode=False,
             scale_factor_base=0.2,
             is_ae=False,
             rotation_angle=None,
             contrast=0.5,
             is_eval=False,
             length=None,
             adaptive_length=False,
             cache_maxbytes=None,
             memmap_dir=None):
    """Generate dataset for segmentation.

    Args:
        zpath_input (str): path to the input .zarr file.
        zpath_seg_label (str): path to the label .zarr file.
        indices (list): timepoints to be used.
        img_size (array-like of length ndim): input image size.
        crop_size (array-like of length ndim): crop size used to generate
            the dataset.
        n_crops (int): number of crops per timepoint.
        keep_axials (array-like of length 4): used to calculate how many
            times down-/up-sampling is performed in the z direction.
            Ignored for 2D data.
        scales (array-like of length ndim): pixel/voxel size in physical
            units (e.g. 0.5 μm/px), used to calculate scale factors for
            augmentation.
        is_livemode (bool): True if training is performed in live mode.
        scale_factor_base (float): base scale factor for augmentation.
        is_ae (bool): True if called from a prior (autoencoder) training.
        rotation_angle (float): rotation angle for augmentation in degrees.
        contrast (float): contrast factor for augmentation.
        is_eval (bool): True if the dataset is for evaluation, in which
            case no augmentation is performed.
        length (int): length of this dataset. If None, it is calculated
            from the length of indices and n_crops.
        adaptive_length (bool): True if the length of the dataset is
            adjusted adaptively. For example, given that length is 10 and
            len(indices) * n_crops is 6, the length becomes 6 if
            adaptive_length is True, while it remains 10 if False.
        cache_maxbytes (int): memory capacity for the cache in bytes.
        memmap_dir (str): path to a directory for storing memmap files.
    """
    if len(img_size) != len(crop_size):
        raise ValueError(
            'img_size: {} and crop_size: {} should have the same length'
            .format(img_size, crop_size))
    if scale_factor_base < 0 or 1 <= scale_factor_base:
        raise ValueError(
            'scale_factor_base should be 0 <= scale_factor_base < 1')
    self.n_dims = len(crop_size)
    if scales is None:
        scales = (1.,) * self.n_dims
    scale_factors = tuple(scale_factor_base * min(scales) / scales[i]
                          for i in range(self.n_dims))
    self.za_input = zarr.open(zpath_input, mode='r')
    crop_size = tuple(
        min(crop_size[i], img_size[i]) for i in range(self.n_dims))
    self.img_size = tuple(img_size)
    self.crop_size = crop_size
    self.scale_factors = scale_factors
    self.rand_crop_ranges = [
        (min(img_size[i],
             round(crop_size[i] * (1. - scale_factors[i]))),
         min(img_size[i] + 1,
             int(crop_size[i] * (1. + scale_factors[i])) + 1))
        for i in range(self.n_dims)
    ]
    self.n_crops = n_crops
    self.is_ae = is_ae
    # Labels are not used for the autoencoder
    if is_ae:
        self.is_eval = False
        self.is_livemode = False
    else:
        self.za_label = zarr.open(zpath_seg_label, mode='r')
        self.zpath_seg_label = zpath_seg_label
        self.indices = indices
        self.is_eval = is_eval
        self.is_livemode = is_livemode and RUN_ON_FLASK
        if self.is_livemode:
            assert redis_client is not None
            redis_client.set(REDIS_KEY_NCROPS, str(n_crops))
        if (adaptive_length and (length is not None)
                and (len(indices) * self.n_crops <= length)):
            length = None
    self.rotation_angle = rotation_angle
    self.contrast = contrast
    if self.is_livemode:
        self.length = int(redis_client.get(REDIS_KEY_NCROPS))
    else:
        self.length = length
    self.keep_axials = torch.tensor(keep_axials)
    if cache_maxbytes:
        self.use_cache = True
        self.cache_dict_input = LRUCacheDict(cache_maxbytes // 2)
        self.cache_dict_label = LRUCacheDict(cache_maxbytes // 2)
    else:
        self.use_cache = False
        self.cache_dict_input = None
        self.cache_dict_label = None
    self.memmap_dir = memmap_dir
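# A minimal construction sketch (the enclosing class name is not shown in
# this excerpt, so `SegDataset` is a hypothetical stand-in; the paths and
# sizes are placeholders):
#
#   dataset = SegDataset('input.zarr', 'seg_labels.zarr',
#                        indices=[0, 1, 2],
#                        img_size=(64, 256, 256),
#                        crop_size=(16, 128, 128),
#                        n_crops=5,
#                        scales=(2.0, 0.5, 0.5))
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1)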