Example #1
0
def predict_seg():
    """Run spot detection for the JSON request body and return the spots.

    The request body is turned into a ``SegmentationEvalConfig``. Its
    ``device`` key is overwritten with ``device()``, and the config is passed
    to ``detect_spots``. While detection runs, the Redis key
    ``REDIS_KEY_STATE`` holds ``TrainState.WAIT.value``. The previous value
    is restored in ``finally``, so it is restored on both success and
    failure. On success a ``'Prediction updated'`` message is published to
    the ``prediction`` routing key.

    Returns
    -------
    Flask response
        * ``({"res": "error"}, 400)`` if the request body is not JSON.
        * ``({"error": ...}, 500)`` if detection or publishing raises.
        * ``{"spots": spots}`` otherwise.
    """
    # ``mimetype`` is the lower-cased media type without parameters such as
    # "; charset=utf-8", and is empty when the header is absent. This accepts
    # every JSON request and answers non-JSON ones with the 400 below,
    # instead of raising on a missing header or rejecting charset variants.
    if request.mimetype != 'application/json':
        return jsonify(res='error'), 400
    req_json = request.get_json()
    req_json['device'] = device()
    config = SegmentationEvalConfig(req_json)
    print(config)
    # Remember the current state so it can be restored after detection.
    state = int(redis_client.get(REDIS_KEY_STATE))
    redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
    try:
        spots = detect_spots(config)
        get_mq_connection().channel().basic_publish(
            exchange='',
            routing_key='prediction',
            body='Prediction updated'
        )
    except RuntimeError as e:
        print(traceback.format_exc())
        return jsonify(error=f'Runtime Error: {e}'), 500
    except Exception as e:
        # Top-level request boundary: report any failure to the client.
        print(traceback.format_exc())
        return jsonify(error=f'Exception: {e}'), 500
    finally:
        # Free memory held by the model run, then restore the saved state.
        gc.collect()
        torch.cuda.empty_cache()
        redis_client.set(REDIS_KEY_STATE, state)
    return jsonify({'spots': spots})
Example #2
0
def main():
    """Run ``detect_spots`` once for the CMU-1 dataset with fixed settings.

    The hard-coded settings are wrapped in a ``SegmentationEvalConfig`` and
    printed. The config's attributes are then forwarded positionally to
    ``detect_spots``. The TIFF input slot is ``None``; ``zpath_input`` and
    ``zpath_seg_output`` come from the config.
    """
    settings = dict(
        dataset_name="CMU-1",
        timepoint=0,
        model_name="CMU-1_detection.pth",
        device="cuda",
        is_3d=False,
        crop_size=[384, 384],
        scales=[0.5, 0.5],
        cache_maxbytes=0,
        use_median=True,
        patch=[4096, 4096],
        use_memmap=False,
    )
    config = SegmentationEvalConfig(settings)
    print(config)

    # detect_spots takes its arguments positionally; this order must match
    # its signature.
    positional_args = (
        str(config.device),
        config.model_path,
        config.keep_axials,
        config.is_pad,
        config.is_3d,
        config.crop_size,
        config.scales,
        config.cache_maxbytes,
        config.use_2d,
        config.use_median,
        config.patch_size,
        config.crop_box,
        config.c_ratio,
        config.p_thresh,
        config.r_min,
        config.r_max,
        config.output_prediction,
        config.zpath_input,
        config.zpath_seg_output,
        config.timepoint,
        None,  # tiff_input
        config.memmap_dir,
        config.batch_size,
        config.input_size,
    )
    detect_spots(*positional_args)
Example #3
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('command', help='detection | linking | export')
    parser.add_argument('input', help='input directory')
    parser.add_argument('config', help='config file')
    parser.add_argument('spots', help='spots file')
    parser.add_argument('--output', help='output directory')
    args = parser.parse_args()
    # list up input image files
    files = [os.path.join(args.input, f)
             for f in sorted(os.listdir(args.input)) if f.endswith('.tif')]
    with io.open(args.config, 'r', encoding='utf-8') as jsonfile:
        config_data = json.load(jsonfile)
    if args.command == 'detection':
        config = SegmentationEvalConfigTiff(config_data)
        print(config)
        spots = []
        for i, f in tqdm(enumerate(files)):
            config.timepoint = i
            config.tiff_input = f
            spots.extend(detect_spots(config))
        with open(args.spots, 'w') as f:
            json.dump({'spots': spots}, f)
    elif args.command == 'linking':
        with io.open(args.spots, 'r', encoding='utf-8') as jsonfile:
            spots_config_data = json.load(jsonfile)
            t = spots_config_data.get('t')
            spots = spots_config_data.get('spots')
        config_data['timepoint'] = t
        config_data['tiffinput'] = files[t-1:t+1]
        config = FlowEvalConfigTiff(config_data)
        print(config)
        # estimate previous spot positions with flow
        res_spots = spots_with_flow(config, spots)
        with open(args.spots, 'w') as f:
            json.dump({'spots': res_spots}, f)
    elif args.command == 'export':
        config_data['savedir'] = args.output
        config_data['shape'] = skimage.io.imread(files[0]).shape
        config = ExportConfig(config_data)
        print(config)
        # load spots and group by t
        with io.open(args.spots, 'r', encoding='utf-8') as jsonfile:
            spots_data = json.load(jsonfile)
        spots_dict = collections.defaultdict(list)
        for spot in spots_data:
            spots_dict[spot['t']].append(spot)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))
        # export labels
        export_ctc_labels(config, spots_dict)
    else:
        parser.print_help()
