Ejemplo n.º 1
0
 def decodeHeader(self, last_chunk_len_format: str = "I") -> None:
     """Find the header chunk among the degree-1 packets and parse it.

     Does nothing if ``self.headerChunk`` is already set. A degree-1 packet whose
     used-chunk set is a subset of ``{0}`` is treated as the header. If no
     degree-1 packet has been collected yet, the header simply stays unset.

     :param last_chunk_len_format: ``struct`` format of the last-chunk-length
         field, forwarded to ``HeaderChunk``.
     """
     if self.headerChunk is not None:
         return  # Header already set
     # .get: there may not be any degree-1 packet yet (the original raised KeyError).
     for decoded in self.degreeToPacket.get(1, ()):
         if decoded.get_used_packets().issubset({0}):
             self.headerChunk = HeaderChunk(
                 decoded, last_chunk_len_format=last_chunk_len_format)
             break  # one header is enough; no need to scan the remaining packets
Ejemplo n.º 2
0
 def saveDecodedFile(self,
                     last_chunk_len_format: str = "I",
                     null_is_terminator: bool = False,
                     print_to_output: bool = False) -> None:
     """Write the chunks reconstructed by ``self.GEPP`` to a file.

     When ``self.use_headerchunk`` is set, row 0 of ``self.GEPP.b`` is parsed as
     the header chunk. The header supplies the output file name and the length
     of the last chunk, and row 0 itself is not written. Otherwise the name is
     ``"DEC_" + basename(self.file)``, or ``"ONLINE.BIN"`` if no file is set.

     :param last_chunk_len_format: ``struct`` format of the header's
         last-chunk-length field.
     :param null_is_terminator: stop writing at the first NUL byte found in a
         non-final chunk.
     :param print_to_output: also print the written bytes decoded as UTF-8.
     :raises AssertionError: if the system has not been solved.
     """
     assert self.is_decoded(), "Can not save File: Unable to reconstruct."
     if self.use_headerchunk:
         self.headerChunk = HeaderChunk(
             OnlinePacket(self.GEPP.b[0],
                          self.number_of_chunks,
                          self.quality,
                          self.epsilon,
                          0, {0},
                          self.dist,
                          read_only=True),
             last_chunk_len_format=last_chunk_len_format)
     if self.file is not None:
         file_name = "DEC_" + os.path.basename(self.file)
     else:
         file_name = "ONLINE.BIN"
     if self.headerChunk is not None:
         file_name = self.headerChunk.get_file_name().decode("utf-8")
     # Drop everything from the first NUL character on.
     file_name = file_name.split("\x00")[0]
     written = []  # pieces kept only for print_to_output; joined once at the end
     with open(file_name, "wb") as f:
         for x in self.GEPP.result_mapping:
             if self.use_headerchunk and x == 0:
                 continue  # row 0 is the header chunk, not file content
             stop = False
             if self.use_headerchunk and x == self.number_of_chunks - 1:
                 # Trim the last chunk to the length recorded in the header.
                 last_len = self.headerChunk.get_last_chunk_length()
                 output = self.GEPP.b[x][0][0:last_len].tobytes()
             elif null_is_terminator:
                 # Split the raw bytes: NUL is the single byte 0x00 in UTF-8, so this
                 # equals splitting the decoded text but never fails on binary data.
                 output, separator, _ = self.GEPP.b[x].tobytes().partition(b"\x00")
                 stop = bool(separator)
             else:
                 output = self.GEPP.b[x].tobytes()
             f.write(output)
             if print_to_output:
                 written.append(output)
             if stop:
                 break  # since we are in null-terminator mode, we exit once we see the first 0-byte
     print("Saved file as '" + str(file_name) + "'")
     if print_to_output:
         print("Result:")
         print(b"".join(written).decode("utf-8"))
Ejemplo n.º 3
0
    def saveDecodedFile(self, last_chunk_len_format: str = "I", null_is_terminator: bool = False,
                        print_to_output: bool = True) -> None:
        """Write the data of the decoded degree-1 packets to a file.

        Packets of ``self.degreeToPacket[1]`` are written in sorted order. With
        ``self.use_headerchunk`` set, the packet covering chunk 0 is parsed as the
        header chunk and not written. The header supplies the file name and the
        last-chunk length, and the packet covering the last chunk is trimmed to
        that length.

        :param last_chunk_len_format: ``struct`` format of the header's
            last-chunk-length field.
        :param null_is_terminator: stop writing at the first NUL byte of a
            non-final chunk.
        :param print_to_output: also print the written bytes decoded as UTF-8.
        :raises AssertionError: if decoding is not complete.
        """
        assert self.is_decoded(), "Can not save File: Unable to reconstruct."
        file_name: str = "DEC_" + self.file.split("\x00")[0]  # split is needed for weird  MAC / Windows bugs...
        sort_list: typing.List = sorted(self.degreeToPacket[1])
        if 0 in sort_list[0].get_used_packets() and self.use_headerchunk:
            # Parse with the caller's format, as in the loop below. The header's field
            # layout depends on it, so the default could misplace the file name.
            self.headerChunk = HeaderChunk(sort_list[0], last_chunk_len_format=last_chunk_len_format)
        written: typing.List[bytes] = []  # pieces kept only for print_to_output
        if self.headerChunk is not None:
            file_name = self.headerChunk.get_file_name().decode("utf-8")
        with open(file_name, "wb") as f:
            for decoded in sort_list:
                if 0 in decoded.get_used_packets() and self.use_headerchunk:
                    self.headerChunk = HeaderChunk(decoded, last_chunk_len_format=last_chunk_len_format)
                    continue  # the header chunk is metadata, not file content
                stop = False
                if self.number_of_chunks - 1 in decoded.get_used_packets() and self.use_headerchunk:
                    # Trim the last chunk to the length stored in the header.
                    piece = decoded.get_data()[0: self.headerChunk.get_last_chunk_length()]
                    output = piece if isinstance(piece, bytes) else piece.tobytes()
                elif null_is_terminator:
                    data = decoded.get_data()
                    raw = data if isinstance(data, bytes) else data.tobytes()
                    # NUL is the single byte 0x00 in UTF-8, so splitting the raw bytes
                    # matches splitting decoded text without failing on binary data.
                    output, separator, _ = raw.partition(b"\x00")
                    stop = bool(separator)
                else:
                    data = decoded.get_data()
                    output = data if isinstance(data, bytes) else data.tobytes()
                f.write(output)
                if print_to_output:
                    written.append(output)
                if stop:
                    break  # since we are in null-terminator mode, we exit once we see the first 0-byte

        print("Saved file as '" + str(file_name) + "'")
        if print_to_output:
            print("Result:")
            print(b"".join(written).decode("utf-8"))
Ejemplo n.º 4
0
 def saveDecodedFile(self, last_chunk_len_format: str = "I", null_is_terminator: bool = False,
                     print_to_output: bool = True) -> None:
     """Write the chunks reconstructed by ``self.GEPP`` to a file.

     When ``self.use_headerchunk`` is set, row 0 of ``self.GEPP.b`` is parsed as
     the header chunk. The header supplies the output file name and the length
     of the last chunk, and row 0 itself is not written. Otherwise the name is
     ``"DEC_" + basename(self.file)``, or ``"LT.BIN"`` if no file is set.

     :param last_chunk_len_format: ``struct`` format of the header's
         last-chunk-length field.
     :param null_is_terminator: stop writing at the first NUL byte found in a
         non-final chunk.
     :param print_to_output: also print the written bytes decoded as UTF-8.
     :raises AssertionError: if the system has not been solved.
     """
     assert self.is_decoded(), "Can not save File: Unable to reconstruct."
     if self.use_headerchunk:
         self.headerChunk = HeaderChunk(Packet(self.GEPP.b[0], {0}, self.number_of_chunks, read_only=True),
                                        last_chunk_len_format=last_chunk_len_format)
     if self.file is not None:
         file_name = "DEC_" + os.path.basename(self.file)
     else:
         file_name = "LT.BIN"
     if self.headerChunk is not None:
         file_name = self.headerChunk.get_file_name().decode("utf-8")
     # Drop everything from the first NUL character on.
     file_name: str = file_name.split("\x00")[0]
     written: typing.List[bytes] = []  # pieces kept only for print_to_output
     with open(file_name, "wb") as f:
         for x in self.GEPP.result_mapping:
             if self.use_headerchunk and x == 0:
                 continue  # row 0 is the header chunk, not file content
             stop = False
             if self.use_headerchunk and x == self.number_of_chunks - 1:
                 # Trim the last chunk to the length recorded in the header.
                 last_len = self.headerChunk.get_last_chunk_length()
                 output = self.GEPP.b[x][0][0:last_len].tobytes()
             elif null_is_terminator:
                 # NUL is the single byte 0x00 in UTF-8, so splitting the raw bytes
                 # matches splitting decoded text without failing on binary data.
                 output, separator, _ = self.GEPP.b[x].tobytes().partition(b"\x00")
                 stop = bool(separator)
             else:
                 output = self.GEPP.b[x].tobytes()
             f.write(output)
             if print_to_output:
                 written.append(output)
             if stop:
                 break  # since we are in null-terminator mode, we exit once we see the first 0-byte
     print("Saved file as '" + str(file_name) + "'")
     if print_to_output:
         print("Result:")
         print(b"".join(written).decode("utf-8"))
