def _get_user_from_sso(jwt_token, token):
    """
    Return the user authenticated by an SSO JWT, creating it on first login.

    The JWT is verified with ``Config.jwt_secret`` using HS256. Its
    ``username`` claim must equal ``token``. When no such user exists yet it
    is created as an SSO user, named after the ``firstName``/``lastName``
    claims (defaulting to the username and ""). One user-task row is added
    per existing task, and everything is inserted in a single transaction.

    :param jwt_token: The raw JWT sent by the client
    :param token: The username the client is trying to log in as
    :return: The user as returned by ``Database.get_user``
    :raises Forbidden: (via BaseHandler.raise_exc) if the JWT is invalid,
        expired or has no username claim, or if that username differs from
        ``token``
    """
    try:
        data = jwt.decode(jwt_token, Config.jwt_secret, algorithms=["HS256"])
    except jwt.exceptions.InvalidTokenError:
        # InvalidTokenError is the common base of DecodeError,
        # ExpiredSignatureError, ImmatureSignatureError, ... : any unusable
        # token means the client has to log in again (not a 500).
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Please login at %s" % Config.sso_url)

    username = data.get("username")
    if username is None:
        # A validly signed token without a username identifies nobody.
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Please login at %s" % Config.sso_url)
    name = data.get("firstName", username)
    surname = data.get("lastName", "")
    if username != token:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Use the same username from the SSO")

    if Database.get_user(username) is None:
        Database.begin()
        try:
            Database.add_user(username, name, surname, sso_user=True,
                              autocommit=False)
            for task in Database.get_tasks():
                Database.add_user_task(username, task["name"],
                                       autocommit=False)
            Database.commit()
        except BaseException:
            # Never leave a half-created user or a dangling open transaction.
            Database.rollback()
            raise
        Logger.info("NEW_USER", "User %s created from SSO" % username)
    return Database.get_user(username)
def upload_source(self, input, file):
    """
    POST /upload_source

    Store the uploaded source file for ``input`` and attach validation alerts.

    A single alert is reported: a warning when ``get_exeflags`` flags the
    content as an executable (the upload is still stored), otherwise a
    success notice.

    :param input: The input row the source belongs to (its ``id`` and
        ``token`` are read)
    :param file: Dict with the uploaded ``name`` and ``content``
    :return: The stored source with formatted dates plus a ``validation``
        entry holding the alerts
    :raises BadRequest: (via BaseHandler.raise_exc) if the file name is
        rejected by the storage manager
    """
    if get_exeflags(file["content"]):
        alerts = [{
            "severity": "warning",
            "message": "You have submitted an executable! Please send the source code.",
        }]
        Logger.info("UPLOAD",
                    "User %s has uploaded an executable" % input["token"])
    else:
        alerts = [{
            "severity": "success",
            "message": "Source file uploaded correctly.",
        }]

    source_id = Database.gen_id()
    try:
        stored_path = StorageManager.new_source_file(source_id, file["name"])
    except ValueError:
        BaseHandler.raise_exc(BadRequest, "INVALID_FILENAME",
                              "The provided file has an invalid name")
    StorageManager.save_file(stored_path, file["content"])
    size = StorageManager.get_file_size(stored_path)
    Database.add_source(source_id, input["id"], stored_path, size)
    Logger.info("UPLOAD", "User %s has uploaded the source %s" %
                (input["token"], source_id))

    response = BaseHandler.format_dates(Database.get_source(source_id))
    response["validation"] = {"alerts": alerts}
    return response
