def create_app_infrastructure(flask_app):
    """
    Configure the shared infrastructure of the Flask app.

    Registers the Swagger UI blueprint, enables CORS, turns off strict trailing
    slashes and wraps the app in a ``CustomErrorHandledApi``. The backend type is
    read from ``TEST_BACKEND`` in the Flask config when present (and passed on to
    ``config.set_backend_type()``), otherwise from the config file. For the "db"
    backend, SQLAlchemy is configured and initialised on the app. The OpenAPI spec
    is created and passed to ``initialise_spec()`` before being returned.

    :param flask_app: The Flask application to set up
    :type flask_app: :class:`flask.Flask`
    :return: Tuple of ``(api, spec)``
    """
    flask_app.register_blueprint(
        get_swaggerui_blueprint(
            "", "/openapi.json", config={"app_name": "DataGateway API OpenAPI Spec"},
        ),
        url_prefix="/",
    )
    spec = APISpec(
        title="DataGateway API",
        version="1.0",
        openapi_version="3.0.3",
        plugins=[RestfulPlugin()],
        security=[{"session_id": []}],
    )

    CORS(flask_app)
    flask_app.url_map.strict_slashes = False
    api = CustomErrorHandledApi(flask_app)

    # Automated tests inject the backend through the Flask config; otherwise fall
    # back to the value in the config file
    try:
        backend_type = flask_app.config["TEST_BACKEND"]
        config.set_backend_type(backend_type)
    except KeyError:
        backend_type = config.get_config_value(APIConfigOptions.BACKEND)

    if backend_type == "db":
        flask_app.config.update(
            SQLALCHEMY_DATABASE_URI=config.get_config_value(APIConfigOptions.DB_URL),
            SQLALCHEMY_TRACK_MODIFICATIONS=False,
        )
        db.init_app(flask_app)

    initialise_spec(spec)
    return api, spec
Ejemplo n.º 2
0
    def test_get_valid_session_details(
        self,
        flask_test_app_icat,
        valid_icat_credentials_header,
    ):
        """Check GET /sessions returns the session's expiry, username and ID."""
        session_details = flask_test_app_icat.get(
            "/sessions",
            headers=valid_icat_credentials_header,
        )

        session_expiry_datetime = DateHandler.str_to_datetime_object(
            session_details.json["expireDateTime"],
        )

        current_datetime = datetime.now(tzlocal())
        # Use total_seconds() rather than .seconds: .seconds is only the sub-day
        # remainder of a timedelta, so e.g. a difference of 1 day + 119 minutes
        # would otherwise pass the range check below
        time_diff = abs(session_expiry_datetime - current_datetime)
        time_diff_minutes = time_diff.total_seconds() / 60

        # Allows a bit of leeway for slow test execution
        assert 118 <= time_diff_minutes < 120

        # Check username is correct
        test_mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
        test_username = config.get_config_value(
            APIConfigOptions.TEST_USER_CREDENTIALS,
        )["username"]
        assert session_details.json["username"] == f"{test_mechanism}/{test_username}"

        # Check session ID matches the header from the request
        expected_session_id = valid_icat_credentials_header["Authorization"].split()[1]
        assert session_details.json["id"] == expected_session_id
Ejemplo n.º 3
0
 def __init__(self):
     """Initialise the client with the ICAT URL and cert-check setting from config."""
     icat_url = config.get_config_value(APIConfigOptions.ICAT_URL)
     check_cert = config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT)
     super().__init__(icat_url, checkCert=check_cert)

     # Disable automatic logout so sessions are not logged out when client
     # objects are cleaned up
     self.autoLogout = False
Ejemplo n.º 4
0
def icat_client():
    """
    Create an ICAT ``Client`` and log it in with the configured test user.

    :return: The logged-in client
    """
    icat_url = config.get_config_value(APIConfigOptions.ICAT_URL)
    check_cert = config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT)
    client = Client(icat_url, checkCert=check_cert)

    mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
    credentials = config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS)
    client.login(mechanism, credentials)
    return client
