Example No. 1
    def post(self):
        '''
        Reset seg model.

        '''
        if all(ctype not in request.headers['Content-Type']
               for ctype in ('multipart/form-data', 'application/json')):
            msg = ('Content-Type should be multipart/form-data or '
                   'application/json')
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running. Model cannot be reset.'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        try:
            device = get_device()
            if 'multipart/form-data' in request.headers['Content-Type']:
                print(request.form)
                req_json = json.loads(request.form.get('data'))
                file = request.files['file']
                checkpoint = torch.load(file.stream, map_location=device)
                state_dicts = checkpoint if isinstance(checkpoint,
                                                       list) else [checkpoint]
                req_json['url'] = None
            else:
                req_json = request.get_json()
                state_dicts = None
            req_json['device'] = device
            config = ResetConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            init_seg_models(config.model_path,
                            config.keep_axials,
                            config.device,
                            config.is_3d,
                            config.n_models,
                            config.n_crops,
                            config.zpath_input,
                            config.crop_size,
                            config.scales,
                            url=config.url,
                            state_dicts=state_dicts,
                            is_cpu=config.is_cpu(),
                            cache_maxbytes=config.cache_maxbytes)
        except RuntimeError as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
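
A hypothetical client call against this endpoint is sketched below (a sketch only: the
host, port and route are assumptions, and the 'model_name'/'is_3d' keys are borrowed
from the config constructor shown in Example No. 3):

import json
import requests

# Hypothetical endpoint URL; the real route depends on how the resource is registered.
URL = 'http://localhost:5000/seg/reset'

# JSON-only reset: weights are resolved from the config itself (model_path or url).
requests.post(URL, json={'model_name': 'seg_model.pth', 'is_3d': True})

# Multipart reset: upload a checkpoint as 'file' and pass the JSON config in 'data'.
with open('checkpoint.pth', 'rb') as f:
    requests.post(
        URL,
        data={'data': json.dumps({'model_name': 'seg_model.pth', 'is_3d': True})},
        files={'file': f},
    )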
Example No. 2
    def post(self):
        '''
        Update flow label.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = FlowTrainConfig(req_json)
            logger().info(config)
            if req_json.get('reset'):
                zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                               config.zpath_flow_label,
                               mode='w')
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots', []):
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            response = _update_flow_labels(spots_dict, config.scales,
                                           config.zpath_flow_label,
                                           config.flow_norm_factor)
        except RuntimeError as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
        return response
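
Several of these handlers group the posted spots by timepoint before doing any work.
In isolation, that grouping pattern looks like the following self-contained sketch
(the spot dicts are made up; real spots carry more fields than just 't'):

import collections

spots = [{'t': 2}, {'t': 0}, {'t': 2}, {'t': 1}]  # made-up spots with only a timepoint

# Bucket spots by timepoint, then order the buckets by increasing t,
# exactly as the handler above does before calling _update_flow_labels.
spots_dict = collections.defaultdict(list)
for spot in spots:
    spots_dict[spot['t']].append(spot)
spots_dict = collections.OrderedDict(sorted(spots_dict.items()))
# -> keys are (0, 1, 2); spots_dict[2] holds both spots with t == 2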
Example No. 3
    def __init__(self, config):
        self.dataset_name = config.get('dataset_name')
        self.timepoint = config.get('timepoint')
        if config.get('model_name') is not None:
            self.model_path = os.path.join(MODELS_DIR,
                                           config.get('model_name'))
        self.device = config.get('device', get_device())
        if not torch.cuda.is_available():
            self.device = torch.device("cpu")
        self.is_mixed_precision = config.get('is_mixed_precision', True)
        self.debug = config.get('debug', False)
        self.output_prediction = config.get('output_prediction', False)
        self.is_3d = config.get('is_3d', True)
        self.use_2d = config.get('use_2d', False)
        self.batch_size = config.get('batch_size', 1)
        self.patch_size = config.get('patch')
        self.log_interval = config.get('log_interval', 10)
        self.cache_maxbytes = config.get('cache_maxbytes')
        if self.patch_size is not None:
            self.patch_size = self.patch_size[::-1]
        if config.get('scales') is not None:
            self.scales = config.get('scales')[::-1]
        else:
            self.scales = None
        if config.get('input_size') is not None:
            self.input_size = tuple(config.get('input_size')[::-1])
        else:
            self.input_size = None
        # U-Net has 4 downsamplings
        n_keep_axials = min(4, config.get('n_keep_axials', 4))
        self.keep_axials = tuple(True if i < n_keep_axials else False
                                 for i in range(4))
        self.crop_size = config.get('crop_size', DEFAULT_CROP_SIZE)[::-1]
        if not self.is_3d:
            if self.patch_size is not None:
                self.patch_size = self.patch_size[-2:]
            if self.scales is not None:
                self.scales = self.scales[-2:]
            self.crop_size = self.crop_size[-2:]
            if self.input_size is not None:
                self.input_size = self.input_size[-2:]
        if self.dataset_name is not None:
            self.zpath_input = os.path.join(DATASETS_DIR, self.dataset_name,
                                            ZARR_INPUT)
        else:
            self.zpath_input = None
        if config.get('use_memmap'):
            self.memmap_dir = MEMMAPS_DIR
            Path(self.memmap_dir).mkdir(parents=True, exist_ok=True)
        else:
            self.memmap_dir = None
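
The constructor reads everything from a plain dict, so a request body can be sketched
as below. The values are illustrative only, and list-valued entries are reversed by the
constructor (presumably from (X, Y, Z) to (Z, Y, X) order):

# Illustrative config dict; the keys mirror the config.get(...) lookups above.
config_dict = {
    'dataset_name': 'demo_dataset',
    'model_name': 'seg_model.pth',
    'is_3d': True,
    'batch_size': 1,
    'patch': [128, 128, 16],      # stored reversed as self.patch_size
    'scales': [1.0, 1.0, 2.0],    # stored reversed as self.scales
    'crop_size': [384, 384, 16],  # stored reversed as self.crop_size
    'n_keep_axials': 2,           # -> keep_axials == (True, True, False, False)
    'use_memmap': False,
}
# The dict would then be passed to the config class that owns the __init__ above.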
Example No. 4
    def post(self):
        '''
        Download outputs in CTC format.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response('', 204)
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = ExportConfig(req_json)
            za_input = zarr.open(config.zpath_input, mode='r')
            config.shape = za_input.shape[1:]
            logger().info(config)
            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots', []):
                spots_dict[spot['t']].append(spot)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            result = export_ctc_labels(config, spots_dict)
            if isinstance(result, str):
                resp = send_file(result)
                # file_remover.cleanup_once_done(resp, result)
            elif not result:
                resp = make_response('', 204)
            else:
                resp = make_response('', 200)
        except RuntimeError as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return resp
Example No. 5
    def post(self):
        '''
        Reset flow model.

        '''
        if all(ctype not in request.headers['Content-Type']
               for ctype in ('multipart/form-data', 'application/json')):
            msg = ('Content-Type should be multipart/form-data or '
                   'application/json')
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if get_state() != TrainState.IDLE.value:
            return make_response(jsonify(error='Process is running'), 500)
        try:
            device = get_device()
            if 'multipart/form-data' in request.headers['Content-Type']:
                print(request.form)
                req_json = json.loads(request.form.get('data'))
                file = request.files['file']
                checkpoint = torch.load(file.stream, map_location=device)
                state_dicts = checkpoint if isinstance(checkpoint,
                                                       list) else [checkpoint]
                req_json['url'] = None
            else:
                req_json = request.get_json()
                state_dicts = None
            req_json['device'] = device
            config = ResetConfig(req_json)
            logger().info(config)
            init_flow_models(config.model_path,
                             config.device,
                             config.is_3d,
                             url=config.url,
                             state_dicts=state_dicts)
        except RuntimeError as e:
            logger().exception('Failed in reset_flow_models')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in reset_flow_models')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        return make_response(jsonify({'completed': True}))
Example No. 6
    def post(self):
        '''
        Train seg model.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.n_crops < 1:
                msg = 'n_crops should be a positive number'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)

            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots', []):
                spots_dict[spot['t']].append(spot)
            if not (spots_dict or config.is_livemode):
                msg = 'nothing to train'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            if get_state() != TrainState.IDLE.value:
                msg = 'Process is running'
                logger().error(msg)
                return make_response(jsonify(error=msg), 500)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            if config.is_livemode:
                redis_client.delete(REDIS_KEY_TIMEPOINT)
            else:
                try:
                    _update_seg_labels(spots_dict,
                                       config.scales,
                                       config.zpath_input,
                                       config.zpath_seg_label,
                                       config.zpath_seg_label_vis,
                                       config.auto_bg_thresh,
                                       config.c_ratio,
                                       memmap_dir=config.memmap_dir)
                except KeyboardInterrupt:
                    return make_response(jsonify({'completed': False}))
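            # Resume the TensorBoard global step from any existing event files so
            # that new training logs continue where the previous run left off.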
            step_offset = 0
            for path in sorted(Path(config.log_dir).glob('event*')):
                try:
                    *_, last_record = TFRecordDataset(str(path))
                    last = event_pb2.Event.FromString(last_record.numpy()).step
                    step_offset = max(step_offset, last + 1)
                except Exception:
                    pass
            epoch_start = 0
            async_result = train_seg_task.delay(
                list(spots_dict.keys()),
                config.batch_size,
                config.crop_size,
                config.class_weights,
                config.false_weight,
                config.model_path,
                config.n_epochs,
                config.keep_axials,
                config.scales,
                config.lr,
                config.n_crops,
                config.is_3d,
                config.is_livemode,
                config.scale_factor_base,
                config.rotation_angle,
                config.contrast,
                config.zpath_input,
                config.zpath_seg_label,
                config.log_interval,
                config.log_dir,
                step_offset,
                epoch_start,
                config.is_cpu(),
                config.is_mixed_precision,
                config.cache_maxbytes,
                config.memmap_dir,
                config.input_size,
            )
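            # Poll the Celery task; an external reset of the state to IDLE aborts the wait.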
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('training aborted')
                    return make_response(jsonify({'completed': False}))
        except Exception as e:
            logger().exception('Failed in train_seg')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            torch.cuda.empty_cache()
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
Example No. 7
    def post(self):
        '''
        Update seg label.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if req_json.get('reset'):
                try:
                    zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                                   config.zpath_seg_label,
                                   mode='w')
                    zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                             mode='r'),
                                   config.zpath_seg_label_vis,
                                   mode='w')
                except RuntimeError as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Runtime Error: {e}'),
                                         500)
                except Exception as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Exception: {e}'), 500)
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            for spot in req_json.get('spots', []):
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if config.is_livemode and len(spots_dict.keys()) != 1:
                msg = 'Livemode should update only a single timepoint'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            try:
                response = _update_seg_labels(spots_dict,
                                              config.scales,
                                              config.zpath_input,
                                              config.zpath_seg_label,
                                              config.zpath_seg_label_vis,
                                              config.auto_bg_thresh,
                                              config.c_ratio,
                                              config.is_livemode,
                                              memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                return make_response(jsonify({'completed': False}))
            except Exception as e:
                logger().exception('Failed in _update_seg_labels')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
        return response
Example No. 8
    def post(self):
        '''
        Predict seg.

        '''
        if request.headers['Content-Type'] != 'application/json':
            msg = 'Content-Type should be application/json'
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while (state == TrainState.WAIT.value):
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if (state == TrainState.IDLE.value):
                return make_response(jsonify({'completed': False}))
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationEvalConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            async_result = detect_spots_task.delay(
                str(config.device),
                config.model_path,
                config.keep_axials,
                config.is_pad,
                config.is_3d,
                config.crop_size,
                config.scales,
                config.cache_maxbytes,
                config.use_2d,
                config.use_median,
                config.patch_size,
                config.crop_box,
                config.c_ratio,
                config.p_thresh,
                config.r_min,
                config.r_max,
                config.output_prediction,
                config.zpath_input,
                config.zpath_seg_output,
                config.timepoint,
                None,
                config.memmap_dir,
                config.batch_size,
                config.input_size,
            )
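            # Poll the Celery task; an external reset of the state to IDLE aborts prediction.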
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('prediction aborted')
                    return make_response(
                        jsonify({
                            'spots': [],
                            'completed': False
                        }))
            if async_result.failed():
                raise async_result.result
            spots = async_result.result
            if spots is None:
                logger().info('prediction aborted')
                return make_response(jsonify({
                    'spots': [],
                    'completed': False
                }))
            publish_mq('prediction', 'Prediction updated')
        except Exception as e:
            logger().exception('Failed in detect_spots')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return make_response(jsonify({'spots': spots, 'completed': True}))
Example No. 9
def test_get_device():
    assert get_device() in (torch.device("cpu"), torch.device("cuda"))
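
The test only constrains the return value to one of two torch devices. A minimal
get_device() consistent with it might look like the sketch below (an assumption, not
necessarily the project's actual implementation):

import torch

def get_device():
    # Prefer CUDA when available, otherwise fall back to the CPU.
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')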
Example No. 10
def device():
    if 'device' not in g:
        g.device = get_device()
    return g.device
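
This helper caches the device on Flask's application-context object g, so get_device()
runs at most once per context. A self-contained sketch of the pattern (the stub
get_device() is an assumption, mirroring Example No. 9):

import torch
from flask import Flask, g

app = Flask(__name__)

def get_device():
    # Stub for illustration only.
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def device():
    # Same caching pattern as above: compute once, then reuse within the app context.
    if 'device' not in g:
        g.device = get_device()
    return g.device

with app.app_context():
    assert device() is device()  # the second call reuses the cached torch.device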