Ejemplo n.º 5
0
class OnlineDecoder(Decoder):
    """Decode Online-code packets read from a file, a folder of packet files or a FASTA file.

    The auxiliary-block contributions of each packet are combined out (see
    ``removeAndXorAuxPackets``). The resulting equation is added as a row to a
    ``GEPP`` solver, whose solution ``saveDecodedFile`` writes to disk.
    """

    def __init__(self,
                 file: typing.Optional[str] = None,
                 error_correction: typing.Callable[[typing.Any],
                                                   typing.Any] = nocode,
                 use_headerchunk: bool = True,
                 static_number_of_chunks: typing.Optional[int] = None,
                 read_all=True):
        """Create a decoder.

        :param file: input file or folder. A path that is not a folder is opened
            here in binary mode.
        :param error_correction: callable applied to each raw packet. When it is
            ``crc32``, a trailing CRC is checked instead (see ``parse_raw_packet``).
        :param use_headerchunk: whether chunk 0 carries the header chunk.
        :param static_number_of_chunks: fixed chunk count, used when packets do not
            store it.
        :param read_all: if False, a pseudo decoder (``isPseudo``) may attempt a
            solve every 25 rows.
        """
        super().__init__(file)
        self.debug: bool = False
        self.isPseudo: bool = False
        self.file: str = file
        self.decodedPackets: typing.Set[OnlinePacket] = set()
        self.degreeToPacket: typing.Dict[int, typing.Set[OnlinePacket]] = {}
        if file is not None:
            self.isFolder: bool = os.path.isdir(file)
            if not self.isFolder:
                self.f = open(self.file, "rb")
        self.correct: int = 0
        self.corrupt: int = 0
        self.rng: numpy.random = numpy.random
        self.number_of_chunks: int = 1000000
        self.headerChunk: typing.Optional[HeaderChunk] = None
        self.auxBlockNumbers: typing.Dict[int, typing.Set[int]] = dict()
        self.auxBlocks: typing.Dict[int, OnlineAuxPacket] = dict()
        self.GEPP: typing.Optional[GEPP] = None
        self.dist: typing.Optional[typing.Union[OnlineDistribution,
                                                Distribution]] = None
        self.read_all_before_decode: bool = read_all
        self.numberOfDecodedAuxBlocks: int = 0
        self.do_count: bool = True
        self.counter: typing.Dict[int, int] = dict()
        self.error_correction: typing.Callable[[typing.Any],
                                               typing.Any] = error_correction
        self.use_headerchunk: bool = use_headerchunk
        self.static_number_of_chunks: int = static_number_of_chunks
        self.EOF: bool = False
        self.quality: int = 0
        self.epsilon: float = 0.0

    def decodeFolder(self,
                     packet_len_format: str = "I",
                     crc_len_format: str = "L",
                     number_of_chunks_len_format: str = "I",
                     quality_len_format: str = "I",
                     epsilon_len_format: str = "f",
                     check_block_number_len_format: str = "I"):
        """Read one packet from every ``*.ONLINE`` / ``*DNA`` file in ``self.file`` and solve.

        :returns: ``-1`` if input ended without a successful decode, else ``None``.
        """
        decoded: bool = False
        self.EOF: bool = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        for dir_file in os.listdir(self.file):
            if dir_file.endswith((".ONLINE", "DNA")):
                self.EOF = False
                if dir_file.endswith("DNA"):
                    self.f = quat_file_to_bin(self.file + "/" + dir_file)
                else:
                    self.f = open(self.file + "/" + dir_file, "rb")
                try:
                    new_pack = self.getNextValidPacket(
                        True,
                        packet_len_format=packet_len_format,
                        crc_len_format=crc_len_format,
                        number_of_chunks_len_format=number_of_chunks_len_format,
                        quality_len_format=quality_len_format,
                        epsilon_len_format=epsilon_len_format,
                        check_block_number_len_format=check_block_number_len_format
                    )
                finally:
                    # The whole file was consumed above; close it now so each
                    # iteration does not leak the previous handle.
                    self.f.close()
                if new_pack is not None:
                    decoded = self.input_new_packet(new_pack)
                if decoded:
                    break
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        if self.GEPP is not None and self.GEPP.isPotentionallySolvable():
            decoded = self.GEPP.solve()
        if hasattr(self, "f"):
            self.f.close()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1

    def decodeFile(self,
                   packet_len_format: str = "I",
                   crc_len_format: str = "L",
                   number_of_chunks_len_format: str = "I",
                   quality_len_format: str = "I",
                   epsilon_len_format: str = "f",
                   check_block_number_len_format: str = "I"):
        """Decode packets from ``self.file`` and try to solve the system.

        A path ending in ``fasta`` (case-insensitive) is read as text: a header
        line, then a quaternary DNA line per packet. Any other file is read as
        length-prefixed binary packets.

        :returns: the result of ``GEPP.solve()`` if the system is potentially
            solvable, ``-1`` if input ended without success, else ``None``.
        """
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        decoded: bool = False
        self.EOF: bool = False
        if self.file.lower().endswith("fasta"):
            self.f.close()
            self.f = open(self.file, "r")
            # NOTE(review): progress_bar is not set in __init__; presumably the base
            # Decoder provides it. Fall back to None so a missing one is not fatal.
            progress_bar = getattr(self, "progress_bar", None)
            while not (decoded or self.EOF):
                # FASTA header line; its content is not needed for decoding.
                line = self.f.readline()
                if not line:
                    self.EOF = True
                    break
                line = self.f.readline()
                if not line:
                    self.EOF = True
                    break
                dna_str = line.replace("\n", "")
                new_pack = self.parse_raw_packet(
                    BytesIO(tranlate_quat_to_byte(dna_str)).read(),
                    crc_len_format=crc_len_format,
                    number_of_chunks_len_format=number_of_chunks_len_format,
                    epsilon_len_format=epsilon_len_format,
                    quality_len_format=quality_len_format,
                    check_block_number_len_format=check_block_number_len_format
                )
                # parse_raw_packet reports failed CRC / error correction as "CORRUPT";
                # such packets must not reach the solver.
                if new_pack != "CORRUPT":
                    decoded = self.input_new_packet(new_pack)
                if progress_bar is not None:
                    progress_bar.update(self.correct, Corrupt=self.corrupt)
        else:
            # Binary input. In the original this `else` was attached to the `while`
            # above, so non-FASTA files were never read.
            while not (decoded or self.EOF):
                new_pack = self.getNextValidPacket(
                    False,
                    packet_len_format=packet_len_format,
                    crc_len_format=crc_len_format,
                    number_of_chunks_len_format=number_of_chunks_len_format,
                    quality_len_format=quality_len_format,
                    epsilon_len_format=epsilon_len_format,
                    check_block_number_len_format=check_block_number_len_format)
                if new_pack is None:
                    break
                decoded = self.input_new_packet(new_pack)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        # Close before a possible early return so the handle is never leaked.
        self.f.close()
        if self.GEPP is not None and self.GEPP.isPotentionallySolvable():
            return self.GEPP.solve()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1

    def createAuxBlocks(self) -> None:
        """Rebuild the auxiliary-block layout from a PRNG seeded with ``number_of_chunks``.

        Each chunk is added to ``self.quality`` randomly chosen aux blocks (a set
        absorbs repeated picks). Empty ``OnlineAuxPacket`` objects are then created
        for every aux block.
        """
        assert self.number_of_chunks is not None, "createAuxBlocks can only be called AFTER first Packet"
        self.rng.seed(self.number_of_chunks)
        number_of_aux_blocks = self.getNumberOfAuxBlocks()  # invariant for the loops below
        if self.debug:
            print("We should have " + str(number_of_aux_blocks) +
                  " Aux-Blocks and " + str(self.number_of_chunks) +
                  " normal Chunks (+ 1 HeaderChunk)")
        for i in range(0, number_of_aux_blocks):
            self.auxBlockNumbers[i] = set()
        for chunk_no in range(0, self.number_of_chunks):
            # Insert this Chunk into quality different Aux-Packets
            for _ in range(0, self.quality):
                # uniform choose a number of aux blocks
                aux_no = self.rng.randint(0, number_of_aux_blocks)
                self.auxBlockNumbers[aux_no].add(chunk_no)

        # XOR all Chunks into the corresponding AUX-Block
        for aux_number in self.auxBlockNumbers.keys():
            self.auxBlocks[aux_number] = OnlineAuxPacket(
                b"",
                self.auxBlockNumbers[aux_number],
                aux_number=aux_number,
                total_number_of_chunks=self.number_of_chunks
            )  # We will add the Data once we have it.

    def getAuxPacketListFromPacket(
            self, packet: OnlinePacket) -> typing.List[typing.List[bool]]:
        """Return the used-chunk bool arrays of every aux block that ``packet`` references."""
        return [self.auxBlocks[i].get_bool_array_used_packets()
                for i, aux in enumerate(packet.getBoolArrayAuxPackets()) if aux]

    def removeAndXorAuxPackets(
            self, packet: OnlinePacket) -> typing.List[typing.List[bool]]:
        """Combine the packet's own chunk mask with those of its aux blocks via ``logical_xor``.

        The result is the packet's equation in terms of source chunks only
        (presumably an element-wise XOR across the rows; see ``logical_xor``).
        """
        aux_mapping = self.getAuxPacketListFromPacket(packet)
        aux_mapping.append(packet.get_bool_array_used_packets())
        return logical_xor(aux_mapping)

    def input_new_packet(self, packet: OnlinePacket) -> bool:
        """Add ``packet`` as a new row of the GEPP system.

        :returns: the result of a partial solve when a pseudo decoder triggers one,
            else ``False``.
        """
        if self.isPseudo and self.auxBlocks == dict():
            self.number_of_chunks = packet.get_total_number_of_chunks()
            self.quality = packet.getQuality()
            self.epsilon = round(packet.getEpsilon(), 6)
            self.dist: Distribution = OnlineDistribution(self.epsilon)
            self.createAuxBlocks()
        removed: typing.List[typing.List[bool]] = self.removeAndXorAuxPackets(
            packet)
        if self.do_count:
            # Per-chunk tally of received rows that reference the chunk. The original
            # initialised every unseen key to 1, even for chunks the row does not use.
            for i, used in enumerate(removed):
                self.counter[i] = self.counter.get(i, 0) + (1 if used else 0)
        if self.GEPP is None:
            self.GEPP: GEPP = GEPP(
                numpy.array([removed], dtype=bool),
                numpy.array([[packet.get_data()]], dtype=bytes),
            )
        else:
            # Reuse the row computed above instead of recomputing it.
            self.GEPP.addRow(
                removed,
                numpy.frombuffer(packet.get_data(), dtype="uint8"),
            )
        if self.isPseudo and not self.read_all_before_decode and (
                self.GEPP.isPotentionallySolvable() and self.GEPP.n % 25 == 0):
            if self.debug:
                print("current size: " + str(self.GEPP.n))
            return self.GEPP.solve(partial=False)
        return False

    def solve(self) -> bool:
        """Run the GEPP solver."""
        return self.GEPP.solve()

    def getSolvedCount(self) -> int:
        """Return the solved count reported by the GEPP solver."""
        return self.GEPP.getSolvedCount()

    def is_decoded(self) -> bool:
        """Return True once a GEPP system exists, is potentially solvable and is solved."""
        return self.GEPP is not None and self.GEPP.isPotentionallySolvable(
        ) and self.GEPP.isSolved()

    def getNextValidPacket(
        self,
        from_multiple_files: bool = False,
        packet_len_format: str = "I",
        crc_len_format: str = "L",
        number_of_chunks_len_format: str = "I",
        quality_len_format: str = "I",
        epsilon_len_format: str = "f",
        check_block_number_len_format: str = "I"
    ) -> typing.Optional[typing.Union[str, OnlinePacket]]:
        """Read packets from ``self.f`` until one parses without being corrupt.

        With ``from_multiple_files`` the whole remaining stream is a single packet.
        Otherwise each packet is preceded by a ``packet_len_format`` length field.

        :returns: the parsed packet, or ``None`` at end of input. In that case
            ``self.EOF`` is set and ``self.f`` is closed.
        """
        len_struct = "<" + packet_len_format
        len_size = struct.calcsize(len_struct)
        # Loop rather than recurse, so a long run of corrupt packets cannot hit the recursion limit.
        while True:
            if not from_multiple_files:
                raw_len = self.f.read(len_size)
                # A short read of the length field means the stream is exhausted;
                # struct.unpack would raise struct.error on it.
                packet_len = struct.unpack(len_struct, raw_len)[0] if len(raw_len) == len_size else 0
                packet = self.f.read(int(packet_len)) if packet_len else b""
            else:
                packet = self.f.read()
                packet_len = len(packet)
            if not packet or not packet_len:  # EOF
                self.EOF: bool = True
                self.f.close()
                return None
            res = self.parse_raw_packet(
                packet,
                crc_len_format=crc_len_format,
                number_of_chunks_len_format=number_of_chunks_len_format,
                quality_len_format=quality_len_format,
                epsilon_len_format=epsilon_len_format,
                check_block_number_len_format=check_block_number_len_format)
            if res == "CORRUPT":
                continue
            return res

    def parse_raw_packet(self, packet: bytes, crc_len_format="L", number_of_chunks_len_format="I",
                         quality_len_format="I", epsilon_len_format="f", check_block_number_len_format="I") -> \
            typing.Optional[typing.Union[str, OnlinePacket]]:
        """Validate and unpack one raw packet.

        :returns: an ``OnlinePacket``, or the string ``"CORRUPT"`` if the CRC does
            not match or the error correction raises.
        """
        crc_len = -struct.calcsize("<" + crc_len_format)
        if self.error_correction.__code__.co_name == crc32.__code__.co_name:
            payload = packet[:crc_len]
            crc: int = struct.unpack("<" + crc_len_format, packet[crc_len:])[0]
            calced_crc: int = calc_crc(payload)
            if crc != calced_crc:  # If the Packet is corrupt, try next one
                print("[-] CRC-Error - " + str(hex(crc)) + " != " +
                      str(hex(calced_crc)))
                self.corrupt += 1
                return "CORRUPT"
        else:
            crc_len = None
            try:
                packet = self.error_correction(packet)
            except Exception:
                return "CORRUPT"  # if RS or other error correction cannot reconstruct this packet
        struct_str: str = "<" + number_of_chunks_len_format + quality_len_format + epsilon_len_format + check_block_number_len_format
        struct_len: int = struct.calcsize(struct_str)
        data: bytes = packet[struct_len:crc_len]
        len_data: typing.Union[typing.Tuple[int, int, float, int],
                               typing.Tuple[int, float, int]] = struct.unpack(
                                   struct_str, packet[0:struct_len])
        if self.static_number_of_chunks is None:
            number_of_chunks, quality, self.epsilon, check_block_number = len_data
            self.number_of_chunks = xor_mask(number_of_chunks,
                                             number_of_chunks_len_format)
        else:
            quality, self.epsilon, check_block_number = len_data
            # NOTE(review): epsilon is rounded only on this branch (input_new_packet also
            # rounds to 6 digits); confirm whether the branch above should match.
            self.epsilon = round(self.epsilon, 6)
        self.quality = xor_mask(quality, quality_len_format)
        if self.dist is None:
            self.dist = OnlineDistribution(self.epsilon)
        if self.correct == 0:
            # Create MockUp AuxBlocks with the given Pseudo-Random Number -> we will know which Packets are Encoded in which AuxBlock
            self.createAuxBlocks()
        self.correct += 1
        res = OnlinePacket(
            data,
            self.number_of_chunks,
            self.quality,
            self.epsilon,
            check_block_number,
            dist=self.dist,
            read_only=True,
            error_correction=self.error_correction,
            crc_len_format=crc_len_format,
            number_of_chunks_len_format=number_of_chunks_len_format,
            quality_len_format=quality_len_format,
            epsilon_len_format=epsilon_len_format,
            check_block_number_len_format=check_block_number_len_format,
            save_number_of_chunks_in_packet=self.static_number_of_chunks is
            None)
        return res

    def decodeHeader(self, last_chunk_len_format: str = "I") -> None:
        """Find the header chunk among the degree-1 packets and parse it (no-op if already set)."""
        if self.headerChunk is not None:
            return  # Header already set
        # .get: there may not be any degree-1 packet yet (the original raised KeyError).
        for decoded in self.degreeToPacket.get(1, ()):
            if decoded.get_used_packets().issubset({0}):
                self.headerChunk = HeaderChunk(
                    decoded, last_chunk_len_format=last_chunk_len_format)
                break  # one header is enough

    def saveDecodedFile(self,
                        last_chunk_len_format: str = "I",
                        null_is_terminator: bool = False,
                        print_to_output: bool = False) -> None:
        """Write the chunks reconstructed by ``self.GEPP`` to a file.

        When ``self.use_headerchunk`` is set, row 0 of ``self.GEPP.b`` is parsed as
        the header chunk. The header supplies the file name and the last-chunk
        length, and row 0 is not written. Otherwise the name is
        ``"DEC_" + basename(self.file)``, or ``"ONLINE.BIN"`` if no file is set.

        :param null_is_terminator: stop writing at the first NUL byte of a
            non-final chunk.
        :param print_to_output: also print the written bytes decoded as UTF-8.
        :raises AssertionError: if the system has not been solved.
        """
        assert self.is_decoded(), "Can not save File: Unable to reconstruct."
        if self.use_headerchunk:
            self.headerChunk = HeaderChunk(
                OnlinePacket(self.GEPP.b[0],
                             self.number_of_chunks,
                             self.quality,
                             self.epsilon,
                             0, {0},
                             self.dist,
                             read_only=True),
                last_chunk_len_format=last_chunk_len_format)
        if self.file is not None:
            file_name = "DEC_" + os.path.basename(self.file)
        else:
            file_name = "ONLINE.BIN"
        if self.headerChunk is not None:
            file_name = self.headerChunk.get_file_name().decode("utf-8")
        # Drop everything from the first NUL character on.
        file_name = file_name.split("\x00")[0]
        written = []  # pieces kept only for print_to_output; joined once at the end
        with open(file_name, "wb") as f:
            for x in self.GEPP.result_mapping:
                if self.use_headerchunk and x == 0:
                    continue  # row 0 is the header chunk, not file content
                stop = False
                if self.use_headerchunk and x == self.number_of_chunks - 1:
                    # Trim the last chunk to the length recorded in the header.
                    last_len = self.headerChunk.get_last_chunk_length()
                    output = self.GEPP.b[x][0][0:last_len].tobytes()
                elif null_is_terminator:
                    # NUL is the single byte 0x00 in UTF-8, so splitting the raw bytes
                    # matches splitting decoded text without failing on binary data.
                    output, separator, _ = self.GEPP.b[x].tobytes().partition(b"\x00")
                    stop = bool(separator)
                else:
                    output = self.GEPP.b[x].tobytes()
                f.write(output)
                if print_to_output:
                    written.append(output)
                if stop:
                    break  # since we are in null-terminator mode, we exit once we see the first 0-byte
        print("Saved file as '" + str(file_name) + "'")
        if print_to_output:
            print("Result:")
            print(b"".join(written).decode("utf-8"))

    def getNumberOfAuxBlocks(self) -> int:
        """Return ``ceil(0.55 * quality * epsilon * number_of_chunks)``."""
        return ceil(0.55 * self.quality * self.epsilon * self.number_of_chunks)
