class Classroom(db.Model):
    """A named group of teams owned by a teacher, with its own ranking."""

    __tablename__ = "classrooms"

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Unicode(64), nullable=False)
    owner = db.Column(db.Integer)  # uid of the owning user (see ``teacher``)
    teams = db.relationship("Team", passive_deletes=True,
                            secondary=team_classroom, backref="teams")
    invites = db.relationship("Team", passive_deletes=True,
                              secondary=classroom_invitation,
                              backref="invites")

    def __contains__(self, obj):
        """Support ``team in classroom``; non-Team objects are never members."""
        return isinstance(obj, Team) and obj in self.teams

    @property
    def teacher(self):
        """The User whose uid equals ``owner``, or None if there is none."""
        matches = User.query.filter_by(uid=self.owner)
        return matches.first()

    @property
    def size(self):
        """Number of member teams."""
        return len(self.teams)

    @property
    def scoreboard(self):
        """Member teams ranked by points (high first); ties go to the earlier last solve."""
        def ranking(team):
            return (team.points(), -team.get_last_solved())

        return sorted(self.teams, key=ranking, reverse=True)
class WrongEgg(db.Model):
    """A rejected submission recorded against an Egg."""

    __tablename__ = "wrong_egg"

    id = db.Column(db.Integer, primary_key=True)
    # Which egg was attempted, and by which team / user.
    eid = db.Column(db.Integer, db.ForeignKey("eggs.eid"), index=True)
    tid = db.Column(db.Integer, db.ForeignKey("teams.tid"), index=True)
    uid = db.Column(db.Integer, db.ForeignKey("users.uid"), index=True)
    # Naive UTC timestamp of the attempt.
    date = db.Column(db.DateTime, default=datetime.utcnow)
    # The text that was submitted.
    submission = db.Column(db.Unicode(64))
class WrongFlag(db.Model):
    """An incorrect flag submitted by a team for a problem."""

    __tablename__ = "wrong_flags"

    id = db.Column(db.Integer, index=True, primary_key=True)
    pid = db.Column(db.Integer, db.ForeignKey("problems.pid"), index=True)
    tid = db.Column(db.Integer, db.ForeignKey("teams.tid"), index=True)
    uid = db.Column(db.Integer, db.ForeignKey("users.uid"), index=True)
    # Naive UTC datetime (default=datetime.utcnow), stored in column "date".
    _date = db.Column("date", db.DateTime, default=datetime.utcnow)
    flag = db.Column(db.Unicode(256), index=True)

    @hybrid_property
    def date(self):
        """Submission time as integer Unix seconds.

        ``_date`` holds a naive UTC datetime, so it is converted with
        ``calendar.timegm`` (which interprets the tuple as UTC). The previous
        ``time.mktime`` interpreted it as local time and was off by the
        server's UTC offset.
        """
        import calendar
        return int(calendar.timegm(self._date.timetuple()))

    # The expression must be bound to the same name: since SQLAlchemy 1.2
    # ``@date.expression`` returns a *copy* of the hybrid, so binding it to a
    # different name left ``WrongFlag.date`` without a SQL expression.
    @date.expression
    def date(cls):
        """SQL-side ``date``: the raw datetime column."""
        return cls._date

    # Backward-compatible alias for the former attribute name.
    date_expression = date
class PasswordResetToken(db.Model):
    """A password-reset token issued for a user, with an expiry time."""

    __tablename__ = "password_reset_tokens"

    id = db.Column(db.Integer, primary_key=True)
    uid = db.Column(db.Integer, db.ForeignKey("users.uid"), index=True)
    active = db.Column(db.Boolean)
    token = db.Column(db.String(32), default=generate_short_string)
    email = db.Column(db.Unicode(128))
    expire = db.Column(db.DateTime)  # compared against datetime.utcnow()

    @property
    def expired(self):
        """True once the current UTC time has reached ``expire``."""
        now = datetime.utcnow()
        return not now < self.expire

    @property
    def user(self):
        """The User this token belongs to, looked up by ``uid``."""
        owner_id = self.uid
        return User.get_by_id(owner_id)
