def create_app_infrastructure(flask_app):
    """
    Attach the shared infrastructure to a Flask app: the Swagger UI blueprint, an
    APISpec object, CORS, the error-handling REST API wrapper and, when the
    backend type is "db", the SQLAlchemy configuration

    :param flask_app: The Flask application to configure
    :type flask_app: :class:`flask.Flask`
    :return: A tuple of the API object and the APISpec object, in that order
    """
    swagger_ui = get_swaggerui_blueprint(
        "", "/openapi.json", config={"app_name": "DataGateway API OpenAPI Spec"},
    )
    flask_app.register_blueprint(swagger_ui, url_prefix="/")

    api_spec = APISpec(
        title="DataGateway API",
        version="1.0",
        openapi_version="3.0.3",
        plugins=[RestfulPlugin()],
        security=[{"session_id": []}],
    )

    CORS(flask_app)
    flask_app.url_map.strict_slashes = False
    rest_api = CustomErrorHandledApi(flask_app)

    # A TEST_BACKEND entry in the app config (presumably set by the test suite)
    # overrides the backend type; otherwise it comes from the config file
    try:
        backend_type = flask_app.config["TEST_BACKEND"]
        config.set_backend_type(backend_type)
    except KeyError:
        backend_type = config.get_config_value(APIConfigOptions.BACKEND)

    if backend_type == "db":
        flask_app.config.update(
            {
                "SQLALCHEMY_DATABASE_URI": config.get_config_value(
                    APIConfigOptions.DB_URL,
                ),
                "SQLALCHEMY_TRACK_MODIFICATIONS": False,
            },
        )
        db.init_app(flask_app)

    initialise_spec(api_spec)
    return (rest_api, api_spec)
def test_get_valid_session_details(
    self, flask_test_app_icat, valid_icat_credentials_header,
):
    """
    GET /sessions with valid credentials: check the expiry time is about two hours
    ahead, and that the username and session ID match the test user and request
    """
    session_details = flask_test_app_icat.get(
        "/sessions", headers=valid_icat_credentials_header,
    )

    session_expiry_datetime = DateHandler.str_to_datetime_object(
        session_details.json["expireDateTime"],
    )
    current_datetime = datetime.now(tzlocal())
    time_diff = abs(session_expiry_datetime - current_datetime)
    # total_seconds() covers the whole interval; timedelta.seconds alone drops the
    # days component and would wrap for a difference of a day or more
    time_diff_minutes = time_diff.total_seconds() / 60

    # Allows a bit of leeway for slow test execution
    assert 118 <= time_diff_minutes < 120

    # Check username is correct
    test_mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
    test_username = config.get_config_value(
        APIConfigOptions.TEST_USER_CREDENTIALS,
    )["username"]
    assert session_details.json["username"] == f"{test_mechanism}/{test_username}"

    # Check session ID matches the header from the request
    assert (
        session_details.json["id"]
        == valid_icat_credentials_header["Authorization"].split()[1]
    )
def __init__(self):
    """
    Initialise the client against the ICAT URL from the config, using the config's
    certificate-checking setting, and disable automatic logout
    """
    icat_url = config.get_config_value(APIConfigOptions.ICAT_URL)
    verify_cert = config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT)
    super().__init__(icat_url, checkCert=verify_cert)

    # When clients are cleaned up, sessions won't be logged out
    self.autoLogout = False
def icat_client():
    """
    Build a python-icat Client for the configured ICAT server and log it in using
    the test mechanism and test user credentials from the config

    :return: The logged-in client
    """
    icat_url = config.get_config_value(APIConfigOptions.ICAT_URL)
    verify_cert = config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT)
    logged_in_client = Client(icat_url, checkCert=verify_cert)

    mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
    credentials = config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS)
    logged_in_client.login(mechanism, credentials)

    return logged_in_client