def upload_output(self, input, file):
    """
    POST /upload_output

    Store the uploaded output file for ``input``, evaluate it and record it.

    :param input: The input row the output answers (its ``id``, ``task``,
        ``path`` and ``token`` are read)
    :param file: Dict with the uploaded ``name`` and ``content``
    :return: The stored output, passed through ``InfoHandler.patch_output``
    :raises BadRequest: (via BaseHandler.raise_exc) if the file name is
        rejected by the storage manager
    :raises InternalServerError: (via BaseHandler.raise_exc) if the
        evaluation of the output fails
    """
    output_id = Database.gen_id()
    try:
        path = StorageManager.new_output_file(output_id, file["name"])
    except ValueError:
        BaseHandler.raise_exc(BadRequest, "INVALID_FILENAME",
                              "The provided file has an invalid name")
    StorageManager.save_file(path, file["content"])
    file_size = StorageManager.get_file_size(path)
    try:
        result = ContestManager.evaluate_output(input["task"], input["path"],
                                                path)
    except Exception as ex:
        # Only ordinary errors become a 500: a bare except would also trap
        # BaseException subclasses such as gevent's GreenletExit/Timeout.
        # Keep the cause in the log, since the client only sees a generic
        # message.
        Logger.warning("UPLOAD", "Failed to evaluate the output %s: %s" %
                       (output_id, ex))
        BaseHandler.raise_exc(InternalServerError, "INTERNAL_ERROR",
                              "Failed to evaluate the output")
    Database.add_output(output_id, input["id"], path, file_size, result)
    Logger.info("UPLOAD", "User %s has uploaded the output %s" %
                (input["token"], output_id))
    return InfoHandler.patch_output(Database.get_output(output_id))
def handle(*args, **kwargs):
    """
    Replace the raw ``param`` keyword with the object ``getter`` resolves.

    ``param``, ``getter``, ``name`` and ``handler`` come from the enclosing
    decorator scope. The resolved object is passed to ``handler`` under the
    keyword ``name``. It is None when ``param`` was not supplied at all.

    :raises Forbidden: (via BaseHandler.raise_exc) if ``param`` was supplied
        but ``getter`` returned None for it
    """
    resolved = None
    if param in kwargs:
        resolved = getter(kwargs.pop(param))
        if resolved is None:
            BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "No such " + name)
    kwargs[name] = resolved
    return handler(*args, **kwargs)
def _ensure_contest_started():
    """
    Abort the request unless the contest start time is set and already reached.

    :raises Forbidden: (via BaseHandler.raise_exc) when ``start_time`` is not
        stored or lies in the future
    """
    start_timestamp = Database.get_meta("start_time", type=int)
    has_not_started = (
        start_timestamp is None
        or datetime.fromtimestamp(start_timestamp, timezone.utc)
        > datetime.now(timezone.utc)
    )
    if has_not_started:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "The contest has not started yet")
def start(self):
    """
    POST /admin/start

    Start the contest right now, recording the current time as its start.

    :return: ``{"start_time": ...}`` formatted by BaseHandler.format_dates
    :raises Forbidden: (via BaseHandler.raise_exc) if a start time is
        already stored
    """
    stored_start = Database.get_meta("start_time", default=None, type=int)
    if stored_start is not None:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Contest has already been started!")

    now = int(time.time())
    Database.set_meta("start_time", now)
    Logger.info("CONTEST", "Contest started")
    return BaseHandler.format_dates({"start_time": now},
                                    fields=["start_time"])
def upload_pack(self, file):
    """
    POST /admin/upload_pack

    Save the uploaded encrypted contest pack to ``Config.encrypted_file``.

    :param file: Dict with the uploaded ``content``
    :return: An empty dict
    :raises Forbidden: if the pack was already extracted (an admin token is
        stored) or uploaded (the encrypted file exists), or if
        ``crypto.validate`` rejects the content
    """
    if Database.get_meta("admin_token"):
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "The pack has already been extracted")
    if os.path.exists(Config.encrypted_file):
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "The pack has already been uploaded")

    payload = file["content"]
    if not crypto.validate(payload):
        self.raise_exc(Forbidden, "BAD_FILE", "The uploaded file is not valid")

    destination = os.path.realpath(Config.encrypted_file)
    StorageManager.save_file(destination, payload)
    return {}
