コード例 #1
0
class ZstdJsonSerializer(Serializer):
    """Serializer that stores items as zstd-compressed, UTF-8 encoded JSON."""

    def __init__(self):
        # A single compressor/decompressor pair, reused for every call.
        self.compressor = ZstdCompressor()
        self.decompressor = ZstdDecompressor()

    def serialize(self, item) -> bytes:
        """Dump *item* to JSON, encode it as UTF-8 and compress the bytes."""
        raw = json.dumps(item).encode('utf8')
        return self.compressor.compress(raw)

    def deserialize(self, serialized_item: bytes):
        """Reverse :meth:`serialize`: decompress, decode UTF-8, parse JSON."""
        text = self.decompressor.decompress(serialized_item).decode('utf8')
        return json.loads(text)
コード例 #2
0
def post_submission(contest_id: str) -> Response:
    """Create a submission in ``contest_id`` and enqueue it for judging.

    The request body supplies ``problem_id``, ``code`` and
    ``environment_id``. The source code is stored zstd-compressed, together
    with its uncompressed UTF-8 size. After the database transaction block
    has exited, ``(contest_id, problem_id, submission_id)`` is pickled and
    published to the ``judge_queue`` queue.

    Aborts with 400 if the environment does not exist, or if the problem
    does not exist within the contest. Aborts with 404 if the contest does
    not exist. Aborts with 429 if the user already has
    ``user_judge_queue_limit`` or more submissions waiting or running.

    Returns:
        The submission summary plus ``user_name``, as JSON with status 201.
    """
    _, body = _validate_request()
    problem_id, code, env_id = body.problem_id, body.code, body.environment_id

    code_encoded = code.encode('utf8')
    compressed_code = ZstdCompressor().compress(code_encoded)

    with transaction() as s:
        u = _validate_token(s, required=True)
        # Narrows the type; required=True presumably aborts before this.
        assert (u)
        if not s.query(Environment).filter(Environment.id == env_id).count():
            abort(400)  # the body is invalid, hence 400
        if not s.query(Contest).filter(Contest.id == contest_id).count():
            abort(404)  # contest_id is part of the URL, hence 404
        if not s.query(Problem).filter(Problem.contest_id == contest_id,
                                       Problem.id == problem_id).count():
            abort(400)  # the body is invalid, hence 400
        queued_submission_count = s.query(Submission).filter(
            Submission.user_id == u['id'],
            Submission.status.in_([JudgeStatus.Waiting,
                                   JudgeStatus.Running])).count()
        # The limit caps how many submissions a user may have queued. This
        # request adds one more, so reject it once the cap is reached. The
        # previous ``>`` check let a user go one over the limit.
        if queued_submission_count >= app.config['user_judge_queue_limit']:
            abort(429)
        submission = Submission(contest_id=contest_id,
                                problem_id=problem_id,
                                user_id=u['id'],
                                code=compressed_code,
                                code_bytes=len(code_encoded),
                                environment_id=env_id)
        s.add(submission)
        # Flush so the INSERT runs and generated columns (the id) are filled
        # in before building the summary.
        s.flush()
        ret = submission.to_summary_dict()
        ret['user_name'] = u['name']

    # Publish only after the transaction block has exited. Presumably the
    # row is committed by then, so a judge worker can find it.
    conn = pika.BlockingConnection(get_mq_conn_params())
    try:
        ch = conn.channel()
        ch.queue_declare(queue='judge_queue')
        # NOTE(review): the payload is pickled, so consumers presumably
        # unpickle it. That is only safe while the queue is reachable by
        # trusted services alone.
        ch.basic_publish(exchange='',
                         routing_key='judge_queue',
                         body=pickle.dumps((contest_id, problem_id, ret['id'])))
        ch.close()
    finally:
        # Always release the broker connection, even if publishing fails.
        conn.close()
    return jsonify(ret, status=201)
