def joint_coding_distortion(target_sig, ref_app, max_rate, search_width,
                            threshold=0, doSubtract=True, debug=0,
                            discard=False, precut=-1, initfftw=True):
    """Compute the joint coding distortion of ``target_sig`` given ``ref_app``.

    Given a fixed maximum rate, measure the distortion achieved when coding
    ``target_sig`` knowing the reference approximation ``ref_app``.  The
    adaptation is limited to a time shift of each reference atom.

    The atoms of ``ref_app`` are visited in order.  For each one, a time
    shift is searched in a window of half an atom length on each side of its
    position, and a shifted copy is added to a fresh approximation of
    ``target_sig``.  Each shift costs ``log2(max(|shift|, 1)) + 1`` bits; the
    loop stops once the accumulated rate reaches ``max_rate`` or all atoms
    have been used.

    Parameters
    ----------
    target_sig : signal object
        Signal to encode; its ``data`` array is copied, not modified here.
    ref_app : approximation object
        Reference approximation providing ``dico``, ``length``, ``fs``,
        ``atom_number`` and indexable atoms.
    max_rate : number
        Rate budget; no new atom is processed once it is reached.
    search_width, threshold, discard :
        Accepted for interface compatibility; not used by this function.
    doSubtract : bool
        If true, subtract each adapted atom from the working residual so
        later atoms are matched against what is left.
    debug : int
        Print progress information when positive.
    precut : int
        If positive, the SRR is checked once, after more than ``precut``
        atoms have been processed; a negative SRR is returned immediately
        (pruning), otherwise the check is disabled for the rest of the run.
    initfftw : bool
        If true, initialize fftw on entry and clean it on every exit path.

    Returns
    -------
    The value of ``compute_srr()`` of the factorized approximation.

    Raises
    ------
    ValueError
        If the search window around an atom does not fit inside the signal
        (typically because the signal is not padded enough).
    """
    sizes = ref_app.dico.sizes
    # One search tolerance per dictionary scale; currently 2 for every scale.
    tolerances = [2] * len(sizes)

    # initialize the fftw
    if initfftw:
        mp._initialize_fftw(ref_app.dico, max_thread_num=1)

    try:
        # initialize factored approx
        factorizedApprox = approx.Approx(ref_app.dico, [], target_sig,
                                         ref_app.length, ref_app.fs,
                                         fast_create=True)
        atom_idx = 0
        rate = 0
        residual = target_sig.data.copy()

        if debug > 0:
            print("starting factorization of  %s  atoms"
                  % (ref_app.atom_number,))

        # Stop only when the target rate is achieved or all atoms have
        # been used
        while (rate < max_rate) and (atom_idx < ref_app.atom_number):
            atom = ref_app[atom_idx]
            # work on a copy so the shift is applied to the new atom
            newAtom = atom.copy()

            # Floor division keeps the half width an integer (usable as a
            # slice bound) whatever the division semantics in effect.
            HalfWidth = ((tolerances[sizes.index(atom.length)] - 1)
                         * atom.length // 2)

            # Search for best time shift using cross correlation
            start = atom.time_position - HalfWidth
            stop = atom.time_position + atom.length + HalfWidth
            input1 = residual[start:stop]
            input2 = np.zeros(2 * HalfWidth + atom.length)
            # Explicit end index: a "-HalfWidth" bound would select nothing
            # if the half width were 0.
            input2[HalfWidth:HalfWidth + atom.length] = atom.waveform

            if input1.shape != input2.shape:
                raise ValueError("This certainly happens because you haven't sufficiently padded your signal")

            # One-element array passed to project_atom and read back as the
            # score (presumably filled in place -- confirm against
            # project_atom).
            scoreVec = np.array([0.0])
            newts = project_atom(input1, input2, scoreVec, atom.length)
            score = scoreVec[0]

            if debug > 0:
                print("Score of  %s  found" % (score,))

            # handle MP atoms (no time shift set yet)
            if newAtom.time_shift is not None:
                newAtom.time_shift += newts
                newAtom.time_position += newts
            else:
                newAtom.time_shift = newts
                newAtom.time_position += newts
                # NOTE(review): this sets the score on the reference atom,
                # not on the copy; kept as in the original -- confirm.
                atom.proj_score = atom.mdct_value

            if debug > 0:
                print("Factorizing with new offset:  %s" % (newts,))

            # a negative score means the atom matches with opposite sign
            if score < 0:
                newAtom.waveform = -newAtom.waveform

            factorizedApprox.add(newAtom)
            if debug > 0:
                print("SRR Achieved of :  %s"
                      % (factorizedApprox.compute_srr(),))

            if doSubtract:
                begin = newAtom.time_position
                residual[begin:begin + newAtom.length] -= newAtom.waveform

            # Clamp the magnitude to 1: log2(0) is -inf, which would make
            # the rate -inf forever and disable the max_rate stop condition.
            rate += np.log2(max(abs(newts), 1)) + 1
            if debug:
                print("Atom %d - rate of %1.3f" % (atom_idx, rate))

            atom_idx += 1

            # Use to prune the calculus
            if atom_idx > precut and precut > 0:
                curdisto = factorizedApprox.compute_srr()
                if curdisto < 0:
                    # pruning
                    return curdisto
                # stop wasting time afterwards
                precut = -1

        # calculate achieved SNR
        return factorizedApprox.compute_srr()
    finally:
        # cleaning: runs on normal return, pruning return and on errors
        if initfftw:
            mp._clean_fftw()
# Now load the long version
from PyMP.signals import LongSignal

seg_size = 5 * 8192
long_signal = LongSignal(
    op.join(os.environ['PYMP_PATH'], 'data/Bach_prelude_40s.wav'),
    seg_size, mono=True, Noverlap=0.5)

# decomposing the long signal: one approximation per segment
apps, decays = mp.mp_long(long_signal, dico, target_srr, max_atom_num)

# dists[idx, jdx] holds the result of coding segment idx knowing the
# approximation of segment jdx; only entries with jdx < idx are filled.
dists = np.zeros((long_signal.n_seg, len(apps)))

# fftw is initialized once here, so each call below passes initfftw=False
mp._initialize_fftw(apps[0].dico, max_thread_num=1)
try:
    for idx in range(long_signal.n_seg):
        for jdx in range(idx):
            # test all preceding segments only
            # NOTE(review): target_sig depends only on idx; it could be
            # fetched once per idx if the callee never modifies it -- confirm.
            target_sig = long_signal.get_sub_signal(
                idx, 1, mono=True, pad=dico.get_pad(), fast_create=True)
            dists[idx, jdx] = joint_coding_distortion(
                target_sig, apps[jdx], max_rate, 1024, initfftw=False)
finally:
    # release fftw even if one of the segment computations fails
    mp._clean_fftw()

plt.figure()
plt.imshow(dists)
plt.colorbar()
plt.show()