Ejemplo n.º 6
0
class LTBPDecoder(BPDecoder):
    """Belief-propagation ("peeling") decoder for LT-coded packets.

    Packets are read either from a single file of length-prefixed records
    (``decodeFile``) or from a folder holding one packet per ``.LT``/``DNA`` file
    (``decodeFolder``). Received packets are bucketed by degree in
    ``degreeToPacket``; decoding is complete once there is one degree-1 packet
    per chunk (see ``is_decoded``).
    """

    def __init__(self, file: typing.Optional[str] = None, error_correction: typing.Callable = nocode,
                 use_headerchunk: bool = True, static_number_of_chunks: typing.Optional[int] = None,
                 implicit_mode: bool = True, dist: typing.Optional[Distribution] = None):
        """
        :param file: Path of a single packet file or of a folder of packet files. ``None`` selects the
            PseudoDecode mode in which payloads are not XORed (see ``compareAndReduce``).
        :param error_correction: Callable applied to each raw packet; passing the ``crc32`` function
            selects CRC checking instead.
        :param use_headerchunk: Whether chunk 0 is a HeaderChunk (file name, last chunk length).
        :param static_number_of_chunks: If set, the number of chunks is not read from the packets.
        :param implicit_mode: If True, the degree is not stored in a packet but drawn from ``dist``.
        :param dist: Degree distribution; must not be None when ``implicit_mode`` is True.
        """
        super().__init__(file, error_correction, use_headerchunk, static_number_of_chunks)
        self.implicit_mode: bool = implicit_mode
        self.use_headerchunk: bool = use_headerchunk
        self.file: str = file
        self.decodedPackets: typing.Dict = {}
        self.degreeToPacket: typing.Dict[int, typing.Set] = {}
        if file is not None:
            self.isFolder: bool = os.path.isdir(file)
            if not self.isFolder:
                self.f = open(self.file, "rb")
        self.correct: int = 0
        self.corrupt: int = 0
        # Placeholder until the real value is read from a packet (or taken from static_number_of_chunks).
        self.number_of_chunks: int = 1000000
        self.headerChunk: typing.Optional[HeaderChunk] = None
        self.queue: deque = deque()
        self.error_correction: typing.Callable = error_correction
        self.static_number_of_chunks: int = static_number_of_chunks
        self.dist: typing.Optional[Distribution] = dist  # if implicit_mode is True, dist MUST be != None

    def decodeFolder(self, packet_len_format: str = "I", crc_len_format: str = "L",
                     number_of_chunks_len_format: str = "I", degree_len_format: str = "I", seed_len_format: str = "I",
                     last_chunk_len_format: str = "I") -> typing.Optional[int]:
        """Decode packets stored one per file in the folder ``self.file``.

        Only files ending in ``.LT`` or ``DNA`` are read (``DNA`` files are converted with
        ``quat_file_to_bin``). Each file handle is closed right after its packet is read.
        Decoding stops as soon as the file can be reconstructed.

        :return: -1 if every file was consumed without reconstructing the file, otherwise None.
        """
        decoded: bool = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        if self.implicit_mode:
            degree_len_format = ""  # the degree is derived from the seed, not stored in the packet
        for file in os.listdir(self.file):
            if not file.endswith((".LT", "DNA")):
                continue
            self.EOF = False
            path = os.path.join(self.file, file)
            if file.endswith("DNA"):
                self.f = quat_file_to_bin(path)
            else:
                self.f = open(path, "rb")
            try:
                new_pack = self.getNextValidPacket(True, packet_len_format=packet_len_format,
                                                   crc_len_format=crc_len_format,
                                                   number_of_chunks_len_format=number_of_chunks_len_format,
                                                   degree_len_format=degree_len_format,
                                                   seed_len_format=seed_len_format,
                                                   last_chunk_len_format=last_chunk_len_format)
            finally:
                # One packet per file: release the handle instead of leaking it on the next iteration.
                self.f.close()
            if new_pack is not None:
                self.addPacket(new_pack)
                decoded = self.updatePackets(new_pack)
            if decoded:
                break
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        if not decoded:
            # All files were consumed, so the input is exhausted regardless of the last file's content.
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1
        return None

    def decodeFile(self, packet_len_format: str = "I", crc_len_format: str = "L",
                   number_of_chunks_len_format: str = "I", degree_len_format: str = "I", seed_len_format: str = "I",
                   last_chunk_len_format: str = "I") -> typing.Optional[int]:
        """Decode length-prefixed packets read sequentially from ``self.f``.

        :return: -1 if the input ended before the file could be reconstructed, otherwise None.
        """
        decoded: bool = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        if self.implicit_mode:
            degree_len_format = ""  # the degree is derived from the seed, not stored in the packet
        while not (decoded or self.EOF):
            new_pack: Packet = self.getNextValidPacket(False, packet_len_format=packet_len_format,
                                                       crc_len_format=crc_len_format,
                                                       number_of_chunks_len_format=number_of_chunks_len_format,
                                                       degree_len_format=degree_len_format,
                                                       seed_len_format=seed_len_format,
                                                       last_chunk_len_format=last_chunk_len_format)
            if new_pack is None:
                break
            self.addPacket(new_pack)
            decoded = self.updatePackets(new_pack)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        self.f.close()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1
        return None

    def input_new_packet(self, packet: Packet) -> bool:
        """Feed one packet (converted to a DecodePacket if needed); used for easy PseudoDecode.

        :return: True once the file is decoded.
        """
        if not isinstance(packet, DecodePacket):
            packet = DecodePacket.from_packet(packet)
        self.addPacket(packet)
        return self.updatePackets(packet)

    def addPacket(self, packet: Packet) -> None:
        """Store ``packet`` in the bucket of its degree and adopt its total number of chunks."""
        degree = packet.get_degree()
        if not isinstance(self.degreeToPacket.get(degree), set):
            self.degreeToPacket[degree] = set()
        self.number_of_chunks = packet.get_total_number_of_chunks()
        self.degreeToPacket[degree].add(packet)

    def updatePackets(self, packet: Packet) -> bool:
        """Queue ``packet`` and run ``reduceAll`` on queued packets until the queue empties or decoding finishes.

        ``reduceAll`` is not defined in this class; presumably it is inherited from ``BPDecoder``.
        """
        self.queue.append(packet)
        finished: bool = False
        while self.queue and not finished:
            finished = self.reduceAll(self.queue.popleft())
        return finished

    def compareAndReduce(self, packet: Packet, other: Packet) -> typing.Union[bool, int]:
        """Remove ``other``'s chunks from ``packet``, re-bucket it, and queue it for further reduction.

        In PseudoDecode mode (``self.file is None``) only the chunk sets are updated, no XOR is computed.

        :return: True if the file is now decoded, otherwise the new degree of ``packet``.
        """
        if self.file is None:  # In case of PseudoDecode: DO NOT REALLY COMPUTE XOR
            packet.remove_packets(other.get_used_packets())
        else:
            packet.xor_and_remove_packet(other)
        degree = packet.get_degree()
        if not isinstance(self.degreeToPacket.get(degree), set):
            self.degreeToPacket[degree] = set()
        self.degreeToPacket[degree].add(packet)
        if self.is_decoded():
            return True
        self.queue.append(packet)
        return degree

    def is_decoded(self) -> bool:
        """Return True when there is one degree-1 packet per chunk."""
        return 1 in self.degreeToPacket and len(self.degreeToPacket[1]) == self.number_of_chunks

    def getSolvedCount(self) -> int:
        """Return the number of degree-1 (solved) packets, 0 before any has been found."""
        return len(self.degreeToPacket.get(1, ()))

    def getNextValidPacket(self, from_multiple_files: bool = False, packet_len_format: str = "I",
                           crc_len_format: str = "L", number_of_chunks_len_format: str = "I",
                           degree_len_format: str = "I", seed_len_format: str = "I",
                           last_chunk_len_format: str = "I") -> typing.Optional[Packet]:
        """Read packets until one passes the integrity check and return it as a DecodePacket.

        Corrupt packets are counted in ``self.corrupt`` and skipped. Corruption means a CRC mismatch,
        a packet too short to hold the CRC, or a packet rejected by ``error_correction``. A loop is
        used rather than recursion, so every format argument is kept for the following packets and
        long runs of corrupt packets cannot exceed the recursion limit.

        :param from_multiple_files: If True the whole content of ``self.f`` is one packet; otherwise
            each packet is prefixed with its length packed as ``packet_len_format``.
        :param last_chunk_len_format: Forwarded to ``HeaderChunk`` when the header packet is found.
        :return: The next valid packet, or None once the input is exhausted (``self.EOF`` is set and
            ``self.f`` is closed).
        """
        len_fmt = "<" + packet_len_format
        len_size = struct.calcsize(len_fmt)
        crc_fmt = "<" + crc_len_format
        crc_size = struct.calcsize(crc_fmt)
        while True:
            if from_multiple_files:
                packet = self.f.read()
                packet_len = len(packet)
            else:
                raw_len = self.f.read(len_size)
                if len(raw_len) < len_size:
                    # Missing/truncated length prefix means end of input; unpacking it would raise struct.error.
                    packet, packet_len = b"", 0
                else:
                    packet_len = struct.unpack(len_fmt, raw_len)[0]
                    packet = self.f.read(int(packet_len))
            if not packet or not packet_len:  # EOF
                self.EOF = True
                self.f.close()
                return None
            if self.error_correction.__code__.co_name == crc32.__code__.co_name:
                crc_len = -crc_size  # the CRC trails the packet
                if len(packet) < crc_size:
                    self.corrupt += 1  # too short to even hold the CRC
                    continue
                payload = packet[:crc_len]
                crc = struct.unpack(crc_fmt, packet[crc_len:])[0]
                calced_crc = calc_crc(payload)
                if crc != calced_crc:  # If the Packet is corrupt, try next one
                    print("[-] CRC-Error - " + str(hex(crc)) + " != " + str(hex(calced_crc)))
                    self.corrupt += 1
                    continue
            else:
                crc_len = None  # packet[x:None] keeps the data up to the end
                try:
                    packet = self.error_correction(packet)
                except Exception:
                    # error_correction could not repair the packet: count it and try the next one.
                    self.corrupt += 1
                    continue
            break

        struct_str = "<" + number_of_chunks_len_format + degree_len_format + seed_len_format
        struct_len = struct.calcsize(struct_str)
        len_data = struct.unpack(struct_str, packet[0:struct_len])
        degree = None
        if self.static_number_of_chunks is None:
            if self.implicit_mode:
                number_of_chunks, seed = len_data
            else:
                number_of_chunks, degree, seed = len_data
            self.number_of_chunks = xor_mask(number_of_chunks, number_of_chunks_len_format)
        else:
            if self.implicit_mode:
                seed, = len_data
            else:
                degree, seed = len_data
        seed = xor_mask(seed, seed_len_format)
        if degree is None:
            # Implicit mode: the degree is drawn from the distribution seeded with the packet's seed.
            self.dist.set_seed(seed)
            degree = self.dist.getNumber()
        else:
            degree = xor_mask(degree, degree_len_format)
        used_packets = self.choose_packet_numbers(degree, seed=seed)
        data = packet[struct_len:crc_len]

        self.correct += 1
        res = DecodePacket(data, used_packets, error_correction=self.error_correction,
                           number_of_chunks=self.number_of_chunks)
        if used_packets.issubset({0}) and self.headerChunk is None and self.use_headerchunk:
            self.headerChunk = HeaderChunk(res, last_chunk_len_format=last_chunk_len_format)
        return res

    def choose_packet_numbers(self, degree: int, seed: int = 0) -> typing.Set:
        """Return ``degree`` distinct chunk numbers drawn with numpy's global RNG seeded by ``seed``.

        Note: this re-seeds the global ``np.random`` state. The draw order must stay exactly as is
        so that the same seed yields the same chunk set.
        """
        assert degree <= self.number_of_chunks
        res = set()
        rng = np.random
        rng.seed(seed)
        for _ in range(0, degree):
            tmp = rng.choice(range(0, self.number_of_chunks))
            while tmp in res:
                tmp = rng.choice(range(0, self.number_of_chunks))
            res.add(tmp)
        return res

    def saveDecodedFile(self, last_chunk_len_format: str = "I", null_is_terminator: bool = False,
                        print_to_output: bool = True) -> None:
        """Write the reconstructed file, optionally printing its content as UTF-8.

        The output name comes from the HeaderChunk if present, else ``"DEC_" + self.file``.
        With a header chunk, chunk 0 is not written and the last chunk is truncated to the stored length.
        The header is parsed with ``last_chunk_len_format`` before its file name is read.

        :param null_is_terminator: Stop writing after the first 0-byte found in a chunk.
        :raises AssertionError: if the file is not fully decoded.
        """
        assert self.is_decoded(), "Can not save File: Unable to reconstruct."

        def as_bytes(buf) -> bytes:
            # get_data() may return bytes or a buffer object with tobytes() (e.g. a numpy array).
            return buf if isinstance(buf, bytes) else buf.tobytes()

        file_name: str = "DEC_" + self.file.split("\x00")[0]  # split is needed for weird  MAC / Windows bugs...
        sort_list: typing.List = sorted(self.degreeToPacket[1])
        if 0 in sort_list[0].get_used_packets() and self.use_headerchunk:
            self.headerChunk = HeaderChunk(sort_list[0], last_chunk_len_format=last_chunk_len_format)
        if self.headerChunk is not None:
            file_name = self.headerChunk.get_file_name().decode("utf-8")
        parts: typing.List[bytes] = []
        with open(file_name, "wb") as f:
            for decoded in sort_list:
                used = decoded.get_used_packets()
                if 0 in used and self.use_headerchunk:
                    self.headerChunk = HeaderChunk(decoded, last_chunk_len_format=last_chunk_len_format)
                    continue
                hit_terminator = False
                if self.number_of_chunks - 1 in used and self.use_headerchunk:
                    # The last chunk is padded; keep only the length recorded in the header.
                    chunk = as_bytes(decoded.get_data()[0: self.headerChunk.get_last_chunk_length()])
                elif null_is_terminator:
                    pieces = as_bytes(decoded.get_data()).decode().split("\x00")
                    chunk = pieces[0].encode()
                    hit_terminator = len(pieces) > 1
                else:
                    chunk = as_bytes(decoded.get_data())
                parts.append(chunk)
                f.write(chunk)
                if hit_terminator:
                    break  # since we are in null-terminator mode, we exit once we see the first 0-byte
        output_concat: bytes = b"".join(parts)

        print("Saved file as '" + str(file_name) + "'")
        if print_to_output:
            print("Result:")
            print(output_concat.decode("utf-8"))
Ejemplo n.º 7
0
    def getNextValidPacket(self, from_multiple_files: bool = False, packet_len_format: str = "I",
                           crc_len_format: str = "L", number_of_chunks_len_format: str = "I",
                           degree_len_format: str = "I", seed_len_format: str = "I",
                           last_chunk_len_format: str = "I") -> typing.Optional[Packet]:
        """Read packets until one passes the integrity check and return it as a DecodePacket.

        Corrupt packets are counted in ``self.corrupt`` and skipped. Corruption means a CRC mismatch,
        a packet too short to hold the CRC, or a packet rejected by ``error_correction``. A loop is
        used rather than recursion, so every format argument is kept for the following packets and
        long runs of corrupt packets cannot exceed the recursion limit.

        :param from_multiple_files: If True the whole content of ``self.f`` is one packet; otherwise
            each packet is prefixed with its length packed as ``packet_len_format``.
        :param last_chunk_len_format: Forwarded to ``HeaderChunk`` when the header packet is found.
        :return: The next valid packet, or None once the input is exhausted (``self.EOF`` is set and
            ``self.f`` is closed).
        """
        len_fmt = "<" + packet_len_format
        len_size = struct.calcsize(len_fmt)
        crc_fmt = "<" + crc_len_format
        crc_size = struct.calcsize(crc_fmt)
        while True:
            if from_multiple_files:
                packet = self.f.read()
                packet_len = len(packet)
            else:
                raw_len = self.f.read(len_size)
                if len(raw_len) < len_size:
                    # Missing/truncated length prefix means end of input; unpacking it would raise struct.error.
                    packet, packet_len = b"", 0
                else:
                    packet_len = struct.unpack(len_fmt, raw_len)[0]
                    packet = self.f.read(int(packet_len))
            if not packet or not packet_len:  # EOF
                self.EOF = True
                self.f.close()
                return None
            if self.error_correction.__code__.co_name == crc32.__code__.co_name:
                crc_len = -crc_size  # the CRC trails the packet
                if len(packet) < crc_size:
                    self.corrupt += 1  # too short to even hold the CRC
                    continue
                payload = packet[:crc_len]
                crc = struct.unpack(crc_fmt, packet[crc_len:])[0]
                calced_crc = calc_crc(payload)
                if crc != calced_crc:  # If the Packet is corrupt, try next one
                    print("[-] CRC-Error - " + str(hex(crc)) + " != " + str(hex(calced_crc)))
                    self.corrupt += 1
                    continue
            else:
                crc_len = None  # packet[x:None] keeps the data up to the end
                try:
                    packet = self.error_correction(packet)
                except Exception:
                    # error_correction could not repair the packet: count it and try the next one.
                    self.corrupt += 1
                    continue
            break

        struct_str = "<" + number_of_chunks_len_format + degree_len_format + seed_len_format
        struct_len = struct.calcsize(struct_str)
        len_data = struct.unpack(struct_str, packet[0:struct_len])
        degree = None
        if self.static_number_of_chunks is None:
            if self.implicit_mode:
                number_of_chunks, seed = len_data
            else:
                number_of_chunks, degree, seed = len_data
            self.number_of_chunks = xor_mask(number_of_chunks, number_of_chunks_len_format)
        else:
            if self.implicit_mode:
                seed, = len_data
            else:
                degree, seed = len_data
        seed = xor_mask(seed, seed_len_format)
        if degree is None:
            # Implicit mode: the degree is drawn from the distribution seeded with the packet's seed.
            self.dist.set_seed(seed)
            degree = self.dist.getNumber()
        else:
            degree = xor_mask(degree, degree_len_format)
        used_packets = self.choose_packet_numbers(degree, seed=seed)
        data = packet[struct_len:crc_len]

        self.correct += 1
        res = DecodePacket(data, used_packets, error_correction=self.error_correction,
                           number_of_chunks=self.number_of_chunks)
        if used_packets.issubset({0}) and self.headerChunk is None and self.use_headerchunk:
            self.headerChunk = HeaderChunk(res, last_chunk_len_format=last_chunk_len_format)
        return res