コード例 #3
0
ファイル: test_api.py プロジェクト: shiodat/PenguinJudge
    def test_submission(self, mock_conn, mock_get_params):
        """End-to-end check of the contest submission endpoints.

        Setup: creates an execution environment, a published contest
        ``abc000`` with problem ``A``, and one test case.

        Checks:
        - listing submissions (403 without credentials, empty list for the
          admin, 404 for an unknown contest);
        - posting a submission and reading it back;
        - the 400/404 error paths;
        - that anonymous listing succeeds once the contest end time has been
          moved to its start time.

        ``mock_conn`` and ``mock_get_params`` are not used in the body.
        Presumably they are mocks injected by patch decorators, stubbing the
        message-queue connection used when posting. Confirm against the
        decorators.
        """
        # TODO(kazuki): rewrite this to go through the API
        env = dict(name='Python 3.7', test_image_name='docker-image')
        with transaction() as s:
            env = Environment(**env)
            s.add(env)
            s.flush()
            env = env.to_dict()

        # Published contest that starts now and ends one hour later.
        start_time = datetime.now(tz=timezone.utc)
        contest_id = app.post_json('/contests', {
            'id': 'abc000',
            'title': 'ABC000',
            'description': '# ABC000\n\nほげほげ\n',
            'start_time': start_time.isoformat(),
            'end_time': (start_time + timedelta(hours=1)).isoformat(),
            'published': True,
        }, headers=self.admin_headers).json['id']
        prefix = '/contests/{}'.format(contest_id)
        app.post_json(
            '{}/problems'.format(prefix), dict(
                id='A', title='A Problem', description='# A', time_limit=2,
                score=100
            ), headers=self.admin_headers)

        # TODO(kazuki): rewrite this to go through the API
        # Test-case input/output are stored zstd-compressed.
        ctx = ZstdCompressor()
        with transaction() as s:
            s.add(TestCase(
                contest_id=contest_id,
                problem_id='A',
                id='1',
                input=ctx.compress(b'1'),
                output=ctx.compress(b'2')))

        # At this point anonymous listing is forbidden, the admin sees no
        # submissions yet, and an unknown contest gives 404.
        app.get('{}/submissions'.format(prefix), status=403)
        self.assertEqual([], app.get(
            '{}/submissions'.format(prefix), headers=self.admin_headers).json)
        app.get('/contests/invalid/submissions', status=404)

        # Post a submission; the listing must then contain exactly the
        # summary returned by the POST.
        code = 'print("Hello World")'
        resp = app.post_json('{}/submissions'.format(prefix), {
            'problem_id': 'A',
            'environment_id': env['id'],
            'code': code,
        }, headers=self.admin_headers).json
        self.assertEqual([resp], app.get(
            '{}/submissions'.format(prefix), headers=self.admin_headers).json)
        # An anonymous detail fetch yields 404.
        app.get('{}/submissions/{}'.format(prefix, resp['id']), status=404)
        # The detail view adds the source code and a 'tests' list (empty
        # here); otherwise it must equal the summary.
        resp2 = app.get('{}/submissions/{}'.format(prefix, resp['id']),
                        headers=self.admin_headers).json
        self.assertEqual(resp2.pop('code'), code)
        resp['tests'] = []
        self.assertEqual(resp, resp2)

        # An unknown problem or environment in the body gives 400, and an
        # unknown submission id gives 404.
        app.post_json('{}/submissions'.format(prefix), {
            'problem_id': 'invalid',
            'environment_id': env['id'],
            'code': code,
        }, headers=self.admin_headers, status=400)
        app.post_json('{}/submissions'.format(prefix), {
            'problem_id': 'A',
            'environment_id': 99999,
            'code': code,
        }, headers=self.admin_headers, status=400)
        app.get('{}/submissions/99999'.format(prefix), status=404)

        # A submission id looked up under a different contest gives 404.
        contest_id2 = app.post_json('/contests', {
            'id': 'abc001',
            'title': 'ABC001',
            'description': '# ABC001',
            'start_time': start_time.isoformat(),
            'end_time': (start_time + timedelta(hours=1)).isoformat(),
        }, headers=self.admin_headers).json['id']
        app.get(
            '/contests/{}/submissions/{}'.format(contest_id2, resp['id']),
            status=404)

        # Move every contest's end_time back to start_time. No status= is
        # given here, so the test client's default success check applies to
        # the anonymous listing (presumably a webtest TestApp).
        with transaction() as s:
            s.query(Contest).update({'end_time': start_time})
        app.get('{}/submissions'.format(prefix))
