示例#1
0
    def post(self):
        '''
        Update flow label.

        Waits while another request holds the WAIT state, and rejects the
        request while a previous flow-label update is still flagged as
        ongoing. The JSON body is parsed into a ``FlowTrainConfig`` (with the
        server device injected). If ``reset`` is truthy, the flow-label zarr
        is overwritten via ``zarr.open_like(..., mode='w')``. Otherwise the
        ``spots`` are grouped by their ``t`` key and written by
        ``_update_flow_labels``.

        Returns a JSON response: ``{'completed': ...}`` on success or abort,
        ``{'error': ...}`` with status 400 for invalid requests, and 500 on
        failures.
        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        # Another request holds WAIT: poll until it is released, and give up
        # if the process is reset to IDLE in the meantime.
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            # Flag this update as ongoing so that the guard above actually
            # rejects concurrent updates; the flag is cleared in `finally`.
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = FlowTrainConfig(req_json)
            logger().info(config)
            if req_json.get('reset'):
                zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                               config.zpath_flow_label,
                               mode='w')
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            # A missing or null `spots` is treated as empty, so it is
            # reported as "nothing to update" rather than failing on
            # iteration over None.
            for spot in req_json.get('spots') or []:
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            response = _update_flow_labels(spots_dict, config.scales,
                                           config.zpath_flow_label,
                                           config.flow_norm_factor)
        except RuntimeError as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            # Restore the pre-request state unless the process was reset to
            # IDLE (an abort) meanwhile, then release the ongoing flag.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
        return response
示例#2
0
    def post(self):
        '''
        Download outputs in CTC format.

        Waits while another request holds the WAIT state, then parses the
        JSON body into an ``ExportConfig``. Its ``shape`` is taken from the
        input zarr without the leading axis. The ``spots`` are grouped by
        their ``t`` key and passed to ``export_ctc_labels``. A string result
        is sent back with ``send_file``; a falsy result yields an empty 204
        and any other result an empty 200.

        Errors yield 400 (bad Content-Type) or 500 with a JSON ``error``.
        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        # Another request holds WAIT: poll until it is released; an empty
        # 204 is returned if the process is reset to IDLE meanwhile.
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response('', 204)
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = ExportConfig(req_json)
            za_input = zarr.open(config.zpath_input, mode='r')
            config.shape = za_input.shape[1:]
            logger().info(config)
            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots'):
                spots_dict[spot['t']].append(spot)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            result = export_ctc_labels(config, spots_dict)
            if isinstance(result, str):
                # NOTE(review): cleanup of the exported file after sending
                # (file_remover.cleanup_once_done) is disabled; confirm
                # whether `result` should be removed once it has been sent.
                resp = send_file(result)
            elif not result:
                resp = make_response('', 204)
            else:
                resp = make_response('', 200)
        except RuntimeError as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            # Restore the pre-request state unless the process was reset to
            # IDLE (an abort) meanwhile.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return resp
 def __getitem__(self, index):
     """Return the item for a flat ``index`` over (data, patch) pairs.

     The patch position varies fastest: ``index`` is split into a data
     index (``index // len(patch_list)``) and a patch index (the
     remainder). When a redis client is configured and the process state
     is IDLE, ``KeyboardInterrupt`` is raised to abort the run.

     Returns ``(self.input[(data_ind, *slices)], self.keep_axials[data_ind],
     data_ind, patch_ind)``.
     """
     abort_requested = (redis_client is not None
                        and get_state() == TrainState.IDLE.value)
     if abort_requested:
         raise KeyboardInterrupt
     patch_count = len(self.patch_list)
     data_ind = index // patch_count
     patch_ind = index % patch_count
     patch_slices, _unused = self.patch_list[patch_ind]
     selection = (data_ind, *patch_slices)
     patch = self.input[selection]
     return (patch, self.keep_axials[data_ind], data_ind, patch_ind)
