def open_write(self, path: str) -> IO[bytes]:
    """Open *path* for binary writing and yield a zstd-compressing writer.

    The compression level is taken from ``self.level``. The body is a
    generator (it yields the writer), so it is presumably wrapped with
    ``contextlib.contextmanager`` where it is defined -- TODO confirm.
    The underlying file is closed once the caller's block has finished,
    even if it raised.
    """
    zstd = ZstdCompressor(level=self.level)
    with open(path, 'wb') as raw:
        with zstd.stream_writer(raw) as writer:
            yield writer
Example #2
0
 def __init__(self, item_type: type, index_path: str, num_pages: int,
              page_size: int):
     """Remember the index file location and create the zstd codecs.

     ``item_type``, ``num_pages`` and ``page_size`` are forwarded to the base
     class. The file handle and the memory map both start out as ``None``.
     """
     super().__init__(item_type, num_pages, page_size)
     self.index_path = index_path
     self.compressor, self.decompressor = ZstdCompressor(), ZstdDecompressor()
     self.index_file, self.mmap = None, None
Example #3
0
    def __enter__(self):
        """Open the output (and optional log) file and set up compression.

        Files are opened with ``"ab"`` when ``self.append`` is true and with
        ``"wb"`` otherwise. When ``self.compression`` is set, a
        ``ZstdCompressor`` is built from the bytes of the module-level
        ``dictionary`` file.

        If any step fails, the files opened so far are closed before the
        error propagates. ``__exit__`` is not called when ``__enter__``
        raises, so those handles would otherwise leak.
        """
        mode = "ab" if self.append else "wb"

        # (1) Open the output file
        self.f = self.file.open(mode)
        log_f = None
        try:
            # (2) Open the log file
            if self.log:
                log_f = self.log_file.open(mode)
                self.log_f = log_f

            # (3) Setup the compression context
            if self.compression:
                dict_data = ZstdCompressionDict(dictionary.read_bytes())
                self.compression_ctx = ZstdCompressor(dict_data=dict_data)
        except BaseException:
            # Release whatever was opened above, then re-raise unchanged.
            if log_f is not None:
                log_f.close()
            self.f.close()
            raise

        return self
Example #4
0
def getCompressorFunction(expect_compression):
    """Return ``(marker, wrapper)`` for writing the onefile payload.

    If ``zstandard`` cannot be imported, the marker is ``b"X"``. The wrapper
    is then a context manager that yields the output file unchanged.
    Otherwise the marker is ``b"Y"``. The wrapper then yields a level-22 zstd
    stream writer over the output file that does not close it
    (``closefd=False``).

    ``expect_compression`` is asserted to match whether ``zstandard`` is
    available.
    """
    # spell-checker: ignore zstd, closefd

    try:
        from zstandard import ZstdCompressor  # pylint: disable=I0021,import-error
    except ImportError:
        assert not expect_compression

        @contextmanager
        def useSameFile(target_file):
            yield target_file

        return b"X", useSameFile

    # zstandard is importable from here on.
    assert expect_compression

    shared_ctx = ZstdCompressor(level=22)

    @contextmanager
    def useCompressedFile(target_file):
        with shared_ctx.stream_writer(target_file, closefd=False) as writer:
            yield writer

    onefile_logger.info("Using compression for onefile payload.")

    return b"Y", useCompressedFile
Example #5
0
def post_submission(contest_id: str) -> Response:
    """Create a submission for *contest_id* and enqueue it for judging.

    The validated request body supplies ``problem_id``, ``code`` and
    ``environment_id``. The code is stored zstd-compressed together with its
    uncompressed UTF-8 byte length. The request is aborted with:

    - 400 for an unknown environment or problem,
    - 404 for an unknown contest,
    - 429 when the user already has more than ``user_judge_queue_limit``
      submissions waiting or running.

    After the database transaction, ``(contest_id, problem_id,
    submission_id)`` is pickled and published to the ``judge_queue`` queue
    through pika. The MQ connection is closed even if publishing fails.

    Returns:
        ``jsonify`` of the submission summary plus ``user_name``, called
        with ``status=201``.
    """
    _, body = _validate_request()
    problem_id, code, env_id = body.problem_id, body.code, body.environment_id

    cctx = ZstdCompressor()
    code_encoded = code.encode('utf8')
    code_compressed = cctx.compress(code_encoded)

    with transaction() as s:
        u = _validate_token(s, required=True)
        assert u
        if not s.query(Environment).filter(Environment.id == env_id).count():
            abort(400)  # the request body is invalid, hence 400
        if not s.query(Contest).filter(Contest.id == contest_id).count():
            abort(404)  # contest_id is part of the URL, hence 404
        if not s.query(Problem).filter(Problem.contest_id == contest_id,
                                       Problem.id == problem_id).count():
            abort(400)  # the request body is invalid, hence 400
        queued_submission_count = s.query(Submission).filter(
            Submission.user_id == u['id'],
            Submission.status.in_([JudgeStatus.Waiting,
                                   JudgeStatus.Running])).count()
        # NOTE(review): with `>` a user can end up with limit + 1 queued
        # submissions; confirm whether `>=` was intended.
        if queued_submission_count > app.config['user_judge_queue_limit']:
            abort(429)
        submission = Submission(contest_id=contest_id,
                                problem_id=problem_id,
                                user_id=u['id'],
                                code=code_compressed,
                                code_bytes=len(code_encoded),
                                environment_id=env_id)
        s.add(submission)
        s.flush()  # presumably populates the generated id used below
        ret = submission.to_summary_dict()
        ret['user_name'] = u['name']

    conn = pika.BlockingConnection(get_mq_conn_params())
    try:
        ch = conn.channel()
        ch.queue_declare(queue='judge_queue')
        ch.basic_publish(exchange='',
                         routing_key='judge_queue',
                         body=pickle.dumps((contest_id, problem_id, ret['id'])))
        ch.close()
    finally:
        # Always release the broker connection, even if publishing raised.
        conn.close()
    return jsonify(ret, status=201)