class Solve(db.Model):
    """A correct flag submission: one row per (problem, team) pair."""

    __tablename__ = "solves"
    __table_args__ = (db.UniqueConstraint('pid', 'tid'), )

    id = db.Column(db.Integer, index=True, primary_key=True)
    pid = db.Column(db.Integer, db.ForeignKey("problems.pid"), index=True)
    tid = db.Column(db.Integer, db.ForeignKey("teams.tid"), index=True)
    uid = db.Column(db.Integer, db.ForeignKey("users.uid"), index=True)
    # Naive UTC datetime (default=datetime.utcnow), stored in column "date".
    _date = db.Column("date", db.DateTime, default=datetime.utcnow)
    flag = db.Column(db.Unicode(256))

    @hybrid_property
    def date(self):
        """Solve time as integer Unix seconds.

        ``_date`` holds a naive UTC datetime, so it is converted with
        ``calendar.timegm`` (which interprets the tuple as UTC). The previous
        ``time.mktime`` interpreted it as local time and was off by the
        server's UTC offset.
        """
        import calendar
        return int(calendar.timegm(self._date.timetuple()))

    # The expression must be bound to the same name: since SQLAlchemy 1.2
    # ``@date.expression`` returns a *copy* of the hybrid, so binding it to a
    # different name left ``Solve.date`` (used in queries) without a SQL
    # expression.
    @date.expression
    def date(cls):
        """SQL-side ``date``: the raw datetime column."""
        return cls._date

    # Backward-compatible alias for the former attribute name.
    date_expression = date
class File(db.Model):
    """A static file attached to a problem and uploaded via ``save_file``."""

    __tablename__ = "files"

    id = db.Column(db.Integer, index=True, primary_key=True)
    pid = db.Column(db.Integer, db.ForeignKey("problems.pid"), index=True)
    filename = db.Column(db.Unicode(64))
    url = db.Column(db.String(128))  # set only when the upload returns HTTP 200

    @staticmethod
    def clean_name(name):
        """Return ``name`` passed through ``filename_filter``."""
        return filename_filter(name)

    def __init__(self, pid, filename, data):
        """Record the file and, unless TESTING is set, upload ``data``.

        ``data`` must be a seekable file-like object; it is rewound first.
        """
        self.pid = pid
        self.filename = filename
        data.seek(0)
        if app.config.get("TESTING"):
            return
        response = save_file(data, suffix="_" + filename)
        if response.status_code == 200:
            self.url = response.text
class Egg(db.Model):
    """An egg identified by ``eid`` and matched by its unique ``flag``."""

    __tablename__ = "eggs"

    eid = db.Column(db.Integer, primary_key=True)
    flag = db.Column(db.Unicode(64), index=True, unique=True, nullable=False)
    # EggSolve rows for this egg (reachable from EggSolve via ``.egg``).
    solves = db.relationship("EggSolve", lazy=True, backref="egg")
class Config(db.Model):
    """Key/value store for site-wide settings (table ``config``)."""

    __tablename__ = "config"

    cid = db.Column(db.Integer, primary_key=True)
    key = db.Column(db.Unicode(32), index=True)
    value = db.Column(db.Text)

    def __init__(self, key, value):
        self.key = key
        self.value = value

    @classmethod
    def get_competition_window(cls):
        # TODO: placeholder; always reports (0, 0).
        return (0, 0)

    @classmethod
    def get_team_size(cls):
        # TODO: actually implement this
        return 5

    @classmethod
    def get(cls, key, default=None):
        """Return the stored value for ``key`` converted to ``str``.

        Returns ``default`` when no row exists for ``key``.
        """
        config = cls.query.filter_by(key=key).first()
        if config is None:
            return default
        return str(config.value)

    @classmethod
    def set(cls, key, value):
        """Create or overwrite the setting ``key`` and commit.

        Previously an existing row was re-added unchanged, so updates to an
        existing key were silently lost; this now shares ``set_many``'s
        create-or-update logic.
        """
        cls.set_many({key: value})

    @classmethod
    def get_many(cls, keys):
        """Return ``{key: value}`` for each of ``keys`` that has a stored row.

        Values are returned as stored (not ``str()``-converted like ``get``).
        """
        items = cls.query.filter(cls.key.in_(keys)).all()
        return {item.key: item.value for item in items}

    @classmethod
    def set_many(cls, configs):
        """Create or overwrite every ``key: value`` in ``configs``; commit once."""
        for key, value in configs.items():
            config = cls.query.filter_by(key=key).first()
            if config is None:
                config = Config(key, value)
            config.value = value
            db.session.add(config)
        db.session.commit()

    @classmethod
    def get_ssh_keys(cls):
        """Return ``(private_key, public_key)`` as ``str``.

        If either key is missing, a new 2048-bit RSA pair is generated, stored
        (private key as PEM, public key in OpenSSH format) and returned.
        """
        private_key = cls.get("private_key")
        public_key = cls.get("public_key")
        if not (private_key and public_key):
            key = RSA.generate(2048)
            # exportKey() yields bytes; decode so freshly generated keys have
            # the same type as keys loaded from the database (Team.credentials
            # wraps the private key in StringIO, which rejects bytes).
            private_key = str(key.exportKey("PEM"), "utf-8")
            public_key = str(key.publickey().exportKey("OpenSSH"), "utf-8")
            cls.set_many({
                "private_key": private_key,
                "public_key": public_key,
            })
        return private_key, public_key

    def __repr__(self):
        return "Config({}={})".format(self.key, self.value)