def test_valid_logout(self, flask_test_app_icat):
    """
    Log in directly against ICAT to obtain a fresh session, then check that
    DELETE /sessions with that session ID responds with 200
    """
    icat_url = config.get_config_value(APIConfigOptions.ICAT_URL)
    verify_cert = config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT)
    client = Client(icat_url, checkCert=verify_cert)

    mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
    credentials = config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS)
    client.login(mechanism, credentials)

    auth_header = {"Authorization": f"Bearer {client.sessionId}"}
    response = flask_test_app_icat.delete("/sessions", headers=auth_header)

    assert response.status_code == 200
def create_client_pool():
    """
    Create an object pool of ICATClient objects

    The initial size and maximum capacity come from the CLIENT_POOL_INIT_SIZE and
    CLIENT_POOL_MAX_SIZE config options. `max_reusable` and `expires` are both
    passed as 0 (NOTE(review): presumably meaning "unlimited" in ObjectPool -
    confirm against its documentation). The ObjectPool class uses the singleton
    design pattern.

    :return: The ICATClient object pool
    """
    initial_size = config.get_config_value(APIConfigOptions.CLIENT_POOL_INIT_SIZE)
    capacity = config.get_config_value(APIConfigOptions.CLIENT_POOL_MAX_SIZE)

    return ObjectPool(
        ICATClient,
        min_init=initial_size,
        max_capacity=capacity,
        max_reusable=0,
        expires=0,
    )
def get_query_filter(request_filter):
    """
    Given a filter, return a matching Query filter object

    The filters are imported inside this method to enable the unit tests to not rely
    on the contents of `config.json`. If they're imported at the top of the file, the
    backend type won't have been updated if the Flask app has been created from an
    automated test (file imports occur before `create_api_endpoints()` executes).

    The filter name (the first key of `request_filter`) is matched
    case-insensitively.

    :param request_filter: The filter to create the QueryFilter for
    :type request_filter: :class:`dict`
    :return: The QueryFilter object created
    :raises ApiError: If the backend type contains an invalid value
    :raises FilterError: If the filter is empty, the filter name is not recognised,
        or an order filter doesn't contain both a field and a direction
    """
    backend_type = config.get_config_value(APIConfigOptions.BACKEND)
    if backend_type == "db":
        from datagateway_api.common.database.filters import (
            DatabaseDistinctFieldFilter as DistinctFieldFilter,
            DatabaseIncludeFilter as IncludeFilter,
            DatabaseLimitFilter as LimitFilter,
            DatabaseOrderFilter as OrderFilter,
            DatabaseSkipFilter as SkipFilter,
            DatabaseWhereFilter as WhereFilter,
        )
    elif backend_type == "python_icat":
        from datagateway_api.common.icat.filters import (
            PythonICATDistinctFieldFilter as DistinctFieldFilter,
            PythonICATIncludeFilter as IncludeFilter,
            PythonICATLimitFilter as LimitFilter,
            PythonICATOrderFilter as OrderFilter,
            PythonICATSkipFilter as SkipFilter,
            PythonICATWhereFilter as WhereFilter,
        )
    else:
        raise ApiError(
            "Cannot select which implementation of filters to import, check the"
            " config file has a valid backend type",
        )

    # An empty filter previously surfaced as an IndexError
    if not request_filter:
        raise FilterError(f"Bad filter: {request_filter}")

    # Match the filter name case-insensitively, but index the request with the key
    # exactly as supplied - looking up the lowercased name would raise KeyError for
    # e.g. {"Where": {...}}
    filter_key = next(iter(request_filter))
    filter_name = filter_key.lower()
    filter_value = request_filter[filter_key]

    if filter_name == "where":
        # Expected shape: {"where": {field: {operation: value}}}
        field = list(filter_value)[0]
        operation = list(filter_value[field])[0]
        value = filter_value[field][operation]
        return WhereFilter(field, value, operation)
    elif filter_name == "order":
        # Expected shape: {"order": "<field> <direction>"}
        order_parts = filter_value.split()
        if len(order_parts) < 2:
            raise FilterError(f"Bad filter: {request_filter}")
        field, direction = order_parts[0], order_parts[1]
        return OrderFilter(field, direction)
    elif filter_name == "skip":
        return SkipFilter(filter_value)
    elif filter_name == "limit":
        return LimitFilter(filter_value)
    elif filter_name == "include":
        return IncludeFilter(filter_value)
    elif filter_name == "distinct":
        return DistinctFieldFilter(filter_value)
    else:
        raise FilterError(f"Bad filter: {request_filter}")