Ejemplo n.º 5
0
    def test_valid_logout(self, flask_test_app_icat):
        """Check DELETE /sessions returns 200 for a freshly created ICAT session."""
        icat_url = config.get_config_value(APIConfigOptions.ICAT_URL)
        check_cert = config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT)
        client = Client(icat_url, checkCert=check_cert)

        mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
        credentials = config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS)
        client.login(mechanism, credentials)

        response = flask_test_app_icat.delete(
            "/sessions",
            headers={"Authorization": f"Bearer {client.sessionId}"},
        )

        assert response.status_code == 200
Ejemplo n.º 6
0
def create_client_pool():
    """
    Build the object pool that hands out ``ICATClient`` objects.

    The pool's initial size and maximum capacity come from config, while
    ``max_reusable`` and ``expires`` are both passed as 0. The ObjectPool class
    uses the singleton design pattern.
    """
    pool_options = {
        "min_init": config.get_config_value(APIConfigOptions.CLIENT_POOL_INIT_SIZE),
        "max_capacity": config.get_config_value(
            APIConfigOptions.CLIENT_POOL_MAX_SIZE,
        ),
        "max_reusable": 0,
        "expires": 0,
    }
    return ObjectPool(ICATClient, **pool_options)
    def get_query_filter(request_filter):
        """
        Given a filter, return a matching Query filter object

        The filters are imported inside this method to enable the unit tests to not rely
        on the contents of `config.json`. If they're imported at the top of the file,
        the backend type won't have been updated if the Flask app has been created from
        an automated test (file imports occur before `create_api_endpoints()` executes).

        :param request_filter: The filter to create the QueryFilter for. Its first key
            names the filter type and is matched case-insensitively
        :type request_filter: :class:`dict`
        :return: The QueryFilter object created
        :raises ApiError: If the backend type contains an invalid value
        :raises FilterError: If the filter is empty, an order filter has no direction,
            or the filter name is not recognised
        """

        backend_type = config.get_config_value(APIConfigOptions.BACKEND)
        if backend_type == "db":
            from datagateway_api.common.database.filters import (
                DatabaseDistinctFieldFilter as DistinctFieldFilter,
                DatabaseIncludeFilter as IncludeFilter,
                DatabaseLimitFilter as LimitFilter,
                DatabaseOrderFilter as OrderFilter,
                DatabaseSkipFilter as SkipFilter,
                DatabaseWhereFilter as WhereFilter,
            )
        elif backend_type == "python_icat":
            from datagateway_api.common.icat.filters import (
                PythonICATDistinctFieldFilter as DistinctFieldFilter,
                PythonICATIncludeFilter as IncludeFilter,
                PythonICATLimitFilter as LimitFilter,
                PythonICATOrderFilter as OrderFilter,
                PythonICATSkipFilter as SkipFilter,
                PythonICATWhereFilter as WhereFilter,
            )
        else:
            raise ApiError(
                "Cannot select which implementation of filters to import, check the"
                " config file has a valid backend type", )

        # An empty filter previously surfaced as an IndexError
        if not request_filter:
            raise FilterError(f" Bad filter: {request_filter}")

        # The name is matched case-insensitively, so the value must be fetched with
        # the key exactly as given. Indexing with the lowercased name would raise
        # KeyError for input such as {"Where": ...}
        filter_key = next(iter(request_filter))
        filter_name = filter_key.lower()
        filter_input = request_filter[filter_key]

        if filter_name == "where":
            field = list(filter_input)[0]
            operation = list(filter_input[field])[0]
            value = filter_input[field][operation]
            return WhereFilter(field, value, operation)
        elif filter_name == "order":
            # Expected form: "<field> <direction>"
            order_parts = filter_input.split(" ")
            if len(order_parts) < 2:
                raise FilterError(f" Bad filter: {request_filter}")
            field, direction = order_parts[0], order_parts[1]
            return OrderFilter(field, direction)
        elif filter_name == "skip":
            return SkipFilter(filter_input)
        elif filter_name == "limit":
            return LimitFilter(filter_input)
        elif filter_name == "include":
            return IncludeFilter(filter_input)
        elif filter_name == "distinct":
            return DistinctFieldFilter(filter_input)
        else:
            raise FilterError(f" Bad filter: {request_filter}")