Ejemplo n.º 8
0
class OnlineBPDecoder(BPDecoder):
    """Belief-propagation decoder for Online-code packets (check blocks over chunks plus auxiliary blocks).

    The auxiliary-block layout is regenerated from ``number_of_chunks``, which is used as the seed of
    ``numpy.random``. It is created once, when the first valid packet is seen.
    """

    def __init__(self,
                 file: str,
                 error_correction: typing.Callable = nocode,
                 use_headerchunk: bool = True,
                 static_number_of_chunks: typing.Optional[int] = None):
        """
        :param file: Path of a single packet file or of a folder of packet files (may be None).
        :param error_correction: Callable applied to each raw packet; passing the ``crc32`` function
            selects CRC checking instead.
        :param use_headerchunk: Whether chunk 0 is a HeaderChunk (file name, last chunk length).
        :param static_number_of_chunks: If set, the number of chunks is not read from the packets.
        """
        super().__init__(file, error_correction, use_headerchunk,
                         static_number_of_chunks)
        self.use_headerchunk: bool = use_headerchunk
        self.file: str = file
        if file is not None:
            self.isFolder: bool = os.path.isdir(file)
            if not self.isFolder:
                self.f = open(self.file, "rb")
        self.rng: numpy.random = numpy.random
        self.auxBlockNumbers: typing.Dict[int, typing.Set[int]] = dict()
        self.error_correction: typing.Callable = error_correction
        self.static_number_of_chunks: int = static_number_of_chunks
        self.epsilon = None
        self.quality = None

    def decodeFolder(self,
                     packet_len_format: str = "I",
                     crc_len_format: str = "L",
                     number_of_chunks_len_format: str = "I",
                     quality_len_format: str = "I",
                     epsilon_len_format: str = "f",
                     check_block_number_len_format: str = "I") -> typing.Optional[int]:
        """Decode packets stored one per file in the folder ``self.file``.

        Only files ending in ``.ONLINE`` or ``DNA`` are read (``DNA`` files are converted with
        ``quat_file_to_bin``). Each file handle is closed right after its packet is read.

        :return: -1 if every file was consumed without reconstructing the file, otherwise None.
        """
        decoded: bool = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        for dir_file in os.listdir(self.file):
            if not dir_file.endswith((".ONLINE", "DNA")):
                continue
            self.EOF = False
            path = os.path.join(self.file, dir_file)
            if dir_file.endswith("DNA"):
                self.f = quat_file_to_bin(path)
            else:
                self.f = open(path, "rb")
            try:
                new_pack = self.getNextValidPacket(
                    True,
                    packet_len_format=packet_len_format,
                    crc_len_format=crc_len_format,
                    number_of_chunks_len_format=number_of_chunks_len_format,
                    quality_len_format=quality_len_format,
                    epsilon_len_format=epsilon_len_format,
                    check_block_number_len_format=check_block_number_len_format
                )
            finally:
                # One packet per file: release the handle instead of leaking it on the next iteration.
                self.f.close()
            if new_pack is not None:
                self.addPacket(new_pack)
                decoded = self.updatePackets(new_pack)
            if decoded:
                break
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        if not decoded:
            # All files were consumed, so the input is exhausted regardless of the last file's content.
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1
        return None

    def decodeFile(self,
                   packet_len_format: str = "I",
                   crc_len_format: str = "L",
                   number_of_chunks_len_format: str = "I",
                   quality_len_format: str = "I",
                   epsilon_len_format: str = "f",
                   check_block_number_len_format: str = "I") -> typing.Optional[int]:
        """Decode length-prefixed packets read sequentially from ``self.f``.

        :return: -1 if the input ended before the file could be reconstructed, otherwise None.
        """
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        decoded = False
        self.EOF = False
        while not (decoded or self.EOF):
            new_pack = self.getNextValidPacket(
                False,
                packet_len_format=packet_len_format,
                crc_len_format=crc_len_format,
                number_of_chunks_len_format=number_of_chunks_len_format,
                quality_len_format=quality_len_format,
                epsilon_len_format=epsilon_len_format,
                check_block_number_len_format=check_block_number_len_format)
            if new_pack is None:
                break
            self.addPacket(new_pack)
            decoded = self.updatePackets(new_pack)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        self.f.close()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1
        return None

    def input_new_packet(self,
                         packet: OnlinePacket,
                         last_chunk_len_format: str = "I") -> bool:
        """Feed one already-parsed packet, adopting its code parameters.

        Aux blocks are created before the first packet is counted.
        :param last_chunk_len_format: Used when the header chunk is decoded.
        :return: True once the file is decoded.
        """
        self.number_of_chunks = packet.total_number_of_chunks
        self.quality: int = packet.quality
        self.epsilon: float = packet.epsilon
        if self.headerChunk is None and self.use_headerchunk:
            self.decodeHeader(last_chunk_len_format=last_chunk_len_format)
        if self.correct == 0:
            self.createAuxBlocks()
        self.correct += 1
        self.addPacket(packet)
        return self.updatePackets(packet)

    def getAuxPacketListFromPacket(
            self, packet: OnlinePacket) -> typing.List[typing.List[bool]]:
        """Return the bool chunk-usage arrays of every aux block referenced by ``packet``."""
        return [self.auxBlocks[i].get_bool_array_used_packets()
                for i, aux in enumerate(packet.getBoolArrayAuxPackets()) if aux]

    def removeAndXorAuxPackets(self,
                               packet: OnlinePacket) -> typing.List[bool]:
        """Return the logical XOR of the packet's own chunk usage and that of its aux blocks."""
        aux_mapping = self.getAuxPacketListFromPacket(packet)
        aux_mapping.append(packet.get_bool_array_used_packets())
        return logical_xor(aux_mapping)

    def createAuxBlocks(self) -> None:
        """Rebuild the aux-block layout from the seeded RNG (empty payloads; data is added later).

        Each chunk is inserted into ``quality`` aux blocks chosen uniformly at random. The RNG is
        seeded with ``number_of_chunks``; the draw order must not change.
        """
        assert self.number_of_chunks is not None, "createAuxBlocks can only be called AFTER first Packet"
        self.rng.seed(self.number_of_chunks)
        number_of_aux_blocks = self.getNumberOfAuxBlocks()  # loop-invariant
        if self.debug:
            print("We should have " + str(number_of_aux_blocks) +
                  " Aux-Blocks and " + str(self.number_of_chunks) +
                  " normal Chunks (+ 1 HeaderChunk)")
        for i in range(0, number_of_aux_blocks):
            self.auxBlockNumbers[i] = set()
        for chunk_no in range(0, self.number_of_chunks):
            # Insert this Chunk into quality different Aux-Packets
            for _ in range(0, self.quality):
                aux_no = self.rng.randint(0, number_of_aux_blocks)
                self.auxBlockNumbers[aux_no].add(chunk_no)

        # Register one aux packet per aux block; the data is added once we have it.
        for aux_number in self.auxBlockNumbers.keys():
            self.auxBlocks[aux_number] = OnlineAuxPacket(
                b"",
                self.auxBlockNumbers[aux_number],
                aux_number=aux_number,
                total_number_of_chunks=self.number_of_chunks
            )

    def solve(self):
        """Decode the header chunk if still missing, then run the parent solver."""
        if self.use_headerchunk and self.headerChunk is None:
            self.decodeHeader()
        super().solve()

    def is_decoded(self) -> bool:
        """Return True once at least ``number_of_chunks`` chunks are solved."""
        return self.getSolvedCount() >= self.number_of_chunks

    def getSolvedCount(self) -> int:
        """Return the number of decoded packets, plus one if the header chunk is known."""
        return len(
            self.decodedPackets) + (1 if self.headerChunk is not None else 0)

    def getNextValidPacket(
        self,
        from_multiple_files: bool = False,
        packet_len_format: str = "I",
        crc_len_format: str = "L",
        number_of_chunks_len_format: str = "I",
        quality_len_format: str = "I",
        epsilon_len_format: str = "f",
        check_block_number_len_format: str = "I"
    ) -> typing.Optional[OnlinePacket]:
        """Read packets until one passes the integrity check and return it as an OnlinePacket.

        The CRC (packed as ``crc_len_format``) trails the packet, as in the LT decoder. Corrupt
        packets are counted in ``self.corrupt`` and skipped inside a loop, so the format arguments
        are kept for the following packets. The first valid packet also triggers ``createAuxBlocks``.

        :return: The next valid packet, or None once the input is exhausted (``self.EOF`` is set and
            ``self.f`` is closed).
        """
        len_fmt = "<" + packet_len_format
        len_size = struct.calcsize(len_fmt)
        crc_fmt = "<" + crc_len_format
        crc_size = struct.calcsize(crc_fmt)
        while True:
            if from_multiple_files:
                packet = self.f.read()
                packet_len = len(packet)
            else:
                raw_len = self.f.read(len_size)
                if len(raw_len) < len_size:
                    # Missing/truncated length prefix means end of input; unpacking it would raise struct.error.
                    packet, packet_len = b"", 0
                else:
                    packet_len = struct.unpack(len_fmt, raw_len)[0]
                    packet = self.f.read(int(packet_len))
            if not packet or not packet_len:  # EOF
                self.EOF = True
                self.f.close()
                return None

            if self.error_correction.__code__.co_name == crc32.__code__.co_name:
                # Negative length: the CRC is at the END of the packet.
                crc_len = -crc_size
                if len(packet) < crc_size:
                    self.corrupt += 1  # too short to even hold the CRC
                    continue
                payload = packet[:crc_len]
                crc = struct.unpack(crc_fmt, packet[crc_len:])[0]
                calced_crc = calc_crc(payload)
                if crc != calced_crc:  # If the Packet is corrupt, try next one
                    print("[-] CRC-Error - " + str(hex(crc)) + " != " +
                          str(hex(calced_crc)))
                    self.corrupt += 1
                    continue
            else:
                crc_len = None  # packet[x:None] keeps the data up to the end
                try:
                    packet = self.error_correction(packet)
                except Exception:
                    # error_correction could not repair the packet: count it and try the next one.
                    self.corrupt += 1
                    continue
            break

        struct_str: str = "<" + number_of_chunks_len_format + quality_len_format + epsilon_len_format + check_block_number_len_format
        struct_len: int = struct.calcsize(struct_str)
        data = packet[struct_len:crc_len]
        len_data = struct.unpack(struct_str, packet[0:struct_len])
        if self.static_number_of_chunks is None:
            number_of_chunks, quality, self.epsilon, check_block_number = len_data
            self.number_of_chunks = xor_mask(number_of_chunks,
                                             number_of_chunks_len_format)
        else:
            quality, self.epsilon, check_block_number = len_data
        self.quality = xor_mask(quality, quality_len_format)
        if self.dist is None:
            self.dist = OnlineDistribution(self.epsilon)
        if self.correct == 0:
            # Create MockUp AuxBlocks with the given Pseudo-Random Number -> we will know which Packets are Encoded in which AuxBlock
            self.createAuxBlocks()

        self.correct += 1
        return OnlinePacket(
            data,
            self.number_of_chunks,
            self.quality,
            self.epsilon,
            check_block_number,
            read_only=True,
            crc_len_format=crc_len_format,
            number_of_chunks_len_format=number_of_chunks_len_format,
            quality_len_format=quality_len_format,
            epsilon_len_format=epsilon_len_format,
            check_block_number_len_format=check_block_number_len_format,
            save_number_of_chunks_in_packet=self.static_number_of_chunks is
            None)

    def decodeHeader(self, last_chunk_len_format: str = "I") -> None:
        """Set ``self.headerChunk`` from the degree-1 packet that covers only chunk 0, if present."""
        if self.headerChunk is not None or 1 not in self.degreeToPacket.keys():
            return  # Header already set or no chunks decoded so far
        for decoded in self.degreeToPacket[1]:
            if decoded.get_used_packets().issubset({0}):
                self.headerChunk = HeaderChunk(
                    decoded, last_chunk_len_format=last_chunk_len_format)
                return

    def saveDecodedFile(self,
                        null_is_terminator: bool = False,
                        print_to_output: bool = True,
                        last_chunk_len_format: str = "I") -> None:
        """Write the reconstructed file, optionally printing its content as UTF-8.

        The output name comes from the HeaderChunk if present, else ``"DEC_" + self.file``.
        With a header chunk, chunk 0 is not written and the last chunk is truncated to the stored length.

        :param null_is_terminator: Stop writing after the first 0-byte found in a chunk.
        :param last_chunk_len_format: Format used when the header chunk still has to be decoded.
        :raises AssertionError: if the file is not fully decoded.
        """
        assert self.is_decoded(), "Can not save File: Unable to reconstruct."

        def as_bytes(buf) -> bytes:
            # get_data() may return bytes or a buffer object with tobytes() (e.g. a numpy array).
            return buf if isinstance(buf, bytes) else buf.tobytes()

        if self.use_headerchunk:
            self.decodeHeader(last_chunk_len_format=last_chunk_len_format)
        file_name = "DEC_" + self.file.split("\x00")[
            0]  # split is needed for weird  MAC / Windows bugs...
        if self.headerChunk is not None:
            file_name = self.headerChunk.get_file_name().decode("utf-8")
        parts: typing.List[bytes] = []
        with open(file_name, "wb") as f:
            for decoded in sorted(self.decodedPackets):
                [num] = decoded.get_used_packets()
                if num == 0 and self.use_headerchunk and self.number_of_chunks - 1 != 0:
                    continue  # chunk 0 is the header chunk, not file content
                hit_terminator = False
                if self.number_of_chunks - 1 == num and self.use_headerchunk:
                    # The last chunk is padded; keep only the length recorded in the header.
                    chunk = as_bytes(
                        decoded.get_data()[0:self.headerChunk.get_last_chunk_length()])
                elif null_is_terminator:
                    pieces = as_bytes(decoded.get_data()).decode().split("\x00")
                    chunk = pieces[0].encode()
                    hit_terminator = len(pieces) > 1
                else:
                    chunk = as_bytes(decoded.get_data())
                parts.append(chunk)
                f.write(chunk)
                if hit_terminator:
                    break  # since we are in null-terminator mode, we exit once we see the first 0-byte
        output_concat = b"".join(parts)
        print("Saved file as '" + str(file_name) + "'")
        if print_to_output:
            print("Result:")
            print(output_concat.decode("utf-8"))

    def getNumberOfAuxBlocks(self) -> int:
        """Return ceil(0.55 * quality * epsilon * number_of_chunks) as an int."""
        return int(
            ceil(0.55 * self.quality * self.epsilon * self.number_of_chunks))