def log(self, start_date: str, end_date: str, level: str,
        category: str = None):
    """
    POST /admin/log

    Return the log entries of at least ``level`` between two dates.

    :param start_date: Lower bound, any format accepted by
        ``dateutil.parser.parse``
    :param end_date: Upper bound, same format as ``start_date``
    :param level: One of ``Logger.HUMAN_MESSAGES``; converted to its index
    :param category: Optional category filter passed to ``Logger.get_logs``
    :return: ``{"items": [...]}`` formatted by BaseHandler.format_dates
    :raises BadRequest: if the level is unknown or a date cannot be parsed
    """
    if level not in Logger.HUMAN_MESSAGES:
        self.raise_exc(BadRequest, "INVALID_PARAMETER",
                       "The level provided is invalid")
    level = Logger.HUMAN_MESSAGES.index(level)

    try:
        start_date = dateutil.parser.parse(start_date).timestamp()
        end_date = dateutil.parser.parse(end_date).timestamp()
    except (ValueError, OverflowError) as e:
        # parse() raises ParserError (a ValueError) for malformed strings but
        # OverflowError for out-of-range numbers; both are client errors.
        BaseHandler.raise_exc(BadRequest, "INVALID_PARAMETER", str(e))

    return BaseHandler.format_dates(
        {"items": Logger.get_logs(level, category, start_date, end_date)})
def _ensure_contest_running(token=None):
    """
    Makes sure that the contest is running for the user, if any. If the user
    has a time window it is also checked.

    :param token: The optional token of the user.
    :raises Forbidden: (via BaseHandler.raise_exc) if the contest has not
        started yet, has ended, or the user's window has ended
    """
    extra_time = None
    start_delay = None
    if token:
        Validators._ensure_window_start(token)
        user = Database.get_user(token)
        if user:
            extra_time = user["extra_time"]
            start_delay = user["contest_start_delay"]

    # Reuse the shared start check instead of a local "start_time is None"
    # test, so a start_time that is only scheduled (in the future) is treated
    # exactly as _ensure_contest_started treats it.
    Validators._ensure_contest_started()

    contest_end = BaseHandler.get_end_time(extra_time)
    window_end = BaseHandler.get_window_end_time(extra_time, start_delay)
    now = time.time()
    # check the contest is not finished
    if contest_end < now:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "The contest has ended")
    # if a window is present check it's not finished
    if window_end and window_end < now:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "Your window has ended")
def _validate_admin_token(token, ip):
    """
    Ensure the admin token is valid.

    If no admin token is stored yet, ``token`` is used to extract and load the
    contest pack. Extraction stores it as the admin token on success, and the
    token is then accepted. Successful logins register the client ip.

    :param token: Token to check
    :param ip: IP of the client
    :raises Forbidden: (via BaseHandler.raise_exc) if the token is invalid
    """
    import hmac

    correct_token = Database.get_meta("admin_token")
    if correct_token is None:
        ContestManager.extract_contest(token)
        ContestManager.read_from_disk()
        correct_token = token

    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed token matched. compare_digest rejects non-ASCII str, hence
    # the comparison is done on UTF-8 bytes.
    token_matches = hmac.compare_digest(str(token).encode("utf-8"),
                                        str(correct_token).encode("utf-8"))
    if not token_matches:
        Logger.warning("LOGIN_ADMIN", "Admin login failed from %s" % ip)
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "Invalid admin token!")
    if Database.register_admin_ip(ip):
        Logger.info("LOGIN_ADMIN",
                    "An admin has connected from a new ip: %s" % ip)
