def _get_user_from_sso(jwt_token, token):
    """
    Resolve (creating it on first login) the user identified by an SSO JWT.

    The JWT is verified with ``Config.jwt_secret`` using HS256. Its
    ``username`` claim must equal ``token``. ``firstName`` (default: the
    username) and ``lastName`` (default: empty) are used as name and surname
    when the user does not exist yet; a newly created user is also added to
    every task returned by ``Database.get_tasks()``.

    :param jwt_token: The raw JWT taken from the request cookie (may be None).
    :param token: The username the client is trying to act as.
    :return: The user row from the database.
    :raises Forbidden: if the JWT is missing, malformed, expired or otherwise
        rejected, has no ``username`` claim, or its username differs from
        ``token``.
    """
    try:
        data = jwt.decode(jwt_token, Config.jwt_secret, algorithms=["HS256"])
    except jwt.exceptions.InvalidTokenError:
        # InvalidTokenError is the common base of DecodeError,
        # ExpiredSignatureError, ImmatureSignatureError, ... so every rejected
        # token becomes a 403 instead of escaping as a 500.
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Please login at %s" % Config.sso_url)
    username = data.get("username")
    if username is None:
        # A verified token without a username cannot identify anyone.
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Please login at %s" % Config.sso_url)
    if username != token:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Use the same username from the SSO")
    name = data.get("firstName", username)
    surname = data.get("lastName", "")
    if Database.get_user(username) is None:
        Database.begin()
        try:
            Database.add_user(username, name, surname, sso_user=True,
                              autocommit=False)
            for task in Database.get_tasks():
                Database.add_user_task(username, task["name"],
                                       autocommit=False)
            Database.commit()
        except BaseException:
            # Never leave a half-applied user creation in an open transaction.
            Database.rollback()
            raise
        Logger.info("NEW_USER", "User %s created from SSO" % username)
    return Database.get_user(username)
def test_end_time(self):
    """
    get_end_time includes both the global and the per-user extra time.

    Expected values: 1000 + 150 + 20 (global) + user extra.
    """
    contest_meta = {"start_time": 1000, "contest_duration": 150,
                    "extra_time": 20}
    for meta_key, meta_value in contest_meta.items():
        Database.set_meta(meta_key, meta_value)
    for user_extra, expected_end in ((20, 1190), (0, 1170)):
        self.assertEqual(BaseHandler.get_end_time(user_extra), expected_end)
def upload_source(self, input, file):
    """
    POST /upload_source

    Store the uploaded source file for the given input and return the stored
    source record together with validation alerts. An executable upload
    produces a warning alert (and a log entry); otherwise a success alert.

    :param input: The input the source belongs to (``id`` and ``token`` used).
    :param file: Dict with the uploaded file ``name`` and ``content``.
    :raises BadRequest: if the storage manager rejects the file name.
    """
    content = file["content"]
    if get_exeflags(content):
        alert = {
            "severity": "warning",
            "message": "You have submitted an executable! Please send the "
                       "source code."
        }
        Logger.info("UPLOAD",
                    "User %s has uploaded an executable" % input["token"])
    else:
        alert = {
            "severity": "success",
            "message": "Source file uploaded correctly."
        }
    alerts = [alert]

    source_id = Database.gen_id()
    try:
        source_path = StorageManager.new_source_file(source_id, file["name"])
    except ValueError:
        BaseHandler.raise_exc(BadRequest, "INVALID_FILENAME",
                              "The provided file has an invalid name")
    StorageManager.save_file(source_path, content)
    source_size = StorageManager.get_file_size(source_path)
    Database.add_source(source_id, input["id"], source_path, source_size)
    Logger.info("UPLOAD", "User %s has uploaded the source %s" % (
        input["token"], source_id))

    response = BaseHandler.format_dates(Database.get_source(source_id))
    response["validation"] = {"alerts": alerts}
    return response