def openapi_config(spec):
    """
    Sort the spec's paths and write them to ``swagger/openapi.yaml``.

    Does nothing unless the GENERATE_SWAGGER config option is enabled. The
    operations under each path (e.g. get, patch, post) are reordered alphabetically
    so openapi.yaml only changes when there's a change to the Swagger docs, rather
    than changing on each startup.

    :param spec: The OpenAPI spec to reorder and write out
    :type spec: :class:`APISpec`
    """
    if not config.get_config_value(APIConfigOptions.GENERATE_SWAGGER):
        return

    log.debug("Reordering OpenAPI docs to alphabetical order")
    # NOTE(review): ``_paths`` is private APISpec state whose values must support
    # ``move_to_end`` (OrderedDict). Re-check this when upgrading apispec
    for entity_data in spec._paths.values():
        # Moving every key to the end in sorted order leaves the mapping sorted
        for endpoint_name in sorted(entity_data.keys()):
            entity_data.move_to_end(endpoint_name)

    openapi_spec_path = Path(__file__).parent / "swagger" / "openapi.yaml"
    # Set the encoding explicitly: the default is locale-dependent, so the written
    # bytes (and the handling of any non-ASCII doc text) would vary across platforms
    openapi_spec_path.write_text(spec.to_yaml(), encoding="utf-8")
    def test_valid_popitem(self):
        """Check popitem is called once the cache gets one entry past its size."""
        cache_under_test = ExtendedLRUCache()
        client_pool = create_client_pool()
        icat_client = Client(
            config.get_config_value(APIConfigOptions.ICAT_URL),
            checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
        )

        # Wrap the real popitem so calls still take effect but are recorded
        cache_under_test.popitem = MagicMock(side_effect=cache_under_test.popitem)

        @cached(cache=cache_under_test)
        def get_cached_client(cache_number, client_pool):
            return icat_client

        # One more distinct key than the configured cache size
        entry_count = config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE) + 1
        for cache_number in range(entry_count):
            get_cached_client(cache_number, client_pool)

        assert cache_under_test.popitem.called
def create_api_endpoints(flask_app, api, spec):
    """
    Add every API resource to ``api`` and document each one in ``spec``.

    The backend type comes from ``TEST_BACKEND`` in the Flask config when present
    (and is passed to ``config.set_backend_type()``), otherwise from the config
    file. An ICAT client pool is only created for the "python_icat" backend. For
    any other backend, ``None`` is passed as each resource's ``client_pool``.

    :param flask_app: The Flask application being set up
    :param api: The API object resources are added to
    :param spec: The OpenAPI spec each resource's path is added to
    """
    try:
        backend_type = flask_app.config["TEST_BACKEND"]
        config.set_backend_type(backend_type)
    except KeyError:
        backend_type = config.get_config_value(APIConfigOptions.BACKEND)

    backend = create_backend(backend_type)
    icat_client_pool = create_client_pool() if backend_type == "python_icat" else None

    def register(resource, url):
        # Add the resource at ``url``, then document it in the spec
        api.add_resource(resource, url)
        spec.path(resource=resource, api=api)

    # Generic endpoints created for every entity, paired with their URL suffix
    entity_endpoint_factories = (
        (get_endpoint, ""),
        (get_id_endpoint, "/<int:id_>"),
        (get_count_endpoint, "/count"),
        (get_find_one_endpoint, "/findone"),
    )
    for entity_name in endpoints:
        entity_type = endpoints[entity_name]
        base_url = f"/{entity_name.lower()}"
        for endpoint_factory, url_suffix in entity_endpoint_factories:
            resource = endpoint_factory(
                entity_name, entity_type, backend, client_pool=icat_client_pool,
            )
            register(resource, f"{base_url}{url_suffix}")

    # Session endpoint
    register(session_endpoints(backend, client_pool=icat_client_pool), "/sessions")

    # Table specific endpoints
    facility_cycles_url = "/instruments/<int:id_>/facilitycycles"
    investigations_url = (
        "/instruments/<int:instrument_id>/facilitycycles/<int:cycle_id>/investigations"
    )
    table_specific_endpoints = (
        (instrument_facility_cycles_endpoint, facility_cycles_url),
        (count_instrument_facility_cycles_endpoint, f"{facility_cycles_url}/count"),
        (instrument_investigation_endpoint, investigations_url),
        (count_instrument_investigation_endpoint, f"{investigations_url}/count"),
    )
    for endpoint_factory, url in table_specific_endpoints:
        register(endpoint_factory(backend, client_pool=icat_client_pool), url)
 def test_valid_cache_creation(self):
     """Check a new cache's maxsize equals the CLIENT_CACHE_SIZE config value."""
     test_cache = ExtendedLRUCache()
     expected_maxsize = config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE)
     assert test_cache.maxsize == expected_maxsize