コード例 #4
0
def upload_test_dataset(contest_id: str, problem_id: str) -> Response:
    """Replace a problem's test cases with the ones in an uploaded ZIP archive.

    The request body is read as a ZIP archive. Members are grouped by
    basename, ignoring directories: ``<name>.in`` and ``<name>.out`` form the
    test case with id ``<name>``.

    A name is imported only when the archive holds exactly one ``.in`` and
    exactly one ``.out`` member for it. Ambiguous or incomplete names are
    skipped, and so are members that cannot be read. Input and output are
    stored zstd-compressed.

    Test cases of the problem that no ``JudgeResult`` references are
    deleted. Each uploaded test case is then inserted, or it updates the
    surviving row with the same id.

    Requires an admin token (``_validate_token(..., admin_required=True)``).

    Returns:
        JSON list of the imported test-case ids, in the order in which each
        name first appears in the archive.
    """
    import shutil
    from contextlib import ExitStack
    from tempfile import TemporaryFile
    from zipfile import ZipFile

    from sqlalchemy import exists

    # Authorize before spooling and recompressing the upload. This way,
    # requests without an admin token are rejected before any expensive work.
    with transaction() as s:
        _validate_token(s, admin_required=True)

    zctx = ZstdCompressor()
    test_cases = []
    ret = []
    with ExitStack() as stack:
        f = stack.enter_context(TemporaryFile())
        shutil.copyfileobj(request.stream, f)
        f.seek(0)
        z = stack.enter_context(ZipFile(f))

        # Maps name -> {'.in': [member paths], '.out': [member paths]}.
        # Dict order follows the first appearance of each name.
        members = {}
        for path in z.namelist():
            name, ext = os.path.splitext(os.path.basename(path))
            if ext not in ('.in', '.out'):
                continue
            members.setdefault(name, {'.in': [], '.out': []})[ext].append(path)

        for k, paths in members.items():
            in_paths, out_paths = paths['.in'], paths['.out']
            if len(in_paths) != 1 or len(out_paths) != 1:
                continue
            try:
                with z.open(in_paths[0]) as zi:
                    in_data = zctx.compress(zi.read())
                with z.open(out_paths[0]) as zo:
                    out_data = zctx.compress(zo.read())
            except Exception:
                # Deliberately best-effort: an unreadable member (e.g.
                # corrupt or encrypted) drops only this test case instead of
                # failing the whole upload.
                continue
            test_cases.append(
                dict(contest_id=contest_id,
                     problem_id=problem_id,
                     id=k,
                     input=in_data,
                     output=out_data))
            ret.append(k)

    # Delete only the unreferenced test cases. Referenced ones that are
    # re-uploaded are UPDATEd; new ones are INSERTed.
    with transaction() as s:
        # NOTE(review): the reference check correlates on test_id only. A
        # JudgeResult with the same test_id from another problem therefore
        # keeps this row. Confirm whether JudgeResult also carries
        # contest_id/problem_id that should be matched.
        s.query(TestCase).filter(
            TestCase.contest_id == contest_id,
            TestCase.problem_id == problem_id,
            ~exists().where(JudgeResult.test_id == TestCase.id)).delete(
                synchronize_session=False)
        for kwargs in test_cases:
            q = s.query(TestCase).filter(TestCase.contest_id == contest_id,
                                         TestCase.problem_id == problem_id,
                                         TestCase.id == kwargs['id'])
            if q.count() == 0:
                s.add(TestCase(**kwargs))
            else:
                # The key columns already match; only the data changes.
                q.update({'input': kwargs['input'],
                          'output': kwargs['output']},
                         synchronize_session=False)
    return jsonify(ret)
コード例 #5
0
class TinyIndexer(TinyIndexBase):
    """Writable, memory-mapped index file made of fixed-size pages.

    Each page stores a zstd-compressed JSON list of item tuples, padded with
    NUL bytes to ``page_size``.

    Use it as a context manager. On entry, the file is created zero-filled
    if missing, then opened read/write and memory-mapped.
    """

    def __init__(self, item_type: type, index_path: str, num_pages: int,
                 page_size: int):
        super().__init__(item_type, num_pages, page_size)
        self.index_path = index_path
        self.compressor = ZstdCompressor()
        self.decompressor = ZstdDecompressor()
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        self.create_if_not_exists()
        self.index_file = open(self.index_path, 'r+b')
        try:
            self.mmap = mmap(self.index_file.fileno(), 0)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release the
            # file handle here instead of leaking it.
            self.index_file.close()
            self.index_file = None
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def index(self, key: str, value):
        """Append *value* to the page that *key* maps to.

        When assertions are enabled, raises AssertionError unless
        ``type(value)`` is exactly ``item_type``. If the compressed page
        would exceed ``page_size``, the value is dropped without error and
        the stored page is left unchanged.
        """
        assert type(value) == self.item_type, f"Can only index the specified type" \
                                              f" ({self.item_type.__name__})"
        page_index = self._get_key_page_index(key)
        current_page = self.get_page(page_index)
        if current_page is None:
            current_page = []
        current_page.append(astuple(value))
        try:
            self._write_page(current_page, page_index)
        except ValueError:
            # The page is full. Deliberately best-effort: keep the page as it
            # was rather than failing the caller.
            pass

    def _write_page(self, data, i):
        """
        Serialise the data using JSON, compress it and store it at index i.
        If the data is too big, it will raise a ValueError and not store anything
        """
        serialised_data = json.dumps(data)
        compressed_data = self.compressor.compress(
            serialised_data.encode('utf8'))
        page_length = len(compressed_data)
        if page_length > self.page_size:
            raise ValueError(
                f"Data is too big ({page_length}) for page size ({self.page_size})"
            )
        # Pad with NULs so the whole page slot is overwritten.
        padding = b'\x00' * (self.page_size - page_length)
        self.mmap[i * self.page_size:(i + 1) *
                  self.page_size] = compressed_data + padding

    def create_if_not_exists(self):
        """Create the index file, ``num_pages * page_size`` NUL bytes long, if missing."""
        if os.path.isfile(self.index_path):
            return
        remaining = self.num_pages * self.page_size
        # Write the zeros in bounded chunks instead of building one
        # file-sized buffer, which can be huge for large indexes.
        zeros = memoryview(b'\x00' * min(remaining, 1 << 20))
        with open(self.index_path, 'wb') as index_file:
            while remaining > 0:
                step = min(remaining, len(zeros))
                index_file.write(zeros[:step])
                remaining -= step