def start(self, start_time: str):
    """
    POST /admin/start

    Start the contest now, schedule its start, or clear the schedule.

    :param start_time: ``"now"``, ``"reset"`` or a date string accepted by
        ``dateutil.parser.parse``
    :return: ``{"start_time": ...}`` formatted by BaseHandler.format_dates, or
        ``{"start_time": None}`` after a reset
    :raises Forbidden: if the stored start time has already passed
    :raises BadRequest: if ``start_time`` is not a parseable date
    """
    previous_start = Database.get_meta("start_time", type=int)
    now = int(time.time())
    # Changing the start is only allowed while the stored start has not passed.
    if previous_start is not None and now > previous_start:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "Contest has already been started!")

    if start_time == "reset":
        Database.del_meta("start_time")
        return {"start_time": None}

    if start_time == "now":
        actual_start = now
    else:
        try:
            # Truncate once so the stored, logged and returned values agree.
            actual_start = int(dateutil.parser.parse(start_time).timestamp())
        except (ValueError, OverflowError) as ex:
            BaseHandler.raise_exc(BadRequest, "INVALID_PARAMETER", str(ex))

    Database.set_meta("start_time", actual_start)
    Logger.info("CONTEST", "Contest starts at " + str(actual_start))
    return BaseHandler.format_dates({"start_time": actual_start},
                                    fields=["start_time"])
def handle(*args, **kwargs):
    """
    Authenticate the ``token`` keyword and call ``handler`` with ``user``.

    The outcome depends on whether the user exists, whether SSO is enabled
    (``Config.jwt_secret``) and whether the user is an SSO user:

    * unknown user: rejected when SSO is disabled. With SSO it is resolved
      (and possibly created) from the JWT in the ``token`` cookie, or told to
      log in when there is no cookie.
    * non-SSO user: accepted as is.
    * SSO user: resolved from the JWT when SSO is enabled, rejected otherwise.

    ``token`` and ``_request`` are stripped from the keyword arguments before
    calling ``handler``, which comes from the enclosing decorator scope.
    """
    jwt_token = (kwargs["_request"].cookies.get("token", None)
                 if "_request" in kwargs else None)
    token = kwargs.pop("token")
    user = Database.get_user(token)

    if not user:
        if not Config.jwt_secret:
            BaseHandler.raise_exc(Forbidden, "FORBIDDEN", "No such user")
        elif jwt_token:
            kwargs["user"] = Validators._get_user_from_sso(jwt_token, token)
        else:
            BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                                  "Please login at %s" % Config.sso_url)
    elif not user["sso_user"]:
        kwargs["user"] = user
    elif Config.jwt_secret:
        kwargs["user"] = Validators._get_user_from_sso(jwt_token, token)
    else:
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "No login method available for this user")

    # makes sure the window starts; a truthy result presumably means the
    # user's window data changed, so the user is reloaded (TODO confirm)
    if Validators._ensure_window_start(token):
        kwargs["user"] = Database.get_user(token)
    kwargs.pop("_request", None)
    return handler(*args, **kwargs)
def _ensure_contest_started():
    """
    Abort the request unless the contest has a start time that was reached.

    A start time may be stored ahead of time (scheduled start). Until that
    moment the contest counts as not started.

    :raises Forbidden: (via BaseHandler.raise_exc) when ``start_time`` is not
        stored or lies in the future
    """
    import time

    start_time = Database.get_meta("start_time", type=int)
    if start_time is None or start_time > time.time():
        BaseHandler.raise_exc(Forbidden, "FORBIDDEN",
                              "The contest has not started yet")