Ejemplo n.º 9
0
class RU10BPDecoder(BPDecoder):
    """
    RU10 (Raptor-distribution based) decoder built on BPDecoder.

    Packets are read from a single packet file, a FASTA file, or a folder holding one packet per file. Before a packet
    is handed to the inherited decoding machinery, the LDPC and half (auxiliary) symbols it references are resolved
    into plain source chunks.
    """

    def __init__(self,
                 file: typing.Optional[str] = None,
                 error_correction=nocode,
                 use_headerchunk: bool = True,
                 static_number_of_chunks: typing.Optional[int] = None,
                 use_method: bool = False):
        """
        :param file: Path to a packet file or to a folder of packet files. A regular file is opened immediately.
        :param error_correction: Callable applied to every raw packet; any exception it raises marks the packet as
            corrupt.
        :param use_headerchunk: True: chunk 0 is a HeaderChunk (file name, last chunk length), not file content.
        :param static_number_of_chunks: If set, the number of chunks is not read from (or stored in) the packets.
        :param use_method: True: the last data byte of every packet encodes the chunk-selection method.
        """
        super().__init__()
        self.file: typing.Optional[str] = file
        self.use_method: bool = use_method
        if file is not None:
            self.isFolder = os.path.isdir(file)
            if not self.isFolder:
                self.f = open(self.file, "rb")
        # Placeholder until the real value is read from a packet (or set from static_number_of_chunks).
        self.number_of_chunks: int = 1000000
        self.s: int = -1  # number of LDPC symbols, computed once the number of chunks is known
        self.h: int = -1  # number of half symbols, computed once the number of chunks is known
        self.error_correction: typing.Callable = error_correction
        self.use_headerchunk: bool = use_headerchunk
        self.static_number_of_chunks: typing.Optional[
            int] = static_number_of_chunks

    def decodeFolder(self,
                     packet_len_format: str = "I",
                     crc_len_format: str = "I",
                     number_of_chunks_len_format: str = "I",
                     id_len_format: str = "I"):
        """
        Decodes the information from a folder if self.file represents a folder and the packets were saved
        in multiple files (one packet per "*.RU10" / "*DNA" file) and prints the number of decoded and corrupted
        packets.
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: -1 if the decoding wasn't successful
        """
        decoded = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        for file in os.listdir(self.file):
            if not file.endswith((".RU10", "DNA")):
                continue
            self.EOF = False
            path = self.file + "/" + file
            if file.endswith("DNA"):
                if self.error_correction.__name__ == 'dna_reed_solomon_decode':
                    reader = quad_file_to_bytes
                else:
                    reader = quat_file_to_bin
                try:
                    self.f = reader(path)
                except TypeError:
                    print(
                        "skipping CORRUPT file - contains illegal character(s)"
                    )
                    self.corrupt += 1
                    continue
            else:
                self.f = open(path, "rb")
            try:
                new_pack = self.getNextValidPacket(
                    True,
                    packet_len_format=packet_len_format,
                    crc_len_format=crc_len_format,
                    number_of_chunks_len_format=number_of_chunks_len_format,
                    id_len_format=id_len_format)
            finally:
                # Each packet file is read completely in one go, so its handle can be released right away.
                self.f.close()
            if new_pack is not None:
                decoded = self.input_new_packet(new_pack)
            if decoded:
                break
        else:
            # Every packet file has been consumed without completing the decoding.
            self.EOF = True
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets: " + str(self.corrupt))
        if hasattr(self, "f"):
            self.f.close()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too many errors?")
            return -1

    def decodeFile(self,
                   packet_len_format: str = "I",
                   crc_len_format: str = "L",
                   number_of_chunks_len_format: str = "I",
                   id_len_format: str = "I"):
        """
        Decodes the information from a file if self.file represents a file and the packets were saved in a single file.
        Files ending in "dna" are converted from their quaternary representation first; files ending in "fasta" are
        read as alternating header / DNA-sequence lines.
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: -1 if the decoding wasn't successful
        """
        decoded = False
        self.EOF = False
        if self.file.lower().endswith("dna"):
            try:
                self.f.close()
                self.f = quat_file_to_bin(self.file)
            except TypeError:
                print("skipping CORRUPT file - contains illegal character(s)")
                self.corrupt += 1
                # The original handle is already closed, so there is nothing left to read.
                self.EOF = True
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        formats = dict(packet_len_format=packet_len_format,
                       crc_len_format=crc_len_format,
                       number_of_chunks_len_format=number_of_chunks_len_format,
                       id_len_format=id_len_format)
        if self.file.lower().endswith("fasta"):
            self.f.close()
            self.f = open(self.file, "r")
            try:
                while not (decoded or self.EOF):
                    # Header line of the record; its contents are not needed for decoding.
                    if not self.f.readline():
                        self.EOF = True
                        break
                    line = self.f.readline()
                    if not line:
                        self.EOF = True
                        break
                    dna_str = line.replace("\n", "")
                    new_pack = self.parse_raw_packet(
                        BytesIO(tranlate_quat_to_byte(dna_str)).read(),
                        **formats)
                    if isinstance(new_pack, str) and new_pack == "CORRUPT":
                        continue  # already counted as corrupt by parse_raw_packet
                    decoded = self.input_new_packet(new_pack)
            finally:
                self.f.close()
        else:
            while not (decoded or self.EOF):
                new_pack = self.getNextValidPacket(False, **formats)
                if new_pack is None:
                    self.EOF = True  # no further packet can be read
                    break
                decoded = self.input_new_packet(new_pack)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1

    def getNumberOfLDPCBlocks(self):
        return self.s

    def getNumberOfHalfBlocks(self):
        return self.h

    def getNumberOfRepairBlocks(self):
        return self.getNumberOfHalfBlocks() + self.getNumberOfLDPCBlocks()

    def removeAndXorAuxPackets(self, packet: RU10Packet) -> typing.List[bool]:
        """
        Resolves the auxiliary symbols (half and LDPC) referenced by a packet into plain source chunks.
        Half symbols are removed first because a half symbol may itself contain LDPC symbols; the LDPC symbols left
        over afterwards are removed in a second pass.
        :param packet: Packet to remove auxpackets from
        :return: Boolean mask (as produced by logical_xor) of the chunks the packet's data consists of
        """
        # Rows to XOR: the (data + LDPC) composition of every half symbol used, plus the packet itself.
        half_rows = self.getHalfPacketListFromPacket(packet)
        half_rows.append(packet.get_bool_array_used_and_ldpc_packets())
        without_half = set(from_true_false_list(
            logical_xor(half_rows)))  # only data + LDPC symbols remain
        if self.debug:
            print(without_half)
        intermediate = type(packet)("",
                                    without_half,
                                    self.number_of_chunks,
                                    packet.id,
                                    packet.dist,
                                    read_only=True)
        ldpc_rows = self.getAuxPacketListFromPacket(intermediate)
        ldpc_rows.append(intermediate.get_bool_array_used_packets())
        return logical_xor(ldpc_rows)

    def input_new_packet(self, packet: RU10Packet):
        """
        Removes auxpackets (LDPC and Half) and adds the remaining data to the decoder.
        :param packet: A Packet to add
        :return: True: If solved. False: Else.
        """
        if self.auxBlocks == dict() and self.dist is None:  # self.isPseudo and
            # Take the number of chunks from the packet first so that the distribution matches it
            # (same order as in parse_raw_packet).
            self.number_of_chunks = packet.get_total_number_of_chunks()
            self.dist = RaptorDistribution(self.number_of_chunks)
            _, self.s, self.h = intermediate_symbols(self.number_of_chunks,
                                                     self.dist)
            self.createAuxBlocks()
        # The removal is done in two passes since half symbols may contain LDPC symbols
        # (which by definition are repair codes).
        if self.debug:
            print("----")
            print("Id = " + str(packet.id))
            print(packet.used_packets)
        removed = self.removeAndXorAuxPackets(packet)
        if self.debug:
            print(from_true_false_list(removed))
            print(packet.get_error_correction())
            print("----")
        packet.set_used_packets(set(from_true_false_list(removed)))
        if self.count:
            for i, is_used in enumerate(removed):
                if i in self.counter:
                    if is_used:
                        self.counter[i] += 1
                else:
                    # NOTE(review): a first-seen index starts at 1 even if it is unused - confirm intent.
                    self.counter[i] = 1
        self.addPacket(packet)
        return self.updatePackets(packet)

    def createAuxBlocks(self):
        """
        Reconstructs the auxblocks (LDPC and half symbols) as RU10IntermediatePackets to be able to remove them
        from incoming packets afterwards. Only their compositions are known at this point; data is added later.
        """
        assert (self.number_of_chunks is not None
                ), "createAuxBlocks can only be called AFTER first Packet"
        if self.debug:
            print("We should have " + str(self.getNumberOfLDPCBlocks()) +
                  " LDPC-Blocks, " + str(self.getNumberOfHalfBlocks()) +
                  " Half-Blocks and " + str(self.number_of_chunks) +
                  " normal Chunks (including 1 HeaderChunk)")
        for i in range(0, self.getNumberOfRepairBlocks()):
            self.repairBlockNumbers[i] = set()
        aux_number = 0
        # LDPC compositions first, followed by the half compositions.
        for group in self.generateIntermediateBlocksFormat(
                self.number_of_chunks):
            for composition in group:
                self.repairBlockNumbers[aux_number] = composition
                aux_number += 1
        for aux_number, composition in self.repairBlockNumbers.items():
            self.auxBlocks[aux_number] = RU10IntermediatePacket(
                "",
                composition,
                total_number_of_chunks=self.number_of_chunks,
                id=aux_number,
                dist=self.dist)  # We will add the data once we have it.
            if self.debug:
                print(
                    str(aux_number) + " : " +
                    str(self.auxBlocks[aux_number].used_packets))

    def getAuxPacketListFromPacket(self, packet: RU10Packet):
        """
        Collects, for every repair (aux) symbol used by the packet, that symbol's boolean used-packet array.
        :param packet: The packet to check.
        :return: List of boolean arrays, one per used aux symbol.
        """
        return [
            self.auxBlocks[i].get_bool_array_used_packets()
            for i, used in enumerate(packet.get_bool_array_repair_packets())
            if used
        ]

    def getHalfPacketListFromPacket(
            self, packet: RU10Packet) -> typing.List[typing.List[bool]]:
        """
        Collects, for every half symbol used by the packet, that symbol's (data + LDPC) boolean array.
        Half symbols are stored in self.auxBlocks right after the LDPC symbols.
        :param packet: The packet to get the list from
        :return: List of boolean arrays, one per used half symbol
        """
        half_used = packet.get_bool_array_half_packets()
        if not any(half_used):
            return []
        ldpc_offset = packet.get_number_of_ldpc_blocks()
        return [
            self.auxBlocks[ldpc_offset +
                           i].get_bool_array_used_and_ldpc_packets()
            for i, used in enumerate(half_used) if used
        ]

    def solve(self):
        """Decodes the header chunk (if required and not yet available) and then runs the base-class solver."""
        if self.use_headerchunk and self.headerChunk is None:
            self.decodeHeader()
        # Zero-argument super() stays correct in subclasses; super(self.__class__, self) would recurse forever.
        super().solve()

    def is_decoded(self) -> bool:
        return self.getSolvedCount() >= self.number_of_chunks

    def getSolvedCount(self) -> int:
        return len(
            self.decodedPackets) + (1 if self.headerChunk is not None else 0)

    def getNextValidPacket(
            self,
            from_multiple_files: bool = False,
            packet_len_format: str = "I",
            crc_len_format: str = "L",
            number_of_chunks_len_format: str = "I",
            id_len_format: str = "I") -> typing.Optional[RU10Packet]:
        """
        Takes a raw packet from a file and calls @parse_raw_packet to get a RU10 packet. Corrupt packets are skipped
        until a valid one is found (iteratively, so long runs of corrupt packets cannot exhaust the stack).
        :param from_multiple_files: True: The packets were saved in multiple files (the whole file is one packet).
            False: Packets were saved in one file, each prefixed by its length.
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: RU10Packet, or None at end of input (self.EOF is set and the file closed) or on an unreadable length.
        """
        len_struct = "<" + packet_len_format
        while True:
            if from_multiple_files:
                packet = self.f.read()
                packet_len = len(packet)
            else:
                raw_len = self.f.read(struct.calcsize(len_struct))
                if not raw_len:
                    packet, packet_len = b"", 0  # clean end of file
                else:
                    try:
                        packet_len = struct.unpack(len_struct, raw_len)[0]
                        packet = self.f.read(int(packet_len))
                    except Exception:
                        return None
            if not packet or not packet_len:  # EOF
                self.EOF = True
                try:
                    self.f.close()
                except Exception:
                    pass
                return None
            res = self.parse_raw_packet(
                packet,
                crc_len_format=crc_len_format,
                packet_len_format=packet_len_format,
                number_of_chunks_len_format=number_of_chunks_len_format,
                id_len_format=id_len_format)
            if not (isinstance(res, str) and res == "CORRUPT"):
                return res

    def _method_chunk_list(self, method_data: str) -> typing.List[int]:
        """
        Maps the 8-character bit string of a packet's method byte to the chunk numbers the packet may combine:
        "00" even chunks, "01" odd chunks, "10"/"11" a window of 30/40 chunks (windows overlap by 10) whose index
        is given by the remaining bits.
        """
        method, window_bits = method_data[:2], method_data[2:]
        if method == '00':
            return [
                ch for ch in range(0, self.number_of_chunks + 1) if ch % 2 == 0
            ]
        if method == '01':
            return [
                ch for ch in range(0, self.number_of_chunks + 1) if ch % 2 != 0
            ]
        if method in ('10', '11'):
            window_size = 30 if method == '10' else 40
            start = int(window_bits, 2) * (window_size - 10)
            return [
                ch for ch in range(start, start + window_size)
                if ch <= self.number_of_chunks
            ]
        raise RuntimeError("Invalid method_data: %s" % method_data)

    def parse_raw_packet(
            self,
            packet,
            crc_len_format: str = "L",
            number_of_chunks_len_format: str = "L",
            packet_len_format: str = "I",
            id_len_format: str = "L") -> typing.Union[RU10Packet, str]:
        """
        Creates a RU10 packet from a raw given packet. Also checks if the packet is corrupted. If any method was used to
        create packets from specific chunks, set self.use_method = True. This will treat the last byte of the raw packet
        data as the byte that contains the information about the used method ("even", "odd", "window_30 + window" or
        "window_40 + window". See RU10Encoder.create_new_packet_from_chunks for further information.
        :param packet: A raw packet
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: RU10Packet or the string "CORRUPT" if error correction rejected the packet
        """
        struct_str = "<" + number_of_chunks_len_format + id_len_format
        struct_len = struct.calcsize(struct_str)
        try:
            packet = self.error_correction(packet)
        except Exception:
            self.corrupt += 1
            return "CORRUPT"

        data = packet[struct_len:]
        len_data = struct.unpack(struct_str, packet[0:struct_len])
        if self.static_number_of_chunks is None:
            self.number_of_chunks = xor_mask(len_data[0],
                                             number_of_chunks_len_format)
            unxored_id = xor_mask(len_data[1], id_len_format)
        else:
            unxored_id = xor_mask(len_data[0], id_len_format)
        if self.dist is None:
            self.dist = RaptorDistribution(self.number_of_chunks)
            _, self.s, self.h = intermediate_symbols(self.number_of_chunks,
                                                     self.dist)
        chunk_lst = None
        if self.use_method:
            # The last data byte selects the chunks; it is computed after number_of_chunks has been read
            # from this packet so the chunk list matches the packet.
            method_data = bin(data[-1])[2:].zfill(8)
            data = data[:-1]
            chunk_lst = self._method_chunk_list(method_data)

        if self.correct == 0:
            self.createAuxBlocks()
        self.correct += 1
        if chunk_lst is not None:
            numbers = choose_packet_numbers(len(chunk_lst),
                                            unxored_id,
                                            self.dist,
                                            systematic=False,
                                            max_l=len(chunk_lst))
            used_packets = set([chunk_lst[i] for i in numbers])
        else:
            used_packets = set(
                choose_packet_numbers(self.number_of_chunks,
                                      unxored_id,
                                      self.dist,
                                      systematic=False))
        return RU10Packet(
            data,
            used_packets,
            self.number_of_chunks,
            unxored_id,
            read_only=True,
            packet_len_format=packet_len_format,
            crc_len_format=crc_len_format,
            number_of_chunks_len_format=number_of_chunks_len_format,
            id_len_format=id_len_format,
            save_number_of_chunks_in_packet=self.static_number_of_chunks is
            None)

    def generateIntermediateBlocksFormat(
            self, number_of_chunks: int
    ) -> typing.List[typing.List[typing.List[int]]]:
        """
        Generates the format of the intermediate blocks from the number of used chunks.
        :param number_of_chunks: The number of used chunks.
        :return: [ldpc_compositions, half_compositions]: for each LDPC symbol (self.s of them) and each half symbol
            (self.h of them) the list of symbol numbers it is composed of.
        """
        s = self.s
        compositions: typing.List[typing.List[int]] = [[] for _ in range(s)]
        for i in range(0, number_of_chunks):
            a = 1 + (int(floor(np.float64(i) / np.float64(s))) % (s - 1))
            b = int(i % s)
            # Each source chunk is appended to three LDPC compositions, stepping by a.
            for _ in range(3):
                compositions[b].append(i)
                b = (b + a) % s

        hprime: int = int(ceil(np.float64(self.h) / 2))
        m = buildGraySequence(number_of_chunks + s, hprime)
        # Half symbol i contains every intermediate symbol j whose Gray-sequence entry has bit i set.
        hcompositions: typing.List[typing.List[int]] = [[
            j for j in range(0, number_of_chunks + s)
            if bitSet(np.uint32(m[j]), np.uint32(i))
        ] for i in range(0, self.h)]
        return [compositions, hcompositions]

    def decodeHeader(self, last_chunk_len_format: str = "I") -> None:
        """Sets self.headerChunk from the first degree-1 packet that contains only chunk 0 (no-op if already set)."""
        if self.headerChunk is not None or 1 not in self.degreeToPacket.keys():
            return  # Header already set
        for decoded in self.degreeToPacket[1]:
            if decoded.get_used_packets().issubset({0}):
                self.headerChunk = HeaderChunk(
                    decoded, last_chunk_len_format=last_chunk_len_format)
                return

    @staticmethod
    def _as_bytes(data) -> bytes:
        """Returns data unchanged if it already is bytes, otherwise its raw buffer via tobytes()."""
        return data if isinstance(data, bytes) else data.tobytes()

    def saveDecodedFile(self,
                        last_chunk_len_format: str = "I",
                        null_is_terminator: bool = False,
                        print_to_output: bool = True,
                        return_file_name=False) -> typing.Union[bytes, str]:
        """
        Saves the file - if decoded. The filename is either taken from the headerchunk or generated based on the input
        filename.
        :param return_file_name: if set to true, this function will return the filename under which the file as been saved
        :param last_chunk_len_format: Format of the last chunk length
        :param null_is_terminator: True: The file is handled as null-terminated C-String.
        :param print_to_output: True: Result we be printed to the command line.
        :return: The file name if return_file_name is True, otherwise the bytes written to the file.
        """
        assert self.is_decoded(), "Can not save File: Unable to reconstruct."
        if self.use_headerchunk:
            self.decodeHeader(last_chunk_len_format=last_chunk_len_format)
        if self.file is not None:
            file_name = "DEC_" + os.path.basename(self.file)
        else:
            file_name = "RU10.BIN"
        if self.headerChunk is not None:
            file_name = self.headerChunk.get_file_name().decode("utf-8")
        file_name = file_name.split("\x00")[0]
        # Chunk 0 holds the header and is not file content - unless it is the only chunk.
        skip_header = self.use_headerchunk and self.number_of_chunks - 1 != 0
        # The header stores the length of the last chunk; without a header the last chunk is written in full.
        truncate_last = self.use_headerchunk and self.headerChunk is not None
        parts: typing.List[bytes] = []
        with open(file_name, "wb") as f:
            for decoded in sorted(self.decodedPackets):
                [num] = decoded.get_used_packets()
                if skip_header and num == 0:
                    continue
                data = decoded.get_data()
                if truncate_last and num == self.number_of_chunks - 1:
                    output = data[0:self.headerChunk.get_last_chunk_length()]
                elif null_is_terminator:
                    splitter = self._as_bytes(data).decode().split("\x00")
                    output = splitter[0].encode()
                    parts.append(output)
                    f.write(output)
                    if len(splitter) > 1:
                        break  # since we are in null-terminator mode, we exit once we see the first 0-byte
                    continue
                else:
                    output = data
                parts.append(self._as_bytes(output))
                f.write(output)
        output_concat = b"".join(parts)
        print("Saved file as '" + str(file_name) + "'")
        if print_to_output:
            print("Result:")
            print(output_concat.decode("utf-8"))
        if return_file_name:
            return file_name
        return output_concat

    def mode_1_bmp_decode(self, last_chunk_len_format: str = "I"):
        """Saves the decoded content and converts it into a 1-bit bitmap; returns the bitmap's file name."""
        dec_out = self.saveDecodedFile(
            last_chunk_len_format=last_chunk_len_format,
            null_is_terminator=False,
            print_to_output=False)
        return self.bytes_to_bitmap(dec_out)

    def bytes_to_bitmap(self, img_byt: bytes):
        """
        Converts decoded bytes into a 1-bit image and saves it as ".bmp".
        The first 4 bytes hold width and height (big-endian unsigned shorts), followed by the packed pixel bits.
        :param img_byt: The decoded bytes.
        :return: The file name the bitmap was saved under.
        """
        width, height = struct.unpack('>HH', img_byt[:4])
        unpack = np.unpackbits(
            np.frombuffer(img_byt,
                          dtype=np.uint8,
                          count=(width * height) // 8,
                          offset=4)).reshape(height, width).transpose()
        flip_bits = np.logical_not(unpack).astype(int)
        new_img = self.draw_img(flip_bits, width, height)
        # Only derive the name from self.file when there is one (basename(None) would raise).
        if self.file is not None:
            file_name = "DEC_" + os.path.basename(self.file) + ".bmp"
        else:
            file_name = "RU10.BIN.bmp"
        new_img.save(file_name)
        return file_name

    @staticmethod
    def draw_img(unpacked_flipped_bits, width: int, height: int) -> Image:
        """Draws a mode-'1' image; unpacked_flipped_bits is indexed [x, y]."""
        new_img = Image.new('1', (width, height))
        pixels = new_img.load()

        for i in range(new_img.size[0]):
            for j in range(new_img.size[1]):
                pixels[i, j] = int(unpacked_flipped_bits[i, j])
        return new_img
