def post(self):
    '''
    Update flow label.

    The JSON body is parsed by ``FlowTrainConfig``. If ``reset`` is truthy the
    flow-label zarr is recreated empty; otherwise ``spots`` (a list of spot
    dicts, grouped here by their ``t`` key) is rasterized by
    ``_update_flow_labels``.

    Waits while another request holds the WAIT state, answers
    ``{'completed': False}`` when the process is IDLE, and rejects the request
    (400) while a previous flow-label update is still marked as ongoing.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while (state == TrainState.WAIT.value):
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if (state == TrainState.IDLE.value):
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        # Mark the update as ongoing so the guard above actually rejects
        # overlapping requests; the flag is cleared again in ``finally``.
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = FlowTrainConfig(req_json)
        logger().info(config)
        if req_json.get('reset'):
            # Recreate the label array with the same metadata but no data.
            zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                           config.zpath_flow_label,
                           mode='w')
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        # A missing/null ``spots`` is treated as empty and rejected below
        # with a 400 instead of crashing while iterating ``None``.
        for spot in req_json.get('spots') or []:
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        response = _update_flow_labels(spots_dict,
                                       config.scales,
                                       config.zpath_flow_label,
                                       config.flow_norm_factor)
    except RuntimeError as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in update_flow_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Restore the state seen on entry unless the process was switched
        # to IDLE (aborted) in the meantime.
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
    return response
def post(self):
    '''
    Download outputs in CTC format.

    The JSON body is parsed by ``ExportConfig`` and its ``spots`` are grouped
    by timepoint before being handed to ``export_ctc_labels``. While another
    request holds the WAIT state this handler waits; when the process is
    IDLE it answers 204 immediately.

    Responds with the exported file when ``export_ctc_labels`` returns a
    path, 204 when it returns a falsy value, and 200 otherwise.
    '''
    if request.headers['Content-Type'] != 'application/json':
        return make_response(
            jsonify(error='Content-Type should be application/json'), 400)

    state = get_state()
    while state == TrainState.WAIT.value:
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if state == TrainState.IDLE.value:
        return make_response('', 204)

    try:
        payload = request.get_json()
        payload['device'] = get_device()
        config = ExportConfig(payload)
        # Spatial shape of the input: drop the leading (time) axis.
        config.shape = zarr.open(config.zpath_input, mode='r').shape[1:]
        logger().info(config)

        grouped = collections.defaultdict(list)
        for spot in payload.get('spots'):
            grouped[spot['t']].append(spot)
        spots_by_t = collections.OrderedDict(sorted(grouped.items()))

        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        result = export_ctc_labels(config, spots_by_t)
        if isinstance(result, str):
            # NOTE: this handler does not remove the exported file after
            # sending it (a cleanup hook was left disabled upstream).
            response = send_file(result)
        else:
            response = make_response('', 200 if result else 204)
    except RuntimeError as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in export_ctc_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Put back the state observed on entry unless someone aborted (IDLE).
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return response
def __getitem__(self, index):
    """Return one patch of one frame.

    ``index`` enumerates (frame, patch) pairs frame-major: the frame index is
    ``index // len(self.patch_list)`` and the patch index is the remainder.

    Returns:
        (patch, keep_axials for the frame, frame index, patch index)

    Raises:
        KeyboardInterrupt: when redis is configured and the process state is
            IDLE (used as an abort signal).
    """
    if redis_client is not None and get_state() == TrainState.IDLE.value:
        raise KeyboardInterrupt
    frame_idx, patch_idx = divmod(index, len(self.patch_list))
    patch_slices, _ = self.patch_list[patch_idx]
    patch = self.input[(frame_idx, *patch_slices)]
    return patch, self.keep_axials[frame_idx], frame_idx, patch_idx
def post(self):
    '''
    Reset seg model.

    Accepts either ``application/json`` (the body is parsed by
    ``ResetConfig``) or ``multipart/form-data`` (a ``data`` form field holding
    the JSON config plus an uploaded checkpoint ``file``). It refuses to run
    unless the process state is IDLE. While ``init_seg_models`` runs, the
    state is RUN, and it is always set back to IDLE afterwards.
    '''
    if all(ctype not in request.headers['Content-Type']
           for ctype in ('multipart/form-data', 'application/json')):
        msg = ('Content-Type should be multipart/form-data or '
               'application/json')
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running. Model cannot be reset.'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)
    try:
        device = get_device()
        if 'multipart/form-data' in request.headers['Content-Type']:
            # Go through the application logger rather than stdout.
            logger().info(request.form)
            req_json = json.loads(request.form.get('data'))
            file = request.files['file']
            # SECURITY(review): torch.load unpickles the uploaded file, which
            # can execute arbitrary code; only expose this to trusted clients.
            checkpoint = torch.load(file.stream, map_location=device)
            state_dicts = checkpoint if isinstance(checkpoint,
                                                   list) else [checkpoint]
            # Weights come from the upload, so no URL download.
            req_json['url'] = None
        else:
            req_json = request.get_json()
            state_dicts = None
        req_json['device'] = device
        config = ResetConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        init_seg_models(config.model_path,
                        config.keep_axials,
                        config.device,
                        config.is_3d,
                        config.n_models,
                        config.n_crops,
                        config.zpath_input,
                        config.crop_size,
                        config.scales,
                        url=config.url,
                        state_dicts=state_dicts,
                        is_cpu=config.is_cpu(),
                        cache_maxbytes=config.cache_maxbytes)
    except RuntimeError as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in init_seg_models')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def _update_flow_labels(spots_dict,
                        scales,
                        zpath_flow_label,
                        flow_norm_factor):
    """Rasterize flow (displacement) labels for the given spots into zarr.

    For every timepoint ``t`` in ``spots_dict``, each spot's ellipse (2D) or
    ellipsoid (3D), built from its ``pos`` and ``covariance``, is filled in
    a float32 buffer. Channels ``0..n_dims-1`` receive
    ``displacement[i] / scales[-1 - i] / flow_norm_factor[i]`` and the last
    channel receives a weight of 1. Only the voxels touched by some spot are
    then written back into ``za_label[t]``.

    Args:
        spots_dict: mapping ``t -> list of spot dicts``; each spot is read for
            ``pos``, ``covariance`` (9 values reshaped to 3x3) and
            ``displacement``. Presumably ``pos``/``displacement`` are in
            X, Y, Z order (they are reversed / indexed from the end here);
            confirm against the client.
        scales: per-axis scales; positions are divided by them to get
            integer centroids. Assumed to be in the reversed (Z, Y, X)
            order. TODO confirm.
        zpath_flow_label: path of the zarr flow-label array, opened in
            append mode. Axis 0 is indexed by timepoint and axis 1 by
            channel; the remaining ``n_dims`` axes are spatial.
        flow_norm_factor: per-component divisors applied to displacements.

    Returns:
        A Flask response: ``{'completed': False}`` if the process state
        became IDLE during the update (the remaining frames are left
        untouched), otherwise ``{'completed': True}``.
    """
    za_label = zarr.open(zpath_flow_label, mode='a')
    MIN_AREA_ELLIPSOID = 9
    # Array layout is (t, channel, *spatial), hence the "- 2".
    n_dims = len(za_label.shape) - 2
    for t, spots in spots_dict.items():
        # Per-frame working buffer of shape (channel, *spatial).
        label = np.zeros(za_label.shape[1:], dtype='float32')
        # Spatial voxel coordinates touched by any spot in this frame.
        label_indices = set()
        centroids = []
        for spot in spots:
            # Abort signal: another request switched the process to IDLE.
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                return make_response(jsonify({'completed': False}))
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            # Reversing the flat 3x3 list reverses both matrix axes.
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            # Eigen-decomposition gives the ellipse axes (sqrt of variances)
            # and their orientation.
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            draw_func = ellipsoid if n_dims == 3 else ellipse
            indices = draw_func(centroid, radii, rotation, scales,
                                label.shape[-n_dims:], MIN_AREA_ELLIPSOID)
            # Constant weight; a tag-dependent weight was considered upstream
            # (tp -> 1, otherwise false_weight) but is not used.
            weight = 1
            displacement = spot['displacement']  # X, Y, Z
            for i in range(n_dims):
                # Prepend a constant channel index i to the spatial indices.
                ind = (np.full(len(indices[0]), i), ) + indices
                label[ind] = (displacement[i] / scales[-1 - i] /
                              flow_norm_factor[i])
            # last channel is for weight
            ind = (np.full(len(indices[0]), -1), ) + indices
            label[ind] = weight
            label_indices.update(
                tuple(map(tuple,
                          np.stack(indices, axis=1).tolist())))
        logger().info(f'frame:{t+1}, {len(spots)} linkings')
        # Turn the set of spatial coordinates into per-axis index arrays.
        # NOTE(review): assumes at least one voxel was drawn; an empty set
        # would make target[0] raise IndexError.
        target = tuple(np.array(list(label_indices)).T)
        # Expand to every channel 0..n_dims: the channel array is
        # [0]*N + [1]*N + ... and each spatial array is tiled to match.
        target = (np.array(
            sum(tuple([(i, ) * len(target[0])
                       for i in range(n_dims + 1)]), ())), ) + tuple(
                np.tile(target[i], n_dims + 1) for i in range(n_dims))
        # Same indices with the timepoint prepended, for the zarr write.
        target_t = (np.full(len(target[0]), t), ) + target
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        za_label[target_t] = label[target]
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Update process state.

    Expects a JSON object whose ``REDIS_KEY_STATE`` entry is the integer
    value of a ``TrainState`` member, stores it in redis, and echoes back the
    resulting state. Invalid values are rejected with 400.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    req_json = request.get_json()
    # A non-object body (e.g. ``null``) carries no state -> rejected below.
    state = (req_json.get(REDIS_KEY_STATE)
             if isinstance(req_json, dict) else None)
    valid_states = {member.value for member in TrainState}
    # bool is a subclass of int: JSON true/false would otherwise compare
    # equal to a member valued 1/0, and redis-py refuses to store bools.
    if (isinstance(state, bool) or not isinstance(state, int)
            or state not in valid_states):
        msg = f'Invalid state: {state}'
        logger().error(msg)
        # ``res`` is kept for existing clients; ``error`` matches the
        # other handlers.
        return make_response(jsonify(res=msg, error=msg), 400)
    redis_client.set(REDIS_KEY_STATE, state)
    return make_response(jsonify(success=True, state=get_state()))
def post(self):
    '''
    Reset flow model.

    Accepts either ``application/json`` (the body is parsed by
    ``ResetConfig``) or ``multipart/form-data`` (a ``data`` form field holding
    the JSON config plus an uploaded checkpoint ``file``) and re-initializes
    the flow models. It refuses to run unless the process state is IDLE.
    '''
    if all(ctype not in request.headers['Content-Type']
           for ctype in ('multipart/form-data', 'application/json')):
        msg = ('Content-Type should be multipart/form-data or '
               'application/json')
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    if get_state() != TrainState.IDLE.value:
        msg = 'Process is running'
        logger().error(msg)
        return make_response(jsonify(error=msg), 500)
    try:
        device = get_device()
        if 'multipart/form-data' in request.headers['Content-Type']:
            # Go through the application logger rather than stdout.
            logger().info(request.form)
            req_json = json.loads(request.form.get('data'))
            file = request.files['file']
            # SECURITY(review): torch.load unpickles the uploaded file, which
            # can execute arbitrary code; only expose this to trusted clients.
            checkpoint = torch.load(file.stream, map_location=device)
            state_dicts = checkpoint if isinstance(checkpoint,
                                                   list) else [checkpoint]
            # Weights come from the upload, so no URL download.
            req_json['url'] = None
        else:
            req_json = request.get_json()
            state_dicts = None
        req_json['device'] = device
        config = ResetConfig(req_json)
        logger().info(config)
        init_flow_models(config.model_path,
                         config.device,
                         config.is_3d,
                         url=config.url,
                         state_dicts=state_dicts)
    except RuntimeError as e:
        logger().exception('Failed in reset_flow_models')
        return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
    except Exception as e:
        logger().exception('Failed in reset_flow_models')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Train seg model.

    Steps:
    1. Validate the request and refuse if another process is running.
    2. Switch the state to RUN.
    3. Refresh the seg labels from ``spots``, except in live mode, where
       ``REDIS_KEY_TIMEPOINT`` is cleared instead.
    4. Launch ``train_seg_task`` and poll it until it finishes or the state
       is switched to IDLE (abort).

    The state is reset to IDLE afterwards only if this request set it to
    RUN.

    Returns a JSON response: ``{'completed': True}`` on success,
    ``{'completed': False}`` when aborted, or ``error`` with status 400/500.
    '''
    import time  # local import: used only by the polling loop below

    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    # Becomes True only after this request has switched the state to RUN.
    # Early rejections must not reset the state: doing so would abort a
    # process started by another request.
    is_started = False
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.n_crops < 1:
            msg = 'n_crops should be a positive number'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)

        spots_dict = collections.defaultdict(list)
        # A missing/null ``spots`` is treated as empty.
        for spot in req_json.get('spots') or []:
            spots_dict[spot['t']].append(spot)
        if not (spots_dict or config.is_livemode):
            msg = 'nothing to train'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
        is_started = True
        if config.is_livemode:
            redis_client.delete(REDIS_KEY_TIMEPOINT)
        else:
            try:
                _update_seg_labels(spots_dict,
                                   config.scales,
                                   config.zpath_input,
                                   config.zpath_seg_label,
                                   config.zpath_seg_label_vis,
                                   config.auto_bg_thresh,
                                   config.c_ratio,
                                   memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                # Raised by _update_seg_labels when the state turns IDLE.
                return make_response(jsonify({'completed': False}))
        # Continue TensorBoard step numbering after the last logged event.
        step_offset = 0
        for path in sorted(Path(config.log_dir).glob('event*')):
            try:
                *_, last_record = TFRecordDataset(str(path))
                last = event_pb2.Event.FromString(last_record.numpy()).step
                step_offset = max(step_offset, last + 1)
            except Exception:
                # Best effort: unreadable/empty event files are skipped.
                pass
        epoch_start = 0
        async_result = train_seg_task.delay(
            list(spots_dict.keys()),
            config.batch_size,
            config.crop_size,
            config.class_weights,
            config.false_weight,
            config.model_path,
            config.n_epochs,
            config.keep_axials,
            config.scales,
            config.lr,
            config.n_crops,
            config.is_3d,
            config.is_livemode,
            config.scale_factor_base,
            config.rotation_angle,
            config.contrast,
            config.zpath_input,
            config.zpath_seg_label,
            config.log_interval,
            config.log_dir,
            step_offset,
            epoch_start,
            config.is_cpu(),
            config.is_mixed_precision,
            config.cache_maxbytes,
            config.memmap_dir,
            config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('training aborted')
                return make_response(jsonify({'completed': False}))
            # Poll gently instead of spinning on the redis connection.
            time.sleep(0.1)
        # A crashed task must not be reported as completed.
        if async_result.failed():
            raise async_result.result
    except Exception as e:
        logger().exception('Failed in train_seg')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        torch.cuda.empty_cache()
        if is_started:
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
    return make_response(jsonify({'completed': True}))
def post(self):
    '''
    Update seg label.

    The JSON body is parsed by ``SegmentationTrainConfig``. If ``reset`` is
    truthy, the seg-label and seg-label-vis zarr arrays are recreated empty.
    Otherwise ``spots`` (grouped by their ``t`` key) is rasterized by
    ``_update_seg_labels``. In live mode exactly one timepoint is allowed,
    and the request is refused while ``REDIS_KEY_TIMEPOINT`` is still set.

    Waits while another request holds the WAIT state, answers
    ``{'completed': False}`` when the process is IDLE, and rejects the request
    (400) while a previous seg-label update is still marked as ongoing.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while (state == TrainState.WAIT.value):
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if (state == TrainState.IDLE.value):
        return make_response(jsonify({'completed': False}))
    if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
        msg = 'Last update is ongoing'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    try:
        # Flag cleared in ``finally``; lets the guard above reject overlaps.
        redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationTrainConfig(req_json)
        logger().info(config)
        if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
            msg = 'Last update/training is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if req_json.get('reset'):
            try:
                # Recreate both label arrays with the same metadata, no data.
                zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                               config.zpath_seg_label,
                               mode='w')
                zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                         mode='r'),
                               config.zpath_seg_label_vis,
                               mode='w')
            except RuntimeError as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Runtime Error: {e}'),
                                     500)
            except Exception as e:
                logger().exception('Failed in opening zarr')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
            return make_response(jsonify({'completed': True}))

        spots_dict = collections.defaultdict(list)
        # A missing/null ``spots`` is treated as empty and rejected below.
        for spot in req_json.get('spots') or []:
            spots_dict[spot['t']].append(spot)
        if not spots_dict:
            msg = 'nothing to update'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if config.is_livemode and len(spots_dict.keys()) != 1:
            msg = 'Livemode should update only a single timepoint'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

        try:
            response = _update_seg_labels(spots_dict,
                                          config.scales,
                                          config.zpath_input,
                                          config.zpath_seg_label,
                                          config.zpath_seg_label_vis,
                                          config.auto_bg_thresh,
                                          config.c_ratio,
                                          config.is_livemode,
                                          memmap_dir=config.memmap_dir)
        except KeyboardInterrupt:
            # Raised by _update_seg_labels when the state turns IDLE.
            return make_response(jsonify({'completed': False}))
        except Exception as e:
            logger().exception('Failed in _update_seg_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
    except Exception as e:
        # Failures before the label update (request parsing, config,
        # grouping) also get a JSON error like the other handlers.
        logger().exception('Failed in update_seg_labels')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        # Restore the state seen on entry unless aborted (IDLE) meanwhile.
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
        redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
    return response
def post(self):
    '''
    Predict seg.

    Runs ``detect_spots_task`` with the settings from
    ``SegmentationEvalConfig`` and returns ``{'spots': ..., 'completed':
    True}``. While another request holds the WAIT state this handler waits.
    When the process is IDLE, or becomes IDLE while the task runs, it returns
    ``completed: False``. The state observed on entry is restored afterwards
    unless it was switched to IDLE.
    '''
    if request.headers['Content-Type'] != 'application/json':
        msg = 'Content-Type should be application/json'
        logger().error(msg)
        return make_response(jsonify(error=msg), 400)
    state = get_state()
    while (state == TrainState.WAIT.value):
        logger().info(f'waiting @{request.path}')
        time.sleep(1)
        state = get_state()
    if (state == TrainState.IDLE.value):
        return make_response(jsonify({'completed': False}))
    try:
        req_json = request.get_json()
        req_json['device'] = get_device()
        config = SegmentationEvalConfig(req_json)
        logger().info(config)
        redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
        async_result = detect_spots_task.delay(
            str(config.device),
            config.model_path,
            config.keep_axials,
            config.is_pad,
            config.is_3d,
            config.crop_size,
            config.scales,
            config.cache_maxbytes,
            config.use_2d,
            config.use_median,
            config.patch_size,
            config.crop_box,
            config.c_ratio,
            config.p_thresh,
            config.r_min,
            config.r_max,
            config.output_prediction,
            config.zpath_input,
            config.zpath_seg_output,
            config.timepoint,
            None,
            config.memmap_dir,
            config.batch_size,
            config.input_size,
        )
        while not async_result.ready():
            if (redis_client is not None
                    and get_state() == TrainState.IDLE.value):
                logger().info('prediction aborted')
                return make_response(
                    jsonify({
                        'spots': [],
                        'completed': False
                    }))
            # Poll gently instead of spinning on the redis connection.
            time.sleep(0.1)
        if async_result.failed():
            raise async_result.result
        spots = async_result.result
        if spots is None:
            logger().info('prediction aborted')
            return make_response(jsonify({
                'spots': [],
                'completed': False
            }))
        publish_mq('prediction', 'Prediction updated')
    except Exception as e:
        logger().exception('Failed in detect_spots')
        return make_response(jsonify(error=f'Exception: {e}'), 500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        if get_state() != TrainState.IDLE.value:
            redis_client.set(REDIS_KEY_STATE, state)
    return make_response(jsonify({'spots': spots, 'completed': True}))
def _update_seg_labels(spots_dict,
                       scales,
                       zpath_input,
                       zpath_seg_label,
                       zpath_seg_label_vis,
                       auto_bg_thresh=0,
                       c_ratio=0.5,
                       is_livemode=False,
                       memmap_dir=None):
    """Rasterize segmentation labels for the given spots and store them.

    For every timepoint ``t`` in ``spots_dict``, spots are drawn as ellipses
    (2D) or ellipsoids (3D) into a uint8 label image and an RGB uint8
    visualization image. The results then replace ``za_label[t]`` and
    ``za_label_vis[t]``. The existing chunk files for ``t`` are deleted
    first, so previously stored labels at ``t`` are discarded.

    Label values written: 1/2/3 (BG/outer/inner) for spot tags
    'tp', 'tb', 'tn', and 4/5/6 for the other tags; 0 stays unlabeled. This
    matches the encoding listed in the segmentation dataset's
    ``__getitem__``.

    Args:
        spots_dict: mapping ``t -> list of spot dicts``; each spot is read for
            ``tag``, ``pos`` and ``covariance`` (9 values reshaped to 3x3).
        scales: per-axis scales (``None`` means all 1.0); positions are
            divided by them to get integer centroids.
        zpath_input: zarr input array; its shape minus axis 0 is the image
            shape.
        zpath_seg_label / zpath_seg_label_vis: zarr label / visualization
            arrays, opened in append mode.
        auto_bg_thresh: if > 0, input voxels below this value are labeled
            as BG (1) before the spots are drawn.
        c_ratio: radius ratio of the inner region drawn for 'tp'/'fn' spots.
        is_livemode: live mode; ``spots_dict`` must hold exactly one
            timepoint, and that timepoint is published to
            ``REDIS_KEY_TIMEPOINT``.
        memmap_dir: if set, cached ``{keybase}-t{t}*-seglabel.dat`` files for
            the updated timepoints are removed (under a file lock).

    Returns:
        A Flask response: ``{'completed': True}``, or a 400 error in live
        mode when ``REDIS_KEY_TIMEPOINT`` is already set.

    Raises:
        KeyboardInterrupt: if the process state becomes IDLE mid-update.
        AssertionError: in live mode with more than one timepoint.
    """
    if is_livemode:
        assert len(spots_dict.keys()) == 1
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 9
    img_shape = za_input.shape[1:]
    n_dims = len(img_shape)
    # Name of the directory containing the label store; used as the prefix
    # of the memmap cache files removed below.
    keybase = Path(za_label.store.path).parent.name
    if scales is None:
        scales = (1., ) * n_dims
    scales = np.array(scales)
    for t, spots in spots_dict.items():
        # Voxel coordinates touched in this frame. Set to None once the
        # count exceeds INDICES_THRESHOLD; the whole frame is then written
        # instead of a sparse update.
        label_indices = set()
        centroids = []
        label = np.zeros(img_shape, dtype='uint8')
        label_vis = np.zeros(img_shape + (3, ), dtype='uint8')
        if 0 < auto_bg_thresh:
            indices_bg = np.nonzero(
                get_input_at(za_input, t, memmap_dir=memmap_dir) <
                auto_bg_thresh)
            label[indices_bg] = 1
            label_vis[indices_bg] = (255, 0, 0)
            if INDICES_THRESHOLD < len(indices_bg[0]):
                label_indices = None
            else:
                label_indices.update(tuple(map(tuple,
                                               np.array(indices_bg).T)))
        cnt = collections.Counter({x: 0 for x in keyorder})
        for spot in spots:
            # Abort signal: another request switched the process to IDLE.
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                raise KeyboardInterrupt
            cnt[spot['tag']] += 1
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            # Reversing the flat 3x3 list reverses both matrix axes.
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            if n_dims == 3:
                draw_func = ellipsoid
                dilate_func = _dilate_3d_indices
            else:
                draw_func = ellipse
                dilate_func = _dilate_2d_indices
            indices_outer = draw_func(centroid, radii, rotation, scales,
                                      img_shape, MIN_AREA_ELLIPSOID)
            if label_indices is not None:
                if INDICES_THRESHOLD < len(indices_outer[0]):
                    label_indices = None
                else:
                    label_indices.update(
                        tuple(map(tuple,
                                  np.array(indices_outer).T)))
                    if INDICES_THRESHOLD < len(label_indices):
                        label_indices = None
            if spot['tag'] in ['tp', 'tb', 'tn']:
                label_offset = 0
                label_vis_value = 255
            else:
                label_offset = 3
                label_vis_value = 127
            # Priority trick: ``label - 1`` on uint8 wraps 0 to 255, so
            # fmod(label - 1, 3) gives 0 for unlabeled/BG (0, 1, 4),
            # 1 for outer (2, 5) and 2 for inner (3, 6).
            # "<= 1" therefore means "not inner": outer never overwrites inner.
            cond_outer_1 = np.fmod(label[indices_outer] - 1, 3) <= 1
            if spot['tag'] in ('tp', 'fn'):
                indices_inner = draw_func(centroid, radii * c_ratio, rotation,
                                          scales, img_shape,
                                          MIN_AREA_ELLIPSOID)
                # Dilated inner region, marked as outer before the inner
                # region itself is painted on top.
                indices_inner_p = dilate_func(*indices_inner, img_shape)
                label[indices_outer] = np.where(cond_outer_1,
                                                2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_1[..., None], (0, label_vis_value, 0),
                    label_vis[indices_outer])
                label[indices_inner_p] = 2 + label_offset
                label_vis[indices_inner_p] = (0, label_vis_value, 0)
                # "<= 2" is always true: inner overwrites everything.
                cond_inner = np.fmod(label[indices_inner] - 1, 3) <= 2
                label[indices_inner] = np.where(cond_inner, 3 + label_offset,
                                                label[indices_inner])
                label_vis[indices_inner] = np.where(
                    cond_inner[..., None], (0, 0, label_vis_value),
                    label_vis[indices_inner])
            elif spot['tag'] in ('tb', 'fb'):
                label[indices_outer] = np.where(cond_outer_1,
                                                2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_1[..., None], (0, label_vis_value, 0),
                    label_vis[indices_outer])
            elif spot['tag'] in ('tn', 'fp'):
                # "<= 0": BG only replaces unlabeled/BG voxels.
                cond_outer_0 = np.fmod(label[indices_outer] - 1, 3) <= 0
                label[indices_outer] = np.where(cond_outer_0,
                                                1 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(
                    cond_outer_0[..., None], (label_vis_value, 0, 0),
                    label_vis[indices_outer])

        logger().info('frame:{}, {}'.format(
            t, sorted(cnt.items(), key=lambda i: keyorder.index(i[0]))))
        if label_indices is None:
            # Too many touched voxels: drop the frame's chunk files and
            # write the full frame.
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[t] = label
            za_label_vis[t] = label_vis
        else:
            # Sparse write: only the touched voxels (and all 3 RGB channels
            # for the visualization) are written after the frame's chunk
            # files are dropped. The helper to_fancy_index presumably
            # prepends t to the index tuple (confirm).
            target = tuple(np.array(list(label_indices)).T)
            target_t = to_fancy_index(t, *target)
            target_vis = (
                tuple(np.tile(target[i], 3) for i in range(n_dims)) +
                (np.array(
                    sum(tuple([(c, ) * len(target[0]) for c in range(3)]),
                        ())), ))
            target_vis_t = to_fancy_index(t, *target_vis)
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[target_t] = label[target]
            za_label_vis[target_vis_t] = label_vis[target_vis]
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        if memmap_dir:
            # Invalidate cached label memmaps for this timepoint.
            for fpath in Path(memmap_dir).glob(
                    f'{keybase}-t{t}*-seglabel.dat'):
                lock = FileLock(str(fpath) + '.lock')
                with lock:
                    if fpath.exists():
                        logger().info(f'remove {fpath}')
                        fpath.unlink()
        if is_livemode:
            if redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            # Publish the updated timepoint for the live-mode dataset.
            redis_client.set(REDIS_KEY_TIMEPOINT, t)
    return make_response(jsonify({'completed': True}))
def get(self):
    '''
    Report the current process state as ``{'success': True, 'state': ...}``.
    '''
    current_state = get_state()
    body = jsonify(success=True, state=current_state)
    return make_response(body)
def __getitem__(self, index):
    """
    Input shape: (2, (D,) H, W)
    Label shape: (ndim+1, (D,) H, W)
    Label channels: (flow_x, flow_y, flow_z, mask)

    Returns ``((input, keep_axials), target)``, where ``target`` is the label
    concatenated with the input along the channel axis. In eval mode the
    whole frame is returned. Otherwise a random crop is taken around a
    labeled voxel, with optional scale/rotation/flip augmentation, and the
    flow vectors are adjusted to match the augmentation.

    Raises:
        KeyboardInterrupt: when redis is configured and the process state is
            IDLE (used as an abort signal).
    """
    if (redis_client is not None
            and get_state() == TrainState.IDLE.value):
        raise KeyboardInterrupt
    if self.length is not None:
        i_frame = np.random.choice(self.indices)
    else:
        i_frame = self.indices[index // self.n_crops]
    img_input = get_inputs_at(self.za_input,
                              i_frame,
                              cache_dict=self.cache_dict_input,
                              memmap_dir=self.memmap_dir,
                              img_size=self.img_size)
    # Factor mapping stored (original-resolution) label coordinates onto
    # the possibly resized input.
    if self.za_input.shape[1:] != self.img_size:
        resize_factor = [
            self.img_size[d] / self.za_input.shape[1 + d]
            for d in range(self.n_dims)
        ]
    else:
        resize_factor = [
            1,
        ] * self.n_dims
    img_label = self._get_label_at(i_frame, img_size=self.img_size)
    if self.is_eval:
        tensor_input = torch.from_numpy(img_input)
        tensor_label = torch.from_numpy(img_label)
        tensor_target = torch.cat((tensor_label, tensor_input), )
        return (tensor_input, self.keep_axials), tensor_target

    # Retry until the crop contains at least one labeled (mask > 0) voxel.
    while True:
        if 0 < sum(self.scale_factors):
            # img_input has a leading channel axis, hence shape[i + 1].
            item_crop_size = [
                randrange(
                    min(
                        img_input.shape[i + 1],
                        round(self.crop_size[i] *
                              (1. - self.scale_factors[i]))),
                    min(
                        img_input.shape[i + 1] + 1,
                        int(self.crop_size[i] *
                            (1. + self.scale_factors[i])) + 1))
                for i in range(self.n_dims)
            ]
        else:
            # Copy: the rotation branch below modifies item_crop_size in
            # place and must not alter self.crop_size.
            item_crop_size = list(self.crop_size)
        if self.rotation_angle is not None and 0 < self.rotation_angle:
            # rotate image
            theta = randint(-self.rotation_angle, self.rotation_angle)
            cos_theta = math.cos(math.radians(theta))
            sin_theta = math.sin(math.radians(theta))
            # Enlarge the crop so that, once rotated, it still covers
            # crop_size.
            for i in (-2, -1):
                item_crop_size[i] *= (abs(cos_theta) + abs(sin_theta))
                item_crop_size[i] = math.ceil(item_crop_size[i])
            # Clamp to the spatial extent; skip the leading channel axis.
            item_crop_size = [
                min(img_input.shape[i + 1], item_crop_size[i])
                for i in range(self.n_dims)
            ]
        # Pool of labeled voxel coordinates, cached in the zarr attributes.
        za_label_a = zarr.open(self.zpath_flow_label, mode='a')
        index_pool = za_label_a.attrs.get(f'label.indices.{i_frame}')
        if index_pool is None:
            index_pool = np.argwhere(0 < img_label[-1])
            za_label_a.attrs[f'label.indices.{i_frame}'] = tuple(
                map(tuple, index_pool.tolist()))
        base_index = index_pool[randrange(len(index_pool))]
        # Crop origin chosen so that the crop contains base_index.
        origins = [
            randint(
                max(0, (int(base_index[i] * resize_factor[i]) -
                        (item_crop_size[i] - 1))),
                min((img_input.shape[1 + i] - item_crop_size[i]),
                    int(base_index[i] * resize_factor[i])))
            for i in range(self.n_dims)
        ]
        slices = (slice(None), ) + tuple(
            slice(origins[i], origins[i] + item_crop_size[i])
            for i in range(self.n_dims))
        sliced_label = img_label[slices].copy()
        sliced_input = img_input[slices].copy()
        # scale labels by resize factor
        # NOTE(review): when rotation is also enabled, item_crop_size
        # includes the rotation margin, and the rotated crop is
        # center-cropped rather than resized. This factor may then not
        # match the final geometry; confirm.
        if 0 < sum(self.scale_factors):
            sliced_label[0] *= self.crop_size[-1] / item_crop_size[-1]  # X
            sliced_label[1] *= self.crop_size[-2] / item_crop_size[-2]  # Y
            if self.n_dims == 3:
                sliced_label[2] *= self.crop_size[-3] / \
                    item_crop_size[-3]  # Z
        if self.rotation_angle is not None and 0 < self.rotation_angle:
            if self.n_dims == 3:
                sliced_input = np.array([
                    [
                        rotate(
                            sliced_input[c, z],
                            theta,
                            resize=True,
                            preserve_range=True,
                            order=1,  # 1: Bi-linear (default)
                        ) for z in range(sliced_input.shape[1])
                    ] for c in range(sliced_input.shape[0])
                ])
            else:
                sliced_input = np.array([
                    rotate(
                        sliced_input[c],
                        theta,
                        resize=True,
                        preserve_range=True,
                        order=1,  # 1: Bi-linear (default)
                    ) for c in range(sliced_input.shape[0])
                ])
            # Center-crop the rotated (enlarged) patch back to crop_size.
            h_crop, w_crop = self.crop_size[-2:]
            h_rotate, w_rotate = sliced_input.shape[-2:]
            r_origin = max(0, (h_rotate - h_crop) // 2)
            c_origin = max(0, (w_rotate - w_crop) // 2)
            sliced_input = sliced_input[...,
                                        r_origin:r_origin + h_crop,
                                        c_origin:c_origin + w_crop]
            # rotate label
            if self.n_dims == 3:
                sliced_label = np.array([
                    [
                        rotate(
                            sliced_label[c, z],
                            theta,
                            resize=True,
                            preserve_range=True,
                            order=0,  # 0: Nearest-neighbor
                        ) for z in range(sliced_label.shape[1])
                    ] for c in range(sliced_label.shape[0])
                ])
            else:
                sliced_label = np.array([
                    rotate(
                        sliced_label[c],
                        theta,
                        resize=True,
                        preserve_range=True,
                        order=0,  # 0: Nearest-neighbor
                    ) for c in range(sliced_label.shape[0])
                ])
            sliced_label = sliced_label[...,
                                        r_origin:r_origin + h_crop,
                                        c_origin:c_origin + w_crop]
            if sliced_label[-1].max() == 0:
                continue
            # update flow label (y axis is inversed in the image coordinate)
            cos_theta = math.cos(math.radians(theta))
            sin_theta = math.sin(math.radians(theta))
            sliced_label_x = sliced_label[0].copy()
            sliced_label_y = sliced_label[1].copy() * -1
            sliced_label[0] = (cos_theta * sliced_label_x -
                               sin_theta * sliced_label_y)
            sliced_label[1] = (sin_theta * sliced_label_x +
                               cos_theta * sliced_label_y) * -1
        if 0 < sliced_label[-1].max():
            break
    tensor_input = torch.from_numpy(sliced_input)
    tensor_label = torch.from_numpy(sliced_label)
    if tensor_input.shape[1:] != self.crop_size:
        interpolate_mode = 'trilinear' if self.n_dims == 3 else 'bilinear'
        tensor_input = F.interpolate(tensor_input[None],
                                     self.crop_size,
                                     mode=interpolate_mode,
                                     align_corners=True)[0]
        tensor_label = F.interpolate(tensor_label[None].float(),
                                     self.crop_size,
                                     mode='nearest')[0]
    # Random flips along spatial axes (-1: X, -2: Y, -3: Z). Flipping an
    # axis negates the matching flow channel (0: x, 1: y, 2: z).
    flip_dims = [
        -(1 + i) for i, v in enumerate(torch.rand(self.n_dims)) if v < 0.5
    ]
    tensor_input = torch.flip(tensor_input, flip_dims)
    tensor_label = torch.flip(tensor_label, flip_dims)
    for flip_dim in flip_dims:
        tensor_label[-1 - flip_dim] *= -1
    # Channel order: (flow_x, flow_y, flow_z, mask, input_t0, input_t1)
    tensor_target = torch.cat((tensor_label, tensor_input), )
    return (tensor_input, self.keep_axials), tensor_target
def __getitem__(self, index):
    """
    Input shape: ((D,) H, W)
    Label shape: ((D,) H, W)
    Label values: 0: unlabeled, 1: BG (LW), 2: Outer (LW), 3: Inner (LW)
                  4: BG (HW), 5: Outer (HW), 6: Inner (HW)

    Returns ``((input, keep_axials), target)``:
    - Autoencoder mode (``is_ae``): a random crop of a random frame, with
      the input itself as the target.
    - Eval mode: the whole frame and its label shifted by -1.
    - Otherwise: a random crop around a labeled voxel, with optional
      scale/contrast/rotation/flip augmentation.
    - Live mode: the frame published in ``REDIS_KEY_TIMEPOINT`` is used.
      A sentinel sample ``(-200, -200)`` is returned when ``self.length``
      no longer matches ``REDIS_KEY_NCROPS``.

    Raises:
        KeyboardInterrupt: when the process state is IDLE (abort signal).
    """
    if (redis_client is not None
            and get_state() == TrainState.IDLE.value):
        raise KeyboardInterrupt
    if self.is_ae:
        i_frame = randrange(self.za_input.shape[0])
    else:
        if self.is_livemode:
            import time  # local import: only the wait loop below needs it
            # Wait until a timepoint has been published (or abort).
            while True:
                v = redis_client.get(REDIS_KEY_TIMEPOINT)
                if v is not None:
                    i_frame = int(v)
                    img_label = self._get_label_at(i_frame, self.img_size)
                    break
                if (get_state() == TrainState.IDLE.value):
                    raise KeyboardInterrupt
                # Poll gently instead of spinning on the redis connection.
                time.sleep(0.1)
            if self.length != int(redis_client.get(REDIS_KEY_NCROPS)):
                return ((torch.tensor(-200.), self.keep_axials),
                        torch.tensor(-200))
        else:
            if self.length is not None:
                i_frame = np.random.choice(self.indices)
            else:
                i_frame = self.indices[index // self.n_crops]
            img_label = self._get_label_at(i_frame, self.img_size)
    img_input = get_input_at(self.za_input,
                             i_frame,
                             self.cache_dict_input,
                             self.memmap_dir,
                             img_size=self.img_size)
    # Factor mapping stored (original-resolution) label coordinates onto
    # the possibly resized input.
    if self.za_input.shape[1:] != self.img_size:
        resize_factor = [
            self.img_size[d] / self.za_input.shape[1 + d]
            for d in range(img_input.ndim)
        ]
    else:
        resize_factor = [
            1,
        ] * img_input.ndim
    if self.is_eval:
        tensor_input = torch.from_numpy(img_input[None])
        tensor_label = torch.from_numpy(img_label).long()
        return (tensor_input, self.keep_axials), tensor_label - 1

    # Retry only when rotation leaves the label crop empty.
    while True:
        if 0 < sum(self.scale_factors):
            item_crop_size = [
                randrange(
                    min(
                        img_input.shape[i],
                        round(self.crop_size[i] *
                              (1. - self.scale_factors[i]))),
                    min(
                        img_input.shape[i] + 1,
                        int(self.crop_size[i] *
                            (1. + self.scale_factors[i])) + 1))
                for i in range(self.n_dims)
            ]
        else:
            # Copy: the rotation branch below modifies item_crop_size in
            # place and must not alter self.crop_size.
            item_crop_size = list(self.crop_size)
        if self.rotation_angle is not None and 0 < self.rotation_angle:
            # rotate image
            theta = randint(-self.rotation_angle, self.rotation_angle)
            cos_theta = math.cos(math.radians(theta))
            sin_theta = math.sin(math.radians(theta))
            # Enlarge the crop so that, once rotated, it still covers
            # crop_size.
            for i in (-2, -1):
                item_crop_size[i] *= (abs(cos_theta) + abs(sin_theta))
                item_crop_size[i] = math.ceil(item_crop_size[i])
            item_crop_size = [
                min(img_input.shape[i], item_crop_size[i])
                for i in range(self.n_dims)
            ]
        if not self.is_ae:
            # Pool of labeled voxel coordinates, cached in zarr attributes.
            za_label_a = zarr.open(self.zpath_seg_label, mode='a')
            index_pool = za_label_a.attrs.get(f'label.indices.{i_frame}')
            if index_pool is None:
                index_pool = np.argwhere(0 < img_label)
                za_label_a.attrs[f'label.indices.{i_frame}'] = tuple(
                    map(tuple, index_pool.tolist()))
        if self.is_ae:
            origins = [
                randint(0, img_input.shape[i] - item_crop_size[i])
                for i in range(self.n_dims)
            ]
        else:
            base_index = index_pool[randrange(len(index_pool))]
            # Crop origin chosen so that the crop contains base_index.
            origins = [
                randint(
                    max(0, (int(base_index[i] * resize_factor[i]) -
                            (item_crop_size[i] - 1))),
                    min((img_input.shape[i] - item_crop_size[i]),
                        int(base_index[i] * resize_factor[i])))
                for i in range(self.n_dims)
            ]
        slices = tuple(
            slice(origins[i], origins[i] + item_crop_size[i])
            for i in range(self.n_dims))
        if not self.is_ae:
            sliced_label = img_label[slices].copy()
        sliced_input = img_input[slices].copy()
        if not self.is_ae and 0 < self.contrast:
            # Randomly shrink the FG/BG intensity gap. Assumes a float input
            # and a non-zero foreground mean (TODO confirm).
            fg_index = np.isin(sliced_label, (2, 3, 5, 6))
            bg_index = np.isin(sliced_label, (1, 4))
            if fg_index.any() and bg_index.any():
                fg_mean = sliced_input[fg_index].mean()
                bg_mean = sliced_input[bg_index].mean()
                cr_factor = (
                    ((fg_mean - bg_mean) * uniform(self.contrast, 1) +
                     bg_mean) / fg_mean)
                sliced_input[fg_index] *= cr_factor
        if self.rotation_angle is not None and 0 < self.rotation_angle:
            if self.n_dims == 3:
                sliced_input = np.array([
                    rotate(
                        sliced_input[z],
                        theta,
                        resize=True,
                        preserve_range=True,
                        order=1,  # 1: Bi-linear (default)
                    ) for z in range(sliced_input.shape[0])
                ])
            else:
                sliced_input = rotate(
                    sliced_input,
                    theta,
                    resize=True,
                    preserve_range=True,
                    order=1,  # 1: Bi-linear (default)
                )
            # Center-crop the rotated (enlarged) patch back to crop_size.
            h_crop, w_crop = self.crop_size[-2:]
            h_rotate, w_rotate = sliced_input.shape[-2:]
            r_origin = max(0, (h_rotate - h_crop) // 2)
            c_origin = max(0, (w_rotate - w_crop) // 2)
            sliced_input = sliced_input[...,
                                        r_origin:r_origin + h_crop,
                                        c_origin:c_origin + w_crop]
            # rotate label
            if not self.is_ae:
                if self.n_dims == 3:
                    sliced_label = np.array([
                        rotate(
                            sliced_label[z],
                            theta,
                            resize=True,
                            preserve_range=True,
                            order=0,  # 0: Nearest-neighbor
                        ) for z in range(sliced_label.shape[0])
                    ])
                else:
                    sliced_label = rotate(
                        sliced_label,
                        theta,
                        resize=True,
                        preserve_range=True,
                        order=0,  # 0: Nearest-neighbor
                    )
                sliced_label = sliced_label[...,
                                            r_origin:r_origin + h_crop,
                                            c_origin:c_origin + w_crop]
                if sliced_label.max() == 0:
                    continue
        break
    tensor_input = torch.from_numpy(sliced_input)[None]
    if not self.is_ae:
        tensor_label = torch.from_numpy(sliced_label)
    if tensor_input.shape[1:] != self.crop_size:
        interpolate_mode = 'trilinear' if self.n_dims == 3 else 'bilinear'
        tensor_input = F.interpolate(tensor_input[None],
                                     self.crop_size,
                                     mode=interpolate_mode,
                                     align_corners=True)[0]
        if not self.is_ae:
            tensor_label = F.interpolate(tensor_label[None, None].float(),
                                         self.crop_size,
                                         mode='nearest')[0, 0]
    if self.is_ae:
        return (tensor_input, self.keep_axials), tensor_input
    tensor_label = tensor_label.long()
    # Random flips along spatial axes.
    flip_dims = [
        -(1 + i) for i, v in enumerate(torch.rand(self.n_dims)) if v < 0.5
    ]
    tensor_input = torch.flip(tensor_input, flip_dims)
    tensor_label = torch.flip(tensor_label, flip_dims)
    # -1: unlabeled, 0: BG (LW), 1: Outer (LW), 2: Inner (LW)
    #  3: BG (HW), 4: Outer (HW), 5: Inner (HW)
    return (tensor_input, self.keep_axials), tensor_label - 1