class TestBaseHandler(unittest.TestCase):
    """
    Tests for BaseHandler.

    Covers error responses, request dispatching (handle/_call), parameter
    casting, contest time helpers, date formatting and request helpers
    (file name/content, client ip).
    """

    def setUp(self):
        Utils.prepare_test()
        self.handler = BaseHandler()
        self.log_backup = Logger.LOG_LEVEL
        Logger.LOG_LEVEL = 9001  # disable the logs

    def tearDown(self):
        Logger.LOG_LEVEL = self.log_backup

    def test_raise_exc(self):
        with self.assertRaises(Forbidden) as ex:
            self.handler.raise_exc(Forbidden, "EX_CODE", "Ex message")
        response = ex.exception.response
        self.assertEqual("application/json", response.mimetype)
        data = json.loads(response.data.decode())
        self.assertEqual("EX_CODE", data["code"])
        self.assertEqual("Ex message", data["message"])

    # Minimal handler whose endpoints exercise _call. Do not add or change
    # parameter annotations here: test_call_cast_parameter and
    # test_call_fail_cast_parameter depend on ``param: int`` driving the cast.
    class DummyHandler(BaseHandler):
        def dummy_endpoint(self, param: int = 123):
            return {"incremented": param + 1}

        def required(self, param):
            self.raise_exc(Forbidden, "NOBUONO", "nononono")

        def myip(self, _ip):
            return _ip

        def file(self, file):
            return file["name"]

        @Validators.validate_input_id
        @Validators.validate_output_id
        def with_decorators(self, input, output):
            pass

    @patch("terry.handlers.base_handler.BaseHandler._call",
           return_value={"foo": "bar"})
    def test_handle(self, call_mock):
        handler = TestBaseHandler.DummyHandler()
        response = handler.handle("dummy_endpoint", 42, 123)
        endpoint = handler.dummy_endpoint
        call_mock.assert_called_once_with(endpoint, 42, 123)
        self.assertEqual(200, response.code)
        self.assertEqual("application/json", response.mimetype)
        self.assertDictEqual({"foo": "bar"},
                             json.loads(response.data.decode()))

    @patch("terry.handlers.base_handler.BaseHandler._call", return_value=None)
    def test_handle_no_content(self, call_mock):
        handler = TestBaseHandler.DummyHandler()
        response = handler.handle("dummy_endpoint", 42, 123)
        self.assertEqual(204, response.code)

    @patch("terry.handlers.base_handler.BaseHandler._call",
           side_effect=Forbidden())
    def test_handle_exceptions(self, call_mock):
        handler = TestBaseHandler.DummyHandler()
        response = handler.handle("dummy_endpoint", 42, 123)
        self.assertIsInstance(response, Forbidden)

    def test_parse_body(self):
        request = Request({})
        request.form = {"foo": "bar"}
        body = self.handler.parse_body(request)
        self.assertEqual("bar", body["foo"])

    def test_end_time(self):
        Database.set_meta("start_time", 1000)
        Database.set_meta("contest_duration", 150)
        Database.set_meta("extra_time", 20)
        self.assertEqual(BaseHandler.get_end_time(20), 1190)
        self.assertEqual(BaseHandler.get_end_time(0), 1170)

    def test_end_time_not_started(self):
        self.assertIsNone(BaseHandler.get_end_time(0))

    def test_window_end_time(self):
        Database.set_meta("start_time", 1000)
        Database.set_meta("contest_duration", 150)
        Database.set_meta("window_duration", 100)
        Database.set_meta("extra_time", 20)
        self.assertEqual(BaseHandler.get_window_end_time(10, 20), 1130)
        self.assertEqual(BaseHandler.get_window_end_time(0, 0), 1100)

    def test_window_end_time_no_window(self):
        Database.set_meta("start_time", 1000)
        Database.set_meta("contest_duration", 150)
        # window_duration is deliberately left unset: no window -> None
        Database.set_meta("extra_time", 20)
        self.assertEqual(BaseHandler.get_window_end_time(20, 42), None)

    def test_format_dates(self):
        dct = {
            "date": 12345678,
            "nondate": 12345678,
            "we": {
                "need": {
                    "to": {
                        "go": {
                            "deeper": 1010101010
                        }
                    }
                }
            },
        }
        formatted = BaseHandler.format_dates(dct, fields=["date", "deeper"])
        self.assertEqual(
            datetime.fromtimestamp(12345678, timezone.utc).isoformat(),
            formatted["date"],
        )
        self.assertEqual(12345678, formatted["nondate"])
        self.assertEqual(
            datetime.fromtimestamp(1010101010, timezone.utc).isoformat(),
            formatted["we"]["need"]["to"]["go"]["deeper"],
        )

    # Stacked @patch decorators inject mocks bottom-up: _get_file_name ->
    # name_mock, _get_file_content -> content_mock, get_ip -> ip_mock.
    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value="1.2.3.4")
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.form = {"param": 42}
        res = handler._call(handler.dummy_endpoint, {}, request)
        self.assertEqual(43, res["incremented"])
        # _call is expected to write an HTTP log row mentioning ip and
        # endpoint; row[3] is presumably the message column (TODO confirm)
        Logger.c.execute("SELECT * FROM logs WHERE category = 'HTTP'")
        row = Logger.c.fetchone()
        self.assertIn("1.2.3.4", row[3])
        self.assertIn("dummy_endpoint", row[3])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip", return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_default(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        res = handler._call(handler.dummy_endpoint, {}, request)
        self.assertEqual(124, res["incremented"])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip", return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_cast_parameter(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.form = {"param": "42"}
        res = handler._call(handler.dummy_endpoint, {}, request)
        self.assertEqual(43, res["incremented"])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip", return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_fail_cast_parameter(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.form = {"param": "nope"}
        with self.assertRaises(BadRequest):
            handler._call(handler.dummy_endpoint, {}, request)

    @patch("terry.handlers.base_handler.BaseHandler.get_ip", return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_route_args(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        res = handler._call(handler.dummy_endpoint, {"param": "42"}, request)
        self.assertEqual(43, res["incremented"])

    @patch("terry.handlers.base_handler.BaseHandler.get_ip", return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_required_args(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        with self.assertRaises(BadRequest) as ex:
            handler._call(handler.required, {}, request)
        response = ex.exception.response
        self.assertIn("MISSING_PARAMETER", response.data.decode())
        self.assertIn("param", response.data.decode())

    @patch("terry.handlers.base_handler.BaseHandler.get_ip", return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_with_error(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        with self.assertRaises(Forbidden) as ex:
            handler._call(handler.required, {"param": 42}, request)
        response = ex.exception.response
        self.assertIn("NOBUONO", response.data.decode())
        self.assertIn("nononono", response.data.decode())

    @patch("terry.handlers.base_handler.BaseHandler.get_ip",
           return_value="1.2.3.4")
    @patch("terry.handlers.base_handler.BaseHandler._get_file_content",
           return_value=42)
    @patch("terry.handlers.base_handler.BaseHandler._get_file_name",
           return_value=42)
    def test_call_general_attrs(self, name_mock, content_mock, ip_mock):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        res = handler._call(handler.myip, {}, request)
        self.assertEqual("1.2.3.4", res)

    def test_call_file(self):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        request.files = {"file": FileStorage(filename="foo")}
        res = handler._call(handler.file, {}, request)
        self.assertEqual("foo", res)

    def test_call_with_decorators(self):
        handler = TestBaseHandler.DummyHandler()
        env = Environ({"wsgi.input": None})
        request = Request(env)
        # Seed the rows that validate_input_id/validate_output_id look up.
        Database.add_user("token", "", "")
        Database.add_task("poldo", "", "", 1, 1)
        Database.add_input("inputid", "token", "poldo", 1, "", 42)
        Database.add_output("outputid", "inputid", "", 42, "")
        handler._call(
            handler.with_decorators,
            {
                "input_id": "inputid",
                "output_id": "outputid"
            },
            request,
        )

    def test_get_file_name(self):
        request = Request(Environ())
        request.files = {"file": FileStorage(filename="foo")}
        self.assertEqual("foo", BaseHandler._get_file_name(request))

    def test_get_file_name_no_file(self):
        request = Request(Environ())
        request.files = {}
        self.assertIsNone(BaseHandler._get_file_name(request))

    def test_get_file_content(self):
        request = Request(Environ())
        stream = _io.BytesIO("hello world".encode())
        request.files = {"file": FileStorage(stream=stream, filename="foo")}
        self.assertEqual("hello world",
                         BaseHandler._get_file_content(request).decode())

    def test_get_file_content_no_file(self):
        request = Request(Environ())
        request.files = {}
        self.assertIsNone(BaseHandler._get_file_content(request))

    # NOTE(review): the two get_ip tests assign Config.num_proxies globally
    # and never restore it; presumably Utils.prepare_test() resets Config in
    # setUp, otherwise later tests see the leaked value. Verify.
    def test_get_ip_no_proxies(self):
        Config.num_proxies = 0
        request = Request(Environ(REMOTE_ADDR="1.2.3.4"))
        ip = BaseHandler.get_ip(request)
        self.assertEqual("1.2.3.4", ip)

    def test_get_ip_3_proxies(self):
        Config.num_proxies = 3
        headers = {"X-Forwarded-For": "1.2.3.4, 5.6.7.8, 8.8.8.8"}
        env = EnvironBuilder(headers=headers).get_environ()
        env["REMOTE_ADDR"] = "6.6.6.6"
        request = Request(env)
        ip = BaseHandler.get_ip(request)
        self.assertEqual("1.2.3.4", ip)
def extract_contest(token):
    """
    Decrypt and extract the contest and store the used admin token in the
    database.

    The token has the form ``<username>-<password>``. The password part is
    decoded into the key of ``Config.encrypted_file``. The decrypted zip is
    written to ``Config.decrypted_file`` and extracted into
    ``Config.contest_path``. ``contest.yaml`` there is then (re)linked to
    ``__users__/<username>.yaml``.

    :param token: The admin token
    :raises Forbidden: (via BaseHandler.raise_exc) on a malformed or wrong
        password, a username without a yaml in the pack, or a corrupted zip
    :raises NotFound: (via BaseHandler.raise_exc) if the pack was not
        uploaded yet
    :raises InternalServerError: (via BaseHandler.raise_exc) on other I/O
        errors while decrypting
    """
    if "-" not in token:
        BaseHandler.raise_exc(Forbidden, "WRONG_PASSWORD",
                              "The provided password is malformed")
    try:
        username, password = token.split("-", 1)
        secret, scrambled_password = decode_data(password, SECRET_LEN)
        file_password = recover_file_password(username, secret,
                                              scrambled_password)
    except ValueError:
        BaseHandler.raise_exc(Forbidden, "WRONG_PASSWORD",
                              "The provided password is malformed")

    try:
        with open(Config.encrypted_file, "rb") as encrypted_file:
            encrypted_data = encrypted_file.read()
        decrypted_data = decode(file_password, encrypted_data)
        with open(Config.decrypted_file, "wb") as decrypted_file:
            decrypted_file.write(decrypted_data)
    except FileNotFoundError:
        BaseHandler.raise_exc(NotFound, "NOT_FOUND",
                              "The contest pack was not uploaded yet")
    except nacl.exceptions.CryptoError:
        BaseHandler.raise_exc(Forbidden, "WRONG_PASSWORD",
                              "The provided password is wrong")
    except OSError as ex:
        BaseHandler.raise_exc(InternalServerError, "FAILED", str(ex))

    # Work with absolute paths instead of os.chdir: the working directory is
    # process-wide state shared by every greenlet of the server.
    zip_abs_path = os.path.realpath(Config.decrypted_file)
    contest_dir = os.path.realpath(Config.contest_path)
    # Symlink targets are resolved relative to the link's own directory, so
    # this relative target works for a link placed inside contest_dir.
    real_yaml = os.path.join("__users__", username + ".yaml")
    link_path = os.path.join(contest_dir, "contest.yaml")
    try:
        os.makedirs(contest_dir, exist_ok=True)
        with zipfile.ZipFile(zip_abs_path) as f:
            f.extractall(contest_dir)
        if not os.path.exists(os.path.join(contest_dir, real_yaml)):
            BaseHandler.raise_exc(Forbidden, "WRONG_PASSWORD",
                                  "Invalid username for the given pack")
        # A link left over from a previous extraction would make os.symlink
        # raise FileExistsError.
        if os.path.lexists(link_path):
            os.remove(link_path)
        os.symlink(real_yaml, link_path)
        Logger.info("CONTEST", "Contest extracted")
    except zipfile.BadZipFile as ex:
        BaseHandler.raise_exc(Forbidden, "FAILED", str(ex))

    Database.set_meta("admin_token", token)
def read_from_disk(remove_enc=True):
    """
    Load a task from the disk and load the data into the database.

    On the first successful call (meta ``contest_imported`` unset) the
    contest metadata, tasks, users and per-user task rows are written in a
    single transaction. Every call then fills ``ContestManager.tasks``, sets
    ``has_contest`` and creates one input queue plus one worker greenlet per
    task.

    :param remove_enc: When the contest files are missing, also delete the
        encrypted pack
    :raises UnprocessableEntity: (via BaseHandler.raise_exc) when
        ``import_contest`` reports a missing file. In that case the extracted
        directories, the decrypted pack and the stored admin token are
        removed first.
    """
    try:
        contest = ContestManager.import_contest(Config.contest_path)
    except FileNotFoundError as ex:
        error = (
            "Contest not found, you probably need to unzip it. Missing file %s"
            % ex.filename)
        Logger.warning("CONTEST", error)
        # Drop every partially-extracted artifact so the next admin login
        # starts from scratch.
        for directory in (Config.statementdir, Config.web_statementdir,
                          Config.contest_path):
            shutil.rmtree(directory, ignore_errors=True)
        if remove_enc:
            with suppress(Exception):
                os.remove(Config.encrypted_file)
        with suppress(Exception):
            os.remove(Config.decrypted_file)
        Database.del_meta("admin_token")
        BaseHandler.raise_exc(UnprocessableEntity, "CONTEST", error)

    if not Database.get_meta("contest_imported", default=False, type=bool):
        Database.begin()
        try:
            Database.set_meta("contest_duration", contest["duration"],
                              autocommit=False)
            Database.set_meta("contest_name", contest.get("name", "Contest"),
                              autocommit=False)
            Database.set_meta(
                "contest_description",
                contest.get("description", ""),
                autocommit=False,
            )
            Database.set_meta(
                "window_duration",
                # if None the contest is not USACO-style
                contest.get("window_duration"),
                autocommit=False,
            )
            for position, task in enumerate(contest["tasks"]):
                Database.add_task(
                    task["name"],
                    task["description"],
                    task["statement_path"],
                    task["max_score"],
                    position,
                    autocommit=False,
                )
            for user in contest["users"]:
                Database.add_user(user["token"], user["name"],
                                  user["surname"], autocommit=False)
            # Query the tasks once instead of once per user; materialize them
            # in case get_tasks() returns a one-shot iterator.
            tasks = list(Database.get_tasks())
            for user in Database.get_users():
                for task in tasks:
                    Database.add_user_task(user["token"], task["name"],
                                           autocommit=False)
            Database.set_meta("contest_imported", True, autocommit=False)
            Database.commit()
        except BaseException:
            Database.rollback()
            raise
    # TODO: when the contest was already imported, check that it is still
    # the same as the one on disk

    # store the task in the ContestManager singleton
    ContestManager.tasks = {task["name"]: task for task in contest["tasks"]}
    ContestManager.has_contest = True
    # create the queues for the task inputs
    for task in ContestManager.tasks:
        ContestManager.input_queue[task] = gevent.queue.Queue(
            Config.queue_size)
        gevent.spawn(ContestManager.worker, task)