Example #1
def check_dataset(dataset, shape):
    """Check a dataset for ELEPHANT.

    Parameters
    ----------
    dataset : str or Path
        Dataset dir to check.
    shape : tuple of int
        Expected image shape (timepoints, depth, height, width) for 3D+t data,
        or (timepoints, height, width) for 2D+t data.

    Returns
    -------
    message : str
        'ready' if everything is OK, otherwise the first problem
        encountered.

    This function checks that the dataset contains the following files
    and that the zarr shapes and dtypes are consistent with imgs.zarr.

    dataset
    ├── flow_hashes.zarr
    ├── flow_labels.zarr
    ├── flow_outputs.zarr
    ├── imgs.zarr
    ├── seg_labels_vis.zarr
    ├── seg_labels.zarr
    └── seg_outputs.zarr
    """
    p = Path(dataset)
    message = 'ready'
    try:
        try:
            img = zarr.open(str(p / 'imgs.zarr'), 'r')
        except ValueError:
            raise Exception(f'{p / "imgs.zarr"} is not found or broken.')
        if img.shape != shape:
            raise Exception('Invalid shape for imgs.zarr\n' +
                            f'Expected: {shape} Found: {img.shape}')
        n_dims = len(shape) - 1
        n_timepoints = shape[0]
        _check_zarr(p / 'flow_outputs.zarr', (
            n_timepoints - 1,
            n_dims,
        ) + shape[-n_dims:], 'float16')
        _check_zarr(p / 'flow_hashes.zarr', (n_timepoints - 1, ), 'S16')
        _check_zarr(p / 'flow_labels.zarr', (
            n_timepoints - 1,
            n_dims + 1,
        ) + shape[-n_dims:], 'float32')
        _check_zarr(p / 'seg_outputs.zarr',
                    (n_timepoints, ) + shape[-n_dims:] + (3, ), 'float16')
        _check_zarr(p / 'seg_labels.zarr', (n_timepoints, ) + shape[-n_dims:],
                    'uint8')
        _check_zarr(p / 'seg_labels_vis.zarr',
                    (n_timepoints, ) + shape[-n_dims:] + (3, ), 'uint8')
    except Exception as e:
        logger().info(str(e))
        message = str(e)
    return message
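
A minimal usage sketch, assuming a dataset directory produced by generate_dataset (Example #20); the path and shape below are hypothetical.

# Hypothetical dataset path and 3D+t shape (timepoints, depth, height, width).
message = check_dataset('/workspace/datasets/demo', shape=(10, 32, 512, 512))
if message != 'ready':
    raise RuntimeError(f'dataset is not ready: {message}')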
Example #2
def _get_memmap_or_load(za,
                        timepoint,
                        memmap_dir=None,
                        use_median=False,
                        img_size=None):
    if memmap_dir:
        # Cache key: dataset name, timepoint and the median flag; a resized
        # copy gets its own file, keyed additionally by img_size.
        key = f'{Path(za.store.path).parent.name}-t{timepoint}-{use_median}'
        fpath_org = Path(memmap_dir) / f'{key}.dat'
        if img_size is not None:
            key += '-' + '-'.join(map(str, img_size))
        fpath = Path(memmap_dir) / f'{key}.dat'
        lock = FileLock(str(fpath) + '.lock')
        with lock:
            if not fpath.exists():
                logger().info(f'creating {fpath}')
                fpath.parent.mkdir(parents=True, exist_ok=True)
                img_org = np.memmap(fpath_org,
                                    dtype='float32',
                                    mode='w+',
                                    shape=za.shape[1:])
                img_org[:] = za[timepoint].astype('float32')
                if img_size is None:
                    img = img_org
                else:
                    img = np.memmap(fpath,
                                    dtype='float32',
                                    mode='w+',
                                    shape=img_size)
                    img[:] = F.interpolate(
                        torch.from_numpy(img_org)[None, None],
                        size=img_size,
                        mode='trilinear' if img.ndim == 3 else 'bilinear',
                        align_corners=True,
                    )[0, 0].numpy()
                if use_median and img.ndim == 3:
                    # Per-slice intensity correction: shift each z slice so
                    # that its median matches the global median.
                    global_median = np.median(img)
                    for z in range(img.shape[0]):
                        slice_median = np.median(img[z])
                        if 0 < slice_median:
                            img[z] -= slice_median - global_median
                img = normalize_zero_one(img)
            logger().info(f'loading from {fpath}')
            return np.memmap(
                fpath,
                dtype='float32',
                mode='c',
                shape=za.shape[1:] if img_size is None else img_size)
    else:
        img = za[timepoint].astype('float32')
        if use_median and img.ndim == 3:
            global_median = np.median(img)
            for z in range(img.shape[0]):
                slice_median = np.median(img[z])
                if 0 < slice_median:
                    img[z] -= slice_median - global_median
        img = normalize_zero_one(img)
    return img
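
A usage sketch, assuming imgs.zarr from a generated dataset; the paths are hypothetical. The first call builds the on-disk float32 cache that later calls reuse.

import zarr

# Hypothetical paths; za is the (t, z, y, x) image array of a dataset dir.
za = zarr.open('/workspace/datasets/demo/imgs.zarr', mode='r')
img = _get_memmap_or_load(za, timepoint=0,
                          memmap_dir='/workspace/memmaps',
                          use_median=True,
                          img_size=(16, 256, 256))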
Example #3
    def post(self):
        '''
        Reset seg model.

        '''
        if all(ctype not in request.headers['Content-Type']
               for ctype in ('multipart/form-data', 'application/json')):
            msg = ('Content-Type should be multipart/form-data or '
                   'application/json')
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running. Model cannot be reset.'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        try:
            device = get_device()
            if 'multipart/form-data' in request.headers['Content-Type']:
                logger().info(request.form)
                req_json = json.loads(request.form.get('data'))
                file = request.files['file']
                checkpoint = torch.load(file.stream, map_location=device)
                state_dicts = checkpoint if isinstance(checkpoint,
                                                       list) else [checkpoint]
                req_json['url'] = None
            else:
                req_json = request.get_json()
                state_dicts = None
            req_json['device'] = device
            config = ResetConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            init_seg_models(config.model_path,
                            config.keep_axials,
                            config.device,
                            config.is_3d,
                            config.n_models,
                            config.n_crops,
                            config.zpath_input,
                            config.crop_size,
                            config.scales,
                            url=config.url,
                            state_dicts=state_dicts,
                            is_cpu=config.is_cpu(),
                            cache_maxbytes=config.cache_maxbytes)
        except RuntimeError as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
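
For reference, a client-side sketch with requests; the URL and the JSON fields passed to ResetConfig are assumptions (the route this resource is mounted at is not shown here).

import requests

# Hypothetical endpoint and payload; ResetConfig defines the accepted keys.
resp = requests.post('http://localhost:8080/seg/reset',
                     json={'model_path': 'models/seg.pth', 'is_3d': True})
print(resp.status_code, resp.json())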
Example #4
    def post(self):
        '''
        Download a model parameter file.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        req_json = request.get_json()
        config = BaseConfig(req_json)
        logger().info(config)
        if not Path(config.model_path).exists():
            logger().info(
                f'model file {config.model_path} not found @{request.path}')
            return make_response(
                jsonify(message=f'model file {config.model_path} not found'),
                204)
        try:
            resp = send_file(config.model_path)
        except Exception as e:
            logger().exception(
                'Failed to prepare a model parameter file for download')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        return resp
Example #5
def _update_flow_labels(spots_dict, scales, zpath_flow_label,
                        flow_norm_factor):
    za_label = zarr.open(zpath_flow_label, mode='a')
    MIN_AREA_ELLIPSOID = 9
    n_dims = len(za_label.shape) - 2
    for t, spots in spots_dict.items():
        label = np.zeros(za_label.shape[1:], dtype='float32')
        label_indices = set()
        centroids = []
        for spot in spots:
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                return make_response(jsonify({'completed': False}))
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            draw_func = ellipsoid if n_dims == 3 else ellipse
            indices = draw_func(centroid, radii, rotation, scales,
                                label.shape[-n_dims:], MIN_AREA_ELLIPSOID)
            weight = 1  # if spot['tag'] in ['tp'] else false_weight
            displacement = spot['displacement']  # X, Y, Z
            for i in range(n_dims):
                ind = (np.full(len(indices[0]), i), ) + indices
                label[ind] = (displacement[i] / scales[-1 - i] /
                              flow_norm_factor[i])
            # the last channel is for the weight
            ind = (np.full(len(indices[0]), -1), ) + indices
            label[ind] = weight
            label_indices.update(
                tuple(map(tuple,
                          np.stack(indices, axis=1).tolist())))
        logger().info(f'frame:{t+1}, {len(spots)} links')
        # Build fancy indices that cover every channel at the drawn voxels:
        # the channel index is repeated per voxel and the spatial indices
        # are tiled once per channel.
        target = tuple(np.array(list(label_indices)).T)
        target = (np.array(
            sum(tuple([(i, ) * len(target[0]) for i in range(n_dims + 1)]),
                ())), ) + tuple(
                    np.tile(target[i], n_dims + 1) for i in range(n_dims))
        target_t = (np.full(len(target[0]), t), ) + target
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        za_label[target_t] = label[target]
    return make_response(jsonify({'completed': True}))
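
The label written above stores, for each drawn voxel, the displacement along each axis divided by the axis scale and flow_norm_factor, plus a trailing weight channel. A sketch of the inverse mapping under those definitions:

# Recover physical displacements (same X, Y, Z order as spot['displacement'])
# from a label produced by _update_flow_labels.
def decode_flow_label(label, scales, flow_norm_factor):
    n_dims = label.shape[0] - 1  # the last channel holds the weight
    displacement = [label[i] * scales[-1 - i] * flow_norm_factor[i]
                    for i in range(n_dims)]
    weight = label[-1]
    return displacement, weight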
Example #6
    def post(self):
        '''
        Update process state.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        req_json = request.get_json()
        state = req_json.get(REDIS_KEY_STATE)
        if (not isinstance(state, int)
                or state not in TrainState._value2member_map_):
            msg = f'Invalid state: {state}'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        redis_client.set(REDIS_KEY_STATE, state)
        return make_response(jsonify(success=True, state=get_state()))
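
A client sketch; the URL and the JSON key (the string value of REDIS_KEY_STATE) are assumptions.

import requests

# Hypothetical endpoint; the posted value must be a valid TrainState value.
resp = requests.post('http://localhost:8080/state', json={'state': 0})
print(resp.json())  # e.g. {'success': True, 'state': 0}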
Example #7
    def post(self):
        '''
        Validate Dataset.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        req_json = request.get_json()
        if 'dataset_name' not in req_json:
            return make_response(
                jsonify(error='dataset_name key is missing'),
                400
            )
        if 'shape' not in req_json:
            return make_response(jsonify(error='shape key is missing'), 400)
        message = dstool.check_dataset(
            Path(DATASETS_DIR) / req_json['dataset_name'],
            tuple(req_json['shape'])
        )
        return make_response(jsonify(message=message), 200)
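
A client sketch; the URL is an assumption, while 'dataset_name' and 'shape' are the keys this handler reads.

import requests

# Hypothetical endpoint; shape is (t, z, y, x) for 3D+t or (t, y, x) for 2D+t.
resp = requests.post('http://localhost:8080/dataset/check',
                     json={'dataset_name': 'demo',
                           'shape': [10, 32, 512, 512]})
print(resp.json()['message'])  # 'ready' or the first problem encountered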
Example #8
    def _get_memmap_or_load_label(self, timepoint, img_size=None):
        if self.memmap_dir:
            key = f'{Path(self.za_label.store.path).parent.name}-t{timepoint}'
            if img_size is not None:
                key += '-' + '-'.join(map(str, img_size))
            key += '-flowlabel'
            fpath = Path(self.memmap_dir) / f'{key}.dat'
            lock = FileLock(str(fpath) + '.lock')
            shape = (self.n_dims + 1, ) + (
                self.za_label.shape[-self.n_dims:]
                if img_size is None else img_size)
            with lock:
                if not fpath.exists():
                    logger().info(f'creating {fpath}')
                    fpath.parent.mkdir(parents=True, exist_ok=True)
                    np.memmap(
                        fpath,
                        dtype='float32',
                        mode='w+',
                        shape=shape,
                    )[:] = (self.za_label[timepoint]
                            if img_size is None else F.interpolate(
                                torch.from_numpy(
                                    self.za_label[timepoint])[None],
                                size=img_size,
                                mode='nearest',
                            )[0].numpy())
                logger().info(f'loading from {fpath}')
                return np.memmap(
                    fpath,
                    dtype='float32',
                    mode='c',
                    shape=shape,
                )
        return (self.za_label[timepoint]
                if img_size is None else F.interpolate(
                    torch.from_numpy(self.za_label[timepoint])[None],
                    size=img_size,
                    mode='nearest',
                )[0].numpy())
Example #9
    def post(self):
        '''
        Download outputs in CTC format.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response('', 204)
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = ExportConfig(req_json)
            za_input = zarr.open(config.zpath_input, mode='r')
            config.shape = za_input.shape[1:]
            logger().info(config)
            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots'):
                spots_dict[spot['t']].append(spot)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            result = export_ctc_labels(config, spots_dict)
            if isinstance(result, str):
                resp = send_file(result)
                # file_remover.cleanup_once_done(resp, result)
            elif not result:
                resp = make_response('', 204)
            else:
                resp = make_response('', 200)
        except RuntimeError as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return resp
Example #10
    def post(self):
        '''
        Generate Dataset.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        req_json = request.get_json()
        if 'dataset_name' not in req_json:
            return make_response(
                jsonify(error='dataset_name key is missing'),
                400
            )
        p_dataset = Path(DATASETS_DIR) / req_json['dataset_name']
        h5_files = list(sorted(p_dataset.glob('*.h5')))
        if len(h5_files) == 0:
            logger().info(f'.h5 file not found @{request.path}')
            return make_response(
                jsonify(
                    message=f'.h5 file not found in {req_json["dataset_name"]}'
                ),
                204
            )
        if p_dataset / (p_dataset.name + '.h5') in h5_files:
            h5_filename = str(p_dataset / (p_dataset.name + '.h5'))
        else:
            h5_filename = str(h5_files[0])
        if 1 < len(h5_files):
            logger().info(f'multiple .h5 files found, use {h5_filename}')
        try:
            generate_dataset_task.delay(
                h5_filename,
                str(p_dataset),
                req_json.get('is_uint16', None),
                req_json.get('divisor', 1.),
                req_json.get('is_2d', False),
            ).wait()
        except Exception as e:
            logger().exception('Failed in gen_dataset')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        return make_response('', 200)
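
A client sketch; the URL is an assumption, and only 'dataset_name' is required (the remaining keys fall back to the defaults read above).

import requests

# Hypothetical endpoint; is_uint16=None lets the task auto-detect the dtype.
resp = requests.post('http://localhost:8080/dataset/generate',
                     json={'dataset_name': 'demo',
                           'is_uint16': None,
                           'divisor': 1.,
                           'is_2d': False})
print(resp.status_code)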
Example #11
    def post(self):
        '''
        Reset flow model.

        '''
        if all(ctype not in request.headers['Content-Type']
               for ctype in ('multipart/form-data', 'application/json')):
            msg = ('Content-Type should be multipart/form-data or '
                   'application/json')
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if get_state() != TrainState.IDLE.value:
            return make_response(jsonify(error='Process is running'), 500)
        try:
            device = get_device()
            if 'multipart/form-data' in request.headers['Content-Type']:
                logger().info(request.form)
                req_json = json.loads(request.form.get('data'))
                file = request.files['file']
                checkpoint = torch.load(file.stream, map_location=device)
                state_dicts = checkpoint if isinstance(checkpoint,
                                                       list) else [checkpoint]
                req_json['url'] = None
            else:
                req_json = request.get_json()
                state_dicts = None
            req_json['device'] = device
            config = ResetConfig(req_json)
            logger().info(config)
            init_flow_models(config.model_path,
                             config.device,
                             config.is_3d,
                             url=config.url,
                             state_dicts=state_dicts)
        except RuntimeError as e:
            logger().exception('Failed in reset_flow_models')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in reset_flow_models')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        return make_response(jsonify({'completed': True}))
Example #12
    def post(self):
        '''
        Update training parameters.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        req_json = request.get_json()
        lr = req_json.get('lr')
        if not isinstance(lr, float) or lr < 0:
            msg = f'Invalid learning rate: {lr}'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        n_crops = req_json.get('n_crops')
        if not isinstance(n_crops, int) or n_crops < 0:
            msg = f'Invalid number of crops: {n_crops}'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        redis_client.set(REDIS_KEY_LR, str(lr))
        redis_client.set(REDIS_KEY_NCROPS, str(n_crops))
        logger().info(f'[params updated] lr: {lr}, n_crops: {n_crops}')
        return make_response(
            jsonify(success=True,
                    lr=float(redis_client.get(REDIS_KEY_LR)),
                    n_crops=int(redis_client.get(REDIS_KEY_NCROPS))))
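
A client sketch; the URL is an assumption. 'lr' must be a non-negative float and 'n_crops' a non-negative int, as validated above.

import requests

# Hypothetical endpoint; the response echoes the values stored in Redis.
resp = requests.post('http://localhost:8080/params',
                     json={'lr': 5e-5, 'n_crops': 3})
print(resp.json())  # e.g. {'success': True, 'lr': 5e-05, 'n_crops': 3}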
Example #13
def log_before_request():
    if request.endpoint not in (None, 'gpus'):
        logger().info(f'START {request.method} {request.path}')
Example #14
    def post(self):
        '''
        Train seg model.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.n_crops < 1:
                msg = 'n_crops should be a positive number'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)

            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots'):
                spots_dict[spot['t']].append(spot)
            if not (spots_dict or config.is_livemode):
                msg = 'nothing to train'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            if get_state() != TrainState.IDLE.value:
                msg = 'Process is running'
                logger().error(msg)
                return make_response(jsonify(error=msg), 500)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            if config.is_livemode:
                redis_client.delete(REDIS_KEY_TIMEPOINT)
            else:
                try:
                    _update_seg_labels(spots_dict,
                                       config.scales,
                                       config.zpath_input,
                                       config.zpath_seg_label,
                                       config.zpath_seg_label_vis,
                                       config.auto_bg_thresh,
                                       config.c_ratio,
                                       memmap_dir=config.memmap_dir)
                except KeyboardInterrupt:
                    return make_response(jsonify({'completed': False}))
            # Resume the TensorBoard step count from the last event record
            # found in the log dir, if any.
            step_offset = 0
            for path in sorted(Path(config.log_dir).glob('event*')):
                try:
                    *_, last_record = TFRecordDataset(str(path))
                    last = event_pb2.Event.FromString(last_record.numpy()).step
                    step_offset = max(step_offset, last + 1)
                except Exception:
                    pass
            epoch_start = 0
            async_result = train_seg_task.delay(
                list(spots_dict.keys()),
                config.batch_size,
                config.crop_size,
                config.class_weights,
                config.false_weight,
                config.model_path,
                config.n_epochs,
                config.keep_axials,
                config.scales,
                config.lr,
                config.n_crops,
                config.is_3d,
                config.is_livemode,
                config.scale_factor_base,
                config.rotation_angle,
                config.contrast,
                config.zpath_input,
                config.zpath_seg_label,
                config.log_interval,
                config.log_dir,
                step_offset,
                epoch_start,
                config.is_cpu(),
                config.is_mixed_precision,
                config.cache_maxbytes,
                config.memmap_dir,
                config.input_size,
            )
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('training aborted')
                    return make_response(jsonify({'completed': False}))
                time.sleep(1)  # avoid busy-waiting while polling for abort
        except Exception as e:
            logger().exception('Failed in train_seg')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            torch.cuda.empty_cache()
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
Example #15
    def post(self):
        '''
        Update flow label.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            # Mark the flow-label update as ongoing (checked above and
            # cleared in the finally block).
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = FlowTrainConfig(req_json)
            logger().info(config)
            if req_json.get('reset'):
                zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                               config.zpath_flow_label,
                               mode='w')
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots'):
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            response = _update_flow_labels(spots_dict, config.scales,
                                           config.zpath_flow_label,
                                           config.flow_norm_factor)
        except RuntimeError as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
        return response
Example #16
    def post(self):
        '''
        Update seg label.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if req_json.get('reset'):
                try:
                    zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                                   config.zpath_seg_label,
                                   mode='w')
                    zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                             mode='r'),
                                   config.zpath_seg_label_vis,
                                   mode='w')
                except RuntimeError as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Runtime Error: {e}'),
                                         500)
                except Exception as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Exception: {e}'), 500)
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots'):
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if config.is_livemode and len(spots_dict.keys()) != 1:
                msg = 'Livemode should update only a single timepoint'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            try:
                response = _update_seg_labels(spots_dict,
                                              config.scales,
                                              config.zpath_input,
                                              config.zpath_seg_label,
                                              config.zpath_seg_label_vis,
                                              config.auto_bg_thresh,
                                              config.c_ratio,
                                              config.is_livemode,
                                              memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                return make_response(jsonify({'completed': False}))
            except Exception as e:
                logger().exception('Failed in _update_seg_labels')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
        return response
Example #17
    def post(self):
        '''
        Predict seg.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationEvalConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            async_result = detect_spots_task.delay(
                str(config.device),
                config.model_path,
                config.keep_axials,
                config.is_pad,
                config.is_3d,
                config.crop_size,
                config.scales,
                config.cache_maxbytes,
                config.use_2d,
                config.use_median,
                config.patch_size,
                config.crop_box,
                config.c_ratio,
                config.p_thresh,
                config.r_min,
                config.r_max,
                config.output_prediction,
                config.zpath_input,
                config.zpath_seg_output,
                config.timepoint,
                None,
                config.memmap_dir,
                config.batch_size,
                config.input_size,
            )
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('prediction aborted')
                    return make_response(
                        jsonify({
                            'spots': [],
                            'completed': False
                        }))
                time.sleep(1)  # avoid busy-waiting while polling for abort
            if async_result.failed():
                raise async_result.result
            spots = async_result.result
            if spots is None:
                logger().info('prediction aborted')
                return make_response(jsonify({
                    'spots': [],
                    'completed': False
                }))
            publish_mq('prediction', 'Prediction updated')
        except Exception as e:
            logger().exception('Failed in detect_spots')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return make_response(jsonify({'spots': spots, 'completed': True}))
Example #18
def _update_seg_labels(spots_dict,
                       scales,
                       zpath_input,
                       zpath_seg_label,
                       zpath_seg_label_vis,
                       auto_bg_thresh=0,
                       c_ratio=0.5,
                       is_livemode=False,
                       memmap_dir=None):
    if is_livemode:
        assert len(spots_dict.keys()) == 1
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 9
    img_shape = za_input.shape[1:]
    n_dims = len(img_shape)
    keybase = Path(za_label.store.path).parent.name
    if scales is None:
        scales = (1., ) * n_dims
    scales = np.array(scales)
    for t, spots in spots_dict.items():
        label_indices = set()
        centroids = []
        label = np.zeros(img_shape, dtype='uint8')
        label_vis = np.zeros(img_shape + (3, ), dtype='uint8')
        if 0 < auto_bg_thresh:
            indices_bg = np.nonzero(
                get_input_at(za_input, t, memmap_dir=memmap_dir) <
                auto_bg_thresh)
            label[indices_bg] = 1
            label_vis[indices_bg] = (255, 0, 0)
            if INDICES_THRESHOLD < len(indices_bg[0]):
                label_indices = None
            else:
                label_indices.update(tuple(map(tuple, np.array(indices_bg).T)))
        cnt = collections.Counter({x: 0 for x in keyorder})
        for spot in spots:
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                raise KeyboardInterrupt
            cnt[spot['tag']] += 1
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            if n_dims == 3:
                draw_func = ellipsoid
                dilate_func = _dilate_3d_indices
            else:
                draw_func = ellipse
                dilate_func = _dilate_2d_indices
            indices_outer = draw_func(centroid, radii, rotation, scales,
                                      img_shape, MIN_AREA_ELLIPSOID)
            if label_indices is not None:
                if INDICES_THRESHOLD < len(indices_outer[0]):
                    label_indices = None
                else:
                    label_indices.update(
                        tuple(map(tuple,
                                  np.array(indices_outer).T)))
                    if INDICES_THRESHOLD < len(label_indices):
                        label_indices = None
            # True tags ('tp', 'tb', 'tn') use label values 1-3; false tags
            # ('fp', 'fn', 'fb') are offset by 3 (values 4-6) and drawn
            # dimmer in the visualization.
            if spot['tag'] in ['tp', 'tb', 'tn']:
                label_offset = 0
                label_vis_value = 255
            else:
                label_offset = 3
                label_vis_value = 127
            # Overwrite outer voxels unless they are already marked as a cell
            # center (fmod maps labels 1-6 to 0: bg, 1: border, 2: center).
            cond_outer_1 = np.fmod(label[indices_outer] - 1, 3) <= 1
            if spot['tag'] in ('tp', 'fn'):
                indices_inner = draw_func(centroid, radii * c_ratio, rotation,
                                          scales, img_shape,
                                          MIN_AREA_ELLIPSOID)
                indices_inner_p = dilate_func(*indices_inner, img_shape)
                label[indices_outer] = np.where(cond_outer_1, 2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_1[..., None],
                                                    (0, label_vis_value, 0),
                                                    label_vis[indices_outer])
                label[indices_inner_p] = 2 + label_offset
                label_vis[indices_inner_p] = (0, label_vis_value, 0)
                cond_inner = np.fmod(label[indices_inner] - 1, 3) <= 2
                label[indices_inner] = np.where(cond_inner, 3 + label_offset,
                                                label[indices_inner])
                label_vis[indices_inner] = np.where(cond_inner[..., None],
                                                    (0, 0, label_vis_value),
                                                    label_vis[indices_inner])
            elif spot['tag'] in ('tb', 'fb'):
                label[indices_outer] = np.where(cond_outer_1, 2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_1[..., None],
                                                    (0, label_vis_value, 0),
                                                    label_vis[indices_outer])
            elif spot['tag'] in ('tn', 'fp'):
                cond_outer_0 = np.fmod(label[indices_outer] - 1, 3) <= 0
                label[indices_outer] = np.where(cond_outer_0, 1 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_0[..., None],
                                                    (label_vis_value, 0, 0),
                                                    label_vis[indices_outer])
        logger().info('frame:{}, {}'.format(
            t, sorted(cnt.items(), key=lambda i: keyorder.index(i[0]))))
        if label_indices is None:
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[t] = label
            za_label_vis[t] = label_vis
        else:
            target = tuple(np.array(list(label_indices)).T)
            target_t = to_fancy_index(t, *target)
            # Extend the spatial indices with an RGB channel axis for the
            # visualization array.
            target_vis = (
                tuple(np.tile(target[i], 3) for i in range(n_dims)) +
                (np.array(
                    sum(tuple([(c, ) * len(target[0]) for c in range(3)]),
                        ())), ))
            target_vis_t = to_fancy_index(t, *target_vis)
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[target_t] = label[target]
            za_label_vis[target_vis_t] = label_vis[target_vis]
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        if memmap_dir:
            for fpath in Path(memmap_dir).glob(
                    f'{keybase}-t{t}*-seglabel.dat'):
                lock = FileLock(str(fpath) + '.lock')
                with lock:
                    if fpath.exists():
                        logger().info(f'remove {fpath}')
                        fpath.unlink()
        if is_livemode:
            if redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            redis_client.set(REDIS_KEY_TIMEPOINT, t)
    return make_response(jsonify({'completed': True}))
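
For reference, the voxel values written by this function, read directly from the branches above (a reading aid; 0 means untouched):

# Label values produced by _update_seg_labels.
SEG_LABEL_VALUES = {
    0: 'unlabeled',
    1: 'true background',    # 'tn' and auto-background voxels
    2: 'true cell border',   # outer ellipse of 'tp' and 'tb'
    3: 'true cell center',   # inner ellipse of 'tp'
    4: 'false background',   # 'fp'
    5: 'false cell border',  # outer ellipse of 'fn' and 'fb'
    6: 'false cell center',  # inner ellipse of 'fn'
}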
Example #19
def log_after_request(response):
    if request.endpoint not in (None, 'gpus'):
        logger().info(
            f'DONE {request.method} {request.path} => [{response.status}]')
    return response
Example #20
def generate_dataset(input,
                     output,
                     is_uint16=False,
                     divisor=1.,
                     is_2d=False,
                     is_message_queue=False,
                     is_multiprocessing=True):
    """Generate a dataset for ELEPHANT.

    Parameters
    ----------
    input : str
        Input .h5 file.
    output : str or Path
        Output directory.
    is_uint16 : bool or None
        If True, the original image is stored as uint16; if False, as uint8.
        If None, the dtype is determined dynamically from the data.
        default: False (uint8)
    divisor : float
        Divide the original pixel values by this value.
        (with uint8, the values should be scaled down to the 0-255 range)
        default: 1.0
    is_2d : bool
        If True, the original image is stored as 2D+time.
        default: False (3D+time)
    is_message_queue : bool
        If True, progress is reported using pika.BlockingConnection.
    is_multiprocessing : bool
        If True, chunks are written using multiprocessing.

    This function generates the following files.

    output
    ├── input.h5
    ├── flow_hashes.zarr
    ├── flow_labels.zarr
    ├── flow_outputs.zarr
    ├── imgs.zarr
    ├── seg_labels_vis.zarr
    ├── seg_labels.zarr
    └── seg_outputs.zarr
    """
    logger().info(f'input: {input}')
    logger().info(f'output dir: {output}')
    logger().info(f'divisor: {divisor}')
    with h5py.File(input, 'r') as f:
        # timepoints are stored as 't00000', 't00001', ...
        timepoints = list(filter(re.compile(r't\d{5}').search, list(f.keys())))

        def func(x):
            return x[0] if is_2d else x

        # determine if uint8 or uint16 dynamically
        if is_uint16 is None:
            is_uint16 = False
            for timepoint in tqdm(timepoints):
                if 255 < np.array(func(
                        f[timepoint]['s00']['0']['cells'])).max():
                    is_uint16 = True
                    break
        shape = f[timepoints[0]]['s00']['0']['cells'].shape
    if is_2d:
        shape = shape[-2:]
    n_dims = 3 - is_2d  # 3 or 2
    n_timepoints = len(timepoints)
    p = Path(output)
    chunk_shape = tuple(min(s, 1024) for s in shape[-2:])
    if not is_2d:
        chunk_shape = (1, ) + chunk_shape
    zarr.open(str(p / 'imgs.zarr'),
              'w',
              shape=(n_timepoints, ) + shape,
              chunks=(1, ) + chunk_shape,
              dtype='u2' if is_uint16 else 'u1')
    zarr.open(str(p / 'flow_outputs.zarr'),
              'w',
              shape=(
                  n_timepoints - 1,
                  n_dims,
              ) + shape,
              chunks=(
                  1,
                  1,
              ) + chunk_shape,
              dtype='f2')
    zarr.open(str(p / 'flow_hashes.zarr'),
              'w',
              shape=(n_timepoints - 1, ),
              chunks=(n_timepoints - 1, ),
              dtype='S16')
    zarr.open(str(p / 'flow_labels.zarr'),
              'w',
              shape=(
                  n_timepoints - 1,
                  n_dims + 1,
              ) + shape,
              chunks=(
                  1,
                  1,
              ) + chunk_shape,
              dtype='f4')
    zarr.open(str(p / 'seg_outputs.zarr'),
              'w',
              shape=(n_timepoints, ) + shape + (3, ),
              chunks=(1, ) + chunk_shape + (1, ),
              dtype='f2')
    zarr.open(str(p / 'seg_labels.zarr'),
              'w',
              shape=(n_timepoints, ) + shape,
              chunks=(1, ) + chunk_shape,
              dtype='u1')
    zarr.open(str(p / 'seg_labels_vis.zarr'),
              'w',
              shape=(n_timepoints, ) + shape + (3, ),
              chunks=(1, ) + chunk_shape + (1, ),
              dtype='u1')
    dtype = np.uint16 if is_uint16 else np.uint8
    if n_dims == 2:
        chunks = tuple((slice(None), slice(y, y + chunk_shape[-2]),
                        slice(x, x + chunk_shape[-1]))
                       for y in range(0, shape[-2], chunk_shape[-2])
                       for x in range(0, shape[-1], chunk_shape[-1]))
    else:
        chunks = tuple(
            (slice(z, z + chunk_shape[-3]), slice(y, y + chunk_shape[-2]),
             slice(x, x + chunk_shape[-1]))
            for z in range(0, shape[-3], chunk_shape[-3])
            for y in range(0, shape[-2], chunk_shape[-2])
            for x in range(0, shape[-1], chunk_shape[-1]))
    if is_multiprocessing:
        pool = mp.Pool()
    try:
        for t, timepoint in tqdm(enumerate(timepoints), total=n_timepoints):
            partial_write_chunk = partial(write_chunk,
                                          zpath=str(p / 'imgs.zarr'),
                                          t=t,
                                          n_dims=n_dims,
                                          is_2d=is_2d,
                                          input=input,
                                          timepoint=timepoint,
                                          divisor=divisor,
                                          dtype=dtype)
            if is_multiprocessing:
                pool.map(partial_write_chunk, chunks)
            else:
                for chunk in chunks:
                    partial_write_chunk(chunk)
            if is_message_queue:
                publish_mq(
                    'dataset',
                    json.dumps({
                        't_max': n_timepoints,
                        't_current': t + 1,
                    }))
    finally:
        if is_multiprocessing:
            pool.close()
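
A minimal usage sketch (hypothetical paths), matching the parameters documented above:

# Writes the zarr files listed in the docstring into the output directory.
generate_dataset('/workspace/datasets/demo/demo.h5',
                 '/workspace/datasets/demo',
                 is_uint16=None,  # auto-detect uint8 vs uint16
                 divisor=1.,
                 is_2d=False,
                 is_message_queue=False,
                 is_multiprocessing=True)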