Ejemplo n.º 10
0
 def saveDecodedFile(self, last_chunk_len_format: str = "I", null_is_terminator: bool = False,
                     print_to_output: bool = True, return_file_name=False, partial_decoding: bool = True) -> \
         typing.Union[bytes, str]:
     """
     Saves the file - if decoded. The filename is either taken from the headerchunk or generated based on the input
     filename ("DEC_<name>", or "RU10.BIN" without an input file).
     :param partial_decoding: perform partial decoding if full decoding failed; chunks that could not be recovered are
         written as zero bytes (to the file and to the returned content)
     :param return_file_name: if set to true, this function will return the filename under which the file has been saved
     :param last_chunk_len_format: Format of the last chunk length
     :param null_is_terminator: True: The file is handled as null-terminated C-String.
     :param print_to_output: True: Result will be printed to the command line.
     :return: The file name if return_file_name is True, otherwise the bytes written to the file.
     """
     assert self.is_decoded(
     ) or partial_decoding, "Can not save File: Unable to reconstruct. You may try saveDecodedFile(partial_decoding=True)"
     if partial_decoding:
         self.solve(partial=True)
     dirty = False
     if self.use_headerchunk:
         header_row = self.GEPP.result_mapping[0]
         if header_row >= 0:
             self.headerChunk = HeaderChunk(
                 Packet(self.GEPP.b[header_row], {0},
                        self.number_of_chunks,
                        read_only=True),
                 last_chunk_len_format=last_chunk_len_format)
     file_name = "DEC_" + os.path.basename(
         self.file) if self.file is not None else "RU10.BIN"
     if self.headerChunk is not None:
         try:
             file_name = self.headerChunk.get_file_name().decode("utf-8")
         except Exception as ex:
             print("Warning:", ex)
     file_name = file_name.split("\x00")[0]
     # The last chunk can only be truncated if the header (which stores its length) was recovered.
     truncate_last = self.use_headerchunk and self.headerChunk is not None
     parts = []
     with open(file_name, "wb") as f:
         # NOTE(review): x is the value stored in GEPP.result_mapping; like before, it is compared directly with
         # chunk positions 0 (header) and number_of_chunks - 1 (last chunk) - confirm this mapping.
         for x in self.GEPP.result_mapping:
             if x < 0:
                 # Unrecovered chunk: keep the file layout intact with a zero-filled chunk.
                 filler = b"\x00" * len(self.GEPP.b[x][0])
                 parts.append(filler)
                 f.write(filler)
                 dirty = True
                 continue
             if 0 == x and self.use_headerchunk:
                 continue  # header chunk: metadata, not file content
             if self.number_of_chunks - 1 == x and truncate_last:
                 output = self.GEPP.b[x][0][0:self.headerChunk.
                                            get_last_chunk_length()]
                 parts.append(output.tobytes())
                 f.write(output)
             elif null_is_terminator:
                 # tobytes() replaces the deprecated ndarray.tostring() (identical result).
                 splitter = self.GEPP.b[x].tobytes().decode().split("\x00")
                 output = splitter[0].encode()
                 parts.append(output)
                 f.write(output)
                 if len(splitter) > 1:
                     break  # since we are in null-terminator mode, we exit once we see the first 0-byte
             else:
                 output = self.GEPP.b[x]
                 parts.append(output if type(output) == bytes else output.tobytes())
                 f.write(output)
     output_concat = b"".join(parts)
     print("Saved file as '" + str(file_name) + "'")
     if dirty:
         print(
             "Some parts could not be restored, file WILL contain sections with \\x00 !"
         )
     if print_to_output:
         print("Result:")
         print(output_concat.decode("utf-8"))
     if self.progress_bar is not None:
         self.progress_bar.update(self.number_of_chunks,
                                  Corrupt=self.corrupt)
     if return_file_name:
         return file_name
     return output_concat