class Team(db.Model):
    """A competing team: members, solves, classroom links and shell credentials."""

    __tablename__ = "teams"

    tid = db.Column(db.Integer, primary_key=True, index=True)
    teamname = db.Column(db.Unicode(32), unique=True)
    school = db.Column(db.Unicode(64))
    owner = db.Column(db.Integer)
    classrooms = db.relationship("Classroom", secondary=team_classroom,
                                 backref="classrooms")
    classroom_invites = db.relationship("Classroom",
                                        secondary=classroom_invitation,
                                        backref="classroom_invites")
    members = db.relationship("User", back_populates="team")
    admin = db.Column(db.Boolean, default=False)
    shell_user = db.Column(db.String(16), unique=True)
    shell_pass = db.Column(db.String(32))
    banned = db.Column(db.Boolean, default=False)
    solves = db.relationship("Solve", backref="team", lazy=True)
    jobs = db.relationship("Job", backref="team", lazy=True)
    _avatar = db.Column("avatar", db.String(128))
    outgoing_invitations = db.relationship(
        "User", secondary=team_player_invitation, lazy="subquery",
        backref=db.backref("incoming_invitations", lazy=True))

    def __repr__(self):
        return "%s_%s" % (self.__class__.__name__, self.tid)

    def __str__(self):
        return "<Team %s>" % self.tid

    @property
    def avatar(self):
        """Avatar URL; an identicon is generated and uploaded on first access."""
        if not self._avatar:
            avatar_file = BytesIO()
            avatar = generate_identicon("team%s" % self.tid)
            avatar.save(avatar_file, format="PNG")
            avatar_file.seek(0)
            # Prefix was "user_avatar_" (swapped with User.avatar's).
            response = save_file(avatar_file, prefix="team_avatar_",
                                 suffix=".png")
            if response.status_code == 200:
                self._avatar = response.text
                db.session.add(self)
                db.session.commit()
        return self._avatar

    @staticmethod
    def get_by_id(id):
        """Return the Team with ``tid == id``, or None."""
        return Team.query.filter_by(tid=id).first()

    @property
    def size(self):
        """Number of members."""
        return len(self.members)

    @hybrid_property
    @cache.memoize(timeout=120)
    def observer(self):
        """Count of members whose level is not USER_REGULAR (truthy => observer)."""
        return User.query.filter(
            and_(User.tid == self.tid, User.level != USER_REGULAR)).count()

    @observer.expression
    @cache.memoize(timeout=120)
    def observer(self):
        # NOTE(review): this runs a query eagerly and returns an int rather
        # than a SQL expression, so it cannot be embedded in other queries.
        # The clauses must be joined with and_(): Python's ``and`` does not
        # build a SQL AND and previously dropped the level condition.
        return db.session.query(User).filter(
            and_(User.tid == self.tid, User.level != USER_REGULAR)).count()

    @hybrid_property
    def prop_points(self):
        """Sum of the values of the problems this team has solved."""
        # Join on Solve.pid; the old Problem.pid == Solve.tid matched problems
        # by *team* id.
        rows = db.session.query(Problem, Solve)\
            .filter(Solve.tid == self.tid)\
            .filter(Problem.pid == Solve.pid)\
            .all()
        return sum(problem.value for problem, solve in rows)

    @prop_points.expression
    def prop_points(self):
        return db.session.query(Problem, Solve)\
            .filter(Solve.tid == self.tid)\
            .filter(Problem.pid == Solve.pid)\
            .with_entities(func.sum(Problem.value)).scalar()

    @cache.memoize(timeout=120)
    def points(self):
        """Total value of the team's solved problems (cached 120 s)."""
        # Uses the Solve.problem backref instead of one query per solve, and
        # no longer re-sorts self.solves in place (order cannot affect a sum).
        return sum(solve.problem.value for solve in self.solves)

    @cache.memoize(timeout=120)
    def place(self):
        """1-based rank on the global scoreboard (cached 120 s).

        Observer teams are left out unless this team is itself an observer.
        A team absent from the scoreboard (e.g. no solves, or banned) ranks
        after every listed team instead of tying with the last one.
        """
        scoreboard = Team.scoreboard()
        if not self.observer:
            # Scoreboard rows are (Team, tid, score, date); the Team entity is
            # on ``row.Team`` -- rows themselves have no ``observer``.
            scoreboard = [row for row in scoreboard if not row.Team.observer]
        for rank, row in enumerate(scoreboard, start=1):
            if row.tid == self.tid:
                return rank
        return len(scoreboard) + 1

    @hybrid_property
    def prop_last_solved(self):
        """Unix time of the team's most recent solve, or 0 if none."""
        # Newest first: ascending order returned the *earliest* solve. Orders
        # on the raw column so it does not depend on Solve.date's expression.
        solve = Solve.query.filter_by(tid=self.tid).order_by(
            Solve._date.desc()).first()
        if not solve:
            return 0
        return solve.date

    @cache.memoize(timeout=120)
    def get_last_solved(self):
        """Unix time of the most recent solve, or 0 if none (cached 120 s)."""
        return max((solve.date for solve in self.solves), default=0)

    def has_unlocked(self, problem):
        """Whether this team's solves unlock ``problem``.

        Problems without a weightmap are always unlocked; otherwise each
        solved problem contributes ``weightmap.get(name, 0)`` and the total
        must reach ``problem.threshold``.
        """
        if not problem.weightmap:
            return True
        current = sum(problem.weightmap.get(solve.problem.name, 0)
                      for solve in self.solves)
        return current >= problem.threshold

    def get_unlocked_problems(self, admin=False, programming=None):
        """Problems visible to this team, ordered by value.

        ``admin`` returns every problem. Otherwise only problems with a
        positive value that pass ``has_unlocked``, optionally filtered by the
        ``programming`` flag.
        """
        if admin:
            return Problem.query.order_by(Problem.value).all()
        match = Problem.value > 0
        if programming is not None:
            match = and_(match, Problem.programming == programming)
        problems = Problem.query.filter(match).order_by(Problem.value).all()
        return [problem for problem in problems if self.has_unlocked(problem)]

    def get_jobs(self):
        """This team's jobs, newest completion first."""
        return Job.query.filter_by(tid=self.tid).order_by(
            Job.completion_time.desc()).all()

    def has_solved(self, pid):
        """Whether a Solve exists for this team and problem ``pid``."""
        return Solve.query.filter_by(tid=self.tid, pid=pid).count() > 0

    @classmethod
    @cache.memoize(timeout=60)
    def scoreboard(cls):
        """Non-banned teams with at least one solve, best first (cached 60 s).

        Rows are ``(Team, tid, score, date)``, ordered by total score
        descending, then by latest solve time ascending (earlier wins ties).
        Adapted from https://github.com/CTFd/CTFd/blob/master/CTFd/scoreboard.py
        """
        # Per-team total score and latest solve time. Uses the raw Solve._date
        # column rather than the Solve.date hybrid.
        scores = db.session\
            .query(
                Solve.tid.label("tid"),
                db.func.max(Solve.pid).label("pid"),
                db.func.sum(Problem.value).label("score"),
                db.func.max(Solve._date).label("date"))\
            .join(Problem)\
            .group_by(Solve.tid)
        results = union_all(scores).alias("results")
        sumscores = db.session\
            .query(results.columns.tid,
                   db.func.sum(results.columns.score).label("score"),
                   db.func.max(results.columns.pid),
                   db.func.max(results.columns.date).label("date"))\
            .group_by(results.columns.tid)\
            .subquery()
        query = db.session\
            .query(Team, Team.tid.label("tid"), sumscores.columns.score,
                   sumscores.columns.date)\
            .filter(Team.banned == False)\
            .join(sumscores, Team.tid == sumscores.columns.tid)\
            .order_by(sumscores.columns.score.desc(), sumscores.columns.date)
        # ``== False`` above is intentional: it builds a SQL comparison.
        return query.all()

    @cache.memoize(timeout=120)
    def get_score_progression(self):
        """Cumulative score over time, as chart rows (cached 120 s).

        Returns ``[["Time", "Score"], [t0, 0], ..., [now, score]]``. Each time
        is the elapsed seconds since the ``start_time`` config value,
        formatted as ``h:mm:ss`` or ``d:hh:mm:ss``.
        """
        def convert_to_time(seconds):
            m, s = divmod(seconds, 60)
            h, m = divmod(m, 60)
            d, h = divmod(h, 24)
            if d > 0:
                return "%d:%02d:%02d:%02d" % (d, h, m, s)
            return "%d:%02d:%02d" % (h, m, s)

        start_time = int(Config.get("start_time", default=0))
        progression = [["Time", "Score"], [convert_to_time(0), 0]]
        score = 0
        # sorted() leaves the relationship collection's order untouched.
        for solve in sorted(self.solves, key=lambda s: s.date):
            score += solve.problem.value
            progression.append(
                [convert_to_time(solve.date - start_time), score])
        progression.append([convert_to_time(time.time() - start_time), score])
        return progression

    def credentials(self):
        """Return ``(shell_user, shell_pass)``, provisioning them if unset.

        Provisioning runs a command over SSH on SHELL_HOST and parses the
        output with ``new_user_pattern``. Returns None when SHELL_HOST is not
        configured or no output line matches.
        """
        host = app.config.get("SHELL_HOST")
        if not host:
            return None
        private_key_contents, _ = Config.get_ssh_keys()
        # Tolerate bytes (freshly generated keys); StringIO requires str.
        if isinstance(private_key_contents, bytes):
            private_key_contents = private_key_contents.decode("utf-8")
        private_key = paramiko.rsakey.RSAKey(
            file_obj=StringIO(private_key_contents))
        if not private_key:
            return None
        if self.shell_user and self.shell_pass:
            return (self.shell_user, self.shell_pass)

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(host, username="******", pkey=private_key,
                           look_for_keys=False)
            _, stdout, _ = client.exec_command("\n")
            data = stdout.read().decode("utf-8").split("\n")
        finally:
            # The SSH connection was previously never closed.
            client.close()
        for line in data:
            match = new_user_pattern.match(line)
            if match:
                username, password = match.group(1), match.group(2)
                break
        else:
            return None
        self.shell_user = username
        self.shell_pass = password
        db.session.commit()
        return (username, password)