Ejemplo n.º 12
0
import logging

from flask import Flask

from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.logger_setup import setup_logger
from datagateway_api.src.api_start_utils import (
    create_api_endpoints,
    create_app_infrastructure,
    create_openapi_endpoint,
    openapi_config,
)

setup_logger()
log = logging.getLogger()
log.info("Logging now setup")

# Build the app: shared infrastructure first, then the endpoints (which add their
# paths to the spec), then write the OpenAPI docs and add the OpenAPI endpoint
app = Flask(__name__)
api, spec = create_app_infrastructure(app)
create_api_endpoints(app, api, spec)
openapi_config(spec)
create_openapi_endpoint(app, spec)

if __name__ == "__main__":
    # Options for Flask's built-in server, all taken from the config file
    run_options = {
        "host": config.get_config_value(APIConfigOptions.HOST),
        "port": config.get_config_value(APIConfigOptions.PORT),
        "debug": config.get_config_value(APIConfigOptions.DEBUG_MODE),
        "use_reloader": config.get_config_value(APIConfigOptions.FLASK_RELOADER),
    }
    app.run(**run_options)
Ejemplo n.º 13
0
class TestSessionHandling:
    """Tests for the ``/sessions`` endpoint, run against the ICAT test app fixture."""

    def test_get_valid_session_details(
        self,
        flask_test_app_icat,
        valid_icat_credentials_header,
    ):
        """Check GET /sessions returns the session's expiry, username and ID."""
        session_details = flask_test_app_icat.get(
            "/sessions",
            headers=valid_icat_credentials_header,
        )

        session_expiry_datetime = DateHandler.str_to_datetime_object(
            session_details.json["expireDateTime"],
        )

        current_datetime = datetime.now(tzlocal())
        # Use total_seconds() rather than .seconds: .seconds is only the sub-day
        # remainder of a timedelta, so e.g. a difference of 1 day + 119 minutes
        # would otherwise pass the range check below
        time_diff = abs(session_expiry_datetime - current_datetime)
        time_diff_minutes = time_diff.total_seconds() / 60

        # Allows a bit of leeway for slow test execution
        assert 118 <= time_diff_minutes < 120

        # Check username is correct
        test_mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM)
        test_username = config.get_config_value(
            APIConfigOptions.TEST_USER_CREDENTIALS,
        )["username"]
        assert session_details.json["username"] == f"{test_mechanism}/{test_username}"

        # Check session ID matches the header from the request
        expected_session_id = valid_icat_credentials_header["Authorization"].split()[1]
        assert session_details.json["id"] == expected_session_id

    def test_get_invalid_session_details(
        self,
        bad_credentials_header,
        flask_test_app_icat,
    ):
        """Check GET /sessions returns 403 for bad credentials."""
        session_details = flask_test_app_icat.get(
            "/sessions",
            headers=bad_credentials_header,
        )

        assert session_details.status_code == 403

    def test_refresh_session(self, valid_icat_credentials_header,
                             flask_test_app_icat):
        """Check PUT /sessions succeeds and changes the session's expiry time."""
        pre_refresh_session_details = flask_test_app_icat.get(
            "/sessions",
            headers=valid_icat_credentials_header,
        )

        refresh_session = flask_test_app_icat.put(
            "/sessions",
            headers=valid_icat_credentials_header,
        )

        post_refresh_session_details = flask_test_app_icat.get(
            "/sessions",
            headers=valid_icat_credentials_header,
        )

        assert refresh_session.status_code == 200

        assert (pre_refresh_session_details.json["expireDateTime"] !=
                post_refresh_session_details.json["expireDateTime"])

    @pytest.mark.usefixtures("single_investigation_test_data")
    @pytest.mark.parametrize(
        "request_body",
        [
            pytest.param(
                {
                    "username":
                    config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS, )["username"],
                    "password":
                    config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS, )["password"],
                    "mechanism":
                    config.get_config_value(APIConfigOptions.TEST_MECHANISM, ),
                },
                id="Normal request body",
            ),
            pytest.param(
                {
                    "username":
                    config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS, )["username"],
                    "password":
                    config.get_config_value(
                        APIConfigOptions.TEST_USER_CREDENTIALS, )["password"],
                },
                id="Missing mechanism in request body",
            ),
        ],
    )
    def test_valid_login(
        self,
        flask_test_app_icat,
        icat_client,
        icat_query,
        request_body,
    ):
        """Check POST /sessions returns 201 and a session ID usable with ICAT."""
        login_response = flask_test_app_icat.post("/sessions",
                                                  json=request_body)
        # Check the status first so a failed login reports a clear assertion
        # instead of an error when reading "sessionID" from the response
        assert login_response.status_code == 201

        icat_client.sessionId = login_response.json["sessionID"]
        icat_query.setAggregate("COUNT")
        title_filter = PythonICATWhereFilter(
            "title",
            "Test data for the Python ICAT Backend on DataGateway API",
            "like",
        )
        title_filter.apply_filter(icat_query)

        test_query = icat_client.search(icat_query)

        assert test_query == [1]

    @pytest.mark.parametrize(
        "request_body, expected_response_code",
        [
            pytest.param(
                {
                    "username":
                    "******",
                    "password":
                    "******",
                    "mechanism":
                    config.get_config_value(APIConfigOptions.TEST_MECHANISM, ),
                },
                403,
                id="Invalid credentials",
            ),
            pytest.param({}, 400, id="Missing credentials"),
        ],
    )
    def test_invalid_login(
        self,
        flask_test_app_icat,
        request_body,
        expected_response_code,
    ):
        """Check POST /sessions rejects invalid or missing credentials."""
        login_response = flask_test_app_icat.post("/sessions",
                                                  json=request_body)

        assert login_response.status_code == expected_response_code

    def test_valid_logout(self, flask_test_app_icat):
        """Check DELETE /sessions returns 200 for a freshly created ICAT session."""
        client = Client(
            config.get_config_value(APIConfigOptions.ICAT_URL),
            checkCert=config.get_config_value(
                APIConfigOptions.ICAT_CHECK_CERT),
        )
        client.login(
            config.get_config_value(APIConfigOptions.TEST_MECHANISM),
            config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS),
        )
        creds_header = {"Authorization": f"Bearer {client.sessionId}"}

        logout_response = flask_test_app_icat.delete("/sessions",
                                                     headers=creds_header)

        assert logout_response.status_code == 200

    def test_invalid_logout(self, bad_credentials_header, flask_test_app_icat):
        """Check DELETE /sessions returns 403 for bad credentials."""
        logout_response = flask_test_app_icat.delete(
            "/sessions",
            headers=bad_credentials_header,
        )

        assert logout_response.status_code == 403
Ejemplo n.º 14
0
 def __init__(self):
     """Initialise the parent cache with ``maxsize`` set from CLIENT_CACHE_SIZE."""
     cache_size = config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE)
     super().__init__(maxsize=cache_size)
    "--years",
    "-y",
    dest="years",
    help="Provide number of years to generate",
    type=int,
    default=20,
)
# Parse the command-line options defined on ``parser``
args = parser.parse_args()
SEED = args.seed
YEARS = args.years  # 4 cycles are generated per year
faker = Faker()
# Seed Faker so generated values are repeatable for the same SEED
Faker.seed(SEED)


# pool_size=100 with max_overflow=0 caps the pool at 100 open connections
engine = create_engine(
    config.get_config_value(APIConfigOptions.DB_URL),
    poolclass=QueuePool,
    pool_size=100,
    max_overflow=0,
)
session_factory = sessionmaker(engine)
# A single module-level session, taken from a scoped_session registry, used by the
# functions below
session = scoped_session(session_factory)()


def post_entity(entity):
    """
    Given an entity, insert it into the ICAT DB
    :param entity:  The entity to be inserted
    :return: None
    """
    session.add(entity)