Ejemplo n.º 11
0
class RU10Decoder(Decoder):
    """
    Decoder for RU10 fountain-coded packets.

    Packets are read from a single file (length-prefixed packets, or a FASTA / DNA file), from a folder holding one
    packet per file, or from a zip archive holding one packet per entry. Every valid packet is reduced to a boolean
    row over the source chunks (its LDPC and Half auxiliary blocks are XORed out) and added to a GEPP matrix which is
    solved once enough rows are available.
    """

    def __init__(self,
                 file: typing.Optional[str] = None,
                 error_correction=nocode,
                 use_headerchunk: bool = True,
                 static_number_of_chunks: typing.Optional[int] = None,
                 use_method: bool = False):
        """
        :param file: Path to a packet file, a folder with one packet per file or a zip archive. None if packets are
            passed in directly via input_new_packet.
        :param error_correction: Callable applied to every raw packet before parsing. If it raises, the packet is
            counted as corrupt.
        :param use_headerchunk: True: chunk 0 is a HeaderChunk (file name and length of the last chunk).
        :param static_number_of_chunks: Fixed number of chunks. If set, the number of chunks is not read from the
            packets.
        :param use_method: True: the last data byte of each packet encodes the chunk-selection method
            (see parse_raw_packet).
        """
        self.debug = False
        super().__init__()
        self.isPseudo: bool = False
        self.file: typing.Optional[str] = file
        self.degreeToPacket: dict = {}
        self.use_method: bool = use_method
        if file is not None:
            self.isFolder = os.path.isdir(file)
            self.isZip = file.endswith(".zip")
            if not self.isFolder:
                self.f = open(self.file, "rb")
        self.correct: int = 0  # number of successfully parsed packets
        self.corrupt: int = 0  # number of packets rejected as corrupt
        # placeholder; overwritten from the packets or from static_number_of_chunks
        self.number_of_chunks: int = 1000000
        self.headerChunk: typing.Optional[HeaderChunk] = None
        self.GEPP: typing.Optional[GEPP] = None
        self.pseudoCount: int = 0
        # aux block number -> intermediate packet describing which chunks the LDPC / Half block XORs
        self.ldpcANDhalf: typing.Dict[int, RU10IntermediatePacket] = dict()
        self.repairBlockNumbers: dict = dict()
        self.s: int = -1  # number of LDPC blocks, set once the number of chunks is known
        self.h: int = -1  # number of Half blocks, set once the number of chunks is known
        self.distribution: typing.Optional[Distribution] = None
        self.EOF: bool = False
        self.counter: dict = dict()  # chunk index -> number of packets referencing it (if self.count)
        self.count: bool = True
        self.error_correction: typing.Callable = error_correction
        self.use_headerchunk: bool = use_headerchunk
        self.static_number_of_chunks: typing.Optional[
            int] = static_number_of_chunks

    def decodeZip(self,
                  packet_len_format: str = "I",
                  crc_len_format: str = "I",
                  number_of_chunks_len_format: str = "I",
                  id_len_format: str = "I"):
        """
        Decodes the packets stored as separate entries of the zip archive self.file and prints the number of decoded
        and corrupted packets.
        Entries are processed in ascending order of the number in the second "_"-separated field of their name; if
        any entry name does not follow that pattern, the archive order is kept.
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: -1 if the decoding wasn't successful
        """
        if hasattr(self, "f"):
            self.f.close()  # release any previously opened packet stream
        decoded = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        # The context manager closes the archive even if decoding raises.
        with ZipFile(self.file, 'r') as archive:
            namelist = archive.namelist()
            try:
                # Sort on the second "_"-separated field but keep the complete entry name: rebuilding the name
                # from only the first two fields breaks entries whose name contains more than one "_".
                namelist = sorted(namelist,
                                  key=lambda entry: float(entry.split("_")[1]))
            except (IndexError, ValueError):
                pass  # names do not follow the "<prefix>_<number>" pattern: keep the archive order
            for name in namelist:
                self.f = io.BytesIO(archive.read(name))
                new_pack = self.getNextValidPacket(
                    True,
                    packet_len_format=packet_len_format,
                    crc_len_format=crc_len_format,
                    number_of_chunks_len_format=number_of_chunks_len_format,
                    id_len_format=id_len_format)
                if hasattr(self, "f"):
                    self.f.close()
                if new_pack is None:
                    break
                if new_pack != "CORRUPT":
                    decoded = self.input_new_packet(new_pack)
                else:
                    if DEBUG:
                        print(f"Packet with name={name} corrupt.")
                if decoded:
                    break
                if self.progress_bar is not None:
                    self.progress_bar.update(self.correct, Corrupt=self.corrupt)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets: " + str(self.corrupt))

        if self.GEPP is None:
            print("No Packet was correctly decoded. Check your configuration.")
            return -1
        if self.GEPP.isPotentionallySolvable():
            decoded = self.GEPP.solve()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too many errors?")
            return -1

    def decodeFolder(self,
                     packet_len_format: str = "I",
                     crc_len_format: str = "I",
                     number_of_chunks_len_format: str = "I",
                     id_len_format: str = "I"):
        """
        Decodes the information from a folder if self.file represents a folder and the packets were saved
        in multiple files and prints the number of decoded and corrupted packets.
        Only files ending in ".RU10" or "DNA" are read; "DNA" files are converted to bytes before parsing.
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: -1 if the decoding wasn't successful
        """
        decoded = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        for file_in_folder in os.listdir(self.file):
            if file_in_folder.endswith((".RU10", "DNA")):
                self.EOF = False
                path = self.file + "/" + file_in_folder
                if file_in_folder.endswith("DNA"):
                    if self.error_correction.__name__ == 'dna_reed_solomon_decode':
                        try:
                            self.f = quad_file_to_bytes(path)
                        except TypeError:
                            print(
                                "skipping CORRUPT file - contains illegal character(s)"
                            )
                            self.corrupt += 1
                            continue
                    else:
                        try:
                            self.f = quat_file_to_bin(path)
                        except TypeError:
                            print(
                                "skipping CORRUPT file - contains illegal character(s)"
                            )
                            self.corrupt += 1
                            continue
                else:
                    self.f = open(path, "rb")
                new_pack = self.getNextValidPacket(
                    True,
                    packet_len_format=packet_len_format,
                    crc_len_format=crc_len_format,
                    number_of_chunks_len_format=number_of_chunks_len_format,
                    id_len_format=id_len_format)
                # Release this packet file right away instead of keeping one open handle per file.
                self.f.close()
                if new_pack is not None and new_pack != "CORRUPT":
                    decoded = self.input_new_packet(new_pack)
                if decoded:
                    break
            if self.progress_bar is not None:
                self.progress_bar.update(self.correct, Corrupt=self.corrupt)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets: " + str(self.corrupt))
        if hasattr(self, "f"):
            self.f.close()
        if self.GEPP is None:
            print("No Packet was correctly decoded. Check your configuration.")
            return -1
        if self.GEPP.isPotentionallySolvable():
            decoded = self.GEPP.solve()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too many errors?")
            return -1

    def decodeFile(self,
                   packet_len_format: str = "I",
                   crc_len_format: str = "L",
                   number_of_chunks_len_format: str = "I",
                   id_len_format: str = "I"):
        """
        Decodes the information from a file if self.file represents a file and the packets were saved in a single file.
        Files ending in "fasta" are read as pairs of lines (header line, DNA sequence line); files ending in "dna"
        are converted to bytes first; all other files contain length-prefixed binary packets.
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: The result of GEPP.solve() if the matrix is potentially solvable, -1 if the decoding wasn't
            successful
        """
        decoded = False
        self.EOF = False
        if self.file.lower().endswith("dna"):
            try:
                self.f.close()
                self.f = quat_file_to_bin(self.file)
            except TypeError:
                print("skipping CORRUPT file - contains illegal character(s)")
                self.corrupt += 1
                # self.f is already closed, so there is nothing left that could be decoded
                return -1
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        if self.file.lower().endswith("fasta"):
            self.f.close()
            self.f = open(self.file, "r")
            try:
                while not (decoded or self.EOF):
                    # FASTA header line; its content is not needed for decoding
                    header = self.f.readline()
                    if not header:
                        self.EOF = True
                        break
                    line = self.f.readline()
                    if not line:
                        self.EOF = True
                        break
                    dna_str = line.replace("\n", "")
                    try:
                        new_pack = self.parse_raw_packet(
                            BytesIO(tranlate_quat_to_byte(dna_str)).read(),
                            crc_len_format=crc_len_format,
                            number_of_chunks_len_format=number_of_chunks_len_format,
                            packet_len_format=packet_len_format,
                            id_len_format=id_len_format)
                    except Exception:
                        new_pack = "CORRUPT"
                    if new_pack != "CORRUPT":
                        decoded = self.input_new_packet(new_pack)
                        if self.progress_bar is not None:
                            self.progress_bar.update(self.correct,
                                                     Corrupt=self.corrupt)
            finally:
                self.f.close()
        else:
            while not (decoded or self.EOF):
                new_pack = self.getNextValidPacket(
                    False,
                    packet_len_format=packet_len_format,
                    crc_len_format=crc_len_format,
                    number_of_chunks_len_format=number_of_chunks_len_format,
                    id_len_format=id_len_format)
                if new_pack is None:
                    break
                if new_pack != "CORRUPT":
                    decoded = self.input_new_packet(new_pack)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        if self.GEPP is None:
            # same guard as decodeZip / decodeFolder: without it the next line raises AttributeError
            print("No Packet was correctly decoded. Check your configuration.")
            return -1
        if self.GEPP.isPotentionallySolvable():
            return self.GEPP.solve()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1

    def getNumberOfLDPCBlocks(self):
        """:return: Number of LDPC blocks (self.s)."""
        return self.s

    def getNumberOfHalfBlocks(self):
        """:return: Number of Half blocks (self.h)."""
        return self.h

    def getNumberOfRepairBlocks(self):
        """:return: Number of LDPC plus Half blocks."""
        return self.getNumberOfHalfBlocks() + self.getNumberOfLDPCBlocks()

    def input_new_packet(self, packet: RU10Packet):
        """
        Removes auxpackets (LDPC and Half) and adds the remaining data to the GEPP matrix.
        If no distribution is known yet, the number of chunks is taken from the packet and the distribution, the
        LDPC / Half block counts, the aux blocks and the progress bar are set up first.
        :param packet: A Packet to add to the GEPP matrix
        :return: True: If solved. False: Else.
        """
        if not self.ldpcANDhalf and self.distribution is None:  # self.isPseudo and
            # The number of chunks has to be known BEFORE the distribution is built from it (same order as in
            # parse_raw_packet); otherwise the placeholder value would be used.
            self.number_of_chunks = packet.get_total_number_of_chunks()
            self.distribution = RaptorDistribution(self.number_of_chunks)
            _, self.s, self.h = intermediate_symbols(self.number_of_chunks,
                                                     self.distribution)
            self.createAuxBlocks()
            self.progress_bar = self.create_progress_bar(
                self.number_of_chunks + 0.02 * self.number_of_chunks)
        # we need to do it twice since half symbols may contain ldpc symbols (which by definition are repair codes.)
        if self.debug:
            print("----")
            print("Id = " + str(packet.id))
            print(packet.used_packets)
        removed = self.removeAndXorAuxPackets(packet)
        if self.debug:
            print(from_true_false_list(removed))
            print(packet.get_error_correction())
            print("----")
        if self.count:
            # per chunk: number of packets whose reduced row references it
            for i, used in enumerate(removed):
                self.counter[i] = self.counter.get(i, 0) + (1 if used else 0)
        if self.GEPP is None:
            self.GEPP = GEPP(
                np.array([removed], dtype=bool),
                np.array([[packet.get_data()]], dtype=bytes),
            )
        else:
            self.GEPP.addRow(
                np.array(removed, dtype=bool),
                np.frombuffer(packet.get_data(), dtype="uint8"),
            )
        if (self.isPseudo or not self.read_all_before_decode
            ) and self.GEPP.isPotentionallySolvable():
            if self.debug:
                print("current size: " + str(self.GEPP.n))
            return self.GEPP.solve(partial=False)
        return False

    # Correct
    def removeAndXorAuxPackets(self, packet: RU10Packet):
        """
        Removes auxpackets (LDPC and Half) from a given packet to get the packets data.
        First the Half blocks are XORed out (leaving chunks + LDPC blocks), then the LDPC blocks.
        :param packet: Packet to remove auxpackets from
        :return: Boolean row over the source chunks only
        """
        aux_mapping = self.getHalfPacketListFromPacket(
            packet)  # contains data + LDPC numbers
        aux_mapping.append(packet.get_bool_array_used_and_ldpc_packets())
        xored_list = logical_xor(aux_mapping)
        del aux_mapping
        tmp = from_true_false_list(
            xored_list)  # only data + LDPC remain
        if self.debug:
            print(tmp)
        tmp = RU10Packet("",
                         tmp,
                         self.number_of_chunks,
                         packet.id,
                         packet.dist,
                         read_only=True)
        aux_mapping = self.getAuxPacketListFromPacket(tmp)
        aux_mapping.append(
            tmp.get_bool_array_used_packets())  # [-len(self.auxBlocks):])
        res = logical_xor(aux_mapping)
        del tmp, aux_mapping
        return res

    def createAuxBlocks(self):
        """
        Reconstructs the auxblocks to be able to remove them afterwards: fills self.repairBlockNumbers with the
        composition of every LDPC / Half block and self.ldpcANDhalf with one RU10IntermediatePacket per aux block.
        :return:
        """
        assert (self.number_of_chunks is not None
                ), "createAuxBlocks can only be called AFTER first Packet"
        if self.debug:
            print("We should have " + str(self.getNumberOfLDPCBlocks()) +
                  " LDPC-Blocks, " + str(self.getNumberOfHalfBlocks()) +
                  " Half-Blocks and " + str(self.number_of_chunks) +
                  " normal Chunks (including 1 HeaderChunk)")
        for i in range(0, self.getNumberOfRepairBlocks()):
            self.repairBlockNumbers[i] = set()
        i = 0
        for group in self.generateIntermediateBlocksFormat(
                self.number_of_chunks):
            for elem in group:
                self.repairBlockNumbers[i] = elem
                i += 1
        # XOR all Chunks into the corresponding AUX-Block
        for aux_number in self.repairBlockNumbers.keys():
            self.ldpcANDhalf[aux_number] = RU10IntermediatePacket(
                "",
                self.repairBlockNumbers[aux_number],
                total_number_of_chunks=self.number_of_chunks,
                id=aux_number,
                dist=self.distribution)
            if self.debug:
                print(
                    str(aux_number) + " : " +
                    str(self.ldpcANDhalf[aux_number].used_packets))

    # Correct
    def getAuxPacketListFromPacket(self, packet: RU10Packet):
        """
        Creates a list for a packet with information about whether auxpackets have been used for that packet.
        :param packet: The packet to check.
        :return: The used-packet bool arrays of every aux block the packet references.
        """
        aux_used_packets = packet.get_bool_array_repair_packets()
        return [self.ldpcANDhalf[i].get_bool_array_used_packets()
                for i, used in enumerate(aux_used_packets) if used]

    def getHalfPacketListFromPacket(
            self, packet: RU10Packet) -> typing.List[typing.List[bool]]:
        """
        Generates a list of halfpackets from a packet.
        Half blocks are stored in self.ldpcANDhalf after the LDPC blocks, hence the index offset.
        :param packet: The packet to get the list from
        :return: List of halfpackets
        """
        aux_used_packets = packet.get_bool_array_half_packets()
        ldpc_offset = packet.get_number_of_ldpc_blocks()
        return [self.ldpcANDhalf[ldpc_offset + i].get_bool_array_used_and_ldpc_packets()
                for i, used in enumerate(aux_used_packets) if used]

    def solve(self, partial=False) -> bool:
        """
        Calls GEPP.solve()
        :param partial: Passed through to GEPP.solve().
        :return: True: If GEPP was able to solve the matrix. False: Else.
        """
        return self.GEPP.solve(partial=partial)

    def getSolvedCount(self) -> int:
        """:return: GEPP.getSolvedCount()"""
        return self.GEPP.getSolvedCount()

    def is_decoded(self) -> bool:
        """
        Checks if the data is decoded.
        :return: True: If the decoding was successfull. False: Else.
        """
        return self.GEPP is not None and self.GEPP.isPotentionallySolvable(
        ) and self.GEPP.isSolved()

    def getNextValidPacket(
            self,
            from_multiple_files: bool = False,
            packet_len_format: str = "I",
            crc_len_format: str = "L",
            number_of_chunks_len_format: str = "I",
            id_len_format: str = "I") -> typing.Optional[RU10Packet]:
        """
        Takes a raw packet from a file and calls @parse_raw_packet to get a RU10 packet. If the packet is corrupt the
        next one will be taken (single-file mode only). Corrupt packets are skipped in a loop, so long runs of corrupt
        packets cannot exceed the recursion limit. Sets self.EOF once the input is exhausted.
        :param from_multiple_files: True: The packets were saved in multiple files. False: Packets were saved in one file.
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: RU10Packet, "CORRUPT" (only with from_multiple_files) or None at the end of the input
        """
        len_struct = "<" + packet_len_format
        while True:
            if not from_multiple_files:
                raw_len = self.f.read(struct.calcsize(len_struct))
                try:
                    packet_len = struct.unpack(len_struct, raw_len)[0]
                except struct.error:
                    # missing or truncated length prefix: the stream is exhausted
                    packet, packet_len = b"", 0
                else:
                    packet = self.f.read(int(packet_len))
            else:
                packet = self.f.read()
                packet_len = len(packet)
            if not packet or not packet_len:  # EOF
                self.EOF = True
                try:
                    self.f.close()
                except Exception:
                    pass  # best effort: the stream may not support closing
                return None
            res = self.parse_raw_packet(
                packet,
                crc_len_format=crc_len_format,
                packet_len_format=packet_len_format,
                number_of_chunks_len_format=number_of_chunks_len_format,
                id_len_format=id_len_format)
            if res == "CORRUPT" and not from_multiple_files:
                continue  # single file: try the next packet in the stream
            return res

    def parse_raw_packet(
            self,
            packet,
            crc_len_format: str = "L",
            number_of_chunks_len_format: str = "L",
            packet_len_format: str = "I",
            id_len_format: str = "L") -> typing.Union[RU10Packet, str]:
        """
        Creates a RU10 packet from a raw given packet. Also checks if the packet is corrupted. If any method was used to
        create packets from specific chunks, set self.use_method = True. This will treat the last byte of the raw packet
        data as the byte that contains the information about the used method ("even", "odd", "window_30 + window" or
        "window_40 + window". See RU10Encoder.create_new_packet_from_chunks for further information.
        :param packet: A raw packet
        :param packet_len_format: Format of the packet length
        :param crc_len_format:  Format of the crc length
        :param number_of_chunks_len_format: Format of the number of chunks length
        :param id_len_format: Format of the ID length
        :return: RU10Packet or "CORRUPT" (error correction failed or the packet is too short)
        """
        struct_str = "<" + number_of_chunks_len_format + id_len_format
        struct_len = struct.calcsize(struct_str)
        try:
            packet = self.error_correction(packet)
        except Exception:
            self.corrupt += 1
            return "CORRUPT"
        # the header (and, with use_method, the method byte) must be present
        if len(packet) < struct_len + (1 if self.use_method else 0):
            self.corrupt += 1
            return "CORRUPT"

        # Unpack the header first: the method-based chunk list below depends on the current number_of_chunks.
        len_data = struct.unpack(struct_str, packet[0:struct_len])
        if self.static_number_of_chunks is None:
            self.number_of_chunks = xor_mask(len_data[0],
                                             number_of_chunks_len_format)
            unxored_id = xor_mask(len_data[1], id_len_format)
        else:
            unxored_id = xor_mask(len_data[0], id_len_format)

        data = packet[struct_len:]
        chunk_lst = []
        if self.use_method:
            # first two bits select the method, the remaining six the window number
            method_data = format(data[-1], "08b")
            data = data[:-1]
            if method_data.startswith('00'):
                chunk_lst = [
                    ch for ch in range(0, self.number_of_chunks + 1)
                    if ch % 2 == 0
                ]
            elif method_data.startswith('01'):
                chunk_lst = [
                    ch for ch in range(0, self.number_of_chunks + 1)
                    if ch % 2 != 0
                ]
            elif method_data.startswith('10'):
                window = int(method_data[2:], 2)
                window_size = 30
                start = window * (window_size - 10)
                chunk_lst = [
                    ch for ch in range(start, start + window_size)
                    if ch <= self.number_of_chunks
                ]
            elif method_data.startswith('11'):
                window = int(method_data[2:], 2)
                window_size = 40
                start = window * (window_size - 10)
                chunk_lst = [
                    ch for ch in range(start, start + window_size)
                    if ch <= self.number_of_chunks
                ]
            else:
                raise RuntimeError("Not a valid start:", method_data)
        if self.distribution is None:
            self.distribution = RaptorDistribution(self.number_of_chunks)
            _, self.s, self.h = intermediate_symbols(self.number_of_chunks,
                                                     self.distribution)
            self.progress_bar = self.create_progress_bar(
                self.number_of_chunks + 0.02 * self.number_of_chunks)

        if self.correct == 0:
            self.createAuxBlocks()
        self.correct += 1
        if self.use_method:
            numbers = choose_packet_numbers(len(chunk_lst),
                                            unxored_id,
                                            self.distribution,
                                            systematic=False,
                                            max_l=len(chunk_lst))
            used_packets = [chunk_lst[i] for i in numbers]
        else:
            used_packets = choose_packet_numbers(self.number_of_chunks,
                                                 unxored_id,
                                                 self.distribution,
                                                 systematic=False)
        res = RU10Packet(
            data,
            used_packets,
            self.number_of_chunks,
            unxored_id,
            read_only=True,
            packet_len_format=packet_len_format,
            crc_len_format=crc_len_format,
            number_of_chunks_len_format=number_of_chunks_len_format,
            id_len_format=id_len_format,
            save_number_of_chunks_in_packet=self.static_number_of_chunks is
            None)
        return res

    def generateIntermediateBlocksFormat(
            self, number_of_chunks: int
    ) -> typing.List[typing.List[typing.List[int]]]:
        """
        Generates the format of the intermediate blocks from the number of used chunks.
        :param number_of_chunks: The number of used chunks.
        :return: [compositions, hcompositions]: for each of the self.s LDPC blocks the chunk indices it XORs, and for
            each of the self.h Half blocks the indices (chunks followed by LDPC blocks) it XORs.
        """
        compositions: typing.List[typing.List[int]] = [[]
                                                       for _ in range(self.s)]
        for i in range(0, number_of_chunks):
            a = 1 + (int(floor(np.float64(i) / np.float64(self.s))) %
                     (self.s - 1))
            b = int(i % self.s)
            compositions[b].append(i)
            b = (b + a) % self.s
            compositions[b].append(i)
            b = (b + a) % self.s
            compositions[b].append(i)

        hprime: int = int(ceil(np.float64(self.h) / 2))
        m = buildGraySequence(number_of_chunks + self.s, hprime)
        hcompositions: typing.List[typing.List[int]] = [[]
                                                        for _ in range(self.h)]
        for i in range(0, self.h):
            hcomposition = []
            for j in range(0, number_of_chunks + self.s):
                if bitSet(np.uint32(m[j]), np.uint32(i)):
                    hcomposition.append(j)
            hcompositions[i] = hcomposition
        res = [compositions, hcompositions]
        return res

    def saveDecodedFile(self, last_chunk_len_format: str = "I", null_is_terminator: bool = False,
                        print_to_output: bool = True, return_file_name=False, partial_decoding: bool = True) -> \
            typing.Union[bytes, str]:
        """
        Saves the file - if decoded. The filename is either taken from the headerchunk or generated based on the input
        filename.
        :param partial_decoding: perform partial decoding if full decoding failed, missing parts will be filled with "\x00"
        :param return_file_name: if set to true, this function will return the filename under which the file as been saved
        :param last_chunk_len_format: Format of the last chunk length
        :param null_is_terminator: True: The file is handled as null-terminated C-String.
        :param print_to_output: True: Result we be printed to the command line (undecodable bytes are replaced).
        :return: The file name if return_file_name, else the bytes written to the file (including "\x00" filler)
        """
        assert self.is_decoded(
        ) or partial_decoding, "Can not save File: Unable to reconstruct. You may try saveDecodedFile(partial_decoding=True)"
        if partial_decoding:
            self.solve(partial=True)
        dirty = False
        if self.use_headerchunk:
            header_row = self.GEPP.result_mapping[0]
            if header_row >= 0:
                self.headerChunk = HeaderChunk(
                    Packet(self.GEPP.b[header_row], {0},
                           self.number_of_chunks,
                           read_only=True),
                    last_chunk_len_format=last_chunk_len_format)
        file_name = "DEC_" + os.path.basename(
            self.file) if self.file is not None else "RU10.BIN"
        output_concat = b""
        if self.headerChunk is not None:
            try:
                file_name = self.headerChunk.get_file_name().decode("utf-8")
            except Exception as ex:
                print("Warning:", ex)
        file_name = file_name.split("\x00")[0]
        # The last chunk can only be trimmed if its length is known from a header chunk.
        trim_last_chunk = self.use_headerchunk and self.headerChunk is not None
        with open(file_name, "wb") as f:
            for x in self.GEPP.result_mapping:
                if x < 0:
                    # unsolved chunk: write one chunk of zeros and keep the returned bytes aligned with the file
                    filler = b"\x00" * len(self.GEPP.b[x][0])
                    f.write(filler)
                    output_concat += filler
                    dirty = True
                    continue
                if x == 0 and self.use_headerchunk:
                    continue  # the header chunk holds metadata, not file content
                if self.number_of_chunks - 1 == x and trim_last_chunk:
                    output = self.GEPP.b[x][0][0:self.headerChunk.
                                               get_last_chunk_length()]
                    output_concat += output.tobytes()
                    f.write(output)
                elif null_is_terminator:
                    splitter = self.GEPP.b[x].tobytes().decode().split("\x00")
                    output = splitter[0].encode()
                    output_concat += output
                    f.write(output)
                    if len(splitter) > 1:
                        break  # since we are in null-terminator mode, we exit once we see the first 0-byte
                else:
                    output = self.GEPP.b[x]
                    output_concat += output if isinstance(
                        output, bytes) else output.tobytes()
                    f.write(output)
        print("Saved file as '" + str(file_name) + "'")
        if dirty:
            print(
                "Some parts could not be restored, file WILL contain sections with \\x00 !"
            )
        if print_to_output:
            print("Result:")
            print(output_concat.decode("utf-8", errors="replace"))
        if self.progress_bar is not None:
            self.progress_bar.update(self.number_of_chunks,
                                     Corrupt=self.corrupt)
        if return_file_name:
            return file_name
        return output_concat

    def mode_1_bmp_decode(self, last_chunk_len_format: str = "I"):
        """
        Saves the decoded file and converts its content into a bitmap (see bytes_to_bitmap).
        :param last_chunk_len_format: Format of the last chunk length
        :return: The file name of the saved bitmap
        """
        dec_out = self.saveDecodedFile(
            last_chunk_len_format=last_chunk_len_format,
            null_is_terminator=False,
            print_to_output=False)
        return self.bytes_to_bitmap(dec_out)

    def bytes_to_bitmap(self, img_byt: bytes):
        """
        Converts decoded bytes into a 1-bit image and saves it as ".bmp".
        Layout: 2-byte big-endian width, 2-byte big-endian height, then the packed pixel bits.
        :param img_byt: The decoded file content
        :return: The file name the image was saved under
        """
        width, height = struct.unpack('>HH', img_byt[:4])
        unpack = np.unpackbits(
            np.frombuffer(img_byt,
                          dtype=np.uint8,
                          count=int((width * height) / 8),
                          offset=4)).reshape(height, width).transpose()
        flip_bits = np.logical_not(unpack).astype(int)
        new_img = self.draw_img(flip_bits, width, height)
        # only derive the name from self.file if there is one (basename(None) would raise)
        if self.file is not None:
            file_name = "DEC_" + os.path.basename(self.file) + ".bmp"
        else:
            file_name = "RU10.BIN.bmp"
        new_img.save(file_name)
        return file_name

    @staticmethod
    def draw_img(unpacked_flipped_bits, width: int, height: int) -> Image:
        """
        Draws a 1-bit image of the given size from a (width, height) indexable bit array.
        :return: The new image
        """
        new_img = Image.new('1', (width, height))
        pixels = new_img.load()

        for i in range(new_img.size[0]):
            for j in range(new_img.size[1]):
                pixels[i, j] = int(unpacked_flipped_bits[i, j])
        return new_img