class Problem(db.Model):
    """A challenge, importable from a directory containing problem.yml etc."""

    __tablename__ = "problems"

    pid = db.Column(db.Integer, index=True, primary_key=True)
    author = db.Column(db.Unicode(32))
    name = db.Column(db.String(32), unique=True)
    title = db.Column(db.Unicode(64))
    description = db.Column(db.Text)
    hint = db.Column(db.Text)
    category = db.Column(db.Unicode(64))
    value = db.Column(db.Integer)
    grader = db.Column(db.UnicodeText)
    autogen = db.Column(db.Boolean)
    programming = db.Column(db.Boolean)
    threshold = db.Column(db.Integer)
    weightmap = db.Column(db.PickleType)
    solves = db.relationship("Solve", backref="problem", lazy=True)
    jobs = db.relationship("Job", backref="problem", lazy=True)
    test_cases = db.Column(db.Integer)
    time_limit = db.Column(db.Integer)  # in seconds
    memory_limit = db.Column(db.Integer)  # in kb
    generator = db.Column(db.Text)  # will use the same grader as regular problems
    source_verifier = db.Column(db.Text)  # may be implemented (possibly)
    path = db.Column(db.String(128))  # path to problem source code
    files = db.relationship("File", backref="problem", lazy=True)
    autogen_files = db.relationship("AutogenFile", backref="problem", lazy=True)

    @staticmethod
    def _read_file(path, filename):
        """Return the text of ``filename`` inside directory ``path``."""
        with open(os.path.join(path, filename)) as f:
            return f.read()

    @staticmethod
    def _load_metadata(path):
        """Parse ``problem.yml`` in ``path`` (plain YAML only, via safe_load)."""
        # safe_load: yaml.load without a Loader can build arbitrary Python
        # objects and is rejected by PyYAML >= 6.
        with open(os.path.join(path, "problem.yml")) as f:
            return yaml.safe_load(f)

    @staticmethod
    def validate_problem(path, name):
        """Check that ``path`` holds what ``import_problem`` needs.

        Prints one line per missing file/key and returns True if none.
        """
        files = os.listdir(path)
        valid = True
        for required_file in ["grader.py", "problem.yml", "description.md"]:
            if required_file not in files:
                print("\t* Missing {}".format(required_file))
                valid = False
        if not valid:
            return valid
        metadata = Problem._load_metadata(path)
        if metadata.get("programming", False):
            if "generator.py" not in files:
                print("\t* Missing generator.py")
                valid = False
            for required_key in ["test_cases", "time_limit", "memory_limit"]:
                if required_key not in metadata:
                    print(
                        "\t* Expected required key {} in 'problem.yml'".format(
                            required_key))
                    valid = False
        return valid

    @staticmethod
    def import_problem(path, name):
        """Create or update the problem ``name`` from directory ``path``.

        Commits the problem, then uploads each file listed under ``files`` in
        problem.yml, updating the URL of an existing File row if present.
        """
        print(" - {}".format(name))
        if not Problem.validate_problem(path, name):
            return
        problem = Problem.query.filter_by(name=name).first()
        if not problem:
            problem = Problem(name=name)
        metadata = Problem._load_metadata(path)
        problem.author = metadata.get("author", "")
        problem.title = metadata.get("title", "")
        problem.category = metadata.get("category", "")
        problem.value = int(metadata.get("value", "0"))
        problem.hint = metadata.get("hint", "")
        problem.autogen = metadata.get("autogen", False)
        problem.programming = metadata.get("programming", False)
        problem.description = Problem._read_file(path, "description.md")
        problem.grader = Problem._read_file(path, "grader.py")
        problem.path = os.path.realpath(path)
        threshold = metadata.get("threshold")
        if threshold and type(threshold) is int:
            problem.threshold = threshold
            problem.weightmap = metadata.get("weightmap", {})
        if problem.programming:
            problem.test_cases = int(metadata.get("test_cases"))
            problem.time_limit = int(metadata.get("time_limit"))
            problem.memory_limit = int(metadata.get("memory_limit"))
            problem.generator = Problem._read_file(path, "generator.py")
        db.session.add(problem)
        db.session.flush()
        db.session.commit()

        for filename in metadata.get("files", []):
            file_path = os.path.join(path, filename)
            if not os.path.isfile(file_path):
                print("\t* File '{}' doesn't exist".format(filename))
                continue
            # File() uploads the data synchronously, so the handle can be
            # closed as soon as it returns.
            with open(file_path, "rb") as source:
                file = File(pid=problem.pid, filename=filename, data=source)
            existing = File.query.filter_by(pid=problem.pid,
                                            filename=filename).first()
            # Update existing file url
            if existing:
                existing.url = file.url
                db.session.add(existing)
            else:
                db.session.add(file)
            db.session.commit()

    @staticmethod
    def import_repository(path):
        """Import every non-hidden subdirectory of ``path`` as a problem.

        Exits the process with status 1 if ``path`` is not a directory.
        """
        if not (os.path.realpath(path) and os.path.exists(path)
                and os.path.isdir(path)):
            print("this isn't a path")
            sys.exit(1)
        path = os.path.realpath(path)
        for name in os.listdir(path):
            if name.startswith("."):
                continue
            problem_dir = os.path.join(path, name)
            if not os.path.isdir(problem_dir):
                continue
            problem_name = os.path.basename(problem_dir)
            Problem.import_problem(problem_dir, problem_name)

    @classmethod
    def categories(cls):
        """Distinct problem categories."""
        rows = db.session.query(Problem.category).distinct().all()
        return [row[0] for row in rows]

    @staticmethod
    def get_by_id(id):
        """Return the Problem with ``pid == id``, or None."""
        # first() already yields None when nothing matches; the extra count()
        # query was redundant.
        return Problem.query.filter_by(pid=id).first()

    @property
    def solved(self):
        """Number of solves of this problem by the current user's team."""
        return Solve.query.filter_by(pid=self.pid,
                                     tid=current_user.tid).count()

    def get_grader(self):
        """Execute the stored grader source in a fresh module and return it.

        The code runs with the working directory set to ``self.path`` (when
        set). The previous directory is restored even if the grader raises.
        """
        import types
        grader = types.ModuleType("grader")  # what imp.new_module returned
        curr = os.getcwd()
        if self.path:
            os.chdir(self.path)
        try:
            # NOTE(review): exec of stored grader code -- trusted only because
            # problems are imported by admins from a problem repository.
            exec(self.grader, grader.__dict__)
        finally:
            os.chdir(curr)
        return grader

    def get_autogen(self, tid):
        """Return a random generator seeded deterministically per (problem, team).

        A private ``random.Random`` is used instead of seeding the shared
        ``random`` module, which made every other ``random`` user in the
        process predictable and could interleave between concurrent requests.
        It exposes the same methods (randint, choice, ...), and string seeding
        produces the same sequence as module-level ``random.seed``.
        """
        import random
        return random.Random("%s_%s_%s" % (SEED, self.pid, tid))

    def render_description(self, tid):
        """Render the markdown description for team ``tid``.

        ``$name`` placeholders are filled from the autogen grader's variables
        and files and from static file URLs. On failure the traceback is
        appended as an HTML comment.
        """
        description = markdown(self.description, extras=["fenced-code-blocks"])
        try:
            variables = {}
            template = Template(description)
            if self.autogen:
                autogen = self.get_autogen(tid)
                grader = self.get_grader()
                generated_problem = grader.generate(autogen)
                if "variables" in generated_problem:
                    variables.update(generated_problem["variables"])
                if "files" in generated_problem:
                    for file in generated_problem["files"]:
                        url = url_for("chals.autogen", pid=self.pid,
                                      filename=file)
                        variables[File.clean_name(file)] = url
            static_files = File.query.filter_by(pid=self.pid).all()
            if static_files is not None:
                for file in static_files:
                    url = "{}/{}".format(app.config["FILESTORE_STATIC"],
                                         file.url)
                    variables[File.clean_name(file.filename)] = url
            description = template.safe_substitute(variables)
        except Exception as e:
            description += "<!-- parsing error: {} -->".format(
                traceback.format_exc())
            traceback.print_exc(file=sys.stderr)
        description = description.replace("${", "{")
        return description

    # TODO: clean up the string-enum return; it is used directly in game.py
    def try_submit(self, flag):
        """Grade ``flag`` for the current user's team.

        Returns ``("error", msg)`` if already solved or this exact flag was
        already tried, otherwise ``("success" | "failure", grader message)``.
        """
        solved = Solve.query.filter_by(tid=current_user.tid,
                                       pid=self.pid).first()
        if solved:
            return "error", "You've already solved this problem"
        already_tried = WrongFlag.query.filter_by(tid=current_user.tid,
                                                  pid=self.pid,
                                                  flag=flag).count()
        if already_tried:
            return "error", "You've already tried this flag"
        random = None
        if self.autogen:
            random = self.get_autogen(current_user.tid)
        grader = self.get_grader()
        correct, message = grader.grade(random, flag)
        if correct:
            submission = Solve(pid=self.pid, tid=current_user.tid,
                               uid=current_user.uid, flag=flag)
            db.session.add(submission)
            db.session.commit()
        elif len(flag) < 256:
            submission = WrongFlag(pid=self.pid, tid=current_user.tid,
                                   uid=current_user.uid, flag=flag)
            db.session.add(submission)
            db.session.commit()
        # Over-long wrong flags are not recorded (WrongFlag.flag is Unicode(256)).
        cache.delete_memoized(current_user.team.place)
        cache.delete_memoized(current_user.team.points)
        cache.delete_memoized(current_user.team.get_last_solved)
        cache.delete_memoized(current_user.team.get_score_progression)
        return "success" if correct else "failure", message

    def api_summary(self):
        """Dict of public fields plus the description rendered for the current team."""
        summary = {
            field: getattr(self, field)
            for field in [
                'pid', 'author', 'name', 'title', 'hint', 'category', 'value',
                'solved', 'programming'
            ]
        }
        summary['description'] = self.render_description(current_user.tid)
        return summary