def openapi_config(spec):
    """
    If the GENERATE_SWAGGER config option is enabled, sort each path's operations
    alphabetically and write the spec as YAML to `swagger/openapi.yaml` in this
    module's directory. Otherwise do nothing.

    :param spec: The OpenAPI spec populated by the endpoint registration
    :type spec: :class:`apispec.APISpec`
    """
    if not config.get_config_value(APIConfigOptions.GENERATE_SWAGGER):
        return

    # Reorder paths (e.g. get, patch, post) so openapi.yaml only changes when there's a
    # change to the Swagger docs, rather than changing on each startup
    log.debug("Reordering OpenAPI docs to alphabetical order")
    # NOTE: relies on apispec's private `_paths` mapping; its values must support
    # move_to_end (an OrderedDict)
    for entity_data in spec._paths.values():
        for endpoint_name in sorted(entity_data):
            entity_data.move_to_end(endpoint_name)

    openapi_spec_path = Path(__file__).parent / "swagger" / "openapi.yaml"
    # Explicit encoding so the generated file doesn't depend on the host's locale
    with open(openapi_spec_path, "w", encoding="utf-8") as f:
        f.write(spec.to_yaml())
def test_valid_popitem(self):
    """
    Filling an ExtendedLRUCache with one more entry than CLIENT_CACHE_SIZE should
    cause popitem() to be called
    """
    cache_under_test = ExtendedLRUCache()
    pool = create_client_pool()
    client = Client(
        config.get_config_value(APIConfigOptions.ICAT_URL),
        checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
    )

    # Spy on popitem: calls are recorded and still forwarded to the real method
    real_popitem = cache_under_test.popitem
    cache_under_test.popitem = MagicMock(side_effect=real_popitem)

    @cached(cache=cache_under_test)
    def get_cached_client(cache_number, client_pool):
        return client

    # ExtendedLRUCache is sized from CLIENT_CACHE_SIZE, so one extra distinct key
    # goes past its capacity
    max_entries = config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE)
    for key_number in range(max_entries + 1):
        get_cached_client(key_number, pool)

    assert cache_under_test.popitem.called
def _add_documented_resource(api, spec, resource, url):
    """Register `resource` on `api` at `url`, then add it to the OpenAPI spec"""
    api.add_resource(resource, url)
    spec.path(resource=resource, api=api)


def create_api_endpoints(flask_app, api, spec):
    """
    Create every API resource and register it on the API and in the OpenAPI spec

    The backend type is taken from the app config's TEST_BACKEND entry if present
    (also pushed into the config via `config.set_backend_type`). Otherwise it comes
    from the BACKEND config option. An ICAT client pool is only created for the
    "python_icat" backend.

    :param flask_app: The Flask application the API belongs to
    :type flask_app: :class:`flask.Flask`
    :param api: The API object to add resources to
    :param spec: The OpenAPI spec each resource is documented in
    """
    try:
        backend_type = flask_app.config["TEST_BACKEND"]
        config.set_backend_type(backend_type)
    except KeyError:
        backend_type = config.get_config_value(APIConfigOptions.BACKEND)

    backend = create_backend(backend_type)

    icat_client_pool = None
    if backend_type == "python_icat":
        # Create client pool
        icat_client_pool = create_client_pool()

    # Each entity gets the same four endpoints. Registration order (and therefore the
    # order paths appear in the spec) matches the original hand-written sequence.
    for entity_name in endpoints:
        entity_info = endpoints[entity_name]
        base_url = f"/{entity_name.lower()}"
        entity_endpoints = (
            (get_endpoint, base_url),
            (get_id_endpoint, f"{base_url}/<int:id_>"),
            (get_count_endpoint, f"{base_url}/count"),
            (get_find_one_endpoint, f"{base_url}/findone"),
        )
        for endpoint_factory, url in entity_endpoints:
            resource = endpoint_factory(
                entity_name, entity_info, backend, client_pool=icat_client_pool,
            )
            _add_documented_resource(api, spec, resource, url)

    # Session endpoint
    session_endpoint_resource = session_endpoints(backend, client_pool=icat_client_pool)
    _add_documented_resource(api, spec, session_endpoint_resource, "/sessions")

    # Table specific endpoints
    investigations_url = (
        "/instruments/<int:instrument_id>/facilitycycles/<int:cycle_id>/investigations"
    )
    table_endpoints = (
        (instrument_facility_cycles_endpoint, "/instruments/<int:id_>/facilitycycles"),
        (
            count_instrument_facility_cycles_endpoint,
            "/instruments/<int:id_>/facilitycycles/count",
        ),
        (instrument_investigation_endpoint, investigations_url),
        (count_instrument_investigation_endpoint, f"{investigations_url}/count"),
    )
    for endpoint_factory, url in table_endpoints:
        resource = endpoint_factory(backend, client_pool=icat_client_pool)
        _add_documented_resource(api, spec, resource, url)