def upload_output(self, input, file):
    """
    POST /upload_output

    Store the uploaded output file, evaluate it against the input, and record
    it in the database.

    :param input: The input the output answers (``id``, ``task``, ``path``
        and ``token`` are used).
    :param file: Dict with the uploaded file ``name`` and ``content``.
    :return: The stored output, sanitized by ``InfoHandler.patch_output``.
    :raises BadRequest: if the storage manager rejects the file name.
    :raises InternalServerError: if the evaluation of the output fails.
    """
    output_id = Database.gen_id()
    try:
        path = StorageManager.new_output_file(output_id, file["name"])
    except ValueError:
        BaseHandler.raise_exc(BadRequest, "INVALID_FILENAME",
                              "The provided file has an invalid name")
    StorageManager.save_file(path, file["content"])
    file_size = StorageManager.get_file_size(path)

    try:
        result = ContestManager.evaluate_output(input["task"], input["path"],
                                                path)
    except Exception as exc:
        # Keep the real cause in the logs; the client only gets a generic
        # INTERNAL_ERROR. (A bare except would also trap KeyboardInterrupt
        # and SystemExit.)
        Logger.warning("UPLOAD", "Failed to evaluate output %s: %s" %
                       (output_id, exc))
        BaseHandler.raise_exc(InternalServerError, "INTERNAL_ERROR",
                              "Failed to evaluate the output")

    Database.add_output(output_id, input["id"], path, file_size, result)
    Logger.info(
        "UPLOAD",
        "User %s has uploaded the output %s" % (input["token"], output_id))
    return InfoHandler.patch_output(Database.get_output(output_id))
def test_window_end_time(self):
    """
    get_window_end_time uses the window duration, the start delay and the
    per-user extra time.

    The expected values are consistent with start + delay + window + user
    extra: 1000 + 20 + 100 + 10 = 1130.
    """
    contest_meta = (
        ("start_time", 1000),
        ("contest_duration", 150),
        ("window_duration", 100),
        ("extra_time", 20),
    )
    for meta_key, meta_value in contest_meta:
        Database.set_meta(meta_key, meta_value)
    for user_extra, delay, expected_end in ((10, 20, 1130), (0, 0, 1100)):
        self.assertEqual(BaseHandler.get_window_end_time(user_extra, delay),
                         expected_end)
def handle(*args, **kwargs):
    """
    Replace the ``param`` keyword argument with the object ``getter``
    returns for it, and pass that object to ``handler`` as ``name``.

    ``param``, ``getter``, ``name`` and ``handler`` are closure variables of
    the enclosing decorator (not visible here). When ``param`` is absent,
    ``name`` is passed as None.

    :raises Forbidden: if ``param`` is given but ``getter`` returns None.
    """
    resolved = None
    if param in kwargs:
        resolved = getter(kwargs.pop(param))
        if resolved is None:
            BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "No such " + name)
    kwargs[name] = resolved
    return handler(*args, **kwargs)
def _ensure_contest_started():
    """
    Raise Forbidden unless the contest start time is set and already reached.

    ``start_time`` is read as an integer Unix timestamp and compared, in UTC,
    with the current time, so a start scheduled in the future still counts
    as "not started".
    """
    start_timestamp = Database.get_meta("start_time", type=int)
    if start_timestamp is None:
        has_started = False
    else:
        starts_at = datetime.fromtimestamp(start_timestamp, timezone.utc)
        has_started = starts_at <= datetime.now(timezone.utc)
    if not has_started:
        BaseHandler.raise_exc(
            Forbidden, "FORBIDDEN", "The contest has not started yet"
        )
def start(self):
    """
    POST /admin/start

    Start the contest at the current time.

    :return: ``{"start_time": ...}`` with the date formatted.
    :raises Forbidden: if a start time is already stored.
    """
    already_started = Database.get_meta("start_time", default=None,
                                        type=int) is not None
    if already_started:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Contest has already been started!")
    now = int(time.time())
    Database.set_meta("start_time", now)
    Logger.info("CONTEST", "Contest started")
    payload = {"start_time": now}
    return BaseHandler.format_dates(payload, fields=["start_time"])