示例#4
0
    def post(self):
        '''
        Reset seg model.

        Accepts either ``application/json`` (config in the body) or
        ``multipart/form-data``. In the multipart case the config is JSON in
        the ``data`` form field, and a checkpoint is uploaded in ``file`` and
        loaded with ``torch.load``. A list checkpoint is used as-is; anything
        else is wrapped in a one-element list.

        Refuses to run unless the process state is IDLE. Sets RUN while
        ``init_seg_models`` runs and resets to IDLE afterwards.

        Returns ``{'completed': True}`` on success, or ``{'error': ...}``
        with 400 for bad input and 500 when busy or on failure.
        '''
        if all(ctype not in request.headers['Content-Type']
               for ctype in ('multipart/form-data', 'application/json')):
            msg = ('Content-Type should be multipart/form-data or '
                   'application/json')
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running. Model cannot be reset.'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        try:
            device = get_device()
            if 'multipart/form-data' in request.headers['Content-Type']:
                # Log through the app logger instead of printing to stdout.
                logger().info(f'form: {request.form}')
                form_data = request.form.get('data')
                if form_data is None:
                    msg = 'Form field "data" is required'
                    logger().error(msg)
                    return make_response(jsonify(error=msg), 400)
                req_json = json.loads(form_data)
                file = request.files['file']
                # NOTE(security): torch.load unpickles the uploaded file and
                # can execute arbitrary code; only accept checkpoints from
                # trusted clients (or consider `weights_only=True`).
                checkpoint = torch.load(file.stream, map_location=device)
                state_dicts = checkpoint if isinstance(checkpoint,
                                                       list) else [checkpoint]
                # Weights come from the upload, so no download URL is used.
                req_json['url'] = None
            else:
                req_json = request.get_json()
                state_dicts = None
            req_json['device'] = device
            config = ResetConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            init_seg_models(config.model_path,
                            config.keep_axials,
                            config.device,
                            config.is_3d,
                            config.n_models,
                            config.n_crops,
                            config.zpath_input,
                            config.crop_size,
                            config.scales,
                            url=config.url,
                            state_dicts=state_dicts,
                            is_cpu=config.is_cpu(),
                            cache_maxbytes=config.cache_maxbytes)
        except RuntimeError as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            # The state was verified to be IDLE above, so restoring IDLE
            # does not interfere with another process.
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
示例#5
0
def _update_flow_labels(spots_dict, scales, zpath_flow_label,
                        flow_norm_factor):
    '''
    Write flow labels for the given spots into the flow-label zarr.

    The zarr at ``zpath_flow_label`` is indexed as ``(t, channel,
    *spatial)``. The first ``n_dims`` channels hold the normalized
    displacement per axis, and the last channel holds the weight (1 for
    every drawn voxel). Each spot is drawn as an ellipse (2D) or ellipsoid
    (3D) from its ``pos`` and ``covariance``. Its ``displacement`` component
    ``i`` is divided by ``scales[-1 - i]`` and ``flow_norm_factor[i]``. Only
    the voxels touched by the spots of a frame are written back.

    Parameters
    ----------
    spots_dict : mapping
        Timepoint -> list of spot dicts (keys used: ``pos``, ``covariance``,
        ``displacement``).
    scales : sequence of float or None
        Per-axis scales; ``None`` means 1 for every spatial axis.
    zpath_flow_label : str
        Path of the flow-label zarr (opened with ``mode='a'``).
    flow_norm_factor : sequence of float
        Per-component normalization of the displacement.

    Returns
    -------
    A Flask response ``{'completed': True}``, or ``{'completed': False}``
    when the process state becomes IDLE during the update.
    '''
    za_label = zarr.open(zpath_flow_label, mode='a')
    MIN_AREA_ELLIPSOID = 9
    # Leading axes of the zarr are (t, channel); the rest are spatial.
    n_dims = len(za_label.shape) - 2
    n_channels = n_dims + 1
    # Same default as `_update_seg_labels`: unit scale on every axis.
    if scales is None:
        scales = (1., ) * n_dims
    scales = np.array(scales)
    draw_func = ellipsoid if n_dims == 3 else ellipse
    for t, spots in spots_dict.items():
        label = np.zeros(za_label.shape[1:], dtype='float32')
        label_indices = set()
        centroids = []
        for spot in spots:
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                return make_response(jsonify({'completed': False}))
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            indices = draw_func(centroid, radii, rotation, scales,
                                label.shape[-n_dims:], MIN_AREA_ELLIPSOID)
            weight = 1  # if spot['tag'] in ['tp'] else false_weight
            displacement = spot['displacement']  # X, Y, Z
            for i in range(n_dims):
                ind = (np.full(len(indices[0]), i), ) + indices
                label[ind] = (displacement[i] / scales[-1 - i] /
                              flow_norm_factor[i])
            # last channel is for weight
            ind = (np.full(len(indices[0]), -1), ) + indices
            label[ind] = weight
            label_indices.update(
                tuple(map(tuple,
                          np.stack(indices, axis=1).tolist())))
        logger().info(f'frame:{t+1}, {len(spots)} linkings')
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        # No voxel was drawn (e.g. every spot fell outside the frame):
        # there is nothing to write, and building `target` would fail.
        if not label_indices:
            continue
        spatial = tuple(np.array(list(label_indices)).T)
        n_voxels = len(spatial[0])
        # Address every channel at every touched voxel: channel indices
        # 0,0,...,1,1,...,n_dims paired with the spatial coordinates
        # repeated once per channel.
        target = (np.repeat(np.arange(n_channels), n_voxels), ) + tuple(
            np.tile(spatial[i], n_channels) for i in range(n_dims))
        target_t = (np.full(len(target[0]), t), ) + target
        za_label[target_t] = label[target]
    return make_response(jsonify({'completed': True}))