Example #6
0
class ZstdJsonSerializer(Serializer):
    """Serializer that stores items as zstd-compressed UTF-8 JSON.

    One compressor and one decompressor are created per instance and reused
    for every call.
    """

    def __init__(self):
        self.compressor, self.decompressor = ZstdCompressor(), ZstdDecompressor()

    def serialize(self, item) -> bytes:
        """JSON-encode *item*, UTF-8 encode the text and compress it."""
        payload = json.dumps(item).encode('utf8')
        return self.compressor.compress(payload)

    def deserialize(self, serialized_item: bytes):
        """Reverse :meth:`serialize`: decompress, UTF-8 decode, parse JSON."""
        text = self.decompressor.decompress(serialized_item).decode('utf8')
        return json.loads(text)
Example #7
0
def compressBlockTask(in_queue, out_list, readyForWork, pleaseKillYourself):
    """Worker loop that compresses blocks taken from *in_queue*.

    ``readyForWork`` is incremented while waiting on the queue and
    decremented once an item arrives. The loop stops when
    ``pleaseKillYourself.value()`` is positive, or when an item's buffer
    equals the sentinel ``0``.

    Each item is ``(buffer, compressionLevel, compressedblockSizeList,
    chunkRelativeBlockID)``; the third field is ignored. The result is
    stored at ``out_list[chunkRelativeBlockID]``: the zstd-compressed
    buffer when it is strictly smaller, otherwise the raw buffer.
    """
    while True:
        readyForWork.increment()
        job = in_queue.get()
        readyForWork.decrement()
        if pleaseKillYourself.value() > 0:
            return
        data, level, _, slot = job
        if data == 0:
            return
        packed = ZstdCompressor(level=level).compress(data)
        if len(packed) < len(data):
            out_list[slot] = packed
        else:
            out_list[slot] = data
Example #8
0
def compressBlockTask(in_queue, out_list, readyForWork, pleaseKillYourself, blockSize):
    """Worker loop that compresses blocks taken from *in_queue*.

    ``readyForWork`` is only incremented here. The matching decrement is
    commented out upstream; see https://github.com/nicoboss/nsz/issues/80.
    The loop stops when ``pleaseKillYourself.value()`` is positive, or when
    an item's buffer equals the sentinel ``0``.

    Each item is ``(buffer, compressionLevel, compressedblockSizeList,
    chunkRelativeBlockID)``; the third field is ignored. A full-size block
    (``len(buffer) == blockSize``) at level 0 is stored as-is. Any other
    block is zstd-compressed, and the compressed form is kept only when it
    is strictly smaller than the input.
    """
    while True:
        readyForWork.increment()
        job = in_queue.get()
        if pleaseKillYourself.value() > 0:
            return
        data, level, _, slot = job
        if data == 0:
            return
        # https://github.com/nicoboss/nsz/issues/79
        if level == 0 and len(data) == blockSize:
            out_list[slot] = data
            continue
        packed = ZstdCompressor(level=level).compress(data)
        if len(packed) < len(data):
            out_list[slot] = packed
        else:
            out_list[slot] = data
Example #9
0
def test_decompressing_text_io_wrapper(tmp_path: Path) -> None:
    """DecompressingTextIOWrapper reads plain, gz, bz2, xz and zst text files.

    For the uncompressed file, with and without a progress bar, the test
    checks ``size()``, ``tell()`` before and after a full read, and a
    partial ``read(4)`` followed by line iteration. For compressed inputs,
    ``tell()`` only has to be positive after a full read.
    """
    content = "This is just\nsome test content.\n"
    encoded = content.encode(encoding="UTF-8")
    content_len = len(encoded)
    first_line_end = content.index("\n") + 1

    plain_file = tmp_path / "file.txt"
    plain_file.write_text(content, encoding="UTF-8")

    for show_progress in (True, False):
        with DecompressingTextIOWrapper(plain_file,
                                        encoding="UTF-8",
                                        progress_bar=show_progress) as reader:
            assert reader.size() == content_len
            assert reader.tell() == 0
            assert reader.read() == content
            assert reader.tell() == content_len

        with DecompressingTextIOWrapper(plain_file,
                                        encoding="UTF-8",
                                        progress_bar=show_progress) as reader:
            assert reader.read(4) == "This"
            assert list(reader) == [content[4:first_line_end],
                                    content[first_line_end:]]

    openers = (
        ("gz", cast(_TOpenFunc, gzip.open)),
        ("bz2", cast(_TOpenFunc, bz2.open)),
        ("xz", cast(_TOpenFunc, lzma.open)),
    )
    for suffix, open_func in openers:
        compressed_file = tmp_path / f"file.{suffix}"
        with open_func(compressed_file, "wt", encoding="UTF-8") as writer:
            writer.write(content)
        with DecompressingTextIOWrapper(compressed_file,
                                        encoding="UTF-8") as reader:
            assert reader.tell() == 0
            assert reader.read() == content
            assert reader.tell() > 0

    zst_file = tmp_path / "file.zst"
    zst_file.write_bytes(ZstdCompressor().compress(encoded))
    with DecompressingTextIOWrapper(zst_file, encoding="UTF-8") as reader:
        assert reader.tell() == 0
        assert reader.read() == content
        assert reader.tell() > 0