def _ensure_contest_running(token=None):
    """
    Makes sure that the contest is running for the user, if any.

    When a token is given, the user's window is started first
    (``Validators._ensure_window_start``), and the user's ``extra_time`` and
    ``contest_start_delay`` are used to compute the contest and window end.
    Without a token (or without a matching user) both are None.

    NOTE(review): this only checks that ``start_time`` is set, not that it
    lies in the past. If start times can be scheduled in the future,
    requests made before the start are not rejected here. Confirm this is
    intended.

    :param token: The optional token of the user.
    :raises Forbidden: if the contest has not started, has ended, or the
        user's window has ended.
    """
    user = None
    if token:
        Validators._ensure_window_start(token)
        user = Database.get_user(token)
    extra_time = user["extra_time"] if user else None
    start_delay = user["contest_start_delay"] if user else None

    if Database.get_meta("start_time") is None:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "The contest has not started yet")

    contest_end = BaseHandler.get_end_time(extra_time)
    window_end = BaseHandler.get_window_end_time(extra_time, start_delay)
    now = time.time()

    if now > contest_end:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "The contest has ended")
    # A falsy window end means there is no window to enforce.
    if window_end and now > window_end:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Your window has ended")
def patch_submission(submission):
    """
    Given a submission from a SQL query with some JOIN create a dict by
    splitting the keys using _

    Each key is split on its FIRST underscore only, so ``output_result``
    becomes ``result["output"]["result"]`` and a longer key such as
    ``a_b_c`` becomes ``result["a"]["b_c"]`` (splitting on every underscore
    used to raise ValueError for such keys). Keys without an underscore are
    copied unchanged. The output's JSON ``result`` provides the top-level
    ``feedback``, and the nested output is sanitized with
    ``InfoHandler.patch_output``.

    :param submission: A dict with the submission
    :return: A dict with some properties nested
    """
    result = {}
    for key, value in submission.items():
        prefix, separator, rest = key.partition("_")
        if separator:
            result.setdefault(prefix, {})[rest] = value
        else:
            result[key] = value
    feedback = json.loads(result["output"]["result"].decode())
    result["feedback"] = feedback["feedback"]
    # Format the dates of the submission itself, then attach the already
    # sanitized/formatted output.
    patched_output = InfoHandler.patch_output(result.pop("output"))
    result = BaseHandler.format_dates(result)
    result["output"] = patched_output
    return result
def upload_pack(self, file):
    """
    POST /admin/upload_pack

    Save the uploaded encrypted contest pack to ``Config.encrypted_file``.

    :param file: Dict with the uploaded file ``content``.
    :return: An empty dict.
    :raises Forbidden: if the pack was already extracted (an admin token is
        stored) or already uploaded (the file exists), or if
        ``crypto.validate`` rejects the content.
    """
    if Database.get_meta("admin_token"):
        rejection = "The pack has already been extracted"
    elif os.path.exists(Config.encrypted_file):
        rejection = "The pack has already been uploaded"
    else:
        rejection = None
    if rejection is not None:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN", rejection)

    content = file["content"]
    if not crypto.validate(content):
        self.raise_exc(Forbidden, "BAD_FILE", "The uploaded file is not valid")
    destination = os.path.realpath(Config.encrypted_file)
    StorageManager.save_file(destination, content)
    return {}
def status(self):
    """
    POST /admin/status

    Report the contest start time, the global extra time (0 when unset), the
    end time computed with a per-user extra time of 0, and whether a contest
    is loaded.
    """
    payload = {
        "start_time": Database.get_meta('start_time', type=int),
        "extra_time": Database.get_meta('extra_time', type=int, default=0),
        "end_time": BaseHandler.get_end_time(0),
        "loaded": ContestManager.has_contest,
    }
    return BaseHandler.format_dates(payload,
                                    fields=["start_time", "end_time"])
def generate_input(self, task, user):
    """
    POST /generate_input

    Generate the next input attempt of ``task`` for ``user`` and make it the
    user's current attempt. The input row and the attempt update are written
    in a single transaction, which is rolled back on any failure.

    :return: The new input row with dates formatted.
    :raises Forbidden: if the user already has a current attempt on the task.
    """
    token = user["token"]
    task_name = task["name"]
    if Database.get_user_task(token, task_name)["current_attempt"]:
        self.raise_exc(Forbidden, "FORBIDDEN",
                       "You already have a ready input!")

    attempt = Database.get_next_attempt(token, task_name)
    input_id, input_path = ContestManager.get_input(task_name, attempt)
    input_size = StorageManager.get_file_size(input_path)

    Database.begin()
    try:
        Database.add_input(input_id, token, task_name, attempt, input_path,
                           input_size, autocommit=False)
        Database.set_user_attempt(token, task_name, attempt,
                                  autocommit=False)
        Database.commit()
    except BaseException:
        Database.rollback()
        raise

    Logger.info(
        "CONTEST", "Generated input %s for user %s on task %s" %
        (input_id, token, task_name))
    return BaseHandler.format_dates(Database.get_input(id=input_id))
