def post(self):
    '''
    Reset seg model.

    Accepts either ``application/json`` (settings only) or
    ``multipart/form-data`` (a ``data`` field holding JSON settings and a
    ``file`` field holding a checkpoint). While the models are being
    re-initialized the server state is RUN; it is set back to IDLE afterwards.

    Returns:
        ``{'completed': True}`` on success, 400 on a bad Content-Type,
        500 when another process is running or initialization fails.
    '''
    content_type = request.headers.get('Content-Type', '')
    if all(ctype not in content_type
           for ctype in ('multipart/form-data', 'application/json')):
        msg = ('Content-Type should be multipart/form-data or '
               'application/json')
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running. Model cannot be reset.'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)
    try:
        device = get_device()
        if 'multipart/form-data' in content_type:
            logger().debug(f'form: {request.form}')
            req_json = json.loads(request.form.get('data'))
            file = request.files['file']
            # NOTE(review): torch.load unpickles the uploaded file and can
            # execute arbitrary code; only accept checkpoints from trusted
            # clients (consider weights_only=True if the torch version
            # in use supports it).
            checkpoint = torch.load(file.stream, map_location=device)
            # A checkpoint may hold one state dict or a list of them.
            state_dicts = (checkpoint if isinstance(checkpoint, list)
                           else [checkpoint])
            # The uploaded checkpoint takes precedence over any URL.
            req_json['url'] = None
        else:
            req_json = request.get_json()
            state_dicts = None
        req_json['device'] = device
        config = ResetConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        init_seg_models(config.model_path,
                        config.keep_axials,
                        config.device,
                        config.is_3d,
                        config.n_models,
                        config.n_crops,
                        config.zpath_input,
                        config.crop_size,
                        config.scales,
                        url=config.url,
                        state_dicts=state_dicts,
                        is_cpu=config.is_cpu(),
                        cache_maxbytes=config.cache_maxbytes)
    except RuntimeError as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Release memory held by the previous models and return to IDLE.
        gc.collect()
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Update flow label.

    Waits while the server state is WAIT. If the state is IDLE, nothing is
    done and ``{'completed': False}`` is returned. Otherwise the posted spots
    are grouped by timepoint and written as flow labels. When ``reset`` is
    set in the request, the flow label zarr is recreated empty instead.

    Returns:
        ``{'completed': True}`` after a reset, the result of
        ``_update_flow_labels`` after an update, 400 on a bad request or when
        another flow update is ongoing, 500 on failure.
    '''
    if request.headers.get('Content-Type') != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        # Mark the update as in progress so the check above rejects
        # concurrent requests; the flag is cleared in ``finally``.
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = FlowTrainConfig(req_json)
        logger().info(config)
        if req_json.get('reset'):
            zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                           config.zpath_flow_label,
                           mode='w')
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        # Treat a missing/null ``spots`` as empty so it yields a 400 below.
        for spot in req_json.get('spots') or []:
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        response = _update_flow_labels(spots_dict,
                                       config.scales,
                                       config.zpath_flow_label,
                                       config.flow_norm_factor)
    except RuntimeError as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Restore the state seen on entry unless it was reset to IDLE.
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
    return response
def __init__(self, config):
    '''
    Build the common configuration from a request dictionary.

    Sequence settings (``patch``, ``scales``, ``input_size``, ``crop_size``)
    are stored reversed. When ``is_3d`` is false they are truncated to
    their last two entries.

    Args:
        config (dict): request parameters.
    '''
    self.dataset_name = config.get('dataset_name')
    self.timepoint = config.get('timepoint')
    model_name = config.get('model_name')
    # Always define model_path so readers get None instead of an
    # AttributeError when no model name is given.
    self.model_path = (None if model_name is None
                       else os.path.join(MODELS_DIR, model_name))
    # Only query get_device() when the request does not supply a device.
    self.device = config['device'] if 'device' in config else get_device()
    if not torch.cuda.is_available():
        self.device = torch.device("cpu")
    self.is_mixed_precision = config.get('is_mixed_precision', True)
    self.debug = config.get('debug', False)
    self.output_prediction = config.get('output_prediction', False)
    self.is_3d = config.get('is_3d', True)
    self.use_2d = config.get('use_2d', False)
    self.batch_size = config.get('batch_size', 1)
    self.patch_size = config.get('patch')
    self.log_interval = config.get('log_interval', 10)
    self.cache_maxbytes = config.get('cache_maxbytes')
    if self.patch_size is not None:
        self.patch_size = self.patch_size[::-1]
    if config.get('scales') is not None:
        self.scales = config.get('scales')[::-1]
    else:
        self.scales = None
    if config.get('input_size') is not None:
        self.input_size = tuple(config.get('input_size')[::-1])
    else:
        self.input_size = None
    # U-Net has 4 downsamplings
    n_keep_axials = min(4, config.get('n_keep_axials', 4))
    self.keep_axials = tuple(i < n_keep_axials for i in range(4))
    self.crop_size = config.get('crop_size', DEFAULT_CROP_SIZE)[::-1]
    if not self.is_3d:
        if self.patch_size is not None:
            self.patch_size = self.patch_size[-2:]
        if self.scales is not None:
            self.scales = self.scales[-2:]
        self.crop_size = self.crop_size[-2:]
        if self.input_size is not None:
            self.input_size = self.input_size[-2:]
    if self.dataset_name is not None:
        self.zpath_input = os.path.join(DATASETS_DIR,
                                        self.dataset_name,
                                        ZARR_INPUT)
    else:
        self.zpath_input = None
    if config.get('use_memmap'):
        self.memmap_dir = MEMMAPS_DIR
        Path(self.memmap_dir).mkdir(parents=True, exist_ok=True)
    else:
        self.memmap_dir = None
def post(self):
    '''
    Download outputs in CTC format.

    Waits while the server state is WAIT. If the state is IDLE, nothing is
    exported (204). Otherwise the posted spots are grouped by timepoint and
    passed to ``export_ctc_labels``.

    Returns:
        The exported file when ``export_ctc_labels`` returns a path string,
        204 when it returns a falsy value, 200 otherwise; 400 on a bad
        Content-Type and 500 on failure.
    '''
    if request.headers.get('Content-Type') != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response('', 204)
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = ExportConfig(req_json)
        za_input = zarr.open(config.zpath_input, mode='r')
        # Drop the first axis of the input array (presumably time -- confirm).
        config.shape = za_input.shape[1:]
        logger().info(config)
        spots_dict = collections.defaultdict(list)
        for spot in req_json.get('spots'):
            spots_dict[spot['t']].append(spot)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        result = export_ctc_labels(config, spots_dict)
        if isinstance(result, str):
            # TODO(review): the exported file is not removed after sending.
            resp = send_file(result)
        elif not result:
            resp = make_response('', 204)
        else:
            resp = make_response('', 200)
    except RuntimeError as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Restore the state seen on entry unless it was reset to IDLE.
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return resp
def post(self):
    '''
    Reset flow model.

    Accepts either ``application/json`` (settings only) or
    ``multipart/form-data`` (a ``data`` field holding JSON settings and a
    ``file`` field holding a checkpoint), then re-initializes the flow models.

    Returns:
        ``{'completed': True}`` on success, 400 on a bad Content-Type,
        500 when another process is running or initialization fails.
    '''
    content_type = request.headers.get('Content-Type', '')
    if all(ctype not in content_type
           for ctype in ('multipart/form-data', 'application/json')):
        msg = ('Content-Type should be multipart/form-data or '
               'application/json')
        logger().error(msg)
        # Wrap with make_response like every other handler; a Response
        # object inside a bare tuple is not a valid resource return value.
        return make_response(jsonify(error=msg), 400)
    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)
    try:
        device = get_device()
        if 'multipart/form-data' in content_type:
            logger().debug(f'form: {request.form}')
            req_json = json.loads(request.form.get('data'))
            file = request.files['file']
            # NOTE(review): torch.load unpickles the uploaded file and can
            # execute arbitrary code; only accept checkpoints from trusted
            # clients (consider weights_only=True if the torch version
            # in use supports it).
            checkpoint = torch.load(file.stream, map_location=device)
            # A checkpoint may hold one state dict or a list of them.
            state_dicts = (checkpoint if isinstance(checkpoint, list)
                           else [checkpoint])
            # The uploaded checkpoint takes precedence over any URL.
            req_json['url'] = None
        else:
            req_json = request.get_json()
            state_dicts = None
        req_json['device'] = device
        config = ResetConfig(req_json)
        logger().info(config)
        init_flow_models(config.model_path,
                         config.device,
                         config.is_3d,
                         url=config.url,
                         state_dicts=state_dicts)
    except RuntimeError as e:
        logger().exception('Failed in reset_flow_models')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in reset_flow_models')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Release memory held by the previous models (mirrors the seg reset).
        gc.collect()
        torch.cuda.empty_cache()
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Train seg model.

    The request is parsed and validated first. Unless in live mode, the
    segmentation labels are then updated from the posted spots. Finally
    ``train_seg_task`` is launched and the handler blocks until the task
    finishes, or until the server state is switched to IDLE (abort).

    Returns:
        ``{'completed': True}`` when training finishes,
        ``{'completed': False}`` when aborted or interrupted, 400 on invalid
        input, 500 when another process is running or an error occurs.
    '''
    if request.headers.get('Content-Type') != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)

    # Parse and validate BEFORE entering the try/finally that owns the
    # server state: returning from inside it would reset a state set by
    # another (running) process to IDLE.
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.n_crops < 1:
            msg = 'n_crops should be a positive number'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)

        spots_dict = collections.defaultdict(list)
        # Treat a missing/null ``spots`` as empty so it yields a 400 below.
        for spot in req_json.get('spots') or []:
            spots_dict[spot['t']].append(spot)
        if not (spots_dict or config.is_livemode):
            msg = 'nothing to train'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))
    except Exception as e:
        logger().exception('Failed in train_seg')
        return make_response(jsonify(error=f'Exception: {e}'), 500)

    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)

    try:
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        if config.is_livemode:
            redis_client.delete(REDIS_KEY_TIMEPOINT)
        else:
            try:
                _update_seg_labels(spots_dict,
                                   config.scales,
                                   config.zpath_input,
                                   config.zpath_seg_label,
                                   config.zpath_seg_label_vis,
                                   config.auto_bg_thresh,
                                   config.c_ratio,
                                   memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                return make_response(jsonify({'completed': False}))

        # Continue step numbering after the last step already logged.
        step_offset = 0
        for path in sorted(Path(config.log_dir).glob('event*')):
            try:
                *_, last_record = TFRecordDataset(str(path))
                last = event_pb2.Event.FromString(last_record.numpy()).step
                step_offset = max(step_offset, last + 1)
            except Exception:
                # Best effort: skip empty or unreadable event files.
                pass
        epoch_start = 0
        async_result = train_seg_task.delay(
            list(spots_dict.keys()),
            config.batch_size,
            config.crop_size,
            config.class_weights,
            config.false_weight,
            config.model_path,
            config.n_epochs,
            config.keep_axials,
            config.scales,
            config.lr,
            config.n_crops,
            config.is_3d,
            config.is_livemode,
            config.scale_factor_base,
            config.rotation_angle,
            config.contrast,
            config.zpath_input,
            config.zpath_seg_label,
            config.log_interval,
            config.log_dir,
            step_offset,
            epoch_start,
            config.is_cpu(),
            config.is_mixed_precision,
            config.cache_maxbytes,
            config.memmap_dir,
            config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('training aborted')
                return make_response(jsonify({'completed': False}))
            # Poll periodically instead of busy-spinning on redis.
            time.sleep(0.1)
        # Surface task failures instead of reporting success.
        if async_result.failed():
            raise async_result.result
    except Exception as e:
        logger().exception('Failed in train_seg')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Update seg label.

    Waits while the server state is WAIT. If the state is IDLE, nothing is
    done and ``{'completed': False}`` is returned. Otherwise the posted spots
    are grouped by timepoint and written as segmentation labels. When
    ``reset`` is set in the request, the label zarrs are recreated empty
    instead.

    Returns:
        ``{'completed': True}`` after a reset, the result of
        ``_update_seg_labels`` after an update, ``{'completed': False}`` on
        interrupt, 400 on a bad request or when an update is ongoing,
        500 on failure.
    '''
    if request.headers.get('Content-Type') != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        # Mark the update as in progress; cleared in ``finally``.
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
            msg = 'Last update/training is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if req_json.get('reset'):
            try:
                zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                               config.zpath_seg_label,
                               mode='w')
                zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                         mode='r'),
                               config.zpath_seg_label_vis,
                               mode='w')
            except RuntimeError as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Runtime Error: {e}'),
                                     500)
            except Exception as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        # Treat a missing/null ``spots`` as empty so it yields a 400 below.
        for spot in req_json.get('spots') or []:
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if config.is_livemode and len(spots_dict.keys()) != 1:
            msg = 'Livemode should update only a single timepoint'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        try:
            response = _update_seg_labels(spots_dict,
                                          config.scales,
                                          config.zpath_input,
                                          config.zpath_seg_label,
                                          config.zpath_seg_label_vis,
                                          config.auto_bg_thresh,
                                          config.c_ratio,
                                          config.is_livemode,
                                          memmap_dir=config.memmap_dir)
        except KeyboardInterrupt:
            return make_response(jsonify({'completed': False}))
        except Exception as e:
            logger().exception('Failed in _update_seg_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
    except Exception as e:
        # Request parsing / config errors previously escaped as a non-JSON
        # 500; report them like every other failure.
        logger().exception('Failed in update_seg_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Restore the state seen on entry unless it was reset to IDLE.
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
    return response
def post(self):
    '''
    Predict seg.

    Waits while the server state is WAIT. If the state is IDLE, nothing is
    done and ``{'completed': False}`` is returned. Otherwise
    ``detect_spots_task`` is launched and the handler blocks until it
    finishes, or until the server state is switched to IDLE (abort).

    Returns:
        ``{'spots': [...], 'completed': True}`` on success,
        ``{'spots': [], 'completed': False}`` when aborted, 400 on a bad
        Content-Type, 500 on failure.
    '''
    if request.headers.get('Content-Type') != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response(jsonify({'completed': False}))
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationEvalConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        async_result = detect_spots_task.delay(
            str(config.device),
            config.model_path,
            config.keep_axials,
            config.is_pad,
            config.is_3d,
            config.crop_size,
            config.scales,
            config.cache_maxbytes,
            config.use_2d,
            config.use_median,
            config.patch_size,
            config.crop_box,
            config.c_ratio,
            config.p_thresh,
            config.r_min,
            config.r_max,
            config.output_prediction,
            config.zpath_input,
            config.zpath_seg_output,
            config.timepoint,
            None,
            config.memmap_dir,
            config.batch_size,
            config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('prediction aborted')
                return make_response(
                    jsonify({
                        'spots': [],
                        'completed': False
                    }))
            # Poll periodically instead of busy-spinning on redis.
            time.sleep(0.1)
        if async_result.failed():
            raise async_result.result
        spots = async_result.result
        if spots is None:
            logger().info('prediction aborted')
            return make_response(jsonify({
                'spots': [],
                'completed': False
            }))
        publish_mq('prediction', 'Prediction updated')
    except Exception as e:
        logger().exception('Failed in detect_spots')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        # Restore the state seen on entry unless it was reset to IDLE.
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return make_response(jsonify({'spots': spots, 'completed': True}))
def test_get_device():
    # get_device() must resolve to either the CPU or the CUDA device.
    allowed = (torch.device("cpu"), torch.device("cuda"))
    assert get_device() in allowed
def device():
    '''Return the device stored on ``g``, resolving it on first access.'''
    if 'device' in g:
        return g.device
    resolved = get_device()
    g.device = resolved
    return resolved