Ejemplo n.º 12
0
class LTDecoder(Decoder):
    def __init__(self, file: typing.Optional[str] = None, error_correction: typing.Callable = nocode,
                 use_headerchunk: bool = True, static_number_of_chunks: typing.Optional[int] = None,
                 implicit_mode: bool = True, dist: typing.Optional[Distribution] = None):
        """Create a decoder for LT-encoded packets.

        :param file: path to a single encoded file (opened here in binary mode),
            or to a folder of per-packet files (see ``decodeFolder``); may be None.
        :param error_correction: callable applied to every raw packet by
            ``parse_raw_packet``; when it is ``crc32`` the trailing CRC is
            verified instead.
        :param use_headerchunk: whether chunk 0 is a header chunk (file name and
            last-chunk length) rather than file content.
        :param static_number_of_chunks: chunk count known in advance; when set,
            packets are parsed without a stored chunk count.
        :param implicit_mode: when True, packets carry no degree; it is derived
            from ``dist`` seeded with the packet seed.
        :param dist: degree distribution used in implicit mode.
        """
        super().__init__(file)
        self.use_headerchunk: bool = use_headerchunk
        self.isPseudo: bool = False
        self.file: typing.Optional[str] = file
        self.degreeToPacket: typing.Dict[int, Packet] = {}
        # Defined unconditionally so it can be queried even when no file was given.
        self.isFolder: bool = file is not None and os.path.isdir(file)
        if file is not None and not self.isFolder:
            # Closed by getNextValidPacket on EOF or at the end of decodeFile.
            self.f = open(self.file, "rb")
        self.correct: int = 0  # packets parsed successfully
        self.corrupt: int = 0  # packets rejected by CRC / error correction
        # Placeholder until parse_raw_packet reads the real count or decodeFile
        # applies static_number_of_chunks.
        self.number_of_chunks: int = 1000000
        self.headerChunk: typing.Optional[HeaderChunk] = None
        self.GEPP: typing.Optional[GEPP] = None  # created on the first input_new_packet call
        self.pseudoCount: int = 0
        self.read_all_before_decode: bool = True
        self.count: bool = True  # whether input_new_packet maintains self.counter
        self.counter: typing.Dict[int, int] = dict()  # chunk index -> number of packets using it
        self.error_correction: typing.Callable = error_correction
        self.static_number_of_chunks: typing.Optional[int] = static_number_of_chunks
        self.implicit_mode: bool = implicit_mode
        self.dist: typing.Optional[Distribution] = dist
        self.EOF: bool = False

    def decodeFolder(self, packet_len_format: str = "I", crc_len_format: str = "L",
                     number_of_chunks_len_format: str = "I", degree_len_format: str = "I", seed_len_format: str = "I",
                     last_chunk_len_format: str = "I") -> typing.Optional[int]:
        """Decode from a folder in which every ``*.LT`` / ``*DNA`` file holds one packet.

        ``*DNA`` files are converted with ``quat_file_to_bin``; ``*.LT`` files
        are read as raw bytes. Reading stops early once a packet completes the
        decoding.

        :return: ``-1`` if the data could not be reconstructed, otherwise None.
        """
        decoded: bool = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        for entry in os.listdir(self.file):
            if not entry.endswith((".LT", "DNA")):
                continue
            self.EOF = False
            path = os.path.join(self.file, entry)
            if entry.endswith("DNA"):
                self.f = quat_file_to_bin(path)
            else:
                self.f = open(path, "rb")
            try:
                new_pack = self.getNextValidPacket(True, packet_len_format=packet_len_format,
                                                   crc_len_format=crc_len_format,
                                                   number_of_chunks_len_format=number_of_chunks_len_format,
                                                   degree_len_format=degree_len_format,
                                                   seed_len_format=seed_len_format,
                                                   last_chunk_len_format=last_chunk_len_format)
            finally:
                # One packet per file: close it now instead of leaking one handle per file.
                self.f.close()
            if new_pack is not None:
                decoded = self.input_new_packet(new_pack)
            if decoded:
                break
        self.EOF = True
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        # GEPP stays None when the folder contained no usable packet.
        if self.GEPP is not None and self.GEPP.isPotentionallySolvable():
            decoded = self.GEPP.solve()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1

    def decodeFile(self, packet_len_format: str = "I", crc_len_format: str = "L",
                   number_of_chunks_len_format: str = "I", degree_len_format: str = "I", seed_len_format: str = "I",
                   last_chunk_len_format: str = "I") -> typing.Optional[int]:
        """Decode packets from ``self.file`` until the data can be reconstructed.

        Two input formats are handled:

        * ``*.fasta`` (case-insensitive): pairs of header line and DNA string;
          each DNA string is translated to bytes and parsed as one packet.
        * anything else: a binary stream of length-prefixed packets, read via
          ``getNextValidPacket``.

        :return: the result of ``GEPP.solve()`` when the collected packets make
            the system potentially solvable; ``-1`` if decoding failed after
            reaching EOF; otherwise None.
        """
        decoded: bool = False
        self.EOF = False
        if self.static_number_of_chunks is not None:
            self.number_of_chunks = self.static_number_of_chunks
            number_of_chunks_len_format = ""  # if we got static number_of_chunks we do not need it in struct string
        if self.file.lower().endswith("fasta"):
            # FASTA is line-oriented text: reopen the file in text mode.
            self.f.close()
            self.f = open(self.file, "r")
            while not (decoded or self.EOF):
                header = self.f.readline()  # header line; its contents are not needed for decoding
                if not header:
                    self.EOF = True
                    break
                line = self.f.readline()
                if not line:
                    self.EOF = True
                    break
                dna_str = line.replace("\n", "")
                new_pack = self.parse_raw_packet(BytesIO(tranlate_quat_to_byte(dna_str)).read(),
                                                 crc_len_format=crc_len_format,
                                                 number_of_chunks_len_format=number_of_chunks_len_format,
                                                 degree_len_format=degree_len_format,
                                                 seed_len_format=seed_len_format)
                # parse_raw_packet returns the string "CORRUPT" for rejected packets.
                if not isinstance(new_pack, str):
                    decoded = self.input_new_packet(new_pack)
                if self.progress_bar is not None:
                    self.progress_bar.update(self.correct, Corrupt=self.corrupt)
        else:
            while not (decoded or self.EOF):
                new_pack = self.getNextValidPacket(False, packet_len_format=packet_len_format,
                                                   crc_len_format=crc_len_format,
                                                   number_of_chunks_len_format=number_of_chunks_len_format,
                                                   degree_len_format=degree_len_format,
                                                   seed_len_format=seed_len_format,
                                                   last_chunk_len_format=last_chunk_len_format)
                if new_pack is None:  # EOF reached
                    break
                decoded = self.input_new_packet(new_pack)
        print("Decoded Packets: " + str(self.correct))
        print("Corrupt Packets : " + str(self.corrupt))
        self.f.close()
        # GEPP stays None when no valid packet was read.
        if self.GEPP is not None and self.GEPP.isPotentionallySolvable():
            return self.GEPP.solve()
        if not decoded and self.EOF:
            print("Unable to retrieve File from Chunks. Too much errors?")
            return -1

    def input_new_packet(self, packet: Packet) -> bool:
        """Add one parsed packet as a row of the GEPP equation system.

        When ``self.count`` is set, ``self.counter[i]`` is increased for every
        chunk ``i`` the packet uses. Every index of the packet's usage array
        gets an entry, with 0 for unused chunks.

        :return: True only if pseudo mode is active, ``read_all_before_decode``
            is off, and ``GEPP.solve(partial=False)`` succeeds; otherwise False.
        """
        self.pseudoCount += 1
        used: typing.List[bool] = packet.get_bool_array_used_packets()
        if self.count:
            # Count actual uses only. The previous version set the counter to 1
            # on first sight even for chunks this packet does not use.
            for i, is_used in enumerate(used):
                self.counter[i] = self.counter.get(i, 0) + (1 if is_used else 0)
        if self.GEPP is None:
            self.GEPP = GEPP(np.array([used], dtype=bool),
                             np.array([[packet.get_data()]], dtype=bytes), )
        else:
            self.GEPP.addRow(used, np.frombuffer(packet.get_data(), dtype="uint8"), )
        if self.isPseudo and not self.read_all_before_decode and self.GEPP.isPotentionallySolvable():
            return self.GEPP.solve(partial=False)
        return False

    def solve(self) -> bool:
        """Run the GEPP solver on the packets collected so far and return its result."""
        equation_system = self.GEPP
        return equation_system.solve()

    def getSolvedCount(self) -> int:
        """Return the solved count reported by the GEPP equation system."""
        equation_system = self.GEPP
        return equation_system.getSolvedCount()

    def choose_packet_numbers(self, degree: int, seed: int = 0) -> typing.Set[int]:
        """Return ``degree`` distinct chunk indices chosen deterministically from ``seed``.

        Side effect: reseeds NumPy's global RNG (``np.random.seed``). Indices
        are drawn one at a time from ``range(number_of_chunks)`` and duplicates
        are rejected. The draw sequence presumably has to mirror the encoder
        (confirm), so keep the order and kind of ``choice`` calls unchanged.

        :raises AssertionError: if ``degree`` exceeds ``number_of_chunks``.
        """
        assert degree <= self.number_of_chunks
        np.random.seed(seed)
        pool = range(0, self.number_of_chunks)
        chosen: typing.Set[int] = set()
        while len(chosen) < degree:
            candidate = np.random.choice(pool)
            if candidate not in chosen:
                chosen.add(candidate)
        return chosen

    def is_decoded(self) -> bool:
        """Report whether the equation system exists, is potentially solvable and is solved."""
        equation_system = self.GEPP
        if equation_system is None:
            return False
        return equation_system.isPotentionallySolvable() and equation_system.isSolved()

    def getNextValidPacket(self, from_multiple_files: bool = False, packet_len_format: str = "I",
                           crc_len_format: str = "L", number_of_chunks_len_format: str = "I",
                           degree_len_format: str = "I", seed_len_format: str = "I",
                           last_chunk_len_format: str = "I") -> typing.Optional[Packet]:
        """Read packets from ``self.f`` until one parses, skipping corrupt ones.

        :param from_multiple_files: if True, the rest of ``self.f`` is one packet;
            otherwise each packet is preceded by a little-endian length field of
            format ``packet_len_format``.
        :param last_chunk_len_format: unused here; kept for interface compatibility.
        :return: the next valid packet, or None at EOF. At EOF ``self.EOF`` is set
            and ``self.f`` is closed.
        """
        len_struct = "<" + packet_len_format
        len_size = struct.calcsize(len_struct)
        # Loop instead of recursing so long runs of corrupt packets cannot hit the
        # recursion limit, and the caller's formats are used for every retry.
        while True:
            if from_multiple_files:
                packet = self.f.read()
                packet_len = len(packet)
            else:
                raw_len = self.f.read(len_size)
                if len(raw_len) < len_size:
                    # Clean EOF (or a truncated length field): nothing left to parse.
                    packet, packet_len = b"", 0
                else:
                    packet_len = struct.unpack(len_struct, raw_len)[0]
                    packet = self.f.read(int(packet_len))
            if not packet or not packet_len:  # EOF
                self.EOF = True
                self.f.close()
                return None
            res = self.parse_raw_packet(packet, crc_len_format=crc_len_format,
                                        number_of_chunks_len_format=number_of_chunks_len_format,
                                        degree_len_format=degree_len_format,
                                        seed_len_format=seed_len_format)
            if not (isinstance(res, str) and res == "CORRUPT"):
                return res
            # Corrupt packet: parse_raw_packet already counted it; try the next one.

    def saveDecodedFile(self, last_chunk_len_format: str = "I", null_is_terminator: bool = False,
                        print_to_output: bool = True) -> None:
        """Write the reconstructed chunks to a file.

        The output name comes from the header chunk when one is used; otherwise
        it is ``DEC_<basename of self.file>``, or ``LT.BIN`` without an input
        file. Anything after the first NUL character in the name is dropped.

        :param last_chunk_len_format: struct format of the last-chunk length
            stored in the header chunk.
        :param null_is_terminator: if True, each chunk is decoded as UTF-8 text
            and writing stops after the first chunk that contains a NUL.
        :param print_to_output: if True, also print the written content decoded
            as UTF-8.
        :raises AssertionError: if the data has not been fully decoded yet.
        """
        assert self.is_decoded(), "Can not save File: Unable to reconstruct."
        if self.use_headerchunk:
            # Chunk 0 holds metadata (file name, length of the last chunk), not file content.
            self.headerChunk = HeaderChunk(Packet(self.GEPP.b[0], {0}, self.number_of_chunks, read_only=True),
                                           last_chunk_len_format=last_chunk_len_format)
        file_name = "DEC_" + os.path.basename(self.file) if self.file is not None else "LT.BIN"
        if self.headerChunk is not None:
            file_name = self.headerChunk.get_file_name().decode("utf-8")
        file_name = file_name.split("\x00")[0]
        written: typing.List[bytes] = []  # joined once at the end instead of quadratic bytes +=
        with open(file_name, "wb") as f:
            for x in self.GEPP.result_mapping:
                if x == 0 and self.use_headerchunk:
                    continue  # header chunk is not part of the file content
                stop = False
                if x == self.number_of_chunks - 1 and self.use_headerchunk:
                    # Trim the last chunk to the length recorded in the header chunk.
                    output = self.GEPP.b[x][0][0:self.headerChunk.get_last_chunk_length()].tobytes()
                elif null_is_terminator:
                    # tobytes() replaces the deprecated, identical ndarray.tostring().
                    text_parts = self.GEPP.b[x].tobytes().decode().split("\x00")
                    output = text_parts[0].encode()
                    # Null-terminator mode: stop once the first 0-byte has been seen.
                    stop = len(text_parts) > 1
                else:
                    output = self.GEPP.b[x].tobytes()
                f.write(output)
                written.append(output)
                if stop:
                    break
        print("Saved file as '" + str(file_name) + "'")
        if print_to_output:
            print("Result:")
            print(b"".join(written).decode("utf-8"))

    def parse_raw_packet(self, packet: bytes, crc_len_format: str = "L", number_of_chunks_len_format: str = "I",
                         degree_len_format: str = "I", seed_len_format: str = "I") -> typing.Union[str, Packet]:
        """Parse one raw packet into a ``Packet``.

        Layout (struct ``<``: little-endian, standard sizes)::

            [number_of_chunks][degree][seed][data][crc]

        * ``number_of_chunks`` is present only when ``static_number_of_chunks`` is None.
        * ``degree`` is present only when ``implicit_mode`` is False.
        * ``crc`` is present only when the error correction is ``crc32``.

        Otherwise the whole packet is first passed through
        ``self.error_correction``. In implicit mode the degree is drawn from
        ``self.dist`` seeded with the packet seed.

        :return: the parsed packet (``self.correct`` is incremented), or the
            string ``"CORRUPT"`` (``self.corrupt`` is incremented) if the CRC
            does not match, error correction raises, or the packet is too short
            for its header.
        """
        crc_len: typing.Optional[int] = -struct.calcsize("<" + crc_len_format)
        if self.error_correction.__code__.co_name == crc32.__code__.co_name:
            if len(packet) < -crc_len:  # too short to even hold the CRC
                self.corrupt += 1
                return "CORRUPT"
            payload: bytes = packet[:crc_len]
            crc: int = struct.unpack("<" + crc_len_format, packet[crc_len:])[0]
            calced_crc: int = calc_crc(payload)
            if crc != calced_crc:  # If the Packet is corrupt, try next one
                print("[-] CRC-Error - " + str(hex(crc)) + " != " + str(hex(calced_crc)))
                self.corrupt += 1
                return "CORRUPT"
        else:
            crc_len = None  # no trailing CRC: data runs to the end of the corrected packet
            try:
                packet = self.error_correction(packet)
            except Exception:
                self.corrupt += 1
                return "CORRUPT"
        if self.implicit_mode:
            degree_len_format = ""
        struct_str: str = "<" + number_of_chunks_len_format + degree_len_format + seed_len_format
        struct_len: int = struct.calcsize(struct_str)
        if len(packet) < struct_len:  # truncated header would make struct.unpack raise
            self.corrupt += 1
            return "CORRUPT"
        len_data: typing.Tuple[int, ...] = struct.unpack(struct_str, packet[0:struct_len])
        degree: typing.Optional[int] = None
        if self.static_number_of_chunks is None:
            if self.implicit_mode:
                number_of_chunks, seed = len_data
            else:
                number_of_chunks, degree, seed = len_data
            self.number_of_chunks = xor_mask(number_of_chunks, number_of_chunks_len_format)
        elif self.implicit_mode:
            # struct.unpack always returns a tuple, even for a single field.
            (seed,) = len_data
        else:
            degree, seed = len_data
        seed = xor_mask(seed, seed_len_format)
        if degree is None:
            self.dist.set_seed(seed)
            degree = self.dist.getNumber()
        else:
            degree = xor_mask(degree, degree_len_format)
        used_packets = self.choose_packet_numbers(degree, seed)
        data = packet[struct_len:crc_len]
        self.correct += 1

        return Packet(data, used_packets, self.number_of_chunks, read_only=True, error_correction=self.error_correction,
                      save_number_of_chunks_in_packet=self.static_number_of_chunks is None)