def test_window_end_time_no_window(self):
    """Without a window_duration, get_window_end_time returns None."""
    contest_meta = (
        ("start_time", 1000),
        ("contest_duration", 150),
        # window_duration is intentionally left unset.
        ("extra_time", 20),
    )
    for meta_key, meta_value in contest_meta:
        Database.set_meta(meta_key, meta_value)
    window_end = BaseHandler.get_window_end_time(20, 42)
    self.assertEqual(window_end, None)
def test_get_file_content(self):
    """_get_file_content returns the raw bytes of the uploaded "file"."""
    payload = "hello world"
    request = Request(Environ())
    upload = FileStorage(stream=_io.BytesIO(payload.encode()), filename="foo")
    request.files = {"file": upload}
    content = BaseHandler._get_file_content(request)
    self.assertEqual(payload, content.decode())
def test_get_ip_3_proxies(self):
    """With 3 trusted proxies, get_ip picks the client from X-Forwarded-For."""
    Config.num_proxies = 3
    forwarded = {"X-Forwarded-For": "1.2.3.4, 5.6.7.8, 8.8.8.8"}
    environ = EnvironBuilder(headers=forwarded).get_environ()
    environ["REMOTE_ADDR"] = "6.6.6.6"
    client_ip = BaseHandler.get_ip(Request(environ))
    self.assertEqual("1.2.3.4", client_ip)
def log(self, start_date: str, end_date: str, level: str,
        category: str = None):
    """
    POST /admin/log

    Return the log entries of at least ``level`` (optionally restricted to
    ``category``) between two dates.

    :param start_date: Start of the range, any format ``dateutil`` parses.
    :param end_date: End of the range, any format ``dateutil`` parses.
    :param level: A level name from ``Logger.HUMAN_MESSAGES``.
    :param category: Optional log category filter.
    :raises BadRequest: INVALID_PARAMETER for an unknown level or an
        unparseable/out-of-range date.
    """
    if level not in Logger.HUMAN_MESSAGES:
        self.raise_exc(BadRequest, "INVALID_PARAMETER",
                       "The level provided is invalid")
    level = Logger.HUMAN_MESSAGES.index(level)

    try:
        start_date = dateutil.parser.parse(start_date).timestamp()
        end_date = dateutil.parser.parse(end_date).timestamp()
    except (ValueError, OverflowError) as e:
        # dateutil raises ParserError (a ValueError) for bad input and
        # OverflowError for dates beyond the platform's C integer range.
        BaseHandler.raise_exc(BadRequest, "INVALID_PARAMETER", str(e))

    return BaseHandler.format_dates(
        {"items": Logger.get_logs(level, category, start_date, end_date)})
def start(self, start_time: str):
    """
    POST /admin/start

    Set, schedule or reset the contest start time.

    :param start_time: ``"reset"`` to clear the stored start time, ``"now"``
        to start immediately, or a date string parsed by ``dateutil`` (a date
        without a timezone is interpreted in the server's local time, as
        ``datetime.timestamp`` does).
    :return: ``{"start_time": ...}``; the value returned is the same integer
        timestamp that was stored (None after a reset).
    :raises Forbidden: if a stored start time has already passed.
    :raises BadRequest: INVALID_PARAMETER if the date cannot be parsed.
    """
    previous_start = Database.get_meta("start_time", type=int)
    now = int(time.time())
    if previous_start and now > previous_start:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Contest has already been started!")

    if start_time == "reset":
        Database.del_meta("start_time")
        return {"start_time": None}

    if start_time == "now":
        actual_start = now
    else:
        try:
            # Truncate once so the stored, logged and returned values agree.
            actual_start = int(dateutil.parser.parse(start_time).timestamp())
        except (ValueError, OverflowError) as e:
            BaseHandler.raise_exc(BadRequest, "INVALID_PARAMETER", str(e))

    Database.set_meta("start_time", actual_start)
    Logger.info("CONTEST", "Contest starts at " + str(actual_start))
    return BaseHandler.format_dates({"start_time": actual_start},
                                    fields=["start_time"])