class User(db.Model):
    """A registered account; also the Flask-Login user object."""

    __tablename__ = "users"

    uid = db.Column(db.Integer, index=True, primary_key=True)
    tid = db.Column(db.Integer, db.ForeignKey("teams.tid"))
    name = db.Column(db.Unicode(32))
    easyctf = db.Column(db.Boolean, index=True, default=False)
    username = db.Column(db.String(16), unique=True, index=True)
    email = db.Column(db.String(128), unique=True)
    _password = db.Column("password", db.String(128))  # bcrypt hash
    admin = db.Column(db.Boolean, default=False)
    level = db.Column(db.Integer)
    # Naive UTC datetime (default=datetime.utcnow).
    _register_time = db.Column("register_time", db.DateTime,
                               default=datetime.utcnow)
    reset_token = db.Column(db.String(32))
    otp_secret = db.Column(db.String(16))
    otp_confirmed = db.Column(db.Boolean, default=False)
    email_token = db.Column(db.String(32))
    email_verified = db.Column(db.Boolean, default=False)
    team = db.relationship("Team", back_populates="members")
    solves = db.relationship("Solve", backref="user", lazy=True)
    jobs = db.relationship("Job", backref="user", lazy=True)
    _avatar = db.Column("avatar", db.String(128))
    outgoing_invitations = db.relationship(
        "Team", secondary=player_team_invitation, lazy="subquery",
        backref=db.backref("incoming_invitations", lazy=True))

    @property
    def avatar(self):
        """Avatar URL; an identicon is generated and uploaded on first access."""
        if not self._avatar:
            avatar_file = BytesIO()
            avatar = generate_identicon("user%s" % self.uid)
            avatar.save(avatar_file, format="PNG")
            avatar_file.seek(0)
            # Prefix was "team_avatar_" (swapped with Team.avatar's).
            response = save_file(avatar_file, prefix="user_avatar_",
                                 suffix=".png")
            if response.status_code == 200:
                self._avatar = response.text
                db.session.add(self)
                db.session.commit()
        return self._avatar or ""  # just so the frontend doesnt 500

    # NOTE(review): defining __eq__ without __hash__ makes User instances
    # unhashable in Python 3 (they cannot go in sets or be dict keys).
    def __eq__(self, other):
        if isinstance(other, User):
            return self.uid == other.uid
        return NotImplemented

    def __str__(self):
        return "<User %s>" % self.uid

    def check_password(self, password):
        """Verify ``password`` against the stored bcrypt hash."""
        return bcrypt.verify(password, self.password)

    def get_id(self):
        return str(self.uid)

    @property
    def is_anonymous(self):
        return False

    @staticmethod
    @login_manager.user_loader
    def get_by_id(id):
        """Return the User with ``uid == id``, or None (Flask-Login loader)."""
        return User.query.filter_by(uid=id).first()

    @property
    def is_active(self):
        # TODO This will be based off account standing.
        return True

    @property
    def is_authenticated(self):
        return True

    @hybrid_property
    def password(self):
        return self._password

    @password.setter
    def password(self, password):
        # NOTE(review): passlib deprecates .encrypt() in favour of .hash().
        self._password = bcrypt.encrypt(password, rounds=10)

    @hybrid_property
    def register_time(self):
        """Registration time as integer Unix seconds.

        ``_register_time`` is naive UTC, so ``calendar.timegm`` is used; the
        previous ``time.mktime`` treated it as local time.
        """
        import calendar
        return int(calendar.timegm(self._register_time.timetuple()))

    @register_time.expression
    def register_time(cls):
        """SQL-side ``register_time``: the raw datetime column."""
        return cls._register_time

    @hybrid_property
    def username_lower(self):
        return self.username.lower()

    # Without this, class-level User.username_lower called str.lower on a
    # Column and failed; in SQL, use LOWER().
    @username_lower.expression
    def username_lower(cls):
        return func.lower(cls.username)

    def get_totp_uri(self):
        """Return an ``otpauth://totp`` URI, creating the OTP secret if missing."""
        if self.otp_secret is None:
            secret = base64.b32encode(os.urandom(10)).decode("utf-8").lower()
            self.otp_secret = secret
            db.session.add(self)
            db.session.commit()
        service_name = Config.get("ctf_name")
        return "otpauth://totp/%s:%s?secret=%s&issuer=%s" % (
            service_name, self.username, self.otp_secret, service_name)

    def verify_totp(self, token):
        return onetimepass.valid_totp(token, self.otp_secret)

    @cache.memoize(timeout=120)
    def points(self):
        """Sum of the values of problems this user solved (cached 120 s)."""
        return sum(solve.problem.value for solve in self.solves)