Example #4
0
def detect_spots_task(device,
                      model_path,
                      keep_axials=(True, ) * 4,
                      is_pad=False,
                      is_3d=True,
                      crop_size=(16, 384, 384),
                      scales=None,
                      cache_maxbytes=None,
                      use_2d=False,
                      use_median=False,
                      patch_size=None,
                      crop_box=None,
                      c_ratio=0.4,
                      p_thresh=0.5,
                      r_min=0,
                      r_max=1e6,
                      output_prediction=False,
                      zpath_input=None,
                      zpath_seg_output=None,
                      timepoint=None,
                      tiff_input=None,
                      memmap_dir=None,
                      batch_size=1,
                      input_size=None):
    """
    Detect spots at the specified timepoint.

    Thin wrapper around ``detect_spots``. Every argument is forwarded
    positionally, in the order of this signature.

    Parameters
    ----------
    device, model_path, keep_axials, is_pad, is_3d, crop_size, scales,
    cache_maxbytes, use_2d, use_median, patch_size, crop_box, c_ratio,
    p_thresh, r_min, r_max, output_prediction, zpath_input, zpath_seg_output,
    timepoint, tiff_input, memmap_dir, batch_size
        Forwarded unchanged to ``detect_spots``; see its documentation for
        their meaning.
    input_size : iterable or None, optional
        Converted to a ``tuple`` before forwarding. ``None`` (the default)
        is forwarded as ``None``.

    Returns
    -------
    spots : list
        Detected spots as list. None is returned on error or cancel.
    """
    # tuple(None) raises TypeError, so with the default input_size=None the
    # call used to fail. Only convert when a value was supplied.
    if input_size is not None:
        input_size = tuple(input_size)
    return detect_spots(device, model_path, keep_axials, is_pad, is_3d,
                        crop_size, scales, cache_maxbytes, use_2d, use_median,
                        patch_size, crop_box, c_ratio, p_thresh, r_min, r_max,
                        output_prediction, zpath_input, zpath_seg_output,
                        timepoint, tiff_input, memmap_dir, batch_size,
                        input_size)
Example #5
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('command', help='detection | linking | export')
    parser.add_argument('input', help='input directory')
    parser.add_argument('config', help='config file')
    parser.add_argument('spots', help='spots file')
    parser.add_argument('--output', help='output directory')
    args = parser.parse_args()
    # list up input image files
    files = [
        os.path.join(args.input, f) for f in sorted(os.listdir(args.input))
        if f.endswith('.tif')
    ]
    with io.open(args.config, 'r', encoding='utf-8') as jsonfile:
        config_data = json.load(jsonfile)
    if args.command == 'detection':
        config_data['patch'] = [
            int(get_next_multiple(s * 0.75, 16))
            for s in skimage.io.imread(files[0]).shape[-2:]
        ]
        config = SegmentationEvalConfigTiff(config_data)
        print(config)
        spots = []
        for i, f in tqdm(enumerate(files)):
            config.timepoint = i
            config.tiff_input = f
            spots.extend(
                detect_spots(
                    str(config.device),
                    config.model_path,
                    config.keep_axials,
                    config.is_pad,
                    config.is_3d,
                    config.crop_size,
                    config.scales,
                    config.cache_maxbytes,
                    config.use_2d,
                    config.use_median,
                    config.patch_size,
                    config.crop_box,
                    config.c_ratio,
                    config.p_thresh,
                    config.r_min,
                    config.r_max,
                    config.output_prediction,
                    None,
                    None,
                    config.timepoint,
                    config.tiff_input,
                    config.memmap_dir,
                    config.batch_size,
                    config.input_size,
                ))
        with open(args.spots, 'w') as f:
            json.dump({'spots': spots}, f)
    elif args.command == 'linking':
        with io.open(args.spots, 'r', encoding='utf-8') as jsonfile:
            spots_config_data = json.load(jsonfile)
            t = spots_config_data.get('t')
            spots = spots_config_data.get('spots')
        config_data['timepoint'] = t
        config_data['tiffinput'] = files[t - 1:t + 1]
        config = FlowEvalConfigTiff(config_data)
        print(config)
        # estimate previous spot positions with flow
        res_spots = spots_with_flow(config, spots)
        with open(args.spots, 'w') as f:
            json.dump({'spots': res_spots}, f)
    elif args.command == 'export':
        config_data['savedir'] = args.output
        config_data['shape'] = skimage.io.imread(files[0]).shape
        config_data['t_start'] = 0
        config_data['t_end'] = len(files) - 1
        config = ExportConfig(config_data)
        print(config)
        # load spots and group by t
        with io.open(args.spots, 'r', encoding='utf-8') as jsonfile:
            spots_data = json.load(jsonfile)
        spots_dict = collections.defaultdict(list)
        for spot in spots_data:
            spots_dict[spot['t']].append(spot)
        spots_dict = collections.OrderedDict(sorted(spots_dict.items()))
        # export labels
        export_ctc_labels(config, spots_dict)
    else:
        parser.print_help()