def _validate_admin_token(token, ip):
    """
    Ensure the admin token is valid

    If no ``admin_token`` is stored yet, ``ContestManager.extract_contest``
    and ``ContestManager.read_from_disk`` are run with the supplied token,
    which is then treated as the correct one. On success the client IP is
    registered, and an IP not seen before is logged.

    :param token: Token to check
    :param ip: IP of the client
    :raises Forbidden: if the token does not match the stored admin token.
    """
    import hmac

    correct_token = Database.get_meta("admin_token")
    if correct_token is None:
        ContestManager.extract_contest(token)
        ContestManager.read_from_disk()
        correct_token = token

    if isinstance(token, str) and isinstance(correct_token, str):
        # Constant-time comparison: response timing must not reveal how much
        # of a guessed admin token is correct. Encode to bytes because
        # compare_digest only accepts ASCII str.
        token_ok = hmac.compare_digest(token.encode("utf-8"),
                                       correct_token.encode("utf-8"))
    else:
        token_ok = token == correct_token

    if not token_ok:
        Logger.warning("LOGIN_ADMIN", "Admin login failed from %s" % ip)
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "Invalid admin token!")
    elif Database.register_admin_ip(ip):
        Logger.info("LOGIN_ADMIN",
                    "An admin has connected from a new ip: %s" % ip)
def wsgi_app(self, environ, start_response):
    """
    Dispatch a WSGI request to the handler registered for its route.

    Route endpoints are ``"controller#action"`` strings; the request is
    forwarded to ``self.handlers[controller].handle(action, args, request)``.

    :return: The handler's response, or the routing ``HTTPException`` when
        the URL does not match a route.
    """
    route = self.router.bind_to_environ(environ)
    request = Request(environ)
    try:
        endpoint, args = route.match()
    except HTTPException as exc:
        # match() raises NotFound, but also MethodNotAllowed (405) and
        # RequestRedirect (e.g. a missing trailing slash). Return the
        # exception itself so the client gets the right status, and log the
        # real code instead of a hard-coded 404.
        Logger.warning("HTTP_ERROR", "%s %s %s %s" %
                       (BaseHandler.get_ip(request), request.method,
                        request.url, exc.code))
        return exc
    controller, action = endpoint.split("#")
    return self.handlers[controller].handle(action, args, request)
def get_user(self, user):
    """
    GET /user/<token>

    Attach the contest info to the user. Before the contest starts, only that
    info is added and ``extra_time`` is removed. Afterwards the user also gets
    ``end_time`` (the earlier of the contest end and the user's window end,
    when a window applies), per-task ``score`` and ``current_input``, and
    ``total_score``.

    :param user: The user row (dict) resolved from the token.
    :return: The enriched user dict (dates formatted once the contest has
        started).
    """
    token = user["token"]
    user["contest"] = self.get_contest()
    if not user["contest"]["has_started"]:
        del user["extra_time"]
        return user

    end_time = InfoHandler.get_end_time(user["extra_time"])
    if user["contest_start_delay"] is not None:
        window_end = InfoHandler.get_window_end_time(
            user["extra_time"], user["contest_start_delay"])
        # get_window_end_time can return None (no window configured);
        # min(number, None) would raise TypeError.
        if window_end is not None:
            end_time = min(end_time, window_end)
    user["end_time"] = end_time
    del user["extra_time"]

    user["tasks"] = {}
    tasks = Database.get_user_task(token)
    for task in tasks:
        task_name = task["task"]
        if task["current_attempt"] is not None:
            current_input = Database.get_input(
                token=token, task=task_name, attempt=task["current_attempt"])
        else:
            current_input = None
        user["tasks"][task_name] = {
            "name": task_name,
            "score": task["score"],
            "current_input": current_input,
        }
    user["total_score"] = sum(task["score"] for task in tasks)
    return BaseHandler.format_dates(user, fields=["end_time"])
def patch_output(output):
    """
    Given an output remove the private fields

    Only ``id``, ``date``, ``path`` and ``size`` are copied. The stored JSON
    ``result`` (bytes) is decoded and reduced to its ``validation`` part, and
    ``input`` is kept when present.

    :param output: A dict from the outputs database table
    :return: The formatted and sanitized dict
    """
    public_fields = ("id", "date", "path", "size")
    sanitized = {field: output[field] for field in public_fields}
    sanitized["validation"] = json.loads(
        output["result"].decode())["validation"]
    if "input" in output:
        sanitized["input"] = output["input"]
    return BaseHandler.format_dates(sanitized)