コード例 #6
0
# NOTE(review): this class uses ``dataclasses.field`` defaults and
# ``__post_init__``. Both only take effect when the class is decorated with
# ``@dataclass``, and no decorator is visible at this point. Confirm it is
# applied where the class is defined.
class AtlasRecordsWriter:
    """
    Write Atlas results in ND-JSON format.

    .. code-block:: python

        from fetchmesh.io import AtlasRecordsWriter
        with AtlasRecordsWriter("results.ndjson") as w:
            w.write({"msm_id": 1001, "prb_id": 1, "...": "..."})
    """

    file: Path
    """Output file path."""

    filters: List[StreamFilter[dict]] = field(default_factory=list)
    """List of filters to apply before writing the records."""

    append: bool = False
    """
    Whether to create a new file, or to append the records to an existing file.
    If append is set to false, and the output file already exists, it will be overwritten.
    When append is set to false, the output file will be deleted if an exception happens.
    """

    log: bool = False
    """Record the size (in bytes) of each record. See :any:`LogEntry`."""

    compression: bool = False
    """
    Compress the records using zstandard.
    We use the one-shot compression API and write one frame per record.
    This results in larger files than a single frame for all the records,
    but it allows us to build an index and make the file seekable.
    We use a pre-built dictionary (see :any:`dictionary`) to reduce the size of the compressed records.
    """

    # Set in __enter__ when ``compression`` is enabled; not an init argument.
    compression_ctx: Optional[ZstdCompressor] = field(default=None, init=False)

    @property
    def log_file(self) -> Path:
        """Path to the (optional) log file."""
        return self.file.with_suffix(self.file.suffix + ".log")

    def __post_init__(self):
        # Accept plain strings as well as Path objects.
        self.file = Path(self.file)

    def __enter__(self):
        mode = "ab" if self.append else "wb"

        # (1) Open the output file
        f = self.file.open(mode)
        log_f = None
        try:
            # (2) Open the log file
            if self.log:
                log_f = self.log_file.open(mode)

            # (3) Setup the compression context
            if self.compression:
                dict_data = ZstdCompressionDict(dictionary.read_bytes())
                self.compression_ctx = ZstdCompressor(dict_data=dict_data)
        except BaseException:
            # __exit__ is not called when __enter__ raises. Close what was
            # opened and, outside append mode, remove the output files,
            # matching the clean-up __exit__ does on errors.
            f.close()
            if log_f is not None:
                log_f.close()
            if not self.append:
                self.file.unlink()
                if log_f is not None:
                    self.log_file.unlink()
            raise

        self.f = f
        if log_f is not None:
            self.log_f = log_f
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the files and handle any exception raised in the block.

        On error, the output (and log) file is deleted unless in append mode,
        and the traceback is printed. Every exception except
        KeyboardInterrupt is then suppressed.
        """
        # (1) Close the output file
        self.f.flush()
        self.f.close()

        # (2) Close the log file
        if self.log:
            self.log_f.flush()
            self.log_f.close()

        # (3) Handle exceptions
        # NOTE: In append mode, we do not delete the file.
        if exc_type:
            if not self.append:
                self.file.unlink()
            if not self.append and self.log_file.exists():
                self.log_file.unlink()
            print_exception(exc_type, exc_value, traceback)

        # Do not reraise exceptions, except for KeyboardInterrupt.
        return exc_type is not KeyboardInterrupt

    def write(self, record: dict):
        """Write a single record, unless one of the filters rejects it."""

        # (1) Filter the record
        for filter_ in self.filters:
            if not filter_.keep(record):
                return

        # (2) Serialize and encode the record
        data = json_trydumps(record) + "\n"
        data = data.encode("utf-8")

        # (3) Compress the record (one independent frame per record)
        if self.compression_ctx:
            data = self.compression_ctx.compress(data)

        # (4) Update the log with the on-disk (possibly compressed) size
        if self.log:
            entry = LogEntry.pack(len(data), record["msm_id"],
                                  record["prb_id"])
            self.log_f.write(entry)

        # (5) Write the record to the output file
        self.f.write(data)

    def writeall(self, records: Iterable[dict]):
        """Write all the records."""

        for record in records:
            self.write(record)