Ejemplo n.º 1
0
class NALUnit:
    """A single NAL unit located in an input stream, together with its RBSP.

    The constructor copies the NAL unit payload into ``self.rbsp_byte`` with
    emulation-prevention bytes removed (``setup_rbsp``), dumps it in hex
    (``print_bin``) and then calls ``setup()`` unless ``nosetup`` is true.
    ``setup()`` is a no-op here; presumably subclasses override it to parse
    their specific RBSP -- confirm against the subclasses.

    Method and attribute names appear to mirror HEVC (H.265) syntax
    structures and syntax elements.
    """

    def __init__(self, filp, pos, utype, size, tid, nosetup=False):
        """
        Args:
            filp: bit reader over the whole input; assumed to be a bitstring
                ``BitStream`` (``peek``/``read``/``bytepos`` are used) --
                confirm against the caller.
            pos: byte offset of the first byte of this NAL unit in ``filp``.
            utype: nal_unit_type value; used as a key into ``NAL_UNIT_TYPE``
                by ``__str__``.
            size: NAL unit length in bytes, including the 2 leading bytes
                that ``setup_rbsp`` skips.
            tid: temporal id.
            nosetup: if true, do not call the ``setup()`` hook.
        """
        self.filp = filp
        self.pos = pos
        self.num_bytes_in_nalu = size
        self.nal_unit_type = utype
        self.temporal_id = tid

        self.rbsp_byte = BitStream()
        self.setup_rbsp()
        self.print_bin()
        if not nosetup:
            self.setup()

    def rbsp_read(self, fmt):
        """Read and consume one value from the RBSP (bitstring format, e.g. 'uint: 8', 'ue')."""
        return self.rbsp_byte.read(fmt)

    def next_bits(self, fmt, forward=False):
        """Peek ``fmt`` from the raw input (or read it when ``forward`` is true).

        Returns the int 0 instead of raising when the input runs out of data.
        """
        try:
            if not forward:
                ret = self.filp.peek(fmt)
            else:
                ret = self.filp.read(fmt)
        except ReadError:
            return 0
        return ret

    def setup_rbsp(self):
        """Copy the NAL unit payload into ``self.rbsp_byte``, dropping emulation-prevention bytes.

        Bytes 2 .. num_bytes_in_nalu-1 are copied (the first 2 bytes,
        presumably the NAL unit header, are skipped).  Inside the payload
        every 0x000003 sequence contributes only its two zero bytes; the 0x03
        emulation_prevention_three_byte is discarded.  ``num_bytes_in_rbsp``
        counts the bytes kept, and ``self.filp`` is left at ``self.pos``.
        """
        # Seek explicitly rather than relying on the caller having left filp
        # positioned at the start of this NAL unit.
        self.filp.bytepos = self.pos + 2
        self.num_bytes_in_rbsp = 0
        i = 2
        while i < self.num_bytes_in_nalu:
            if i + 2 < self.num_bytes_in_nalu and self.next_bits('hex: 24') == '000003':
                self.rbsp_byte.append(self.filp.read(8))
                self.num_bytes_in_rbsp += 1
                self.rbsp_byte.append(self.filp.read(8))
                self.num_bytes_in_rbsp += 1
                # discard emulation_prevention_three_byte
                self.filp.read(8)
                i += 3
            else:
                self.rbsp_byte.append(self.filp.read(8))
                self.num_bytes_in_rbsp += 1
                i += 1
        self.filp.bytepos = self.pos

    def more_rbsp_data(self):
        """Return True if RBSP data remains before the rbsp_stop_one_bit.

        The stop bit is the last bit equal to 1 in the RBSP.  The read
        position is NOT moved, so callers can keep parsing the remaining
        syntax elements and finish with ``rbsp_trailing_bits()``.
        """
        # str.rfind gives -1 when the RBSP contains no 1 bit at all, so an
        # RBSP without a stop bit reports "no more data".
        last_one = self.rbsp_byte.bin.rfind('1')
        return self.rbsp_byte.pos < last_one

    def rbsp_trailing_bits(self):
        """Consume rbsp_trailing_bits: a 1 bit, then 0 bits up to the next byte boundary.

        Prints an error and exits the process with status 1 if the bits read
        do not have that pattern.
        """
        # Between 1 and 8 bits remain up to the next byte boundary; when the
        # position is already byte-aligned the trailing bits fill a whole
        # byte (0b10000000).  The stop bit is the most significant of them.
        bits = 8 - self.rbsp_byte.pos % 8
        if self.rbsp_read('uint: %d' % bits) != 1 << (bits - 1):
            print('Wrong rbsp_trailing_bits at NALU begined at bytes %d' % self.pos)
            exit(1)

    def profile_tier_level(self, ProfilePresentFlag, MaxNumSubLayersMinus1):
        """Parse profile_tier_level(ProfilePresentFlag, MaxNumSubLayersMinus1) into attributes.

        General profile fields are read only when ``ProfilePresentFlag`` is
        set; ``general_level_idc`` is always read.  Then, for each of the
        ``MaxNumSubLayersMinus1`` sub-layers, the two presence flags are read,
        followed by that sub-layer's profile and/or level fields.

        NOTE(review): this layout (general_reserved_zero_16bits, sub-layer
        flags interleaved with sub-layer data) differs from the published
        H.265 layout; it presumably targets an HEVC draft -- confirm.
        """
        if ProfilePresentFlag:
            self.general_profile_space = self.rbsp_read('uint: 2')
            self.general_tier_flag = self.rbsp_read('uint: 1')
            self.general_profile_idc = self.rbsp_read('uint: 5')
            self.general_profile_compatibility_flag = [self.rbsp_read('uint: 1') for _ in range(32)]
            self.general_reserved_zero_16bits = self.rbsp_read('uint: 16')
        self.general_level_idc = self.rbsp_read('uint: 8')
        self.sub_layer_profile_present_flag = [0] * MaxNumSubLayersMinus1
        self.sub_layer_level_present_flag = [0] * MaxNumSubLayersMinus1
        self.sub_layer_profile_space = [0] * MaxNumSubLayersMinus1
        self.sub_layer_tier_flag = [0] * MaxNumSubLayersMinus1
        self.sub_layer_profile_idc = [0] * MaxNumSubLayersMinus1
        self.sub_layer_reserved_zero_16bits = [0] * MaxNumSubLayersMinus1
        self.sub_layer_level_idc = [0] * MaxNumSubLayersMinus1
        # One independent 32-entry row per sub-layer; `[[0] * 32] * n` would
        # alias every row to the same list.
        self.sub_layer_profile_compatibility_flag = [[0] * 32 for _ in range(MaxNumSubLayersMinus1)]
        for i in range(MaxNumSubLayersMinus1):
            self.sub_layer_profile_present_flag[i] = self.rbsp_read('uint: 1')
            self.sub_layer_level_present_flag[i] = self.rbsp_read('uint: 1')
            if ProfilePresentFlag and self.sub_layer_profile_present_flag[i]:
                self.sub_layer_profile_space[i] = self.rbsp_read('uint: 2')
                self.sub_layer_tier_flag[i] = self.rbsp_read('uint: 1')
                self.sub_layer_profile_idc[i] = self.rbsp_read('uint: 5')
                for j in range(32):
                    self.sub_layer_profile_compatibility_flag[i][j] = self.rbsp_read('uint: 1')
                self.sub_layer_reserved_zero_16bits[i] = self.rbsp_read('uint: 16')
            if self.sub_layer_level_present_flag[i]:
                self.sub_layer_level_idc[i] = self.rbsp_read('uint: 8')

    def op_point(self, opIdx):
        """Parse op_point(opIdx): a ue(v) count-minus-1 followed by that many + 1 six-bit layer ids.

        Results are stored in ``self.op_num_layer_id_values_minus1[opIdx]``
        and ``self.op_layer_id[opIdx]``.  Both containers are created as
        dicts when absent; containers pre-allocated by a caller are reused.
        """
        if not hasattr(self, 'op_num_layer_id_values_minus1'):
            self.op_num_layer_id_values_minus1 = {}
        if not hasattr(self, 'op_layer_id'):
            self.op_layer_id = {}
        num_minus1 = self.rbsp_read('ue')
        self.op_num_layer_id_values_minus1[opIdx] = num_minus1
        # "_minus1" convention: num_minus1 + 1 layer ids follow.
        self.op_layer_id[opIdx] = [self.rbsp_read('uint: 6') for _ in range(num_minus1 + 1)]

    def short_term_ref_pic_set(self, idxRps):
        """Parse short_term_ref_pic_set(idxRps) into attributes.

        Reads ``self.num_short_term_ref_pic_sets`` and indexes/updates
        ``self.NumDeltaPocs``; both must be set up by the caller (not visible
        here).  In the explicit (non-predicted) branch
        ``NumDeltaPocs[idxRps]`` is set to the negative + positive picture
        count.

        NOTE(review): in the inter-RPS-prediction branch ``NumDeltaPocs[idxRps]``
        is not derived, so later sets predicting from this one will index a
        stale/missing value -- confirm whether callers handle that.
        """
        if idxRps != 0:
            self.inter_ref__pic_set_prediction_flag = self.rbsp_read('uint: 1')
        else:
            self.inter_ref__pic_set_prediction_flag = 0
        if self.inter_ref__pic_set_prediction_flag:
            # delta_idx_minus1 is only coded for the RPS carried in a slice
            # header (idxRps == num_short_term_ref_pic_sets).
            if idxRps == self.num_short_term_ref_pic_sets:
                self.delta_idx_minus1 = self.rbsp_read('ue')
            else:
                self.delta_idx_minus1 = 0
            RIdx = idxRps - self.delta_idx_minus1 - 1
            self.delta_rps_sign = self.rbsp_read('uint: 1')
            self.abs_delta_rps_minus1 = self.rbsp_read('ue')
            count = self.NumDeltaPocs[RIdx] + 1
            self.used_by_curr_pic_flag = [0] * count
            # use_delta_flag is inferred to be 1 when it is not coded.
            self.use_delta_flag = [1] * count
            for i in range(count):
                self.used_by_curr_pic_flag[i] = self.rbsp_read('uint: 1')
                if not self.used_by_curr_pic_flag[i]:
                    self.use_delta_flag[i] = self.rbsp_read('uint: 1')
        else:
            num_negative_pics = self.rbsp_read('ue')
            num_positive_pics = self.rbsp_read('ue')
            self.delta_poc_s0_minus1 = [0] * num_negative_pics
            self.used_by_curr_pic_s0_flag = [0] * num_negative_pics
            for i in range(num_negative_pics):
                self.delta_poc_s0_minus1[i] = self.rbsp_read('ue')
                self.used_by_curr_pic_s0_flag[i] = self.rbsp_read('uint: 1')
            self.delta_poc_s1_minus1 = [0] * num_positive_pics
            self.used_by_curr_pic_s1_flag = [0] * num_positive_pics
            for i in range(num_positive_pics):
                self.delta_poc_s1_minus1[i] = self.rbsp_read('ue')
                self.used_by_curr_pic_s1_flag[i] = self.rbsp_read('uint: 1')
            self.NumDeltaPocs[idxRps] = num_negative_pics + num_positive_pics

    def setup(self):
        """Hook called by the constructor unless ``nosetup``; does nothing here."""
        pass

    def __str__(self):
        """One-line summary; the type is rendered through the ``NAL_UNIT_TYPE`` lookup."""
        return 'NALU:  pos=%d, length=%d, type=%s, tid=%d' % (self.pos, self.num_bytes_in_nalu, NAL_UNIT_TYPE[self.nal_unit_type], self.temporal_id)

    def print_bin(self):
        """Print the RBSP bytes as hex, 16 per line, followed by a blank line."""
        i = 1
        for abyte in self.rbsp_byte.tobytes():
            if i == 16:
                print('%3s' % hex(abyte)[2: ])
                i = 1
            else:
                print('%3s' % hex(abyte)[2: ], end=' ')
                i += 1
        print('\n')