def handle(*args, **kwargs):
    """
    Resolve the ``token`` keyword argument into a ``user`` keyword argument
    for ``handler`` (a closure variable of the enclosing decorator).

    SSO is considered enabled when ``Config.jwt_secret`` is set; the JWT is
    read from the ``token`` cookie of ``_request`` when present.

    * unknown user, SSO off            -> Forbidden "No such user"
    * unknown user, SSO on, cookie     -> user resolved from the SSO JWT
    * unknown user, SSO on, no cookie  -> Forbidden "Please login at ..."
    * SSO user, SSO on                 -> user resolved from the SSO JWT
    * SSO user, SSO off                -> Forbidden "No login method ..."
    * non-SSO user                     -> used as-is

    Afterwards ``Validators._ensure_window_start(token)`` is called and, when
    it returns truthy, the user is re-read from the database. ``token`` and
    ``_request`` are removed before calling ``handler``.
    """
    jwt_token = (kwargs["_request"].cookies.get("token", None)
                 if "_request" in kwargs else None)
    token = kwargs["token"]
    user = Database.get_user(token)
    sso_enabled = Config.jwt_secret

    if not user:
        if not sso_enabled:
            BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "No such user")
        elif jwt_token:
            kwargs["user"] = Validators._get_user_from_sso(jwt_token, token)
        else:
            BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                                  "Please login at %s" % Config.sso_url)
    elif user["sso_user"]:
        if sso_enabled:
            kwargs["user"] = Validators._get_user_from_sso(jwt_token, token)
        else:
            BaseHandler.raise_exc(
                Forbidden, "FORBIDDEN",
                "No login method available for this user")
    else:
        kwargs["user"] = user

    # Reload the user when the window start check reports a change
    # (presumably the stored user row was updated — confirm).
    if Validators._ensure_window_start(token):
        kwargs["user"] = Database.get_user(token)
    kwargs.pop("token")
    kwargs.pop("_request", None)
    return handler(*args, **kwargs)
def test_format_dates(self):
    """format_dates converts only the listed fields, at any nesting depth."""
    shallow_ts = 12345678
    deep_ts = 1010101010
    dct = {
        "date": shallow_ts,
        "nondate": shallow_ts,
        "we": {"need": {"to": {"go": {"deeper": deep_ts}}}},
    }
    formatted = BaseHandler.format_dates(dct, fields=["date", "deeper"])
    self.assertEqual(
        datetime.datetime.fromtimestamp(shallow_ts).isoformat(),
        formatted["date"])
    self.assertEqual(shallow_ts, formatted["nondate"])
    deepest = formatted["we"]["need"]["to"]["go"]
    self.assertEqual(
        datetime.datetime.fromtimestamp(deep_ts).isoformat(),
        deepest["deeper"])
def test_get_ip_no_proxies(self):
    """With no trusted proxies, get_ip returns REMOTE_ADDR."""
    Config.num_proxies = 0
    client_ip = BaseHandler.get_ip(Request(Environ(REMOTE_ADDR="1.2.3.4")))
    self.assertEqual("1.2.3.4", client_ip)
