예제 #1
0
def joint_coding_distortion(target_sig , ref_app, max_rate,
                          search_width , threshold=0 , doSubtract=True, 
                          debug=0 , discard = False,
                          precut=-1, initfftw=True):
    """Compute the joint coding distortion of target_sig given ref_app.

    Given a fixed maximum rate, measure the distortion achieved when coding
    ``target_sig`` knowing the reference approximation ``ref_app``. Each atom
    of ``ref_app`` is re-used after a time adaptation only: the atom is
    shifted to the offset returned by ``project_atom`` and added to a new
    approximation of ``target_sig``.

    Parameters
    ----------
    target_sig : signal whose ``data`` array is approximated.
    ref_app : reference approximation; its ``dico``, ``length``, ``fs``,
        ``atom_number`` and indexed atoms are read.
    max_rate : rate budget; the loop stops once the accumulated rate
        reaches it (or when every reference atom has been used).
    search_width, threshold, discard : accepted for interface
        compatibility; not used by this implementation.
    doSubtract : when true, each adapted atom is subtracted from the
        running residual before the next atom is searched.
    debug : any value > 0 prints progress information.
    precut : when > 0, once more than ``precut`` atoms have been used the
        current SRR is checked a single time; a negative SRR is returned
        immediately (pruning), otherwise the check is disabled.
    initfftw : when true, FFTW is initialized on entry and cleaned on
        every exit path (normal return, pruning return, or exception).

    Returns
    -------
    The value of ``compute_srr()`` on the factorized approximation.

    Raises
    ------
    ValueError
        If the search window around an atom does not fit in the residual
        (signal insufficiently padded).
    """
    tolerances = [2]*len(ref_app.dico.sizes)
    # initialize the fftw
    if initfftw:
        mp._initialize_fftw(ref_app.dico, max_thread_num=1)
    try:
        # initialize factored approx
        factorizedApprox = approx.Approx(ref_app.dico,
                                         [],
                                         target_sig,
                                         ref_app.length,
                                         ref_app.fs,
                                         fast_create=True)
        # Shifts are recorded per atom; the array is not returned.
        timeShifts = np.zeros(ref_app.atom_number)
        atom_idx = 0
        rate = 0
        residual = target_sig.data.copy()
        if debug > 0:
            print("starting factorization of  %s  atoms" % (ref_app.atom_number,))
        while (rate < max_rate) and (atom_idx < ref_app.atom_number):
            # Stop only when the target rate is achieved or all atoms have been used
            atom = ref_app[atom_idx]
            # make a copy of the atom
            newAtom = atom.copy()
            # Floor division keeps the half-width an integer so it can be
            # used as a slice bound.
            HalfWidth = (tolerances[ref_app.dico.sizes.index(atom.length)] - 1) * atom.length // 2

            # Search for best time shift using cross correlation
            input1 = residual[atom.time_position - HalfWidth : atom.time_position + atom.length + HalfWidth]
            input2 = np.zeros((2*HalfWidth + atom.length))
            input2[HalfWidth:-HalfWidth] = atom.waveform
            if not (input1.shape == input2.shape):
                raise ValueError("This certainly happens because you haven't sufficiently padded your signal")
            # NOTE(review): project_atom presumably writes the correlation
            # score into scoreVec in place - confirm against its definition.
            scoreVec = np.array([0.0])
            newts = project_atom(input1, input2, scoreVec, atom.length)
            score = scoreVec

            if debug > 0:
                print("Score of  %s  found" % (score,))
            # handle MP atoms #
            if newAtom.time_shift is not None:
                newAtom.time_shift += newts
                newAtom.time_position += newts
            else:
                newAtom.time_shift = newts
                newAtom.time_position += newts
                # NOTE(review): this writes to the reference atom, not the
                # copy - confirm that is intended.
                atom.proj_score = atom.mdct_value

            if debug > 0:
                print("Factorizing with new offset:  %s" % (newts,))

            # A negative score means the atom matches with inverted polarity.
            if score < 0:
                newAtom.waveform = -newAtom.waveform

            factorizedApprox.add(newAtom)
            if debug > 0:
                print("SRR Achieved of :  %s" % (factorizedApprox.compute_srr(),))
            timeShifts[atom_idx] = newAtom.time_shift

            if doSubtract:
                residual[newAtom.time_position : newAtom.time_position + newAtom.length ] -= newAtom.waveform

            # Cost of coding the shift: log2(|shift|) + 1. A zero shift
            # would give log2(0) = -inf and pin the rate at -inf for the
            # rest of the run, so it is charged the minimal cost of 1.
            shift_magnitude = abs(newts)
            if shift_magnitude > 0:
                rate += np.log2(shift_magnitude) + 1
            else:
                rate += 1
            if debug:
                print("Atom %d - rate of %1.3f" % (atom_idx, rate))
            atom_idx += 1

            # Use to prune the calculus
            if atom_idx > precut and precut > 0:
                curdisto = factorizedApprox.compute_srr()
                if curdisto < 0:
                    # pruning
                    return curdisto
                else:
                    # stop wasting time afterwards
                    precut = -1

        # calculate achieved SNR :
        return factorizedApprox.compute_srr()
    finally:
        # cleaning: runs on the pruning return and on errors too, so the
        # FFTW state initialized above is never leaked.
        if initfftw:
            mp._clean_fftw()
예제 #2
0
# Now load the long version
from PyMP.signals import LongSignal
seg_size = 5*8192
long_signal = LongSignal(op.join(os.environ['PYMP_PATH'],'data/Bach_prelude_40s.wav'),
                         seg_size,
                         mono=True, Noverlap=0.5)

# decomposing the long signal: one approximation per segment
apps, decays = mp.mp_long(long_signal,
               dico,
               target_srr, max_atom_num)

# dists[idx, jdx] holds the joint coding distortion of segment idx when it
# is coded with the approximation of segment jdx; only jdx < idx is filled.
dists = np.zeros((long_signal.n_seg, len(apps)))
# FFTW is initialized once here, so every call below passes initfftw=False.
mp._initialize_fftw(apps[0].dico, max_thread_num=1)
try:
    for idx in range(long_signal.n_seg):
        for jdx in range(idx):
            # test all preceding segments only
            target_sig = long_signal.get_sub_signal(idx, 1, mono=True, pad=dico.get_pad(),fast_create=True)
            dists[idx,jdx] = joint_coding_distortion(target_sig, apps[jdx],max_rate,1024, initfftw=False)
finally:
    # Release the FFTW state even if a segment comparison fails.
    mp._clean_fftw()

# Display the distance matrix
plt.figure()
plt.imshow(dists)
plt.colorbar()
plt.show()