def parse_args():
    """Parse the command-line flag that selects the model's output type.

    The four flags live in a required, mutually exclusive group, so exactly
    one of them must be supplied; otherwise argparse prints a usage error
    and exits.

    Returns:
        argparse.Namespace: boolean attributes ``segmentation_model_2D``,
        ``segmentation_model_3D``, ``bounding_box_model`` and
        ``classification_model``. The chosen flag is True, the others False.
    """
    # (short flag, long flag, help text), listed in the order --help shows them.
    flag_specs = (
        ("-s2D", "--segmentation_model_2D", "If the model's output is a 2D segmentation mask"),
        ("-s3D", "--segmentation_model_3D", "If the model's output is a 3D segmentation mask"),
        ("-b", "--bounding_box_model", "If the model's output are bounding boxes"),
        ("-cl", "--classification_model", "If the model's output are labels"),
    )

    cli = argparse.ArgumentParser()
    model_kind = cli.add_mutually_exclusive_group(required=True)
    for short_flag, long_flag, description in flag_specs:
        model_kind.add_argument(
            short_flag,
            long_flag,
            action='store_true',
            default=False,
            help=description,
        )

    return cli.parse_args()

if __name__ == '__main__':
    # Script entry point: choose the inference handler that matches the
    # command-line flag, then start the gateway server.
    args = parse_args()

    # Resolve the handler first so the route is registered in one place.
    if args.bounding_box_model:
        selected_handler = request_handler_bbox
    elif args.segmentation_model_3D:
        selected_handler = request_handler_3D_segmentation
    elif args.classification_model:
        selected_handler = request_handler_classification
    else:
        # Fallback when none of the flags above is set: 2D segmentation.
        selected_handler = request_handler_2D_segmentation

    app = Gateway(__name__)
    # Route every uncaught exception through the shared handler.
    app.register_error_handler(Exception, handle_exception)
    app.add_inference_route('/', selected_handler)
    app.add_healthcheck_route(healthcheck_handler)

    # NOTE(review): debug=True while listening on all interfaces ('0.0.0.0')
    # is presumably meant for local development only; if Gateway is
    # Flask-based this exposes the interactive debugger -- confirm before
    # deploying.
    app.run(
        host='0.0.0.0',
        port=8000,
        debug=True,
        use_reloader=True,
    )
# Example #2
# 0
                {
                    "label": result_data['type'],
                    "probability": round(result_data['probability'], 2),
                    "SOPInstanceUID": dcm.SOPInstanceUID,
                    "top_left": [0, 0],
                    "bottom_right": [image_width, image_height]
                }
            )

    return response_json, []


def request_handler(json_input, dicom_instances, input_digest):
    """
    Log the incoming request and run the COVID prediction on the DICOM input.

    Args:
        json_input: Request payload; it is only written to the log here.
        dicom_instances: Passed unchanged to ``get_prediction_covid``.
        input_digest: Attached to every log line as the ``input_hash`` tag.

    Returns:
        Whatever ``get_prediction_covid(dicom_instances)`` returns, unmodified.
    """
    # Tag this request's log lines with the input digest.
    request_log = tagged_logger.TaggedLogger(logger)
    request_log.add_tags({'input_hash': input_digest})

    message = 'mock_model received json_input={}'.format(json_input)
    request_log.info(message)

    prediction = get_prediction_covid(dicom_instances)
    return prediction


if __name__ == '__main__':
    # Script entry point: build the model object, wire up the gateway and
    # start serving.

    # Module-level instance; presumably read by the prediction code elsewhere
    # in this file -- verify before renaming or removing.
    x_ray = Xray()

    app = Gateway(__name__)
    # Route every uncaught exception through the shared handler.
    app.register_error_handler(Exception, handle_exception)
    app.add_inference_route('/', request_handler)

    # NOTE(review): debug=True while listening on all interfaces ('0.0.0.0')
    # is presumably meant for local development only -- confirm before
    # deploying.
    run_options = {
        'host': '0.0.0.0',
        'port': 8000,
        'debug': True,
        'use_reloader': True,
    }
    app.run(**run_options)