def test_valid_cache_creation(self):
    """A new ExtendedLRUCache's maxsize should equal the CLIENT_CACHE_SIZE option"""
    new_cache = ExtendedLRUCache()
    configured_size = config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE)

    assert new_cache.maxsize == configured_size
import logging

from flask import Flask

from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.logger_setup import setup_logger
from datagateway_api.src.api_start_utils import (
    create_api_endpoints,
    create_app_infrastructure,
    create_openapi_endpoint,
    openapi_config,
)

# Logging is configured before anything else so that app setup can log
setup_logger()
log = logging.getLogger()
log.info("Logging now setup")

# Build the app at module level so `app` can be imported by other code, not
# only run via `python main.py`
app = Flask(__name__)
api, spec = create_app_infrastructure(app)
create_api_endpoints(app, api, spec)
openapi_config(spec)
create_openapi_endpoint(app, spec)

if __name__ == "__main__":
    # Server options for the Flask run, all read from the config
    run_options = {
        "host": config.get_config_value(APIConfigOptions.HOST),
        "port": config.get_config_value(APIConfigOptions.PORT),
        "debug": config.get_config_value(APIConfigOptions.DEBUG_MODE),
        "use_reloader": config.get_config_value(APIConfigOptions.FLASK_RELOADER),
    }
    app.run(**run_options)