Example #10
0
                        help="Output file")
    args = parser.parse_args()

    with open(args.manifest) as f:
        selected = yaml.safe_load(f.read())
    with TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        dl_cache = tmpdir / "cache"
        downloadPackages(selected, dl_cache)
        unpacked = tmpdir / "unpack"
        extractPackages(selected, dl_cache, unpacked)
        stem = Path(Path(args.output.stem).stem)
        # Create an archive containing all the paths in lowercase form for
        # cross-compiles.
        with open(args.output, "wb") as f:
            with ZstdCompressor().stream_writer(f) as z:
                with tarfile.open(mode="w|", fileobj=z) as tar:
                    for subpath, dest in (
                        ("VC", "vc"),
                        ("Program Files/Windows Kits/10", "windows kits/10"),
                        ("DIA SDK", "dia sdk"),
                    ):
                        subpath = unpacked / subpath
                        dest = Path(dest)
                        for root, dirs, files in os.walk(subpath):
                            relpath = Path(root).relative_to(subpath)
                            for f in files:
                                path = Path(root) / f
                                info = tar.gettarinfo(path)
                                with open(path, "rb") as fh:
                                    info.name = str(stem / dest / relpath /
Example #11
0
    def test_submission(self, mock_conn, mock_get_params):
        """End-to-end check of the contest submission endpoints.

        Setup: an environment, a published contest ``abc000`` with problem
        ``A``, and one test case. The test then exercises listing, posting
        and fetching submissions, including the 400/403/404 error paths
        and fetching a submission through a different contest.

        NOTE(review): ``mock_conn`` and ``mock_get_params`` are presumably
        injected by ``mock.patch`` decorators (not shown here) that stub the
        message-queue connection used when a submission is posted -- TODO
        confirm.
        """
        # TODO(kazuki): rewrite this to go through the API
        env = dict(name='Python 3.7', test_image_name='docker-image')
        with transaction() as s:
            env = Environment(**env)
            s.add(env)
            s.flush()
            env = env.to_dict()

        start_time = datetime.now(tz=timezone.utc)
        contest_id = app.post_json('/contests', {
            'id': 'abc000',
            'title': 'ABC000',
            'description': '# ABC000\n\nほげほげ\n',
            'start_time': start_time.isoformat(),
            'end_time': (start_time + timedelta(hours=1)).isoformat(),
            'published': True,
        }, headers=self.admin_headers).json['id']
        prefix = '/contests/{}'.format(contest_id)
        app.post_json(
            '{}/problems'.format(prefix), dict(
                id='A', title='A Problem', description='# A', time_limit=2,
                score=100
            ), headers=self.admin_headers)

        # TODO(kazuki): rewrite this to go through the API
        # Test-case input/output are stored zstd-compressed.
        ctx = ZstdCompressor()
        with transaction() as s:
            s.add(TestCase(
                contest_id=contest_id,
                problem_id='A',
                id='1',
                input=ctx.compress(b'1'),
                output=ctx.compress(b'2')))

        # Listing without credentials is 403; an unknown contest is 404.
        app.get('{}/submissions'.format(prefix), status=403)
        self.assertEqual([], app.get(
            '{}/submissions'.format(prefix), headers=self.admin_headers).json)
        app.get('/contests/invalid/submissions', status=404)

        code = 'print("Hello World")'
        resp = app.post_json('{}/submissions'.format(prefix), {
            'problem_id': 'A',
            'environment_id': env['id'],
            'code': code,
        }, headers=self.admin_headers).json
        self.assertEqual([resp], app.get(
            '{}/submissions'.format(prefix), headers=self.admin_headers).json)
        # Detail view: 404 without credentials. With admin headers it
        # returns the summary plus the original `code` text and a `tests`
        # list (empty here).
        app.get('{}/submissions/{}'.format(prefix, resp['id']), status=404)
        resp2 = app.get('{}/submissions/{}'.format(prefix, resp['id']),
                        headers=self.admin_headers).json
        self.assertEqual(resp2.pop('code'), code)
        resp['tests'] = []
        self.assertEqual(resp, resp2)

        # An unknown problem or environment in the body is a 400.
        app.post_json('{}/submissions'.format(prefix), {
            'problem_id': 'invalid',
            'environment_id': env['id'],
            'code': code,
        }, headers=self.admin_headers, status=400)
        app.post_json('{}/submissions'.format(prefix), {
            'problem_id': 'A',
            'environment_id': 99999,
            'code': code,
        }, headers=self.admin_headers, status=400)
        app.get('{}/submissions/99999'.format(prefix), status=404)

        # A submission cannot be fetched through a different contest's URL.
        contest_id2 = app.post_json('/contests', {
            'id': 'abc001',
            'title': 'ABC001',
            'description': '# ABC001',
            'start_time': start_time.isoformat(),
            'end_time': (start_time + timedelta(hours=1)).isoformat(),
        }, headers=self.admin_headers).json['id']
        app.get(
            '/contests/{}/submissions/{}'.format(contest_id2, resp['id']),
            status=404)

        # End every contest, then list submissions without credentials.
        # No explicit status is given, so presumably the client's default
        # success check applies -- confirm.
        with transaction() as s:
            s.query(Contest).update({'end_time': start_time})
        app.get('{}/submissions'.format(prefix))
Example #12
0
 def compress(self, fobj: IO[bytes]) -> IO[bytes]:
     """Yield a zstd stream writer at ``self.level`` that writes into *fobj*.

     The body is a generator, so it is presumably wrapped with
     ``contextlib.contextmanager`` where it is defined -- TODO confirm.

     NOTE(review): recent python-zstandard releases default
     ``stream_writer`` to ``closefd=True``, so *fobj* would be closed when
     the writer exits. Confirm that callers expect this.
     """
     with ZstdCompressor(level=self.level).stream_writer(fobj) as writer:
         yield writer
Example #13
0
def create_tinfoil_index(index_to_write: dict,
                         out_path: Path,
                         compression_flag: int,
                         rsa_pub_key_path: Path = None,
                         vm_path: Path = None):
    """Serialize *index_to_write* into a Tinfoil index file at *out_path*.

    Payload: if *vm_path* is an existing file, the payload starts with the
    magic ``13 37 B0 0B``, the VM blob's 4-byte little-endian length and the
    blob itself. The index, serialized by ``json_serialize``, follows. The
    payload is compressed according to *compression_flag* (zstd level 22,
    zlib level 9, or none), then zero-padded up to a 16-byte boundary.

    Encryption: if *rsa_pub_key_path* is an existing file, the padded
    payload is encrypted with AES-ECB under a fresh random 128-bit key.
    That key is encrypted with RSA-OAEP (SHA-256) to form the session key.
    Otherwise the session key is 0x100 zero bytes.

    File layout: ``b"TINFOIL"``, one flag byte (compression flag OR
    encryption flag), the session key, the unpadded compressed size as
    8 little-endian bytes, then the padded (possibly encrypted) payload.
    Missing parent directories of *out_path* are created.

    Raises:
        NotImplementedError: if *compression_flag* is not recognised.
    """
    # A CSPRNG is required for the AES key; `random` is not suitable.
    from secrets import token_bytes

    to_compress_buffer = b""

    if vm_path is not None and vm_path.is_file():
        with open(vm_path, "rb") as vm_stream:
            vm_buffer = vm_stream.read()

        to_compress_buffer += b"\x13\x37\xB0\x0B"
        to_compress_buffer += len(vm_buffer).to_bytes(4, "little")
        to_compress_buffer += vm_buffer

    to_compress_buffer += bytes(json_serialize(index_to_write).encode())

    if compression_flag == CompressionFlag.ZSTD_COMPRESSION:
        to_write_buffer = ZstdCompressor(
            level=22).compress(to_compress_buffer)

    elif compression_flag == CompressionFlag.ZLIB_COMPRESSION:
        to_write_buffer = zlib_compress(to_compress_buffer, 9)

    elif compression_flag == CompressionFlag.NO_COMPRESSION:
        to_write_buffer = to_compress_buffer

    else:
        raise NotImplementedError(
            "Compression method supplied is not implemented yet.")

    data_size = len(to_write_buffer)
    # Always appends 1..16 zero bytes: a whole block when already aligned.
    to_write_buffer += (b"\x00" * (0x10 - (data_size % 0x10)))

    if rsa_pub_key_path is not None and rsa_pub_key_path.is_file():
        # read_text() closes the key file, unlike a bare open().read().
        rsa_pub_key = import_rsa_key(rsa_pub_key_path.read_text())
        rand_aes_key = token_bytes(0x10)

        pkcs1_oaep_ctx = new_pkcs1_oaep_ctx(rsa_pub_key,
                                            hashAlgo=SHA256,
                                            label=b"")
        aes_ctx = new_aes_ctx(rand_aes_key, MODE_ECB)

        session_key = pkcs1_oaep_ctx.encrypt(rand_aes_key)
        to_write_buffer = aes_ctx.encrypt(to_write_buffer)
        flag = compression_flag | EncryptionFlag.ENCRYPT
    else:
        session_key = b"\x00" * 0x100
        flag = compression_flag | EncryptionFlag.NO_ENCRYPT

    Path(out_path.parent).mkdir(parents=True, exist_ok=True)

    with open(out_path, "wb") as out_stream:
        out_stream.write(b"TINFOIL")
        out_stream.write(flag.to_bytes(1, byteorder="little"))
        out_stream.write(session_key)
        out_stream.write(data_size.to_bytes(8, "little"))
        out_stream.write(to_write_buffer)
Example #14
0
def solidCompress(filePath, compressionLevel=18, outputDir=None, threads=-1):
    """Compress the container at *filePath* into a new container file.

    The output name is the input name with its last character replaced by
    ``'z'``. It is placed next to the input, or in *outputDir* when given.

    Entry handling:

    - DATA NCAs (delta fragments) are skipped.
    - PROGRAM/PUBLICDATA NCAs accepted by ``isNcaPacked`` are written as the
      first ``0x4000`` header bytes verbatim, then an ``NCZSECTN`` section
      table, then the section data as one zstd stream at *compressionLevel*.
      Multi-threaded compression is used only when ``threads > 1``.
    - All other entries are copied verbatim.

    Error handling: on ``KeyboardInterrupt`` the partial output is removed
    and the interrupt is re-raised. Any other exception is logged with
    ``Print.error`` and the partial output is removed.

    Returns:
        The output path. NOTE(review): it is returned even when a
        non-KeyboardInterrupt error removed the output, so callers cannot
        tell failure from success. Confirm this is intended.
    """
    ncaHeaderSize = 0x4000
    filePath = str(Path(filePath).resolve())
    container = factory(filePath)
    container.open(filePath, 'rb')
    CHUNK_SZ = 0x1000000
    nszPath = str(
        Path(filePath[0:-1] + 'z' if outputDir is None else Path(outputDir).
             joinpath(Path(filePath[0:-1] + 'z').name)).resolve(strict=False))

    for nspf in container:
        if isinstance(nspf, Ticket.Ticket):
            # Return value unused; presumably called for a side effect --
            # confirm.
            nspf.getRightsId()
            break  # No need to go for other objects

    Print.info('compressing (level %d) %s -> %s' %
               (compressionLevel, filePath, nszPath))
    newNsp = Pfs0.Pfs0Stream(nszPath)

    try:
        for nspf in container:
            if isinstance(
                    nspf,
                    Nca.Nca) and nspf.header.contentType == Type.Content.DATA:
                Print.info('skipping delta fragment')
                continue
            if isinstance(nspf, Nca.Nca) and (
                    nspf.header.contentType == Type.Content.PROGRAM
                    or nspf.header.contentType == Type.Content.PUBLICDATA):
                if isNcaPacked(nspf, ncaHeaderSize):
                    newFileName = nspf._path[0:-1] + 'z'
                    f = newNsp.add(newFileName, nspf.size)
                    start = f.tell()
                    # The NCA header is stored uncompressed.
                    nspf.seek(0)
                    f.write(nspf.read(ncaHeaderSize))
                    sections = []

                    for fs in sortedFs(nspf):
                        sections += fs.getEncryptionSections()

                    if len(sections) == 0:
                        raise Exception(
                            "NCA can't be decrypted. Outdated keys.txt?")
                    # Section table: magic, section count, then for each
                    # section its offset, size and cryptoType (8 bytes LE
                    # each), 8 zero bytes, the crypto key and the counter.
                    header = b'NCZSECTN'
                    header += len(sections).to_bytes(8, 'little')

                    for fs in sections:
                        header += fs.offset.to_bytes(8, 'little')
                        header += fs.size.to_bytes(8, 'little')
                        header += fs.cryptoType.to_bytes(8, 'little')
                        header += b'\x00' * 8
                        header += fs.cryptoKey
                        header += fs.cryptoCounter

                    f.write(header)
                    decompressedBytes = ncaHeaderSize

                    with tqdm(total=nspf.size, unit_scale=True,
                              unit="B") as bar:
                        partitions = [
                            nspf.partition(
                                offset=section.offset,
                                size=section.size,
                                n=None,
                                cryptoType=section.cryptoType,
                                cryptoKey=section.cryptoKey,
                                cryptoCounter=bytearray(section.cryptoCounter),
                                autoOpen=True) for section in sections
                        ]
                        partNr = 0
                        bar.update(f.tell())
                        if threads > 1:
                            cctx = ZstdCompressor(level=compressionLevel,
                                                  threads=threads)
                        else:
                            cctx = ZstdCompressor(level=compressionLevel)
                        compressor = cctx.stream_writer(f)

                        while True:
                            # Fill a CHUNK_SZ buffer, moving on to the next
                            # partition whenever the current one runs dry.
                            buffer = partitions[partNr].read(CHUNK_SZ)

                            while (len(buffer) < CHUNK_SZ
                                   and partNr < len(partitions) - 1):
                                partitions[partNr].close()
                                partitions[partNr] = None
                                partNr += 1
                                buffer += partitions[partNr].read(CHUNK_SZ -
                                                                  len(buffer))

                            if len(buffer) == 0:
                                break
                            compressor.write(buffer)
                            decompressedBytes += len(buffer)
                            bar.update(len(buffer))

                        partitions[partNr].close()
                        partitions[partNr] = None

                    # NOTE(review): COMPRESSOBJ_FLUSH_FINISH is a compressobj
                    # constant; in python-zstandard its value coincides with
                    # FLUSH_BLOCK. Confirm this second flush is intended.
                    compressor.flush(FLUSH_FRAME)
                    compressor.flush(COMPRESSOBJ_FLUSH_FINISH)
                    written = f.tell() - start
                    print('compressed %d%% %d -> %d  - %s' %
                          (int(written * 100 / nspf.size), decompressedBytes,
                           written, nspf._path))
                    newNsp.resize(newFileName, written)
                    continue
                else:
                    print('not packed!')
            # Every other entry is copied verbatim.
            f = newNsp.add(nspf._path, nspf.size)
            nspf.seek(0)
            while not nspf.eof():
                buffer = nspf.read(CHUNK_SZ)
                f.write(buffer)
    except KeyboardInterrupt:
        remove(nszPath)
        raise  # re-raise the original interrupt, keeping its traceback
    except BaseException:
        # NOTE(review): every other exception (including SystemExit) is
        # logged and swallowed here.
        Print.error(format_exc())
        remove(nszPath)
    finally:
        newNsp.close()
        container.close()
    return nszPath
Example #15
0
class TinyIndexer(TinyIndexBase):
    """Page-based index stored in a memory-mapped file.

    The file holds ``num_pages`` pages of ``page_size`` bytes each. Every
    written page is a zstd-compressed JSON list of tuples, zero-padded to
    ``page_size``.

    Use as a context manager: entering creates the file if it is missing,
    then opens and maps it; exiting unmaps and closes it.
    """

    def __init__(self, item_type: type, index_path: str, num_pages: int,
                 page_size: int):
        super().__init__(item_type, num_pages, page_size)
        self.index_path = index_path
        self.compressor = ZstdCompressor()
        self.decompressor = ZstdDecompressor()
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        self.create_if_not_exists()
        self.index_file = open(self.index_path, 'r+b')
        try:
            self.mmap = mmap(self.index_file.fileno(), 0)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so close the
            # file here instead of leaking the handle.
            self.index_file.close()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def index(self, key: str, value):
        """Append ``astuple(value)`` to the page that *key* maps to.

        *value* must be exactly of type ``self.item_type``; this is checked
        with ``assert``. If the updated page no longer fits in
        ``page_size`` bytes once compressed, nothing is written and the
        value is dropped.
        """
        assert type(value) == self.item_type, f"Can only index the specified type" \
                                              f" ({self.item_type.__name__})"
        page_index = self._get_key_page_index(key)
        current_page = self.get_page(page_index)
        if current_page is None:
            current_page = []
        current_page.append(astuple(value))
        try:
            self._write_page(current_page, page_index)
        except ValueError:
            # Page is full: drop the value instead of raising.
            pass

    def _write_page(self, data, i):
        """
        Serialise the data using JSON, compress it and store it at index i.
        If the data is too big, it will raise a ValueError and not store anything
        """
        serialised_data = json.dumps(data)
        compressed_data = self.compressor.compress(
            serialised_data.encode('utf8'))
        page_length = len(compressed_data)
        if page_length > self.page_size:
            raise ValueError(
                f"Data is too big ({page_length}) for page size ({self.page_size})"
            )
        padding = b'\x00' * (self.page_size - page_length)
        self.mmap[i * self.page_size:(i + 1) *
                  self.page_size] = compressed_data + padding

    def create_if_not_exists(self):
        """Create the index file as ``num_pages * page_size`` zero bytes if absent."""
        if not os.path.isfile(self.index_path):
            file_length = self.num_pages * self.page_size
            with open(self.index_path, 'wb') as index_file:
                index_file.write(b'\x00' * file_length)
Example #16
0
 def __init__(self):
     """Create the zstd compressor/decompressor pair this object reuses."""
     self.compressor, self.decompressor = ZstdCompressor(), ZstdDecompressor()
Example #17
0
def processContainer(readContainer, writeContainer, compressionLevel, threads,
                     stusReport, id, pleaseNoPrint):
    """Copy every entry of *readContainer* into *writeContainer*.

    Entry handling:

    - DATA NCAs (delta fragments) are skipped.
    - PROGRAM/PUBLICDATA NCAs larger than ``UNCOMPRESSABLE_HEADER_SIZE``
      and accepted by ``isNcaPacked`` are compressed. They are written
      under a name whose last character is replaced by ``'z'``: the first
      ``UNCOMPRESSABLE_HEADER_SIZE`` bytes verbatim, then an ``NCZSECTN``
      section table, then the remaining data as one zstd stream at
      *compressionLevel*. Multi-threaded compression is used only when
      ``threads > 1``.
    - All other entries are copied verbatim in ``CHUNK_SZ`` pieces.

    Progress is published as
    ``stusReport[id] = [read position, write position, total size]``.

    Raises:
        Exception: if a packed NCA yields no encryption sections (the
            message suggests outdated keys).
    """
    for nspf in readContainer:
        if isinstance(
                nspf,
                Nca.Nca) and nspf.header.contentType == Type.Content.DATA:
            Print.info('Skipping delta fragment {0}'.format(nspf._path))
            continue

        if isinstance(nspf, Nca.Nca) and (
                nspf.header.contentType == Type.Content.PROGRAM
                or nspf.header.contentType == Type.Content.PUBLICDATA
        ) and nspf.size > UNCOMPRESSABLE_HEADER_SIZE:
            if isNcaPacked(nspf):

                offsetFirstSection = sortedFs(nspf)[0].offset
                newFileName = nspf._path[0:-1] + 'z'

                with writeContainer.add(newFileName, nspf.size,
                                        pleaseNoPrint) as f:
                    start = f.tell()

                    # The header region is stored uncompressed.
                    nspf.seek(0)
                    f.write(nspf.read(UNCOMPRESSABLE_HEADER_SIZE))

                    sections = []
                    for fs in sortedFs(nspf):
                        sections += fs.getEncryptionSections()

                    if len(sections) == 0:
                        raise Exception(
                            "NCA can't be decrypted. Outdated keys.txt?")

                    # Section table: magic, section count, then for each
                    # section its offset, size and cryptoType (8 bytes LE
                    # each), 8 zero bytes, the crypto key and the counter.
                    header = b'NCZSECTN'
                    header += len(sections).to_bytes(8, 'little')

                    for fs in sections:
                        header += fs.offset.to_bytes(8, 'little')
                        header += fs.size.to_bytes(8, 'little')
                        header += fs.cryptoType.to_bytes(8, 'little')
                        header += b'\x00' * 8
                        header += fs.cryptoKey
                        header += fs.cryptoCounter

                    f.write(header)

                    decompressedBytes = UNCOMPRESSABLE_HEADER_SIZE

                    stusReport[id] = [0, 0, nspf.size]

                    # Partitions are read in order: an optional cryptoType
                    # NONE partition covering any gap between the header and
                    # the first section, then one partition per encryption
                    # section.
                    partitions = []
                    if offsetFirstSection - UNCOMPRESSABLE_HEADER_SIZE > 0:
                        partitions.append(
                            nspf.partition(offset=UNCOMPRESSABLE_HEADER_SIZE,
                                           size=offsetFirstSection -
                                           UNCOMPRESSABLE_HEADER_SIZE,
                                           cryptoType=Type.Crypto.CTR.NONE,
                                           autoOpen=True))
                    for section in sections:
                        partitions.append(
                            nspf.partition(offset=section.offset,
                                           size=section.size,
                                           cryptoType=section.cryptoType,
                                           cryptoKey=section.cryptoKey,
                                           cryptoCounter=bytearray(
                                               section.cryptoCounter),
                                           autoOpen=True))
                    # If the first section starts inside the header region,
                    # skip the bytes that were already written verbatim.
                    if UNCOMPRESSABLE_HEADER_SIZE - offsetFirstSection > 0:
                        partitions[0].seek(UNCOMPRESSABLE_HEADER_SIZE -
                                           offsetFirstSection)

                    partNr = 0
                    stusReport[id] = [nspf.tell(), f.tell(), nspf.size]
                    if threads > 1:
                        cctx = ZstdCompressor(level=compressionLevel,
                                              threads=threads)
                    else:
                        cctx = ZstdCompressor(level=compressionLevel)
                    compressor = cctx.stream_writer(f)
                    while True:
                        # Fill a CHUNK_SZ buffer, moving on to the next
                        # partition whenever the current one runs dry.
                        buffer = partitions[partNr].read(CHUNK_SZ)
                        while (len(buffer) < CHUNK_SZ
                               and partNr < len(partitions) - 1):
                            partitions[partNr].close()
                            partitions[partNr] = None
                            partNr += 1
                            buffer += partitions[partNr].read(CHUNK_SZ -
                                                              len(buffer))
                        if len(buffer) == 0:
                            break
                        compressor.write(buffer)

                        decompressedBytes += len(buffer)
                        stusReport[id] = [nspf.tell(), f.tell(), nspf.size]
                    partitions[partNr].close()
                    partitions[partNr] = None

                    # NOTE(review): COMPRESSOBJ_FLUSH_FINISH is a compressobj
                    # constant; in python-zstandard its value coincides with
                    # FLUSH_BLOCK. Confirm this second flush is intended.
                    compressor.flush(FLUSH_FRAME)
                    compressor.flush(COMPRESSOBJ_FLUSH_FINISH)
                    stusReport[id] = [nspf.tell(), f.tell(), nspf.size]

                    written = f.tell() - start
                    Print.info(
                        'Compressed {0}% {1} -> {2}  - {3}'.format(
                            written * 100 / nspf.size, decompressedBytes,
                            written, nspf._path), pleaseNoPrint)
                    writeContainer.resize(newFileName, written)
                    continue
            else:
                Print.info('Skipping not packed {0}'.format(nspf._path))

        # Every other entry is copied verbatim.
        with writeContainer.add(nspf._path, nspf.size, pleaseNoPrint) as f:
            nspf.seek(0)
            while not nspf.eof():
                buffer = nspf.read(CHUNK_SZ)
                f.write(buffer)
Example #18
0
def upload_test_dataset(contest_id: str, problem_id: str) -> Response:
    """Replace a problem's test cases from a zip archive in the request body.

    Archive handling: every ``<name>.in`` / ``<name>.out`` pair becomes a
    test case whose input and output are stored zstd-compressed. Pairs are
    matched by basename, so directories inside the archive are ignored.
    A stem is skipped when it does not occur exactly twice among ``.in``
    and ``.out`` entries, or when reading its entries fails.

    Database update: existing test cases of the problem that no
    ``JudgeResult`` references are deleted. Referenced ones are updated in
    place, and new ones are inserted. An admin token is required. It is
    checked before the upload is read.

    Returns:
        JSON list of the imported test-case ids.
    """
    from collections import Counter
    import shutil
    from zipfile import ZipFile
    from contextlib import ExitStack
    from tempfile import TemporaryFile
    from sqlalchemy import exists

    # Authorize first, so unauthenticated callers cannot make the server
    # spool and decompress arbitrary archives.
    with transaction() as s:
        _validate_token(s, admin_required=True)

    zctx = ZstdCompressor()
    test_cases = []
    ret = []
    with ExitStack() as stack:
        f = stack.enter_context(TemporaryFile())
        shutil.copyfileobj(request.stream, f)
        f.seek(0)
        z = stack.enter_context(ZipFile(f))
        counts = Counter()  # type: ignore
        path_mapping = {}
        for x in z.namelist():
            if not (x.endswith('.in') or x.endswith('.out')):
                continue
            name = os.path.basename(x)
            counts.update([os.path.splitext(name)[0]])
            path_mapping[name] = x

        for k, v in counts.items():
            if v != 2:
                continue
            try:
                with z.open(path_mapping[k + '.in']) as zi:
                    in_data = zctx.compress(zi.read())
                with z.open(path_mapping[k + '.out']) as zo:
                    out_data = zctx.compress(zo.read())
            except Exception:
                # Best effort: skip pairs that are incomplete or unreadable.
                continue
            test_cases.append(
                dict(contest_id=contest_id,
                     problem_id=problem_id,
                     id=k,
                     input=in_data,
                     output=out_data))
            ret.append(k)

    # Delete only the test cases that nothing references,
    # UPDATE the test cases that are referenced,
    # and INSERT the new test cases.
    with transaction() as s:
        s.query(TestCase).filter(
            TestCase.contest_id == contest_id,
            TestCase.problem_id == problem_id,
            ~exists().where(JudgeResult.test_id == TestCase.id)).delete(
                synchronize_session=False)
        for kwargs in test_cases:
            q = s.query(TestCase).filter(TestCase.contest_id == contest_id,
                                         TestCase.problem_id == problem_id,
                                         TestCase.id == kwargs['id'])
            if q.count() == 0:
                s.add(TestCase(**kwargs))
            else:
                kwargs.pop('contest_id')
                kwargs.pop('problem_id')
                kwargs.pop('id')
                q.update(kwargs, synchronize_session=False)
    return jsonify(ret)
Example #19
0
 def compress(data):
     """Compress *data* with a newly created ZstdCompressor.

     A fresh compressor is built on every call because ZstdCompressor is
     not thread safe.
     """
     # TODO: Use a pool?
     compressor = ZstdCompressor()
     return compressor.compress(data)
Example #20
0
class AtlasRecordsWriter:
    """
    Write Atlas results in ND-JSON format.

    .. code-block:: python

        from fetchmesh.io import AtlasRecordsWriter
        with AtlasRecordsWriter("results.ndjson") as w:
            w.write({"msm_id": 1001, "prb_id": 1, "...": "..."})
    """

    file: Path
    """Output file path."""

    filters: List[StreamFilter[dict]] = field(default_factory=list)
    """List of filters to apply before writing the records."""

    append: bool = False
    """
    Whether to create a new file, or to append the records to an existing file.
    If append is set to false, and the output file already exists, it will be overwritten.
    When append is set to false, the output file will be deleted if an exception happens.
    """

    log: bool = False
    """Record the size (in bytes) of each record. See :any:`LogEntry`."""

    compression: bool = False
    """
    Compress the records using zstandard.
    We use the one-shot compression API and write one frame per record.
    This results in larger files than a single frame for all the records,
    but it allows us to build an index and make the file seekable.
    We use a pre-built dictionary (see :any:`dictionary`) to reduce the size of the compressed records.
    """

    compression_ctx: Optional[ZstdCompressor] = field(default=None, init=False)

    @property
    def log_file(self) -> Path:
        """Path to the (optional) log file."""
        return self.file.with_suffix(self.file.suffix + ".log")

    def __post_init__(self):
        self.file = Path(self.file)

    def __enter__(self):
        """Open the output (and optional log) file and set up compression.

        If a later step fails, the files opened so far are closed before
        the error propagates. ``__exit__`` is not called when ``__enter__``
        raises, so those handles would otherwise leak.
        """
        mode = "ab" if self.append else "wb"

        # (1) Open the output file
        self.f = self.file.open(mode)
        log_f = None
        try:
            # (2) Open the log file
            if self.log:
                log_f = self.log_file.open(mode)
                self.log_f = log_f

            # (3) Setup the compression context
            if self.compression:
                dict_data = ZstdCompressionDict(dictionary.read_bytes())
                self.compression_ctx = ZstdCompressor(dict_data=dict_data)
        except BaseException:
            # Release whatever was opened above, then re-raise unchanged.
            if log_f is not None:
                log_f.close()
            self.f.close()
            raise

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # (1) Close the output file
        self.f.flush()
        self.f.close()

        # (2) Close the log file
        if self.log:
            self.log_f.flush()
            self.log_f.close()

        # (3) Handle exceptions
        # NOTE: In append mode, we do not delete the file.
        if exc_type:
            if not self.append:
                self.file.unlink()
            if not self.append and self.log_file.exists():
                self.log_file.unlink()
            print_exception(exc_type, exc_value, traceback)

        # Do not reraise exceptions, except for KeyboardInterrupt.
        return exc_type is not KeyboardInterrupt

    def write(self, record: dict):
        """Write a single record."""

        # (1) Filter the record
        for filter_ in self.filters:
            if not filter_.keep(record):
                return

        # (2) Serialize and encode the record
        data = json_trydumps(record) + "\n"
        data = data.encode("utf-8")

        # (3) Compress the record (one zstd frame per record)
        if self.compression_ctx:
            data = self.compression_ctx.compress(data)

        # (4) Update the log
        if self.log:
            entry = LogEntry.pack(len(data), record["msm_id"],
                                  record["prb_id"])
            self.log_f.write(entry)

        # (5) Write the record to the output file
        self.f.write(data)

    def writeall(self, records: Iterable[dict]):
        """Write all the records."""

        for record in records:
            self.write(record)