Ejemplo n.º 2
0
class Bitstream:
    """
    class for input bitstream

    Wraps the input in a ``BitStream`` and, after ``init()``, holds one NAL
    unit object (built by ``nalunit.get_nalu``) per start-code-delimited NAL
    unit found in it.  ``get_next_nalu``/``set_next_nalu_pos`` provide a
    simple cursor over that list.
    """

    def __init__(self, file):
        self.srcfile = BitStream(file)
        self.cur_num_bytes_in_nalu = 0
        self.cur_index = 0
        self.nalu_list = []

    def get_next_nalu(self):
        """Return the NAL unit at the cursor and advance the cursor by one."""
        self.cur_index += 1
        return self.nalu_list[self.cur_index - 1]

    def set_next_nalu_pos(self, pos):
        """Move the cursor so the next ``get_next_nalu()`` returns ``nalu_list[pos]``."""
        self.cur_index = pos

    def get_all_nalu(self):
        """Return the list of parsed NAL units (the internal list, not a copy)."""
        return self.nalu_list

    def next_bits(self, bits, forward=False):
        """Peek ``bits`` bits as a hex string (read them when ``forward`` is true).

        Returns the int 0 instead of raising when fewer than ``bits`` bits
        remain, so callers can compare against 0 to detect end of stream.
        """
        try:
            if not forward:
                ret = self.srcfile.peek('hex:%d' % bits)
            else:
                ret = self.srcfile.read('hex:%d' % bits)
        except ReadError:
            return 0
        return ret

    def calculate_nalu_size(self):
        """Return the size in bytes of the NAL unit starting at the current byte position.

        The NAL unit ends right before the next 0x000000 or 0x000001, or at
        the end of the stream.  The read position is restored before
        returning.  Prints an error and exits with status 1 if the position
        is not byte-aligned after the scan.
        """
        head_pos = self.srcfile.bytepos

        value = self.next_bits(24)
        while value not in (0, '000000', '000001'):
            self.srcfile.bytepos += 1
            value = self.next_bits(24)

        if value == 0:
            # Fewer than 3 bytes remain: the NAL unit runs to the end of the
            # stream.  len() of the source is in bits.
            size = len(self.srcfile) // 8 - head_pos
        else:
            size = self.srcfile.bytepos - head_pos

        if self.srcfile.bytealign() != 0:
            print('Byte Align Error when calculate NALU Size!')
            exit(1)
        self.srcfile.bytepos = head_pos
        return size

    def init(self):
        """
        save all NALU positions in the list

        Repeats the byte_stream_nal_unit() pattern until the stream is
        exhausted: leading_zero_8bits, an optional zero_byte,
        start_code_prefix_one_3bytes, the NAL unit itself, then
        trailing_zero_8bits.  Any unexpected value prints a message and exits
        the process with status 1.
        """
        while True:
            # Skip zero bytes until a 3- or 4-byte start code is next.
            while self.next_bits(24) != '000001' \
                    and self.next_bits(32) != '00000001':
                value = self.next_bits(bits_of_leading_zero_8bits, forward=True)
                if value != leading_zero_8bits:
                    print('wrong leading_zero_8bits')
                    exit(1)

            # A 4-byte start code carries an extra leading zero_byte.
            if self.next_bits(24) != '000001':
                value = self.next_bits(bits_of_zero_byte, forward=True)
                if value != zero_byte:
                    print('wrong zero_byte')
                    exit(1)

            value = self.next_bits(bits_of_start_code_prefix_one_3bytes, forward=True)
            if value != start_code_prefix_one_3bytes:
                print('wrong start_code_prefix_one_3bytes')
                exit(1)

            self.cur_num_bytes_in_nalu = self.calculate_nalu_size()

            self.nalu_list.append(nalunit.get_nalu(self.srcfile, self.srcfile.bytepos, self.cur_num_bytes_in_nalu))

            self.srcfile.bytepos = self.srcfile.bytepos + self.cur_num_bytes_in_nalu
            # next_bits() yields the int 0 only at end of stream (data comes
            # back as a hex string), so `!= 0` means "more data remains".
            while self.next_bits(8) != 0 \
                    and self.next_bits(24) != '000001' \
                    and self.next_bits(32) != '00000001':
                value = self.next_bits(bits_of_trailing_zero_8bits, forward=True)
                if value != trailing_zero_8bits:
                    print('wrong trailing_zero_8bits')
                    exit(1)

            if self.next_bits(8) == 0:
                return

    def get_nalu_nums(self):
        """Return how many NAL units have been collected."""
        return len(self.nalu_list)

    def __str__(self):
        return 'bitstream: length=%d, numofNALU=%d' % (len(self.srcfile), self.get_nalu_nums())