class TestSessionHandling:
    """
    Tests for the /sessions endpoints, sent through the `flask_test_app_icat` test
    client
    """

    def test_get_valid_session_details(
        self, flask_test_app_icat, valid_icat_credentials_header,
    ):
        session_details = flask_test_app_icat.get(
            "/sessions", headers=valid_icat_credentials_header,
        )

        session_expiry_datetime = DateHandler.str_to_datetime_object(
            session_details.json["expireDateTime"],
        )
        current_datetime = datetime.now(tzlocal())
        time_diff = abs(session_expiry_datetime - current_datetime)
        # total_seconds() covers the whole interval; timedelta.seconds alone drops
        # the days component and would wrap for a difference of a day or more
        time_diff_minutes = time_diff.total_seconds() / 60

        # Allows a bit of leeway for slow test execution
        assert 118 <= time_diff_minutes < 120

        # Check username is correct
        test_mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
        test_username = config.get_config_value(
            APIConfigOptions.TEST_USER_CREDENTIALS,
        )["username"]
        assert session_details.json["username"] == f"{test_mechanism}/{test_username}"

        # Check session ID matches the header from the request
        assert (
            session_details.json["id"]
            == valid_icat_credentials_header["Authorization"].split()[1]
        )

    def test_get_invalid_session_details(
        self, bad_credentials_header, flask_test_app_icat,
    ):
        session_details = flask_test_app_icat.get(
            "/sessions", headers=bad_credentials_header,
        )

        assert session_details.status_code == 403

    def test_refresh_session(self, valid_icat_credentials_header, flask_test_app_icat):
        pre_refresh_session_details = flask_test_app_icat.get(
            "/sessions", headers=valid_icat_credentials_header,
        )

        refresh_session = flask_test_app_icat.put(
            "/sessions", headers=valid_icat_credentials_header,
        )

        post_refresh_session_details = flask_test_app_icat.get(
            "/sessions", headers=valid_icat_credentials_header,
        )

        assert refresh_session.status_code == 200
        assert (
            pre_refresh_session_details.json["expireDateTime"]
            != post_refresh_session_details.json["expireDateTime"]
        )

    @pytest.mark.usefixtures("single_investigation_test_data")
    @pytest.mark.parametrize(
        "request_body",
        [
            pytest.param(
                {
                    "username": config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS,
                    )["username"],
                    "password": config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS,
                    )["password"],
                    "mechanism": config.get_config_value(
                        APIConfigOptions.TEST_MECHANISM,
                    ),
                },
                id="Normal request body",
            ),
            pytest.param(
                {
                    "username": config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS,
                    )["username"],
                    "password": config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS,
                    )["password"],
                },
                id="Missing mechanism in request body",
            ),
        ],
    )
    def test_valid_login(
        self, flask_test_app_icat, icat_client, icat_query, request_body,
    ):
        login_response = flask_test_app_icat.post("/sessions", json=request_body)
        # Check the status first, so a failed login is reported as such rather than
        # as an error when reading the missing session ID below
        assert login_response.status_code == 201

        # Reuse the new session to query ICAT and prove it is a working login
        icat_client.sessionId = login_response.json["sessionID"]
        icat_query.setAggregate("COUNT")
        title_filter = PythonICATWhereFilter(
            "title", "Test data for the Python ICAT Backend on DataGateway API", "like",
        )
        title_filter.apply_filter(icat_query)

        test_query = icat_client.search(icat_query)

        assert test_query == [1]

    @pytest.mark.parametrize(
        "request_body, expected_response_code",
        [
            pytest.param(
                {
                    "username": "******",
                    "password": "******",
                    "mechanism": config.get_config_value(
                        APIConfigOptions.TEST_MECHANISM,
                    ),
                },
                403,
                id="Invalid credentials",
            ),
            pytest.param({}, 400, id="Missing credentials"),
        ],
    )
    def test_invalid_login(
        self, flask_test_app_icat, request_body, expected_response_code,
    ):
        login_response = flask_test_app_icat.post("/sessions", json=request_body)

        assert login_response.status_code == expected_response_code

    def test_valid_logout(self, flask_test_app_icat):
        # A dedicated client/session is used so logging it out doesn't affect
        # sessions used by other tests
        client = Client(
            config.get_config_value(APIConfigOptions.ICAT_URL),
            checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
        )
        client.login(
            config.get_config_value(APIConfigOptions.TEST_MECHANISM),
            config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS),
        )
        creds_header = {"Authorization": f"Bearer {client.sessionId}"}

        logout_response = flask_test_app_icat.delete("/sessions", headers=creds_header)

        assert logout_response.status_code == 200

    def test_invalid_logout(self, bad_credentials_header, flask_test_app_icat):
        logout_response = flask_test_app_icat.delete(
            "/sessions", headers=bad_credentials_header,
        )

        assert logout_response.status_code == 403
def __init__(self):
    """Initialise the cache with maxsize taken from the CLIENT_CACHE_SIZE option"""
    cache_size = config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE)
    super().__init__(maxsize=cache_size)
"--years", "-y", dest="years", help="Provide number of years to generate", type=int, default=20, ) args = parser.parse_args() SEED = args.seed YEARS = args.years # 4 Cycles per years generated faker = Faker() Faker.seed(SEED) engine = create_engine( config.get_config_value(APIConfigOptions.DB_URL), poolclass=QueuePool, pool_size=100, max_overflow=0, ) session_factory = sessionmaker(engine) session = scoped_session(session_factory)() def post_entity(entity): """ Given an entity, insert it into the ICAT DB :param entity: The entity to be inserted :return: None """ session.add(entity)