class TestBaseHandler(unittest.TestCase):
    """Tests for BaseHandler: errors, dispatch, parameter parsing, helpers."""

    def setUp(self):
        Utils.prepare_test()
        self.handler = BaseHandler()
        self.log_backup = Logger.LOG_LEVEL
        Logger.LOG_LEVEL = 9001  # disable the logs

    def tearDown(self):
        Logger.LOG_LEVEL = self.log_backup

    def _set_num_proxies(self, value):
        """
        Set Config.num_proxies for the current test only.

        Config is process-global, so the previous value is restored through
        addCleanup and does not leak into tests that run afterwards.
        """
        previous = Config.num_proxies
        self.addCleanup(setattr, Config, "num_proxies", previous)
        Config.num_proxies = value

    def test_raise_exc(self):
        with self.assertRaises(Forbidden) as ex:
            self.handler.raise_exc(Forbidden, "EX_CODE", "Ex message")
        response = ex.exception.response
        self.assertEqual("application/json", response.mimetype)
        data = json.loads(response.data.decode())
        self.assertEqual("EX_CODE", data["code"])
        self.assertEqual("Ex message", data["message"])

    class DummyHandler(BaseHandler):
        """Minimal handler exposing endpoints used by the _call tests."""

        def dummy_endpoint(self, param: int = 123):
            return {"incremented": param + 1}

        def required(self, param):
            self.raise_exc(Forbidden, "NOBUONO", "nononono")

        def myip(self, _ip):
            return _ip

        def file(self, file):
            return file["name"]

        @Validators.validate_input_id
        @Validators.validate_output_id
        def with_decorators(self, input, output):
            pass

    @patch("terry.handlers.base_handler.BaseHandler._call",
           return_value={"foo": "bar"})
    def test_handle(self, call_mock):
        handler = TestBaseHandler.DummyHandler()
        response = handler.handle("dummy_endpoint", 42, 123)
        endpoint = handler.dummy_endpoint
        call_mock.assert_called_once_with(endpoint, 42, 123)
        self.assertEqual(200, response.code)
        self.assertEqual("application/json", response.mimetype)
        self.assertDictEqual({"foo": "bar"},
                             json.loads(response.data.decode()))

    @patch("terry.handlers.base_handler.BaseHandler._call",
           return_value=None)
    def test_handle_no_content(self, call_mock):
        handler = TestBaseHandler.DummyHandler()
        response = handler.handle("dummy_endpoint", 42, 123)
        self.assertEqual(204, response.code)

    @patch("terry.handlers.base_handler.BaseHandler._call",
           side_effect=Forbidden())
    def test_handle_exceptions(self, call_mock):
        handler = TestBaseHandler.DummyHandler()
        response = handler.handle("dummy_endpoint", 42, 123)
        self.assertIsInstance(response, Forbidden)

    def test_parse_body(self):
        request = Request({})
        request.form = {"foo": "bar"}
        body = self.handler.parse_body(request)
        self.assertEqual("bar", body["foo"])

    def test_end_time(self):
        Database.set_meta("start_time", 1000)
        Database.set_meta("contest_duration", 150)
        Database.set_meta("extra_time", 20)
        self.assertEqual(BaseHandler.get_end_time(20), 1190)
        self.assertEqual(BaseHandler.get_end_time(0), 1170)

    def test_end_time_not_started(self):
        self.assertIsNone(BaseHandler.get_end_time(0))

    def test_window_end_time(self):
        Database.set_meta("start_time", 1000)
        Database.set_meta("contest_duration", 150)
        Database.set_meta("window_duration", 100)
        Database.set_meta("extra_time", 20)
        self.assertEqual(BaseHandler.get_window_end_time(10, 20), 1130)
        self.assertEqual(BaseHandler.get_window_end_time(0, 0), 1100)

    def test_window_end_time_no_window(self):
        Database.set_meta("start_time", 1000)
        Database.set_meta("contest_duration", 150)
        # window_duration is intentionally left unset.
        Database.set_meta("extra_time", 20)
        self.assertEqual(BaseHandler.get_window_end_time(20, 42), None)

    def test_format_dates(self):
        dct = {
            "date": 12345678,
            "nondate": 12345678,
            "we": {
                "need": {
                    "to": {
                        "go": {
                            "deeper": 1010101010
                        }
                    }
                }
            },
        }
        formatted = BaseHandler.format_dates(dct, fields=["date", "deeper"])
        self.assertEqual(
            datetime.fromtimestamp(12345678, timezone.utc).isoformat(),
            formatted["date"],
        )
        self.assertEqual(12345678, formatted["nondate"])
        self.assertEqual(
            datetime.fromtimestamp(1010101010, timezone.utc).isoformat(),
            formatted["we"]["need"]["to"]["go"]["deeper"],
        )

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value="1.2.3.4")
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.form = {"param": 42}
        res = handler._call(handler.dummy_endpoint, {}, request)
        self.assertEqual(43, res["incremented"])
        Logger.c.execute("SELECT * FROM logs WHERE category = 'HTTP'")
        row = Logger.c.fetchone()
        self.assertIn("1.2.3.4", row[3])
        self.assertIn("dummy_endpoint", row[3])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_default(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        res = handler._call(handler.dummy_endpoint, {}, request)
        self.assertEqual(124, res["incremented"])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_cast_parameter(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.form = {"param": "42"}
        res = handler._call(handler.dummy_endpoint, {}, request)
        self.assertEqual(43, res["incremented"])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_fail_cast_parameter(self, name_mock, content_mock,
                                      ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.form = {"param": "nope"}
        with self.assertRaises(BadRequest):
            handler._call(handler.dummy_endpoint, {}, request)

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_route_args(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        res = handler._call(handler.dummy_endpoint, {"param": "42"}, request)
        self.assertEqual(43, res["incremented"])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_required_args(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        with self.assertRaises(BadRequest) as ex:
            handler._call(handler.required, {}, request)
        response = ex.exception.response
        self.assertIn("MISSING_PARAMETER", response.data.decode())
        self.assertIn("param", response.data.decode())

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_with_error(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        with self.assertRaises(Forbidden) as ex:
            handler._call(handler.required, {"param": 42}, request)
        response = ex.exception.response
        self.assertIn("NOBUONO", response.data.decode())
        self.assertIn("nononono", response.data.decode())

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value="1.2.3.4")
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_general_attrs(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        res = handler._call(handler.myip, {}, request)
        self.assertEqual("1.2.3.4", res)

    def test_call_file(self):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.files = {"file": FileStorage(filename="foo")}
        res = handler._call(handler.file, {}, request)
        self.assertEqual("foo", res)

    def test_call_with_decorators(self):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        Database.add_user("token", "", "")
        Database.add_task("poldo", "", "", 1, 1)
        Database.add_input("inputid", "token", "poldo", 1, "", 42)
        Database.add_output("outputid", "inputid", "", 42, "")
        handler._call(
            handler.with_decorators,
            {
                "input_id": "inputid",
                "output_id": "outputid"
            },
            request,
        )

    def test_get_file_name(self):
        request = Request(Environ())
        request.files = {"file": FileStorage(filename="foo")}
        self.assertEqual("foo", BaseHandler._get_file_name(request))

    def test_get_file_name_no_file(self):
        request = Request(Environ())
        request.files = {}
        self.assertIsNone(BaseHandler._get_file_name(request))

    def test_get_file_content(self):
        request = Request(Environ())
        stream = _io.BytesIO("hello world".encode())
        request.files = {"file": FileStorage(stream=stream, filename="foo")}
        self.assertEqual("hello world",
                         BaseHandler._get_file_content(request).decode())

    def test_get_file_content_no_file(self):
        request = Request(Environ())
        request.files = {}
        self.assertIsNone(BaseHandler._get_file_content(request))

    def test_get_ip_no_proxies(self):
        self._set_num_proxies(0)
        request = Request(Environ(REMOTE_ADDR="1.2.3.4"))
        ip = BaseHandler.get_ip(request)
        self.assertEqual("1.2.3.4", ip)

    def test_get_ip_3_proxies(self):
        self._set_num_proxies(3)
        headers = {"X-Forwarded-For": "1.2.3.4, 5.6.7.8, 8.8.8.8"}
        env = EnvironBuilder(headers=headers).get_environ()
        env["REMOTE_ADDR"] = "6.6.6.6"
        request = Request(env)
        ip = BaseHandler.get_ip(request)
        self.assertEqual("1.2.3.4", ip)
def _ensure_contest_started():
    """Raise Forbidden unless a ``start_time`` meta value has been stored."""
    contest_started = Database.get_meta("start_time") is not None
    if not contest_started:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "The contest has not started yet")
def setUp(self):
    """Reset the test environment and silence logging for the test."""
    Utils.prepare_test()
    self.handler = BaseHandler()
    # Keep the current level for tearDown, then raise it above every real
    # level so nothing is printed.
    self.log_backup, Logger.LOG_LEVEL = Logger.LOG_LEVEL, 9001
def test_get_file_name(self):
    """_get_file_name returns the filename of the uploaded "file"."""
    request = Request(Environ())
    upload = FileStorage(filename="foo")
    request.files = {"file": upload}
    file_name = BaseHandler._get_file_name(request)
    self.assertEqual("foo", file_name)
def test_get_file_content_no_file(self):
    """Without an uploaded "file", _get_file_content returns None."""
    request = Request(Environ())
    request.files = dict()
    content = BaseHandler._get_file_content(request)
    self.assertIsNone(content)