class FlacDecoder(object):
    """Pure-Python FLAC decoder driven by a BitstreamReader.

    The constructor parses the stream's metadata blocks (keeping the
    STREAMINFO values) and leaves the reader positioned at the first
    frame.  Each call to read() then decodes exactly one FLAC frame.

    The code only uses constructs that behave the same on Python 2 and
    Python 3 (floor division, range(), bytes literals).
    """

    # number of channels for each 4-bit channel assignment value;
    # None marks assignments that are rejected as invalid
    CHANNEL_COUNT = [1, 2, 3, 4, 5, 6, 7, 8, 2, 2, 2,
                     None, None, None, None, None]

    (SUBFRAME_CONSTANT,
     SUBFRAME_VERBATIM,
     SUBFRAME_FIXED,
     SUBFRAME_LPC) = range(4)

    # for the three stereo-decorrelated channel assignments,
    # the index of the difference/side channel,
    # which is stored with 1 additional bit per sample
    WIDE_CHANNEL = {0x8: 1,   # left-difference
                    0x9: 0,   # difference-right
                    0xA: 1}   # mid-side

    def __init__(self, filename, channel_mask):
        """Open "filename" and read its metadata blocks.

        channel_mask is stored as-is in self.channel_mask.

        Raises ValueError if the file does not start with the
        "fLaC" stream marker.  The reader is closed again if
        anything goes wrong while parsing the metadata.
        """

        self.reader = BitstreamReader(open(filename, "rb"), 0)

        try:
            if (self.reader.read_bytes(4) != b'fLaC'):
                raise ValueError("invalid FLAC file")

            self.current_md5sum = md5()

            # locate the STREAMINFO,
            # which is sometimes needed to handle non-subset streams
            # (every block is walked so the reader ends up
            #  just past the final metadata block)
            for (block_id,
                 block_size,
                 block_reader) in self.metadata_blocks(self.reader):
                if (block_id == 0):
                    self._read_streaminfo(block_reader, channel_mask)
        except Exception:
            # don't leak the file handle of an object
            # the caller will never receive
            self.reader.close()
            raise

    def _read_streaminfo(self, block_reader, channel_mask):
        """Populate the STREAMINFO attributes and header lookup tables."""

        self.minimum_block_size = block_reader.read(16)
        self.maximum_block_size = block_reader.read(16)
        self.minimum_frame_size = block_reader.read(24)
        self.maximum_frame_size = block_reader.read(24)
        self.sample_rate = block_reader.read(20)
        self.channels = block_reader.read(3) + 1
        self.channel_mask = channel_mask
        self.bits_per_sample = block_reader.read(5) + 1
        self.total_frames = block_reader.read64(36)
        self.md5sum = block_reader.read_bytes(16)

        # these are frame header lookup tables
        # which vary slightly depending on STREAMINFO's values
        self.BLOCK_SIZE = [self.maximum_block_size,
                           192, 576, 1152, 2304, 4608,
                           None, None,
                           256, 512, 1024, 2048, 4096,
                           8192, 16384, 32768]
        self.SAMPLE_RATE = [self.sample_rate,
                            88200, 176400, 192000,
                            8000, 16000, 22050, 24000,
                            32000, 44100, 48000, 96000,
                            None, None, None, None]
        self.BITS_PER_SAMPLE = [self.bits_per_sample,
                                8, 12, None, 16, 20, 24, None]

    def metadata_blocks(self, reader):
        """yields a (block_id, block_size, block_reader) tuple
        per metadata block
        where block_reader is a BitstreamReader substream

        Blocks are read from "reader" until (and including) the one
        whose "last block" flag is set.
        """

        while True:
            (last_block,
             block_id,
             block_size) = reader.parse("1u 7u 24u")
            yield (block_id, block_size, reader.substream(block_size))
            if (last_block != 0):
                return

    def read(self, pcm_frames):
        """Decode the next FLAC frame and return it as a FrameList.

        pcm_frames is accepted for interface compatibility but is not
        consulted; one whole FLAC frame is returned per call.

        Once the stream is exhausted, the accumulated MD5 sum is
        compared against STREAMINFO's and an empty FrameList is
        returned.

        Raises ValueError on an MD5 mismatch, an invalid channel
        assignment, or a frame footer CRC-16 mismatch.
        """

        # if the stream is exhausted,
        # verify its MD5 sum and return an empty pcm.FrameList object
        if (self.total_frames < 1):
            if (self.md5sum == self.current_md5sum.digest()):
                return from_list([], self.channels,
                                 self.bits_per_sample, True)
            else:
                raise ValueError("MD5 checksum mismatch")

        # every byte from here to the footer is fed to the CRC-16
        crc16 = CRC16()
        self.reader.add_callback(crc16.update)

        # fetch the decoding parameters from the frame header
        (block_size,
         channel_assignment,
         bits_per_sample) = self.read_frame_header()
        channel_count = self.CHANNEL_COUNT[channel_assignment]
        if (channel_count is None):
            raise ValueError("invalid channel assignment")

        # channel data will be a list of signed sample lists,
        # one per channel
        # such as [[1, 2, 3, ...], [4, 5, 6, ...]] for a 2 channel stream
        wide_channel = self.WIDE_CHANNEL.get(channel_assignment)
        channel_data = []
        for channel_number in range(channel_count):
            if (channel_number == wide_channel):
                # the difference channel has 1 additional bit
                channel_data.append(
                    self.read_subframe(block_size, bits_per_sample + 1))
            else:
                # otherwise, use the frame's bits-per-sample value
                channel_data.append(
                    self.read_subframe(block_size, bits_per_sample))

        # once all the subframes have been decoded,
        # reconstruct them depending on the channel assignment
        samples = self._interleave_channels(channel_assignment,
                                            channel_data,
                                            block_size,
                                            channel_count)

        self.reader.byte_align()

        # read and verify the frame's trailing CRC-16 footer
        self.reader.read(16)
        self.reader.pop_callback()
        if (int(crc16) != 0):
            raise ValueError("CRC16 mismatch in frame footer")

        # deduct the amount of PCM frames from the remaining amount
        self.total_frames -= block_size

        # build a pcm.FrameList object from the combined samples
        framelist = from_list(samples, channel_count, bits_per_sample, True)

        # update the running MD5 sum calculation with the frame's data
        self.current_md5sum.update(framelist.to_bytes(0, 1))

        # and finally return the frame data
        return framelist

    def _interleave_channels(self, channel_assignment, channel_data,
                             block_size, channel_count):
        """Return one flat, interleaved sample list for a decoded frame.

        Undoes the stereo decorrelation selected by channel_assignment
        (0x8 left-difference, 0x9 difference-right, 0xA mid-side);
        any other assignment is treated as independent channels.
        """

        samples = []
        if (channel_assignment == 0x8):
            # left-difference
            for (left, difference) in zip(*channel_data):
                samples.append(left)
                samples.append(left - difference)
        elif (channel_assignment == 0x9):
            # difference-right
            for (difference, right) in zip(*channel_data):
                samples.append(difference + right)
                samples.append(right)
        elif (channel_assignment == 0xA):
            # mid-side
            # floor division keeps the samples integers on Python 3
            for (mid, side) in zip(*channel_data):
                doubled_mid = (mid * 2) + (side % 2)
                samples.append((doubled_mid + side) // 2)
                samples.append((doubled_mid - side) // 2)
        else:
            # independent
            samples = [0] * block_size * channel_count
            for (i, channel) in enumerate(channel_data):
                samples[i::channel_count] = channel
        return samples

    def read_frame_header(self):
        """Parse a frame header from the reader.

        Returns a (block_size, channel_assignment, bits_per_sample)
        tuple.

        Raises ValueError on a bad sync code, an invalid sample rate
        or bits-per-sample field, or a header CRC-8 mismatch.
        """

        crc8 = CRC8()
        self.reader.add_callback(crc8.update)

        # read the 32-bit FLAC frame header
        sync_code = self.reader.read(14)
        if (sync_code != 0x3FFE):
            raise ValueError("invalid sync code")

        self.reader.skip(1)
        self.reader.read(1)  # blocking strategy, not needed for decoding
        block_size_bits = self.reader.read(4)
        sample_rate_bits = self.reader.read(4)
        channel_assignment = self.reader.read(4)
        bits_per_sample_bits = self.reader.read(3)
        self.reader.skip(1)

        # the frame number is a UTF-8 encoded value
        # which takes a variable number of whole bytes
        # (its value is unused, but it must be consumed)
        self.read_utf8()

        # unpack the 4 bit block size field
        # which is the total PCM frames in the FLAC frame
        # and may require up to 16 more bits
        # if the frame is unusually-sized
        # (which typically happens at the end of the stream)
        if (block_size_bits == 0x6):
            block_size = self.reader.read(8) + 1
        elif (block_size_bits == 0x7):
            block_size = self.reader.read(16) + 1
        else:
            block_size = self.BLOCK_SIZE[block_size_bits]

        # unpack the 4 bit sample rate field
        # which is used for playback, but not needed for decoding
        # and may require up to 16 more bits
        # if the stream has a particularly unusual sample rate
        # (the extra bits must be consumed even though
        #  the resulting rate is discarded)
        if (sample_rate_bits == 0xC):
            self.reader.read(8)    # sample rate in kHz
        elif (sample_rate_bits == 0xD):
            self.reader.read(16)   # sample rate in Hz
        elif (sample_rate_bits == 0xE):
            self.reader.read(16)   # sample rate in tens of Hz
        elif (sample_rate_bits == 0xF):
            raise ValueError("invalid sample rate")

        # unpack the 3 bit bits-per-sample field
        # this never requires additional bits
        if ((bits_per_sample_bits == 0x3) or (bits_per_sample_bits == 0x7)):
            raise ValueError("invalid bits per sample")
        else:
            bits_per_sample = self.BITS_PER_SAMPLE[bits_per_sample_bits]

        # read and verify frame's CRC-8 value
        self.reader.read(8)
        self.reader.pop_callback()
        if (int(crc8) != 0):
            raise ValueError("CRC8 mismatch in frame header")

        return (block_size, channel_assignment, bits_per_sample)

    def read_subframe_header(self):
        """returns a tuple of (subframe_type, subframe_order, wasted_bps)

        subframe_order is None for CONSTANT and VERBATIM subframes.
        Raises ValueError for an unrecognized subframe type.
        """

        self.reader.skip(1)
        subframe_type = self.reader.read(6)
        if (self.reader.read(1) == 1):
            wasted_bps = self.reader.unary(1) + 1
        else:
            wasted_bps = 0

        # extract "order" value from 6 bit subframe type, if necessary
        if (subframe_type == 0):
            return (self.SUBFRAME_CONSTANT, None, wasted_bps)
        elif (subframe_type == 1):
            return (self.SUBFRAME_VERBATIM, None, wasted_bps)
        elif ((subframe_type & 0x38) == 0x08):
            return (self.SUBFRAME_FIXED, subframe_type & 0x07, wasted_bps)
        elif ((subframe_type & 0x20) == 0x20):
            return (self.SUBFRAME_LPC, (subframe_type & 0x1F) + 1, wasted_bps)
        else:
            raise ValueError("invalid subframe type")

    def read_subframe(self, block_size, bits_per_sample):
        """Decode one subframe and return its list of signed samples."""

        (subframe_type,
         subframe_order,
         wasted_bps) = self.read_subframe_header()

        # wasted bits are not stored in the subframe body
        effective_bps = bits_per_sample - wasted_bps

        # read a list of signed sample values
        # depending on the subframe type, block size,
        # adjusted bits per sample and optional subframe order
        if (subframe_type == self.SUBFRAME_CONSTANT):
            subframe_samples = self.read_constant_subframe(
                block_size, effective_bps)
        elif (subframe_type == self.SUBFRAME_VERBATIM):
            subframe_samples = self.read_verbatim_subframe(
                block_size, effective_bps)
        elif (subframe_type == self.SUBFRAME_FIXED):
            subframe_samples = self.read_fixed_subframe(
                block_size, effective_bps, subframe_order)
        else:
            subframe_samples = self.read_lpc_subframe(
                block_size, effective_bps, subframe_order)

        # account for wasted bits-per-sample, if necessary
        if (wasted_bps):
            return [sample << wasted_bps for sample in subframe_samples]
        else:
            return subframe_samples

    def read_constant_subframe(self, block_size, bits_per_sample):
        """Return block_size copies of the subframe's single sample."""

        sample = self.reader.read_signed(bits_per_sample)
        return [sample] * block_size

    def read_verbatim_subframe(self, block_size, bits_per_sample):
        """Return block_size samples stored uncompressed."""

        return [self.reader.read_signed(bits_per_sample)
                for _ in range(block_size)]

    def read_fixed_subframe(self, block_size, bits_per_sample, order):
        """Decode a FIXED subframe of the given predictor order (0-4).

        Raises ValueError for any other order.
        """

        # "order" number of warm-up samples
        samples = [self.reader.read_signed(bits_per_sample)
                   for _ in range(order)]

        # "block_size" - "order" number of residual values
        residuals = self.read_residual(block_size, order)

        # which are applied to the warm-up samples
        # depending on the FIXED subframe order
        # and results in "block_size" number of total samples
        if (order == 0):
            return residuals
        elif (order == 1):
            for residual in residuals:
                samples.append(samples[-1] + residual)
            return samples
        elif (order == 2):
            for residual in residuals:
                samples.append((2 * samples[-1]) -
                               samples[-2] +
                               residual)
            return samples
        elif (order == 3):
            for residual in residuals:
                samples.append((3 * samples[-1]) -
                               (3 * samples[-2]) +
                               samples[-3] +
                               residual)
            return samples
        elif (order == 4):
            for residual in residuals:
                samples.append((4 * samples[-1]) -
                               (6 * samples[-2]) +
                               (4 * samples[-3]) -
                               samples[-4] +
                               residual)
            return samples
        else:
            raise ValueError("unsupported FIXED subframe order")

    def read_lpc_subframe(self, block_size, bits_per_sample, order):
        """Decode an LPC subframe with "order" QLP coefficients."""

        # "order" number of warm-up samples
        samples = [self.reader.read_signed(bits_per_sample)
                   for _ in range(order)]

        # the size of each QLP coefficient, in bits
        qlp_precision = self.reader.read(4)

        # the amount of right shift to apply
        # during LPC calculation
        # (though this is a signed value, negative shifts are noops
        #  in the reference FLAC decoder)
        qlp_shift_needed = max(self.reader.read_signed(5), 0)

        # "order" number of signed QLP coefficients
        qlp_coeffs = [self.reader.read_signed(qlp_precision + 1)
                      for _ in range(order)]
        # QLP coefficients are applied in reverse order
        qlp_coeffs.reverse()

        # "block_size" - "order" number of residual values
        residuals = self.read_residual(block_size, order)

        # which are applied to the running LPC calculation
        for residual in residuals:
            prediction = sum([coeff * sample for (coeff, sample) in
                              zip(qlp_coeffs, samples[-order:])])
            samples.append((prediction >> qlp_shift_needed) + residual)

        return samples

    def read_residual(self, block_size, order):
        """Return the block_size - order residuals of a subframe."""

        residuals = []

        coding_method = self.reader.read(2)
        partition_order = self.reader.read(4)

        # each partition contains block_size / 2 ** partition_order
        # number of residuals
        # (floor division so the count stays an integer on Python 3)
        partition_size = block_size // (2 ** partition_order)
        for partition_number in range(2 ** partition_order):
            if (partition_number == 0):
                # except for the first partition
                # which contains "order" less than the rest
                residual_count = partition_size - order
            else:
                residual_count = partition_size
            residuals.extend(
                self.read_residual_partition(coding_method,
                                             residual_count))

        return residuals

    def read_residual_partition(self, coding_method, residual_count):
        """Return residual_count signed residuals from one partition.

        Raises ValueError if coding_method is neither 0 nor 1.
        """

        if (coding_method == 0):
            # the Rice parameters determines the number of
            # least-significant bits to read for each residual
            rice_parameter = self.reader.read(4)
            escaped = (rice_parameter == 0xF)
        elif (coding_method == 1):
            # 24 bps streams may use a 5-bit Rice parameter
            # for better compression
            rice_parameter = self.reader.read(5)
            escaped = (rice_parameter == 0x1F)
        else:
            raise ValueError("invalid Rice coding parameter")

        if (escaped):
            # an all-ones parameter means the residuals are stored
            # as plain signed values of "escape_code" bits each
            escape_code = self.reader.read(5)
            return [self.reader.read_signed(escape_code)
                    for _ in range(residual_count)]

        # a list of signed residual values
        partition_residuals = []
        for _ in range(residual_count):
            msb = self.reader.unary(1)              # most-significant bits
            lsb = self.reader.read(rice_parameter)  # least-significant bits
            value = (msb << rice_parameter) | lsb   # combined into a value
            if (value & 1):    # whose least-significant bit is the sign value
                partition_residuals.append(-(value >> 1) - 1)
            else:
                partition_residuals.append(value >> 1)

        return partition_residuals

    def read_utf8(self):
        """Return the UTF-8-style variable-length integer at the reader."""

        # the count of leading 1 bits gives the total encoded byte count
        total_bytes = self.reader.unary(0)
        value = self.reader.read(7 - total_bytes)
        while (total_bytes > 1):
            # each continuation byte is 2 marker bits and 6 value bits
            value = ((value << 6) | self.reader.parse("2p 6u")[0])
            total_bytes -= 1
        return value

    def close(self):
        """Close the underlying reader."""

        self.reader.close()
class ALACDecoder(object):
    """Pure-Python ALAC decoder driven by a BitstreamReader.

    The constructor pulls the decoding parameters out of the file's
    "alac" and "mdhd" atoms and leaves the reader positioned inside
    the "mdat" atom.  Each call to read() then decodes one frameset.

    The code only uses constructs that behave the same on Python 2 and
    Python 3 (floor division, range(), bytes literals).
    The decoder may be used as a context manager, which closes it
    on exit.
    """

    # for each supported channel count above 2,
    # the frameset channel indexes arranged in Wave order
    WAVE_ORDER = {3: (1, 2, 0),
                  4: (1, 2, 0, 3),
                  5: (1, 2, 0, 3, 4),
                  6: (1, 2, 0, 5, 3, 4),
                  7: (1, 2, 0, 6, 3, 4, 5),
                  8: (3, 4, 0, 7, 5, 6, 1, 2)}

    def __init__(self, filename):
        """Open "filename" and read its decoding parameters.

        Raises ValueError if the stsd atom is missing, the alac atom
        is malformed or the mdhd version is unknown.  The reader is
        closed again if anything goes wrong while parsing.
        """

        self.reader = BitstreamReader(open(filename, "rb"), 0)

        try:
            self.reader.mark()
            try:
                self._read_stream_parameters()
            finally:
                self.reader.unmark()
        except Exception:
            # don't leak the file handle of an object
            # the caller will never receive
            self.reader.close()
            raise

    def _read_stream_parameters(self):
        """Parse the alac and mdhd atoms, then seek to the mdat atom.

        Expects the reader to be marked at the start of the file.
        """

        # locate the "alac" atom
        # which is full of required decoding parameters
        try:
            stsd = self.find_sub_atom(b"moov", b"trak", b"mdia",
                                      b"minf", b"stbl", b"stsd")
        except KeyError:
            raise ValueError("required stsd atom not found")

        # stsd version and description count are consumed but unused
        stsd.parse("8u 24p 32u")
        (alac1,
         alac2,
         self.samples_per_frame,
         self.bits_per_sample,
         self.history_multiplier,
         self.initial_history,
         self.maximum_k,
         self.channels,
         self.sample_rate) = stsd.parse(
            # ignore much of the stuff in the "high" ALAC atom
            "32p 4b 6P 16p 16p 16p 4P 16p 16p 16p 16p 4P" +
            # and use the attributes in the "low" ALAC atom instead
            "32p 4b 4P 32u 8p 8u 8u 8u 8u 8u 16p 32p 32p 32u")

        self.channel_mask = {1: 0x0004,
                             2: 0x0003,
                             3: 0x0007,
                             4: 0x0107,
                             5: 0x0037,
                             6: 0x003F,
                             7: 0x013F,
                             8: 0x00FF}.get(self.channels, 0)

        if ((alac1 != b'alac') or (alac2 != b'alac')):
            raise ValueError("Invalid alac atom")

        # also locate the "mdhd" atom
        # which contains the stream's length in PCM frames
        self.reader.rewind()
        mdhd = self.find_sub_atom(b"moov", b"trak", b"mdia", b"mdhd")
        (version, ) = mdhd.parse("8u 24p")
        if (version == 0):
            (self.total_pcm_frames,) = mdhd.parse(
                "32p 32p 32p 32u 2P 16p")
        elif (version == 1):
            (self.total_pcm_frames,) = mdhd.parse(
                "64p 64p 32p 64U 2P 16p")
        else:
            raise ValueError("invalid mdhd version")

        # finally, set our stream to the "mdat" atom
        self.reader.rewind()
        (atom_size, atom_name) = self.reader.parse("32u 4b")
        while (atom_name != b"mdat"):
            self.reader.skip_bytes(atom_size - 8)
            (atom_size, atom_name) = self.reader.parse("32u 4b")

    def find_sub_atom(self, *atom_names):
        """Return a substream of the atom at the given nested path.

        Each name is searched for inside the atom found for the
        previous name.  Raises KeyError (carrying the missing name)
        if the reader raises IOError while searching.
        """

        reader = self.reader

        for (last, next_atom) in iter_last(iter(atom_names)):
            try:
                (length, stream_atom) = reader.parse("32u 4b")
                while (stream_atom != next_atom):
                    # the 8 byte atom header has already been read
                    reader.skip_bytes(length - 8)
                    (length, stream_atom) = reader.parse("32u 4b")
                if (last):
                    return reader.substream(length - 8)
                else:
                    reader = reader.substream(length - 8)
            except IOError:
                raise KeyError(next_atom)

    def read(self, pcm_frames):
        """Decode the next ALAC frameset and return it as a FrameList.

        pcm_frames is accepted for interface compatibility but is not
        consulted.  An empty FrameList is returned once the stream
        is exhausted.

        Raises ValueError if the stream's channel count is unsupported.
        """

        # if the stream is exhausted, return an empty pcm.FrameList object
        if (self.total_pcm_frames == 0):
            return from_list([], self.channels, self.bits_per_sample, True)

        # otherwise, read one ALAC frameset's worth of frame data
        # (a 3 bit field of all 1s ends the frameset)
        frameset_data = []
        frame_channels = self.reader.read(3) + 1
        while (frame_channels != 0x8):
            frameset_data.extend(self.read_frame(frame_channels))
            frame_channels = self.reader.read(3) + 1
        self.reader.byte_align()

        # reorder the frameset to Wave order, depending on channel count
        # (mono and stereo are already in Wave order)
        if (self.channels not in (1, 2)):
            wave_order = self.WAVE_ORDER.get(self.channels)
            if (wave_order is None):
                raise ValueError("unsupported channel count")
            frameset_data = [frameset_data[i] for i in wave_order]

        framelist = from_channels([from_list(channel,
                                             1,
                                             self.bits_per_sample,
                                             True)
                                   for channel in frameset_data])

        # deduct PCM frames from remainder
        self.total_pcm_frames -= framelist.frames

        # return samples as a pcm.FrameList object
        return framelist

    def read_frame(self, channel_count):
        """returns a list of PCM sample lists, one per channel"""

        # read the ALAC frame header
        self.reader.skip(16)
        has_sample_count = self.reader.read(1)
        uncompressed_lsb_size = self.reader.read(2)
        uncompressed = self.reader.read(1)
        if (has_sample_count):
            sample_count = self.reader.read(32)
        else:
            sample_count = self.samples_per_frame

        if (uncompressed == 1):
            # if the frame is uncompressed,
            # read the raw, interlaced samples
            samples = [self.reader.read_signed(self.bits_per_sample)
                       for _ in range(sample_count * channel_count)]
            return [samples[i::channel_count] for i in range(channel_count)]

        # if the frame is compressed,
        # read the interlacing parameters
        interlacing_shift = self.reader.read(8)
        interlacing_leftweight = self.reader.read(8)

        # subframe headers
        subframe_headers = [self.read_subframe_header()
                            for _ in range(channel_count)]

        # optional uncompressed LSB values
        lsb_bits = uncompressed_lsb_size * 8
        if (uncompressed_lsb_size > 0):
            uncompressed_lsbs = [self.reader.read(lsb_bits)
                                 for _ in range(sample_count *
                                                channel_count)]
        else:
            uncompressed_lsbs = []

        sample_size = (self.bits_per_sample -
                       lsb_bits +
                       channel_count - 1)

        # and residual blocks
        residual_blocks = [self.read_residuals(sample_size, sample_count)
                           for _ in range(channel_count)]

        # calculate subframe samples based on
        # subframe header's QLP coefficients and QLP shift-needed
        decoded_subframes = [self.decode_subframe(qlp_shift_needed,
                                                  qlp_coefficients,
                                                  sample_size,
                                                  residuals)
                             for ((qlp_shift_needed, qlp_coefficients),
                                  residuals) in
                             zip(subframe_headers, residual_blocks)]

        # decorrelate channels according interlacing shift and leftweight
        decorrelated_channels = self.decorrelate_channels(
            decoded_subframes,
            interlacing_shift,
            interlacing_leftweight)

        if (uncompressed_lsb_size == 0):
            return decorrelated_channels

        # if uncompressed LSB values are present,
        # place them below each sample of each channel
        channels = []
        for (i, channel) in enumerate(decorrelated_channels):
            channel_lsbs = uncompressed_lsbs[i::channel_count]
            assert(len(channel) == len(channel_lsbs))
            channels.append([(s << lsb_bits) | l
                             for (s, l) in zip(channel, channel_lsbs)])
        return channels

    def read_subframe_header(self):
        """Return a (qlp_shift_needed, qlp_coefficients) tuple."""

        self.reader.read(4)  # prediction type, unused
        qlp_shift_needed = self.reader.read(4)
        self.reader.read(3)  # rice modifier, unused
        coefficient_count = self.reader.read(5)
        qlp_coefficients = [self.reader.read_signed(16)
                            for _ in range(coefficient_count)]

        return (qlp_shift_needed, qlp_coefficients)

    def read_residuals(self, sample_size, sample_count):
        """Return a list of signed residuals for one channel."""

        residuals = []
        history = self.initial_history
        sign_modifier = 0
        i = 0

        while (i < sample_count):
            # get an unsigned residual based on "history"
            # and on "sample_size" as a last resort
            k = min(log2(history // (2 ** 9) + 3), self.maximum_k)
            unsigned = self.read_residual(k, sample_size) + sign_modifier

            # clear out old sign modifier, if any
            sign_modifier = 0

            # change unsigned residual to signed residual
            if (unsigned & 1):
                residuals.append(-((unsigned + 1) // 2))
            else:
                residuals.append(unsigned // 2)

            # update history based on unsigned residual
            if (unsigned <= 0xFFFF):
                history += ((unsigned * self.history_multiplier) -
                            ((history * self.history_multiplier) >> 9))
            else:
                history = 0xFFFF

            # if history gets too small, we may have a block of 0 samples
            # which can be compressed more efficiently
            if ((history < 128) and ((i + 1) < sample_count)):
                zeroes_k = min(7 -
                               log2(history) +
                               ((history + 16) // 64),
                               self.maximum_k)
                zero_residuals = self.read_residual(zeroes_k, 16)
                if (zero_residuals > 0):
                    residuals.extend([0] * zero_residuals)
                    i += zero_residuals

                history = 0

                if (zero_residuals <= 0xFFFF):
                    sign_modifier = 1

            i += 1

        return residuals

    def read_residual(self, k, sample_size):
        """Return a single unsigned residual value.

        When the unary prefix hits its limit (the reader returns None),
        the value is read as "sample_size" plain bits instead.
        """

        msb = self.reader.limited_unary(0, 9)
        if (msb is None):
            return self.reader.read(sample_size)
        elif (k == 0):
            return msb

        lsb = self.reader.read(k)
        if (lsb > 1):
            return msb * ((1 << k) - 1) + (lsb - 1)

        # lsb is 0 or 1 here, and that bit is handed back to the reader
        self.reader.unread(lsb)
        return msb * ((1 << k) - 1)

    def decode_subframe(self, qlp_shift_needed, qlp_coefficients,
                        sample_size, residuals):
        """Return the list of samples decoded from "residuals".

        The predictor adapts its coefficients while decoding, but
        works on private copies, so neither qlp_coefficients nor
        residuals is modified for the caller.
        """

        qlp_coefficients = list(qlp_coefficients)
        residuals = list(residuals)
        count = len(qlp_coefficients)

        # first sample is always copied verbatim
        samples = [residuals.pop(0)]

        if (count >= 31):
            # residuals are encoded as simple difference values
            for residual in residuals:
                samples.append(truncate_bits(samples[-1] + residual,
                                             sample_size))
            return samples

        # the next "coefficient count" samples
        # are applied as differences to the previous
        for _ in range(count):
            samples.append(truncate_bits(samples[-1] + residuals.pop(0),
                                         sample_size))

        # remaining samples are processed much like LPC
        for residual in residuals:
            base_sample = samples[-count - 1]
            lpc_sum = sum([(s - base_sample) * c for (s, c) in
                           zip(samples[-count:],
                               reversed(qlp_coefficients))])
            outval = (1 << (qlp_shift_needed - 1)) + lpc_sum
            outval >>= qlp_shift_needed
            samples.append(truncate_bits(outval + residual + base_sample,
                                         sample_size))
            buf = samples[-count - 2:-1]

            # error value then adjusts the coefficients table
            if (residual > 0):
                predictor_num = count - 1

                while ((predictor_num >= 0) and residual > 0):
                    val = (buf[0] - buf[count - predictor_num])

                    sign = sign_only(val)

                    qlp_coefficients[predictor_num] -= sign

                    val *= sign

                    residual -= ((val >> qlp_shift_needed) *
                                 (count - predictor_num))

                    predictor_num -= 1

            elif (residual < 0):
                # the same as above, but we break if residual goes positive
                predictor_num = count - 1

                while ((predictor_num >= 0) and residual < 0):
                    val = (buf[0] - buf[count - predictor_num])

                    sign = -sign_only(val)

                    qlp_coefficients[predictor_num] -= sign

                    val *= sign

                    residual -= ((val >> qlp_shift_needed) *
                                 (count - predictor_num))

                    predictor_num -= 1

        return samples

    def decorrelate_channels(self, channel_data,
                             interlacing_shift, interlacing_leftweight):
        """Return channel_data with any stereo interlacing undone.

        Anything other than 2 channels, or a leftweight of 0,
        is returned unchanged.
        """

        if ((len(channel_data) != 2) or (interlacing_leftweight == 0)):
            return channel_data

        left = []
        right = []
        for (ch1, ch2) in zip(*channel_data):
            # floor division keeps the samples integers on Python 3
            right.append(ch1 - ((ch2 * interlacing_leftweight) //
                                (2 ** interlacing_shift)))
            left.append(ch2 + right[-1])
        return [left, right]

    def close(self):
        """Close the underlying reader."""

        self.reader.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
class ALACDecoder(object):
    """Pure-Python ALAC decoder driven by a BitstreamReader.

    The constructor pulls the decoding parameters out of the file's
    "alac" and "mdhd" atoms and leaves the reader positioned inside
    the "mdat" atom.  Each call to read() then decodes one frameset.
    The decoder may be used as a context manager, which closes it
    on exit.
    """

    # for each supported channel count above 2,
    # the frameset channel indexes arranged in Wave order
    WAVE_ORDER = {
        3: (1, 2, 0),
        4: (1, 2, 0, 3),
        5: (1, 2, 0, 3, 4),
        6: (1, 2, 0, 5, 3, 4),
        7: (1, 2, 0, 6, 3, 4, 5),
        8: (3, 4, 0, 7, 5, 6, 1, 2),
    }

    def __init__(self, filename):
        """Open "filename" and read its decoding parameters.

        Raises ValueError if the stsd atom is missing, the alac atom
        is malformed or the mdhd version is unknown.  The reader is
        closed again if anything goes wrong while parsing.
        """

        self.reader = BitstreamReader(open(filename, "rb"), False)

        try:
            self.reader.mark()
            try:
                self._read_stream_parameters()
            finally:
                self.reader.unmark()
        except Exception:
            # don't leak the file handle of an object
            # the caller will never receive
            self.reader.close()
            raise

    def _read_stream_parameters(self):
        """Parse the alac and mdhd atoms, then seek to the mdat atom.

        Expects the reader to be marked at the start of the file.
        """

        # locate the "alac" atom
        # which is full of required decoding parameters
        try:
            stsd = self.find_sub_atom(b"moov", b"trak", b"mdia",
                                      b"minf", b"stbl", b"stsd")
        except KeyError:
            raise ValueError("required stsd atom not found")

        # stsd version and description count are consumed but unused
        stsd.parse("8u 24p 32u")
        (alac1,
         alac2,
         self.samples_per_frame,
         self.bits_per_sample,
         self.history_multiplier,
         self.initial_history,
         self.maximum_k,
         self.channels,
         self.sample_rate) = stsd.parse(
            # ignore much of the stuff in the "high" ALAC atom
            "32p 4b 6P 16p 16p 16p 4P 16p 16p 16p 16p 4P" +
            # and use the attributes in the "low" ALAC atom instead
            "32p 4b 4P 32u 8p 8u 8u 8u 8u 8u 16p 32p 32p 32u")

        self.channel_mask = {
            1: 0x0004,
            2: 0x0003,
            3: 0x0007,
            4: 0x0107,
            5: 0x0037,
            6: 0x003F,
            7: 0x013F,
            8: 0x00FF
        }.get(self.channels, 0)

        if ((alac1 != b'alac') or (alac2 != b'alac')):
            raise ValueError("Invalid alac atom")

        # also locate the "mdhd" atom
        # which contains the stream's length in PCM frames
        self.reader.rewind()
        mdhd = self.find_sub_atom(b"moov", b"trak", b"mdia", b"mdhd")
        (version, ) = mdhd.parse("8u 24p")
        if (version == 0):
            (self.total_pcm_frames, ) = mdhd.parse("32p 32p 32p 32u 2P 16p")
        elif (version == 1):
            (self.total_pcm_frames, ) = mdhd.parse("64p 64p 32p 64U 2P 16p")
        else:
            raise ValueError("invalid mdhd version")

        # finally, set our stream to the "mdat" atom
        self.reader.rewind()
        (atom_size, atom_name) = self.reader.parse("32u 4b")
        while (atom_name != b"mdat"):
            self.reader.skip_bytes(atom_size - 8)
            (atom_size, atom_name) = self.reader.parse("32u 4b")

    def find_sub_atom(self, *atom_names):
        """Return a substream of the atom at the given nested path.

        Each name is searched for inside the atom found for the
        previous name.  Raises KeyError (carrying the missing name)
        if the reader raises IOError while searching.
        """

        reader = self.reader

        for (last, next_atom) in iter_last(iter(atom_names)):
            try:
                (length, stream_atom) = reader.parse("32u 4b")
                while (stream_atom != next_atom):
                    # the 8 byte atom header has already been read
                    reader.skip_bytes(length - 8)
                    (length, stream_atom) = reader.parse("32u 4b")
                if (last):
                    return reader.substream(length - 8)
                else:
                    reader = reader.substream(length - 8)
            except IOError:
                raise KeyError(next_atom)

    def read(self, pcm_frames):
        """Decode the next ALAC frameset and return it as a FrameList.

        pcm_frames is accepted for interface compatibility but is not
        consulted.  An empty FrameList is returned once the stream
        is exhausted.

        Raises ValueError if the stream's channel count is unsupported.
        """

        # if the stream is exhausted, return an empty pcm.FrameList object
        if (self.total_pcm_frames == 0):
            return empty_framelist(self.channels, self.bits_per_sample)

        # otherwise, read one ALAC frameset's worth of frame data
        # (a 3 bit field of all 1s ends the frameset)
        frameset_data = []
        frame_channels = self.reader.read(3) + 1
        while (frame_channels != 0x8):
            frameset_data.extend(self.read_frame(frame_channels))
            frame_channels = self.reader.read(3) + 1
        self.reader.byte_align()

        # reorder the frameset to Wave order, depending on channel count
        # (mono and stereo are already in Wave order)
        if (self.channels not in (1, 2)):
            wave_order = self.WAVE_ORDER.get(self.channels)
            if (wave_order is None):
                raise ValueError("unsupported channel count")
            frameset_data = [frameset_data[i] for i in wave_order]

        framelist = from_channels([
            from_list(channel, 1, self.bits_per_sample, True)
            for channel in frameset_data
        ])

        # deduct PCM frames from remainder
        self.total_pcm_frames -= framelist.frames

        # return samples as a pcm.FrameList object
        return framelist

    def read_frame(self, channel_count):
        """returns a list of PCM sample lists, one per channel"""

        # read the ALAC frame header
        self.reader.skip(16)
        has_sample_count = self.reader.read(1)
        uncompressed_lsb_size = self.reader.read(2)
        uncompressed = self.reader.read(1)
        if (has_sample_count):
            sample_count = self.reader.read(32)
        else:
            sample_count = self.samples_per_frame

        if (uncompressed == 1):
            # if the frame is uncompressed,
            # read the raw, interlaced samples
            samples = [
                self.reader.read_signed(self.bits_per_sample)
                for _ in range(sample_count * channel_count)
            ]
            return [samples[i::channel_count] for i in range(channel_count)]

        # if the frame is compressed,
        # read the interlacing parameters
        interlacing_shift = self.reader.read(8)
        interlacing_leftweight = self.reader.read(8)

        # subframe headers
        subframe_headers = [
            self.read_subframe_header() for _ in range(channel_count)
        ]

        # optional uncompressed LSB values
        lsb_bits = uncompressed_lsb_size * 8
        if (uncompressed_lsb_size > 0):
            uncompressed_lsbs = [
                self.reader.read(lsb_bits)
                for _ in range(sample_count * channel_count)
            ]
        else:
            uncompressed_lsbs = []

        sample_size = (self.bits_per_sample - lsb_bits + channel_count - 1)

        # and residual blocks
        residual_blocks = [
            self.read_residuals(sample_size, sample_count)
            for _ in range(channel_count)
        ]

        # calculate subframe samples based on
        # subframe header's QLP coefficients and QLP shift-needed
        decoded_subframes = [
            self.decode_subframe(qlp_shift_needed, qlp_coefficients,
                                 sample_size, residuals)
            for ((qlp_shift_needed, qlp_coefficients), residuals) in zip(
                subframe_headers, residual_blocks)
        ]

        # decorrelate channels according interlacing shift and leftweight
        decorrelated_channels = self.decorrelate_channels(
            decoded_subframes, interlacing_shift, interlacing_leftweight)

        if (uncompressed_lsb_size == 0):
            return decorrelated_channels

        # if uncompressed LSB values are present,
        # place them below each sample of each channel
        channels = []
        for (i, channel) in enumerate(decorrelated_channels):
            channel_lsbs = uncompressed_lsbs[i::channel_count]
            assert (len(channel) == len(channel_lsbs))
            channels.append([
                (s << lsb_bits) | l for (s, l) in zip(channel, channel_lsbs)
            ])
        return channels

    def read_subframe_header(self):
        """Return a (qlp_shift_needed, qlp_coefficients) tuple."""

        self.reader.read(4)  # prediction type, unused
        qlp_shift_needed = self.reader.read(4)
        self.reader.read(3)  # rice modifier, unused
        coefficient_count = self.reader.read(5)
        qlp_coefficients = [
            self.reader.read_signed(16) for _ in range(coefficient_count)
        ]

        return (qlp_shift_needed, qlp_coefficients)

    def read_residuals(self, sample_size, sample_count):
        """Return a list of signed residuals for one channel."""

        residuals = []
        history = self.initial_history
        sign_modifier = 0
        i = 0

        while (i < sample_count):
            # get an unsigned residual based on "history"
            # and on "sample_size" as a last resort
            k = min(log2(history // (2**9) + 3), self.maximum_k)
            unsigned = self.read_residual(k, sample_size) + sign_modifier

            # clear out old sign modifier, if any
            sign_modifier = 0

            # change unsigned residual to signed residual
            if (unsigned & 1):
                residuals.append(-((unsigned + 1) // 2))
            else:
                residuals.append(unsigned // 2)

            # update history based on unsigned residual
            if (unsigned <= 0xFFFF):
                history += ((unsigned * self.history_multiplier) -
                            ((history * self.history_multiplier) >> 9))
            else:
                history = 0xFFFF

            # if history gets too small, we may have a block of 0 samples
            # which can be compressed more efficiently
            if ((history < 128) and ((i + 1) < sample_count)):
                zeroes_k = min(7 - log2(history) + ((history + 16) // 64),
                               self.maximum_k)
                zero_residuals = self.read_residual(zeroes_k, 16)
                if (zero_residuals > 0):
                    residuals.extend([0] * zero_residuals)
                    i += zero_residuals

                history = 0

                if (zero_residuals <= 0xFFFF):
                    sign_modifier = 1

            i += 1

        return residuals

    def read_residual(self, k, sample_size):
        """Return a single unsigned residual value.

        When the prefix code comes back as -1, the value is read as
        "sample_size" plain bits instead.
        """

        msb = self.reader.read_huffman_code(RESIDUAL)
        if (msb == -1):
            return self.reader.read(sample_size)
        elif (k == 0):
            return msb

        lsb = self.reader.read(k)
        if (lsb > 1):
            return msb * ((1 << k) - 1) + (lsb - 1)

        # lsb is 0 or 1 here, and that bit is handed back to the reader
        self.reader.unread(lsb)
        return msb * ((1 << k) - 1)

    def decode_subframe(self, qlp_shift_needed, qlp_coefficients,
                        sample_size, residuals):
        """Return the list of samples decoded from "residuals".

        The predictor adapts its coefficients while decoding, but
        works on private copies, so neither qlp_coefficients nor
        residuals is modified for the caller.
        """

        qlp_coefficients = list(qlp_coefficients)
        residuals = list(residuals)
        count = len(qlp_coefficients)

        # first sample is always copied verbatim
        samples = [residuals.pop(0)]

        if (count >= 31):
            # residuals are encoded as simple difference values
            for residual in residuals:
                samples.append(
                    truncate_bits(samples[-1] + residual, sample_size))
            return samples

        # the next "coefficient count" samples
        # are applied as differences to the previous
        for _ in range(count):
            samples.append(
                truncate_bits(samples[-1] + residuals.pop(0), sample_size))

        # remaining samples are processed much like LPC
        for residual in residuals:
            base_sample = samples[-count - 1]
            lpc_sum = sum([(s - base_sample) * c
                           for (s, c) in zip(samples[-count:],
                                             reversed(qlp_coefficients))])
            outval = (1 << (qlp_shift_needed - 1)) + lpc_sum
            outval >>= qlp_shift_needed
            samples.append(
                truncate_bits(outval + residual + base_sample, sample_size))
            buf = samples[-count - 2:-1]

            # error value then adjusts the coefficients table
            if (residual > 0):
                predictor_num = count - 1

                while ((predictor_num >= 0) and residual > 0):
                    val = (buf[0] - buf[count - predictor_num])

                    sign = sign_only(val)

                    qlp_coefficients[predictor_num] -= sign

                    val *= sign

                    residual -= ((val >> qlp_shift_needed) *
                                 (count - predictor_num))

                    predictor_num -= 1

            elif (residual < 0):
                # the same as above, but we break if residual goes positive
                predictor_num = count - 1

                while ((predictor_num >= 0) and residual < 0):
                    val = (buf[0] - buf[count - predictor_num])

                    sign = -sign_only(val)

                    qlp_coefficients[predictor_num] -= sign

                    val *= sign

                    residual -= ((val >> qlp_shift_needed) *
                                 (count - predictor_num))

                    predictor_num -= 1

        return samples

    def decorrelate_channels(self, channel_data, interlacing_shift,
                             interlacing_leftweight):
        """Return channel_data with any stereo interlacing undone.

        Anything other than 2 channels, or a leftweight of 0,
        is returned unchanged.
        """

        if ((len(channel_data) != 2) or (interlacing_leftweight == 0)):
            return channel_data

        left = []
        right = []
        for (ch1, ch2) in zip(*channel_data):
            right.append(ch1 - ((ch2 * interlacing_leftweight) //
                                (2**interlacing_shift)))
            left.append(ch2 + right[-1])
        return [left, right]

    def close(self):
        """Close the underlying reader."""

        self.reader.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
class FlacDecoder(object):
    """a pure-Python FLAC stream decoder

    opens the file at "filename", verifies its "fLaC" signature,
    loads the STREAMINFO metadata block
    and then decodes one FLAC frame per read() call,
    verifying each frame's CRCs and, once the stream is exhausted,
    the MD5 sum of all the decoded PCM data"""

    # number of subframes per 4 bit channel assignment value
    # (0x8 - 0xA are the 2 channel difference encodings,
    #  None marks an invalid assignment)
    CHANNEL_COUNT = [1, 2, 3, 4, 5, 6, 7, 8, 2, 2, 2,
                     None, None, None, None, None]

    (SUBFRAME_CONSTANT,
     SUBFRAME_VERBATIM,
     SUBFRAME_FIXED,
     SUBFRAME_LPC) = range(4)

    def __init__(self, filename, channel_mask):
        """filename is the path of the FLAC file to open
        channel_mask is stored as-is in self.channel_mask

        raises ValueError if the file lacks the "fLaC" signature
        or contains no STREAMINFO block
        in which case the file is closed before the error propagates"""

        self.reader = BitstreamReader(open(filename, "rb"), False)

        try:
            if self.reader.read_bytes(4) != b'fLaC':
                raise ValueError("invalid FLAC file")

            self.current_md5sum = md5()

            # locate the STREAMINFO,
            # which is sometimes needed to handle non-subset streams
            streaminfo_found = False
            for (block_id,
                 block_size,
                 block_reader) in self.metadata_blocks(self.reader):
                if block_id == 0:
                    # read STREAMINFO
                    self.minimum_block_size = block_reader.read(16)
                    self.maximum_block_size = block_reader.read(16)
                    self.minimum_frame_size = block_reader.read(24)
                    self.maximum_frame_size = block_reader.read(24)
                    self.sample_rate = block_reader.read(20)
                    self.channels = block_reader.read(3) + 1
                    self.channel_mask = channel_mask
                    self.bits_per_sample = block_reader.read(5) + 1
                    self.total_frames = block_reader.read(36)
                    self.md5sum = block_reader.read_bytes(16)
                    streaminfo_found = True

            if not streaminfo_found:
                # without this check the lookup tables below
                # would fail with a confusing AttributeError
                raise ValueError("STREAMINFO block not found")
        except Exception:
            # the caller never receives the object on failure
            # so nobody else is able to close the file
            self.reader.close()
            raise

        # these are frame header lookup tables
        # which vary slightly depending on STREAMINFO's values
        self.BLOCK_SIZE = [self.maximum_block_size,
                           192, 576, 1152, 2304, 4608,
                           None, None,
                           256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
        self.SAMPLE_RATE = [self.sample_rate,
                            88200, 176400, 192000,
                            8000, 16000, 22050, 24000,
                            32000, 44100, 48000, 96000,
                            None, None, None, None]
        self.BITS_PER_SAMPLE = [self.bits_per_sample,
                                8, 12, None, 16, 20, 24, None]

    def metadata_blocks(self, reader):
        """yields a (block_id, block_size, block_reader) tuple
        per metadata block
        where block_reader is a BitstreamReader substream

        blocks are pulled from "reader" up to and including
        the block flagged as the last one"""

        while True:
            (last_block, block_id, block_size) = reader.parse("1u 7u 24u")
            yield (block_id, block_size, reader.substream(block_size))
            if last_block != 0:
                break

    def read(self, pcm_frames):
        """decodes the next FLAC frame and returns it as a FrameList

        pcm_frames is not consulted;
        every call decodes exactly one whole FLAC frame

        once all of STREAMINFO's total frames have been decoded
        an empty FrameList is returned if the MD5 sums match

        raises ValueError on an MD5 or CRC-16 mismatch,
        or if the frame itself is invalid"""

        # if the stream is exhausted,
        # verify its MD5 sum and return an empty pcm.FrameList object
        if self.total_frames < 1:
            if self.md5sum == self.current_md5sum.digest():
                return empty_framelist(self.channels, self.bits_per_sample)
            else:
                raise ValueError("MD5 checksum mismatch")

        crc16 = CRC16()
        self.reader.add_callback(crc16.update)
        try:
            # fetch the decoding parameters from the frame header
            (block_size,
             channel_assignment,
             bits_per_sample) = self.read_frame_header()
            channel_count = self.CHANNEL_COUNT[channel_assignment]
            if channel_count is None:
                raise ValueError("invalid channel assignment")

            # for left-difference (0x8), difference-right (0x9)
            # and mid-side (0xA) assignments
            # the difference/side channel has 1 additional bit;
            # this is the index of that channel, if any
            wide_channel = {0x8: 1, 0x9: 0, 0xA: 1}.get(channel_assignment)

            # channel data will be a list of signed sample lists,
            # one per channel
            # such as [[1, 2, 3, ...], [4, 5, 6, ...]]
            # for a 2 channel stream
            channel_data = []
            for channel_number in range(channel_count):
                if channel_number == wide_channel:
                    channel_data.append(
                        self.read_subframe(block_size, bits_per_sample + 1))
                else:
                    # otherwise, use the frame's bits-per-sample value
                    channel_data.append(
                        self.read_subframe(block_size, bits_per_sample))

            # once all the subframes have been decoded,
            # reconstruct them depending on the channel assignment
            if channel_assignment == 0x8:
                # left-difference
                samples = []
                for (left, difference) in zip(*channel_data):
                    samples.append(left)
                    samples.append(left - difference)
            elif channel_assignment == 0x9:
                # difference-right
                samples = []
                for (difference, right) in zip(*channel_data):
                    samples.append(difference + right)
                    samples.append(right)
            elif channel_assignment == 0xA:
                # mid-side
                samples = []
                for (mid, side) in zip(*channel_data):
                    # restore the bit dropped from "mid" by the encoder
                    mid = (mid * 2) + (side % 2)
                    samples.append((mid + side) // 2)
                    samples.append((mid - side) // 2)
            else:
                # independent
                samples = [0] * block_size * channel_count
                for (i, channel) in enumerate(channel_data):
                    samples[i::channel_count] = channel

            self.reader.byte_align()

            # read the frame's trailing CRC-16 footer
            # so that it's folded into the running checksum
            self.reader.read(16)
        finally:
            # never leave our callback on the reader, even on error
            self.reader.pop_callback()

        # a checksum run over the data plus its own CRC yields 0
        if int(crc16) != 0:
            raise ValueError("CRC16 mismatch in frame footer")

        # deduct the amount of PCM frames from the remaining amount
        self.total_frames -= block_size

        # build a pcm.FrameList object from the combined samples
        framelist = from_list(samples, channel_count, bits_per_sample, True)

        # update the running MD5 sum calculation with the frame's data
        self.current_md5sum.update(framelist.to_bytes(0, 1))

        # and finally return the frame data
        return framelist

    def read_frame_header(self):
        """returns a (block_size, channel_assignment, bits_per_sample) tuple
        from the next frame header in the stream

        raises ValueError if the header's sync code, sample rate,
        bits-per-sample or CRC-8 is invalid"""

        crc8 = CRC8()
        self.reader.add_callback(crc8.update)
        try:
            # read the 32-bit FLAC frame header
            if self.reader.read(14) != 0x3FFE:
                raise ValueError("invalid sync code")
            self.reader.skip(1)
            self.reader.read(1)  # blocking strategy, unused
            block_size_bits = self.reader.read(4)
            sample_rate_bits = self.reader.read(4)
            channel_assignment = self.reader.read(4)
            bits_per_sample_bits = self.reader.read(3)
            self.reader.skip(1)

            # the frame number is a UTF-8 encoded value
            # which takes a variable number of whole bytes
            # (it must be consumed, but its value is unused)
            self.read_utf8()

            # unpack the 4 bit block size field
            # which is the total PCM frames in the FLAC frame
            # and may require up to 16 more bits
            # if the frame is unusually-sized
            # (which typically happens at the end of the stream)
            if block_size_bits == 0x6:
                block_size = self.reader.read(8) + 1
            elif block_size_bits == 0x7:
                block_size = self.reader.read(16) + 1
            else:
                block_size = self.BLOCK_SIZE[block_size_bits]

            # unpack the 4 bit sample rate field
            # which is used for playback, but not needed for decoding
            # and may require up to 16 more bits
            # if the stream has a particularly unusual sample rate
            # (those bits are consumed and discarded)
            if sample_rate_bits == 0xC:
                self.reader.read(8)
            elif sample_rate_bits in (0xD, 0xE):
                self.reader.read(16)
            elif sample_rate_bits == 0xF:
                raise ValueError("invalid sample rate")

            # unpack the 3 bit bits-per-sample field
            # this never requires additional bits
            if bits_per_sample_bits in (0x3, 0x7):
                raise ValueError("invalid bits per sample")
            bits_per_sample = self.BITS_PER_SAMPLE[bits_per_sample_bits]

            # read the frame's CRC-8 value
            # so that it's folded into the running checksum
            self.reader.read(8)
        finally:
            # never leave our callback on the reader, even on error
            self.reader.pop_callback()

        if int(crc8) != 0:
            raise ValueError("CRC8 mismatch in frame header")

        return (block_size, channel_assignment, bits_per_sample)

    def read_subframe_header(self):
        """returns a tuple of (subframe_type, subframe_order, wasted_bps)

        subframe_order is None for CONSTANT and VERBATIM subframes
        raises ValueError if the subframe type is invalid"""

        self.reader.skip(1)
        subframe_type = self.reader.read(6)
        if self.reader.read(1) == 1:
            wasted_bps = self.reader.unary(1) + 1
        else:
            wasted_bps = 0

        # extract "order" value from 6 bit subframe type, if necessary
        if subframe_type == 0:
            return (self.SUBFRAME_CONSTANT, None, wasted_bps)
        elif subframe_type == 1:
            return (self.SUBFRAME_VERBATIM, None, wasted_bps)
        elif (subframe_type & 0x38) == 0x08:
            return (self.SUBFRAME_FIXED, subframe_type & 0x07, wasted_bps)
        elif (subframe_type & 0x20) == 0x20:
            return (self.SUBFRAME_LPC, (subframe_type & 0x1F) + 1, wasted_bps)
        else:
            raise ValueError("invalid subframe type")

    def read_subframe(self, block_size, bits_per_sample):
        """returns a list of "block_size" signed samples
        decoded from the next subframe in the stream"""

        (subframe_type,
         subframe_order,
         wasted_bps) = self.read_subframe_header()

        # wasted bits aren't stored in the subframe's samples
        effective_bps = bits_per_sample - wasted_bps

        # read a list of signed sample values
        # depending on the subframe type, block size,
        # adjusted bits per sample and optional subframe order
        if subframe_type == self.SUBFRAME_CONSTANT:
            subframe_samples = self.read_constant_subframe(
                block_size, effective_bps)
        elif subframe_type == self.SUBFRAME_VERBATIM:
            subframe_samples = self.read_verbatim_subframe(
                block_size, effective_bps)
        elif subframe_type == self.SUBFRAME_FIXED:
            subframe_samples = self.read_fixed_subframe(
                block_size, effective_bps, subframe_order)
        else:
            subframe_samples = self.read_lpc_subframe(
                block_size, effective_bps, subframe_order)

        # account for wasted bits-per-sample, if necessary
        if wasted_bps:
            return [sample << wasted_bps for sample in subframe_samples]
        else:
            return subframe_samples

    def read_constant_subframe(self, block_size, bits_per_sample):
        """returns a single signed sample repeated "block_size" times"""

        sample = self.reader.read_signed(bits_per_sample)
        return [sample] * block_size

    def read_verbatim_subframe(self, block_size, bits_per_sample):
        """returns "block_size" signed samples read directly
        from the stream"""

        return [self.reader.read_signed(bits_per_sample)
                for x in range(block_size)]

    def read_fixed_subframe(self, block_size, bits_per_sample, order):
        """returns "block_size" signed samples
        from a FIXED subframe with a predictor order of 0 to 4

        raises ValueError if the order is unsupported"""

        # "order" number of warm-up samples
        samples = [self.reader.read_signed(bits_per_sample)
                   for i in range(order)]

        # "block_size" - "order" number of residual values
        residuals = self.read_residual(block_size, order)

        # which are applied to the warm-up samples
        # depending on the FIXED subframe order
        # and results in "block_size" number of total samples
        if order == 0:
            return residuals
        elif order == 1:
            for residual in residuals:
                samples.append(samples[-1] + residual)
            return samples
        elif order == 2:
            for residual in residuals:
                samples.append((2 * samples[-1]) -
                               samples[-2] +
                               residual)
            return samples
        elif order == 3:
            for residual in residuals:
                samples.append((3 * samples[-1]) -
                               (3 * samples[-2]) +
                               samples[-3] +
                               residual)
            return samples
        elif order == 4:
            for residual in residuals:
                samples.append((4 * samples[-1]) -
                               (6 * samples[-2]) +
                               (4 * samples[-3]) -
                               samples[-4] +
                               residual)
            return samples
        else:
            raise ValueError("unsupported FIXED subframe order")

    def read_lpc_subframe(self, block_size, bits_per_sample, order):
        """returns "block_size" signed samples
        from an LPC subframe with the given predictor order"""

        # "order" number of warm-up samples
        samples = [self.reader.read_signed(bits_per_sample)
                   for i in range(order)]

        # the size of each QLP coefficient, in bits
        qlp_precision = self.reader.read(4)

        # the amount of right shift to apply
        # during LPC calculation
        # (though this is a signed value, negative shifts are noops
        #  in the reference FLAC decoder)
        qlp_shift_needed = max(self.reader.read_signed(5), 0)

        # "order" number of signed QLP coefficients
        qlp_coeffs = [self.reader.read_signed(qlp_precision + 1)
                      for i in range(order)]

        # QLP coefficients are applied in reverse order
        qlp_coeffs.reverse()

        # "block_size" - "order" number of residual values
        residuals = self.read_residual(block_size, order)

        # which are applied to the running LPC calculation
        for residual in residuals:
            prediction = sum(coeff * sample
                             for (coeff, sample)
                             in zip(qlp_coeffs, samples[-order:]))
            samples.append((prediction >> qlp_shift_needed) + residual)

        return samples

    def read_residual(self, block_size, order):
        """returns "block_size" - "order" signed residual values

        raises ValueError if the partition order leaves
        the first partition smaller than the predictor order"""

        coding_method = self.reader.read(2)
        partition_order = self.reader.read(4)

        # each partition contains block_size / 2 ** partition_order
        # number of residuals
        partition_size = block_size >> partition_order

        if partition_size < order:
            # the first partition would need a negative residual count
            # which would silently produce a short subframe
            raise ValueError("invalid residual partition order")

        # except for the first partition
        # which contains "order" less than the rest
        residuals = []
        residuals.extend(
            self.read_residual_partition(coding_method,
                                         partition_size - order))
        for partition_number in range(1, 1 << partition_order):
            residuals.extend(
                self.read_residual_partition(coding_method,
                                             partition_size))

        return residuals

    def read_residual_partition(self, coding_method, residual_count):
        """returns "residual_count" signed residual values
        from a single Rice-coded or escaped partition

        raises ValueError if the coding method is invalid"""

        if coding_method == 0:
            # the Rice parameter determines the number of
            # least-significant bits to read for each residual
            (parameter_size, escape_marker) = (4, 0xF)
        elif coding_method == 1:
            # 24 bps streams may use a 5-bit Rice parameter
            # for better compression
            (parameter_size, escape_marker) = (5, 0x1F)
        else:
            raise ValueError("invalid Rice coding parameter")

        rice_parameter = self.reader.read(parameter_size)

        if rice_parameter == escape_marker:
            # an escaped partition stores its residuals
            # as plain signed values of "escape_code" bits each
            escape_code = self.reader.read(5)
            if escape_code == 0:
                # 0 bits per residual means every residual is 0
                # and nothing more is read from the stream
                return [0] * residual_count
            return [self.reader.read_signed(escape_code)
                    for i in range(residual_count)]

        # a list of signed residual values
        partition_residuals = []
        for i in range(residual_count):
            msb = self.reader.unary(1)               # most-significant bits
            lsb = self.reader.read(rice_parameter)   # least-significant bits
            value = (msb << rice_parameter) | lsb    # combined into a value
            if value & 1:    # whose least-significant bit is the sign value
                partition_residuals.append(-(value >> 1) - 1)
            else:
                partition_residuals.append(value >> 1)

        return partition_residuals

    def read_utf8(self):
        """returns the unsigned value of the next UTF-8 encoded number
        which occupies a variable number of whole bytes"""

        # the count of leading 1 bits is the total size in bytes
        total_bytes = self.reader.unary(0)
        value = self.reader.read(7 - total_bytes)
        while total_bytes > 1:
            # continuation bytes hold 2 bits of padding and 6 of payload
            value = (value << 6) | self.reader.parse("2p 6u")[0]
            total_bytes -= 1
        return value

    def close(self):
        """closes the underlying stream reader"""

        self.reader.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()