示例#6
0
    def post(self):
        '''
        Update process state.

        Expects a JSON object whose ``REDIS_KEY_STATE`` entry is an integer
        that is a valid ``TrainState`` value. Stores it in redis and returns
        ``{'success': True, 'state': <state read back>}``. Responds with 400
        for a wrong Content-Type or an invalid or missing state.
        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        req_json = request.get_json()
        # Valid JSON need not be an object (e.g. a list or number). Treat
        # anything but a dict as "no state" instead of raising
        # AttributeError.
        state = (req_json.get(REDIS_KEY_STATE)
                 if isinstance(req_json, dict) else None)
        # bool is a subclass of int; reject it so that JSON true/false is
        # not silently accepted as the integer 1/0.
        if (isinstance(state, bool) or not isinstance(state, int)
                or state not in TrainState._value2member_map_):
            msg = f'Invalid state: {state}'
            logger().error(msg)
            # `res` is kept for existing clients; `error` matches the key
            # used by the other endpoints.
            return make_response(jsonify(res=msg, error=msg), 400)
        redis_client.set(REDIS_KEY_STATE, state)
        return make_response(jsonify(success=True, state=get_state()))
示例#7
0
    def post(self):
        '''
        Reset flow model.

        Accepts either ``application/json`` (config in the body) or
        ``multipart/form-data``. In the multipart case the config is JSON in
        the ``data`` form field, and a checkpoint is uploaded in ``file`` and
        loaded with ``torch.load``. A list checkpoint is used as-is; anything
        else is wrapped in a one-element list.

        Refuses to run unless the process state is IDLE, then calls
        ``init_flow_models``.

        Returns ``{'completed': True}`` on success, or ``{'error': ...}``
        with 400 for bad input and 500 when busy or on failure.
        '''
        if all(ctype not in request.headers['Content-Type']
               for ctype in ('multipart/form-data', 'application/json')):
            msg = ('Content-Type should be multipart/form-data or '
                   'application/json')
            logger().error(msg)
            # Wrap with make_response like every other path. A bare
            # (Response, status) tuple is not handled uniformly by all
            # Resource implementations (flask-restful, for example, would
            # try to serialize the Response object as data).
            return make_response(jsonify(error=msg), 400)
        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        try:
            device = get_device()
            if 'multipart/form-data' in request.headers['Content-Type']:
                # Log through the app logger instead of printing to stdout.
                logger().info(f'form: {request.form}')
                form_data = request.form.get('data')
                if form_data is None:
                    msg = 'Form field "data" is required'
                    logger().error(msg)
                    return make_response(jsonify(error=msg), 400)
                req_json = json.loads(form_data)
                file = request.files['file']
                # NOTE(security): torch.load unpickles the uploaded file and
                # can execute arbitrary code; only accept checkpoints from
                # trusted clients (or consider `weights_only=True`).
                checkpoint = torch.load(file.stream, map_location=device)
                state_dicts = checkpoint if isinstance(checkpoint,
                                                       list) else [checkpoint]
                # Weights come from the upload, so no download URL is used.
                req_json['url'] = None
            else:
                req_json = request.get_json()
                state_dicts = None
            req_json['device'] = device
            config = ResetConfig(req_json)
            logger().info(config)
            init_flow_models(config.model_path,
                             config.device,
                             config.is_3d,
                             url=config.url,
                             state_dicts=state_dicts)
        except RuntimeError as e:
            logger().exception('Failed in reset_flow_models')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in reset_flow_models')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        return make_response(jsonify({'completed': True}))
示例#8
0
    def post(self):
        '''
        Train seg model.

        Validates the JSON request and groups ``spots`` by their ``t`` key.
        Unless in live mode, it writes the seg labels for those timepoints
        before dispatching ``train_seg_task``, then polls the task until it
        finishes. The process state is RUN for the duration and is reset to
        IDLE afterwards, but only by the request that set it. Setting the
        state to IDLE from elsewhere aborts the request.

        Returns ``{'completed': True}`` when the task finishes successfully,
        ``{'completed': False}`` when aborted, and ``{'error': ...}`` with
        400 for invalid requests or 500 when another process is running or
        on failure.
        '''
        import time  # throttles polling of the training task

        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        # Only the request that switched the state to RUN may reset it to
        # IDLE. Otherwise a request rejected with "Process is running" (or
        # failing validation) would abort the process it was rejected for.
        owns_state = False
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.n_crops < 1:
                msg = 'n_crops should be a positive number'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)

            spots_dict = collections.defaultdict(list)
            # A missing or null `spots` is treated as empty.
            for spot in req_json.get('spots') or []:
                spots_dict[spot['t']].append(spot)
            if not (spots_dict or config.is_livemode):
                msg = 'nothing to train'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            if get_state() != TrainState.IDLE.value:
                msg = 'Process is running'
                logger().error(msg)
                return make_response(jsonify(error=msg), 500)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            owns_state = True
            if config.is_livemode:
                redis_client.delete(REDIS_KEY_TIMEPOINT)
            else:
                try:
                    _update_seg_labels(spots_dict,
                                       config.scales,
                                       config.zpath_input,
                                       config.zpath_seg_label,
                                       config.zpath_seg_label_vis,
                                       config.auto_bg_thresh,
                                       config.c_ratio,
                                       memmap_dir=config.memmap_dir)
                except KeyboardInterrupt:
                    return make_response(jsonify({'completed': False}))
            # Continue the step numbering of existing event logs so that new
            # records do not overlap old ones.
            step_offset = 0
            for path in sorted(Path(config.log_dir).glob('event*')):
                try:
                    *_, last_record = TFRecordDataset(str(path))
                    last = event_pb2.Event.FromString(last_record.numpy()).step
                    step_offset = max(step_offset, last + 1)
                except Exception:
                    # Best effort: unreadable or empty event files are
                    # ignored.
                    pass
            epoch_start = 0
            async_result = train_seg_task.delay(
                list(spots_dict.keys()),
                config.batch_size,
                config.crop_size,
                config.class_weights,
                config.false_weight,
                config.model_path,
                config.n_epochs,
                config.keep_axials,
                config.scales,
                config.lr,
                config.n_crops,
                config.is_3d,
                config.is_livemode,
                config.scale_factor_base,
                config.rotation_angle,
                config.contrast,
                config.zpath_input,
                config.zpath_seg_label,
                config.log_interval,
                config.log_dir,
                step_offset,
                epoch_start,
                config.is_cpu(),
                config.is_mixed_precision,
                config.cache_maxbytes,
                config.memmap_dir,
                config.input_size,
            )
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('training aborted')
                    return make_response(jsonify({'completed': False}))
                # Avoid a busy loop that hammers redis and a CPU core.
                time.sleep(0.1)
            # Surface task failures instead of reporting completion
            # (same handling as the seg prediction endpoint).
            if async_result.failed():
                raise async_result.result
        except Exception as e:
            logger().exception('Failed in train_seg')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            torch.cuda.empty_cache()
            if owns_state:
                redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
示例#9
0
    def post(self):
        '''
        Update seg label.

        Waits while another request holds the WAIT state, and rejects the
        request while a previous seg-label update is still flagged as
        ongoing. The JSON body is parsed into a ``SegmentationTrainConfig``
        (with the server device injected). If ``reset`` is truthy, the seg
        label and seg label visualization zarrs are overwritten via
        ``zarr.open_like(..., mode='w')``. Otherwise the ``spots`` are
        grouped by their ``t`` key and written by ``_update_seg_labels``.
        In live mode exactly one timepoint is accepted, and the request is
        rejected while ``REDIS_KEY_TIMEPOINT`` is set.

        Returns a JSON response: ``{'completed': ...}`` on success or abort,
        ``{'error': ...}`` with 400 for invalid requests, and 500 on
        failures.
        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        # Another request holds WAIT: poll until it is released, and give up
        # if the process is reset to IDLE in the meantime.
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if req_json.get('reset'):
                try:
                    zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                                   config.zpath_seg_label,
                                   mode='w')
                    zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                             mode='r'),
                                   config.zpath_seg_label_vis,
                                   mode='w')
                except RuntimeError as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Runtime Error: {e}'),
                                         500)
                except Exception as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Exception: {e}'), 500)
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            # A missing or null `spots` is treated as empty, so it is
            # reported as "nothing to update".
            for spot in req_json.get('spots') or []:
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if config.is_livemode and len(spots_dict.keys()) != 1:
                msg = 'Livemode should update only a single timepoint'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            try:
                response = _update_seg_labels(spots_dict,
                                              config.scales,
                                              config.zpath_input,
                                              config.zpath_seg_label,
                                              config.zpath_seg_label_vis,
                                              config.auto_bg_thresh,
                                              config.c_ratio,
                                              config.is_livemode,
                                              memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                # Raised by _update_seg_labels when the state becomes IDLE.
                return make_response(jsonify({'completed': False}))
            except Exception as e:
                logger().exception('Failed in _update_seg_labels')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
        # Errors while parsing the request or building the config are
        # reported as JSON errors, like the other endpoints, instead of
        # escaping as an unhandled server error.
        except RuntimeError as e:
            logger().exception('Failed in update_seg_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in update_seg_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            # Restore the pre-request state unless the process was reset to
            # IDLE (an abort) meanwhile, then release the ongoing flag.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
        return response
示例#10
0
    def post(self):
        '''
        Predict seg.

        Waits while another request holds the WAIT state, then parses the
        JSON body into a ``SegmentationEvalConfig``. It dispatches
        ``detect_spots_task`` and polls it until it finishes. Setting the
        state to IDLE from elsewhere aborts the wait. On success a
        'prediction' message is published.

        Returns ``{'spots': [...], 'completed': True}`` on success,
        ``{'spots': [], 'completed': False}`` when aborted or when the task
        returns None, and ``{'error': ...}`` with 400 (bad Content-Type) or
        500 on failure.
        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        # Another request holds WAIT: poll until it is released, and give up
        # if the process is reset to IDLE in the meantime.
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationEvalConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            async_result = detect_spots_task.delay(
                str(config.device),
                config.model_path,
                config.keep_axials,
                config.is_pad,
                config.is_3d,
                config.crop_size,
                config.scales,
                config.cache_maxbytes,
                config.use_2d,
                config.use_median,
                config.patch_size,
                config.crop_box,
                config.c_ratio,
                config.p_thresh,
                config.r_min,
                config.r_max,
                config.output_prediction,
                config.zpath_input,
                config.zpath_seg_output,
                config.timepoint,
                None,
                config.memmap_dir,
                config.batch_size,
                config.input_size,
            )
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('prediction aborted')
                    return make_response(
                        jsonify({
                            'spots': [],
                            'completed': False
                        }))
                # Avoid a busy loop that hammers redis and a CPU core.
                time.sleep(0.1)
            if async_result.failed():
                raise async_result.result
            spots = async_result.result
            if spots is None:
                logger().info('prediction aborted')
                return make_response(jsonify({
                    'spots': [],
                    'completed': False
                }))
            publish_mq('prediction', 'Prediction updated')
        except Exception as e:
            logger().exception('Failed in detect_spots')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            # Restore the pre-request state unless the process was reset to
            # IDLE (an abort) meanwhile.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return make_response(jsonify({'spots': spots, 'completed': True}))
示例#11
0
def _update_seg_labels(spots_dict,
                       scales,
                       zpath_input,
                       zpath_seg_label,
                       zpath_seg_label_vis,
                       auto_bg_thresh=0,
                       c_ratio=0.5,
                       is_livemode=False,
                       memmap_dir=None):
    '''
    Rebuild the seg label frames for the given spots and write them to zarr.

    For each timepoint ``t`` in ``spots_dict`` two frames are built from
    scratch: a uint8 label frame shaped like one input frame, and a uint8
    visualization frame with an extra trailing axis of size 3 (presumably
    RGB). The existing chunk files of frame ``t`` are deleted, and the new
    frames are written to the two label zarrs.

    Each spot is drawn as an ellipse (2D) or ellipsoid (3D) from its ``pos``
    and ``covariance``. Its ``tag`` selects the label codes:

    * ``tp``/``fn``: outer region -> 2; then the dilation of the inner
      region (radii scaled by ``c_ratio``) -> 2 and the inner region -> 3
    * ``tb``/``fb``: outer region -> 2
    * ``tn``/``fp``: outer region -> 1

    Codes get an offset of 0 for ``tp``/``tb``/``tn`` and 3 for any other
    tag. Outer regions never overwrite voxels already coded 3 or 6.
    ``tn``/``fp`` outer regions only overwrite voxels coded 0, 1 or 4.

    Parameters
    ----------
    spots_dict : mapping
        Timepoint -> list of spot dicts (keys used: ``tag``, ``pos``,
        ``covariance``).
    scales : sequence of float or None
        Per-axis scales; ``None`` means 1 for every axis.
    zpath_input : str
        Input zarr; ``shape[1:]`` is the frame shape.
    zpath_seg_label, zpath_seg_label_vis : str
        Label and visualization zarrs, opened with ``mode='a'``.
    auto_bg_thresh : number
        If > 0, input voxels below this value are pre-labelled 1.
    c_ratio : float
        Radius ratio of the inner region for ``tp``/``fn`` spots.
    is_livemode : bool
        Requires exactly one timepoint. After writing, ``t`` is stored under
        ``REDIS_KEY_TIMEPOINT``.
    memmap_dir : str or None
        If given, cached ``{keybase}-t{t}*-seglabel.dat`` files are removed.

    Returns
    -------
    A Flask response ``{'completed': True}``. In live mode it is instead a
    400 error when ``REDIS_KEY_TIMEPOINT`` is already set; note that the
    labels have already been written at that point.

    Raises
    ------
    KeyboardInterrupt
        When the process state becomes IDLE during the update.
    AssertionError
        In live mode when ``spots_dict`` does not have exactly one key.
    '''
    # Live mode handles exactly one timepoint per call.
    if is_livemode:
        assert len(spots_dict.keys()) == 1
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    # Tag order used only for the log line.
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 9
    img_shape = za_input.shape[1:]
    n_dims = len(img_shape)
    # Parent directory name of the label store; it prefixes the memmap
    # cache file names removed below.
    keybase = Path(za_label.store.path).parent.name
    if scales is None:
        scales = (1., ) * n_dims
    scales = np.array(scales)
    for t, spots in spots_dict.items():
        # Voxel coordinates touched in this frame. Switched to None once
        # more than INDICES_THRESHOLD are collected, in which case the whole
        # frame is written instead of a sparse update.
        label_indices = set()
        centroids = []
        label = np.zeros(img_shape, dtype='uint8')
        label_vis = np.zeros(img_shape + (3, ), dtype='uint8')
        # Auto background: voxels below the threshold get label 1.
        if 0 < auto_bg_thresh:
            indices_bg = np.nonzero(
                get_input_at(za_input, t, memmap_dir=memmap_dir) <
                auto_bg_thresh)
            label[indices_bg] = 1
            label_vis[indices_bg] = (255, 0, 0)
            if INDICES_THRESHOLD < len(indices_bg[0]):
                label_indices = None
            else:
                label_indices.update(tuple(map(tuple, np.array(indices_bg).T)))
        # Per-tag counts, used only for logging.
        cnt = collections.Counter({x: 0 for x in keyorder})
        for spot in spots:
            # Abort requested from elsewhere (state set to IDLE).
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                raise KeyboardInterrupt
            cnt[spot['tag']] += 1
            # pos/covariance are reversed (presumably x,y,z -> z,y,x to match
            # array axis order) and restricted to the last n_dims axes.
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            # Ellipse axes: sqrt of the eigenvalues as radii, eigenvectors as
            # the rotation.
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            if n_dims == 3:
                draw_func = ellipsoid
                dilate_func = _dilate_3d_indices
            else:
                draw_func = ellipse
                dilate_func = _dilate_2d_indices
            indices_outer = draw_func(centroid, radii, rotation, scales,
                                      img_shape, MIN_AREA_ELLIPSOID)
            if label_indices is not None:
                if INDICES_THRESHOLD < len(indices_outer[0]):
                    label_indices = None
                else:
                    label_indices.update(
                        tuple(map(tuple,
                                  np.array(indices_outer).T)))
                    if INDICES_THRESHOLD < len(label_indices):
                        label_indices = None
            # Offset 0 / full intensity for tp, tb, tn; offset 3 / half
            # intensity for every other tag.
            if spot['tag'] in ['tp', 'tb', 'tn']:
                label_offset = 0
                label_vis_value = 255
            else:
                label_offset = 3
                label_vis_value = 127
            # np.fmod(code - 1, 3) maps codes 1..6 to 0,1,2,0,1,2, and code 0
            # wraps (uint8) to 255 -> 0. So this is False only for codes 3
            # and 6, which keeps outer regions from overwriting inner ones.
            cond_outer_1 = np.fmod(label[indices_outer] - 1, 3) <= 1
            if spot['tag'] in ('tp', 'fn'):
                indices_inner = draw_func(centroid, radii * c_ratio, rotation,
                                          scales, img_shape,
                                          MIN_AREA_ELLIPSOID)
                indices_inner_p = dilate_func(*indices_inner, img_shape)
                label[indices_outer] = np.where(cond_outer_1, 2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_1[..., None],
                                                    (0, label_vis_value, 0),
                                                    label_vis[indices_outer])
                label[indices_inner_p] = 2 + label_offset
                label_vis[indices_inner_p] = (0, label_vis_value, 0)
                # NOTE(review): np.fmod(..., 3) of a uint8 value is always
                # <= 2, so cond_inner is always True here.
                cond_inner = np.fmod(label[indices_inner] - 1, 3) <= 2
                label[indices_inner] = np.where(cond_inner, 3 + label_offset,
                                                label[indices_inner])
                label_vis[indices_inner] = np.where(cond_inner[..., None],
                                                    (0, 0, label_vis_value),
                                                    label_vis[indices_inner])
            elif spot['tag'] in ('tb', 'fb'):
                label[indices_outer] = np.where(cond_outer_1, 2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_1[..., None],
                                                    (0, label_vis_value, 0),
                                                    label_vis[indices_outer])
            elif spot['tag'] in ('tn', 'fp'):
                # True only for codes 0, 1 and 4 (unlabelled/background).
                cond_outer_0 = np.fmod(label[indices_outer] - 1, 3) <= 0
                label[indices_outer] = np.where(cond_outer_0, 1 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_0[..., None],
                                                    (label_vis_value, 0, 0),
                                                    label_vis[indices_outer])
        logger().info('frame:{}, {}'.format(
            t, sorted(cnt.items(), key=lambda i: keyorder.index(i[0]))))
        # Existing chunk files of frame t are removed before writing
        # (presumably so that voxels not written below read back as the fill
        # value; assumes a directory-backed store with '<t>.*' chunk keys).
        if label_indices is None:
            # Dense path: write the whole frame.
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[t] = label
            za_label_vis[t] = label_vis
        else:
            # Sparse path: write only the touched voxels.
            target = tuple(np.array(list(label_indices)).T)
            target_t = to_fancy_index(t, *target)
            # Visualization index: spatial coordinates repeated for each of
            # the 3 trailing channels, followed by channel indices
            # 0,0,...,1,1,...,2,2,...
            target_vis = (
                tuple(np.tile(target[i], 3) for i in range(n_dims)) +
                (np.array(
                    sum(tuple([(c, ) * len(target[0]) for c in range(3)]),
                        ())), ))
            target_vis_t = to_fancy_index(t, *target_vis)
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[target_t] = label[target]
            za_label_vis[target_vis_t] = label_vis[target_vis]
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        # Invalidate cached seg-label memmaps for this frame.
        if memmap_dir:
            for fpath in Path(memmap_dir).glob(
                    f'{keybase}-t{t}*-seglabel.dat'):
                lock = FileLock(str(fpath) + '.lock')
                with lock:
                    if fpath.exists():
                        logger().info(f'remove {fpath}')
                        fpath.unlink()
        # NOTE(review): this check runs after the labels have been written,
        # so a 400 here does not undo the update.
        if is_livemode:
            if redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            redis_client.set(REDIS_KEY_TIMEPOINT, t)
    return make_response(jsonify({'completed': True}))
示例#12
0
    def get(self):
        '''
        Report the current process state.

        Responds with a JSON body ``{"success": true, "state": <state>}``,
        where ``state`` is the value returned by ``get_state()``.
        '''
        current_state = get_state()
        payload = jsonify(success=True, state=current_state)
        return make_response(payload)
示例#13
0
    def __getitem__(self, index):
        """
        Return one flow-estimation sample.

        Input shape: (2, (D,) H, W)
        Label shape: (ndim+1, (D,) H, W)
        Label channels: (flow_x, flow_y, flow_z, mask)

        In eval mode the whole frame is returned. Otherwise a random crop is
        drawn around an index taken from the ``label.indices.{i_frame}`` pool
        stored in the flow-label zarr attrs (computed from ``mask > 0`` and
        cached there if absent). The crop is optionally rescaled, rotated in
        the XY plane and randomly flipped; the flow channels are scaled,
        rotated and negated accordingly. Crops are retried until the mask
        channel contains a positive value.

        Args:
            index: dataset index; only used to pick the frame
                (``index // self.n_crops``) when ``self.length`` is None,
                otherwise a frame is drawn at random from ``self.indices``.

        Returns:
            ((tensor_input, self.keep_axials), tensor_target), where
            tensor_target concatenates the label channels and the input
            channels along dim 0.

        Raises:
            KeyboardInterrupt: if redis is in use and the state is IDLE.
        """
        if (redis_client is not None and get_state() == TrainState.IDLE.value):
            raise KeyboardInterrupt
        if self.length is not None:
            i_frame = np.random.choice(self.indices)
        else:
            i_frame = self.indices[index // self.n_crops]
        img_input = get_inputs_at(self.za_input,
                                  i_frame,
                                  cache_dict=self.cache_dict_input,
                                  memmap_dir=self.memmap_dir,
                                  img_size=self.img_size)
        if self.za_input.shape[1:] != self.img_size:
            resize_factor = [
                self.img_size[d] / self.za_input.shape[1 + d]
                for d in range(self.n_dims)
            ]
        else:
            resize_factor = [
                1,
            ] * self.n_dims
        img_label = self._get_label_at(i_frame, img_size=self.img_size)
        if self.is_eval:
            tensor_input = torch.from_numpy(img_input)
            tensor_label = torch.from_numpy(img_label)
            tensor_target = torch.cat((tensor_label, tensor_input), )
            return (tensor_input, self.keep_axials), tensor_target

        # The pool of candidate crop centers depends only on i_frame, so look
        # it up (and cache it in the zarr attrs if missing) once per call
        # rather than on every retry of the loop below.
        za_label_a = zarr.open(self.zpath_flow_label, mode='a')
        index_pool = za_label_a.attrs.get(f'label.indices.{i_frame}')
        if index_pool is None:
            index_pool = np.argwhere(0 < img_label[-1])
            za_label_a.attrs[f'label.indices.{i_frame}'] = tuple(
                map(tuple, index_pool.tolist()))

        while True:
            if 0 < sum(self.scale_factors):
                item_crop_size = [
                    randrange(
                        min(
                            img_input.shape[i + 1],
                            round(self.crop_size[i] *
                                  (1. - self.scale_factors[i]))),
                        min(
                            img_input.shape[i + 1] + 1,
                            int(self.crop_size[i] *
                                (1. + self.scale_factors[i])) + 1))
                    for i in range(self.n_dims)
                ]
            else:
                # Copy: the rotation step scales item_crop_size in place and
                # must not modify self.crop_size.
                item_crop_size = list(self.crop_size)
            if self.rotation_angle is not None and 0 < self.rotation_angle:
                # rotate image
                theta = randint(-self.rotation_angle, self.rotation_angle)
                cos_theta = math.cos(math.radians(theta))
                sin_theta = math.sin(math.radians(theta))
                # Enlarge the XY crop so the rotated patch still covers the
                # center crop taken after rotation.
                for i in (-2, -1):
                    item_crop_size[i] *= (abs(cos_theta) + abs(sin_theta))
                    item_crop_size[i] = math.ceil(item_crop_size[i])
                # img_input has a leading channel axis, so spatial dim i is
                # shape[i + 1].
                item_crop_size = [
                    min(img_input.shape[i + 1], item_crop_size[i])
                    for i in range(self.n_dims)
                ]
            base_index = index_pool[randrange(len(index_pool))]
            # Choose an origin so the crop contains the (resized) base index.
            origins = [
                randint(
                    max(0, (int(base_index[i] * resize_factor[i]) -
                            (item_crop_size[i] - 1))),
                    min((img_input.shape[1 + i] - item_crop_size[i]),
                        int(base_index[i] * resize_factor[i])))
                for i in range(self.n_dims)
            ]
            slices = (slice(None), ) + tuple(
                slice(origins[i], origins[i] + item_crop_size[i])
                for i in range(self.n_dims))
            sliced_label = img_label[slices].copy()
            sliced_input = img_input[slices].copy()

            # scale labels by resize factor
            if 0 < sum(self.scale_factors):
                sliced_label[0] *= self.crop_size[-1] / item_crop_size[-1]  # X
                sliced_label[1] *= self.crop_size[-2] / item_crop_size[-2]  # Y
                if self.n_dims == 3:
                    sliced_label[2] *= self.crop_size[-3] / \
                        item_crop_size[-3]  # Z

            if self.rotation_angle is not None and 0 < self.rotation_angle:
                if self.n_dims == 3:
                    sliced_input = np.array([
                        [
                            rotate(
                                sliced_input[c, z],
                                theta,
                                resize=True,
                                preserve_range=True,
                                order=1,  # 1: Bi-linear (default)
                            ) for z in range(sliced_input.shape[1])
                        ] for c in range(sliced_input.shape[0])
                    ])
                else:
                    sliced_input = np.array([
                        rotate(
                            sliced_input[c],
                            theta,
                            resize=True,
                            preserve_range=True,
                            order=1,  # 1: Bi-linear (default)
                        ) for c in range(sliced_input.shape[0])
                    ])
                h_crop, w_crop = self.crop_size[-2:]
                h_rotate, w_rotate = sliced_input.shape[-2:]
                r_origin = max(0, (h_rotate - h_crop) // 2)
                c_origin = max(0, (w_rotate - w_crop) // 2)
                sliced_input = sliced_input[..., r_origin:r_origin + h_crop,
                                            c_origin:c_origin + w_crop]
                # rotate label
                if self.n_dims == 3:
                    sliced_label = np.array([
                        [
                            rotate(
                                sliced_label[c, z],
                                theta,
                                resize=True,
                                preserve_range=True,
                                order=0,  # 0: Nearest-neighbor
                            ) for z in range(sliced_label.shape[1])
                        ] for c in range(sliced_label.shape[0])
                    ])
                else:
                    sliced_label = np.array([
                        rotate(
                            sliced_label[c],
                            theta,
                            resize=True,
                            preserve_range=True,
                            order=0,  # 0: Nearest-neighbor
                        ) for c in range(sliced_label.shape[0])
                    ])
                sliced_label = sliced_label[..., r_origin:r_origin + h_crop,
                                            c_origin:c_origin + w_crop]
                if sliced_label[-1].max() == 0:
                    continue
                # update flow label (y axis is inversed in the image coordinate)
                cos_theta = math.cos(math.radians(theta))
                sin_theta = math.sin(math.radians(theta))
                sliced_label_x = sliced_label[0].copy()
                sliced_label_y = sliced_label[1].copy() * -1
                sliced_label[0] = (cos_theta * sliced_label_x -
                                   sin_theta * sliced_label_y)
                sliced_label[1] = (sin_theta * sliced_label_x +
                                   cos_theta * sliced_label_y) * -1
            if 0 < sliced_label[-1].max():
                break

        tensor_input = torch.from_numpy(sliced_input)
        tensor_label = torch.from_numpy(sliced_label)
        if tensor_input.shape[1:] != self.crop_size:
            interpolate_mode = 'trilinear' if self.n_dims == 3 else 'bilinear'
            tensor_input = F.interpolate(tensor_input[None],
                                         self.crop_size,
                                         mode=interpolate_mode,
                                         align_corners=True)[0]
            tensor_label = F.interpolate(tensor_label[None].float(),
                                         self.crop_size,
                                         mode='nearest')[0]
        # Random flips. Spatial dim -(1 + i) corresponds to flow channel i
        # (x, y, z order), whose sign is negated when that axis is flipped.
        flip_dims = [
            -(1 + i) for i, v in enumerate(torch.rand(self.n_dims))
            if v < 0.5
        ]
        tensor_input = torch.flip(tensor_input, flip_dims)
        tensor_label = torch.flip(tensor_label, flip_dims)
        for flip_dim in flip_dims:
            tensor_label[-1 - flip_dim] *= -1
        # Channel order: (flow_x, flow_y, flow_z, mask, input_t0, input_t1)
        tensor_target = torch.cat((tensor_label, tensor_input), )
        return (tensor_input, self.keep_axials), tensor_target
示例#14
0
    def __getitem__(self, index):
        """
        Return one segmentation (or autoencoder) sample.

        Input shape: ((D,) H, W)
        Label shape: ((D,) H, W)
        Label values: 0: unlabeled, 1: BG (LW), 2: Outer (LW), 3: Inner (LW)
                                    4: BG (HW), 5: Outer (HW), 6: Inner (HW)

        Frame selection:
          * autoencoder mode (``self.is_ae``): a random frame of the input.
          * live mode: polls redis until ``REDIS_KEY_TIMEPOINT`` is set and
            uses that frame.
          * otherwise: random frame from ``self.indices`` when
            ``self.length`` is set, else ``self.indices[index // n_crops]``.

        In training mode a random crop is taken (around an index from the
        ``label.indices.{i_frame}`` pool unless in autoencoder mode). It is
        optionally rescaled, contrast-augmented, rotated in the XY plane and
        randomly flipped, then resampled to ``self.crop_size``.

        Returns:
            ((tensor_input, self.keep_axials), target). In autoencoder mode
            target is tensor_input itself; otherwise it is the long label
            tensor minus 1 (so unlabeled becomes -1). In live mode, if the
            ``REDIS_KEY_NCROPS`` value differs from ``self.length`` while
            waiting, the sentinel ((tensor(-200.), keep_axials),
            tensor(-200)) is returned.

        Raises:
            KeyboardInterrupt: if redis is in use and the state is IDLE.
        """
        if (redis_client is not None and get_state() == TrainState.IDLE.value):
            raise KeyboardInterrupt
        if self.is_ae:
            i_frame = randrange(self.za_input.shape[0])
        else:
            if self.is_livemode:
                while True:
                    v = redis_client.get(REDIS_KEY_TIMEPOINT)
                    if v is not None:
                        i_frame = int(v)
                        img_label = self._get_label_at(i_frame, self.img_size)
                        break
                    if (get_state() == TrainState.IDLE.value):
                        raise KeyboardInterrupt
                    if self.length != int(redis_client.get(REDIS_KEY_NCROPS)):
                        return ((torch.tensor(-200.), self.keep_axials),
                                torch.tensor(-200))
            else:
                if self.length is not None:
                    i_frame = np.random.choice(self.indices)
                else:
                    i_frame = self.indices[index // self.n_crops]
                img_label = self._get_label_at(i_frame, self.img_size)
        img_input = get_input_at(self.za_input,
                                 i_frame,
                                 self.cache_dict_input,
                                 self.memmap_dir,
                                 img_size=self.img_size)
        if self.za_input.shape[1:] != self.img_size:
            resize_factor = [
                self.img_size[d] / self.za_input.shape[1 + d]
                for d in range(img_input.ndim)
            ]
        else:
            resize_factor = [
                1,
            ] * img_input.ndim
        if self.is_eval:
            tensor_input = torch.from_numpy(img_input[None])
            tensor_label = torch.from_numpy(img_label).long()
            return (tensor_input, self.keep_axials), tensor_label - 1

        # The pool of candidate crop centers depends only on i_frame, so look
        # it up (and cache it in the zarr attrs if missing) once per call
        # rather than on every retry of the loop below.
        if not self.is_ae:
            za_label_a = zarr.open(self.zpath_seg_label, mode='a')
            index_pool = za_label_a.attrs.get(f'label.indices.{i_frame}')
            if index_pool is None:
                index_pool = np.argwhere(0 < img_label)
                za_label_a.attrs[f'label.indices.{i_frame}'] = tuple(
                    map(tuple, index_pool.tolist()))

        while True:
            if 0 < sum(self.scale_factors):
                item_crop_size = [
                    randrange(
                        min(
                            img_input.shape[i],
                            round(self.crop_size[i] *
                                  (1. - self.scale_factors[i]))),
                        min(
                            img_input.shape[i] + 1,
                            int(self.crop_size[i] *
                                (1. + self.scale_factors[i])) + 1))
                    for i in range(self.n_dims)
                ]
            else:
                # Copy: the rotation step scales item_crop_size in place and
                # must not modify self.crop_size.
                item_crop_size = list(self.crop_size)
            if self.rotation_angle is not None and 0 < self.rotation_angle:
                # rotate image
                theta = randint(-self.rotation_angle, self.rotation_angle)
                cos_theta = math.cos(math.radians(theta))
                sin_theta = math.sin(math.radians(theta))
                # Enlarge the XY crop so the rotated patch still covers the
                # center crop taken after rotation.
                for i in (-2, -1):
                    item_crop_size[i] *= (abs(cos_theta) + abs(sin_theta))
                    item_crop_size[i] = math.ceil(item_crop_size[i])
                item_crop_size = [
                    min(img_input.shape[i], item_crop_size[i])
                    for i in range(self.n_dims)
                ]
            if self.is_ae:
                origins = [
                    randint(0, img_input.shape[i] - item_crop_size[i])
                    for i in range(self.n_dims)
                ]
            else:
                base_index = index_pool[randrange(len(index_pool))]
                # Choose an origin so the crop contains the (resized) index.
                origins = [
                    randint(
                        max(0, (int(base_index[i] * resize_factor[i]) -
                                (item_crop_size[i] - 1))),
                        min((img_input.shape[i] - item_crop_size[i]),
                            int(base_index[i] * resize_factor[i])))
                    for i in range(self.n_dims)
                ]
            slices = tuple(
                slice(origins[i], origins[i] + item_crop_size[i])
                for i in range(self.n_dims))
            if not self.is_ae:
                sliced_label = img_label[slices].copy()

            sliced_input = img_input[slices].copy()

            # Contrast augmentation: pull the foreground mean towards the
            # background mean by a random factor in [self.contrast, 1].
            if not self.is_ae and 0 < self.contrast:
                fg_index = np.isin(sliced_label, (2, 3, 5, 6))
                bg_index = np.isin(sliced_label, (1, 4))
                if fg_index.any() and bg_index.any():
                    fg_mean = sliced_input[fg_index].mean()
                    bg_mean = sliced_input[bg_index].mean()
                    cr_factor = (
                        ((fg_mean - bg_mean) * uniform(self.contrast, 1) +
                         bg_mean) / fg_mean)
                    sliced_input[fg_index] *= cr_factor

            if self.rotation_angle is not None and 0 < self.rotation_angle:
                if self.n_dims == 3:
                    sliced_input = np.array([
                        rotate(
                            sliced_input[z],
                            theta,
                            resize=True,
                            preserve_range=True,
                            order=1,  # 1: Bi-linear (default)
                        ) for z in range(sliced_input.shape[0])
                    ])
                else:
                    sliced_input = rotate(
                        sliced_input,
                        theta,
                        resize=True,
                        preserve_range=True,
                        order=1,  # 1: Bi-linear (default)
                    )
                h_crop, w_crop = self.crop_size[-2:]
                h_rotate, w_rotate = sliced_input.shape[-2:]
                r_origin = max(0, (h_rotate - h_crop) // 2)
                c_origin = max(0, (w_rotate - w_crop) // 2)
                sliced_input = sliced_input[..., r_origin:r_origin + h_crop,
                                            c_origin:c_origin + w_crop]

                # rotate label
                if not self.is_ae:
                    if self.n_dims == 3:
                        sliced_label = np.array([
                            rotate(
                                sliced_label[z],
                                theta,
                                resize=True,
                                preserve_range=True,
                                order=0,  # 0: Nearest-neighbor
                            ) for z in range(sliced_label.shape[0])
                        ])
                    else:
                        sliced_label = rotate(
                            sliced_label,
                            theta,
                            resize=True,
                            preserve_range=True,
                            order=0,  # 0: Nearest-neighbor
                        )
                    sliced_label = sliced_label[...,
                                                r_origin:r_origin + h_crop,
                                                c_origin:c_origin + w_crop]
                    # Retry if rotation left no labeled pixel in the crop.
                    if sliced_label.max() == 0:
                        continue
            break

        tensor_input = torch.from_numpy(sliced_input)[None]
        if not self.is_ae:
            tensor_label = torch.from_numpy(sliced_label)
        if tensor_input.shape[1:] != self.crop_size:
            interpolate_mode = 'trilinear' if self.n_dims == 3 else 'bilinear'
            tensor_input = F.interpolate(tensor_input[None],
                                         self.crop_size,
                                         mode=interpolate_mode,
                                         align_corners=True)[0]
            if not self.is_ae:
                tensor_label = F.interpolate(tensor_label[None, None].float(),
                                             self.crop_size,
                                             mode='nearest')[0, 0]
        if self.is_ae:
            return (tensor_input, self.keep_axials), tensor_input
        tensor_label = tensor_label.long()
        flip_dims = [
            -(1 + i) for i, v in enumerate(torch.rand(self.n_dims)) if v < 0.5
        ]
        tensor_input = torch.flip(tensor_input, flip_dims)
        tensor_label = torch.flip(tensor_label, flip_dims)
        # -1: unlabeled, 0: BG (LW), 1: Outer (LW), 2: Inner (LW)
        #                3: BG (HW), 4: Outer (HW), 5: Inner (HW)
        return (tensor_input, self.keep_axials), tensor_label - 1