示例#1
0
def Train(params, sess, saver, qn, corpus, envs, envsTest):
    """Train the Q-network ``qn`` on every environment in ``envs``.

    For each of ``params.max_epochs`` epochs, one trajectory is played per
    training environment, its reward summary is printed, and gradients are
    accumulated with ``qn.CalcGrads(sess, corpus)``.  On epochs divisible by
    ``params.updateFrequency`` the per-environment "train" plot is saved, the
    test environments are evaluated via ``RunRLSavePlots`` and -- from the
    second such epoch on (epoch 0 is skipped) -- the accumulated gradients are
    applied with ``qn.UpdateGrads(sess, corpus)``.

    ``saver`` is accepted for interface compatibility but is not used here.
    """
    print("Start training")
    for epoch in range(params.max_epochs):
        for env in envs:
            TIMER.Start("Trajectory")
            curve, reward, discounted = Trajectory(
                env, params, sess, qn, corpus, False)
            TIMER.Pause("Trajectory")

            # Language info from the most recent transition, for the log line.
            langs = corpus.transitions[-1].langsVisited
            print("epoch train", epoch, env.rootURL, curve[-1], reward,
                  discounted, langs)

            if epoch % params.updateFrequency == 0:
                SavePlot(params, env, params.saveDir, epoch, "train",
                         curve, reward, discounted)

            TIMER.Start("CalcGrads")
            qn.CalcGrads(sess, corpus)
            TIMER.Pause("CalcGrads")

        if epoch % params.updateFrequency == 0:
            RunRLSavePlots(sess, qn, corpus, params, envsTest, params.saveDir,
                           epoch, "test")
            # No gradient update on the very first epoch.
            if epoch:
                print("UpdateGrads & Validating")
                TIMER.Start("UpdateGrads")
                qn.UpdateGrads(sess, corpus)
                TIMER.Pause("UpdateGrads")

        sys.stdout.flush()
示例#2
0
def Train(params, sess, saver, qns, envs, envsTest):
    """Train the pair of Q-networks held in ``qns``.

    Per epoch: play one trajectory in every training environment, printing
    its rewards and saving a "train" plot for it, then call
    ``corpus.Train(sess, params)`` on ``qns.q[0]`` and ``qns.q[1]``.  On every
    positive multiple of ``params.walk`` the test environments are evaluated
    and "test" plots are saved via ``RunRLSavePlots``.

    ``saver`` is accepted for interface compatibility but is not used here.
    """
    print("Start training")
    for epoch in range(params.max_epochs):
        for env in envs:
            TIMER.Start("Trajectory")
            curve, reward, discounted = Trajectory(
                env, params, sess, qns, False, 1)
            TIMER.Pause("Trajectory")

            print("epoch train", epoch, env.rootURL, reward, discounted)
            SavePlot(params, env, params.saveDirPlots, epoch, "train",
                     curve, reward, discounted)

        # Fit both networks on their own corpora, timed as one phase.
        TIMER.Start("Train")
        for network in (0, 1):
            qns.q[network].corpus.Train(sess, params)
        TIMER.Pause("Train")

        # Validation only runs on positive multiples of params.walk.
        if epoch <= 0 or epoch % params.walk != 0:
            continue
        print("Validating")
        RunRLSavePlots(sess, qns, params, envsTest, params.saveDirPlots,
                       epoch, "test")
def plot_results(sr, audio, onsets, savefile):
    """Plot ``audio`` against time with a dashed red line at each onset.

    Parameters
    ----------
    sr : int or float
        Sample rate in samples per second.
    audio : numpy.ndarray
        Samples.  When it has more than one dimension it is averaged over
        axis 1 (presumably the channel axis) into a single trace.
    onsets : iterable
        Onset positions in samples; each is drawn at ``o / sr`` seconds.
    savefile : str
        Destination passed to ``SavePlot`` with ``save`` enabled and
        ``auto_overwrite=True``.
    """
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    # Sample n lies at n / sr seconds -- the same scale as the onset markers.
    # np.linspace(0, len/sr, len) put the last sample at len/sr, stretching
    # the waveform by len/(len-1) relative to the markers.
    t = np.arange(len(audio)) / sr

    plt.plot(t, audio, color='dodgerblue')

    # Iterating an empty sequence draws nothing, so no length guard is needed.
    for o in onsets:
        plt.axvline(o / sr, color='red', linestyle='--')

    plt.xlabel('Time (s)')
    plt.grid()

    sp = SavePlot(True, savefile=savefile, auto_overwrite=True)
    sp.plot(plt)
def plot_file_data(files, savefile=None, plotlabels=None):
    """Radar-plot F-measure, precision and recall for onset-analysis files.

    Each file is parsed with ``read_onset_analysis`` and contributes one row
    ``[F-measure, Precision, Recall]`` to the graph drawn by
    ``make_radar_graph``.

    Parameters
    ----------
    files : iterable of str
        Paths of onset-analysis result files.
    savefile : str, optional
        Where to save the figure.  When None, ``SavePlot`` is created with
        ``False`` as its first argument and no file name.
    plotlabels : optional
        Forwarded to ``make_radar_graph`` as ``labels``.
    """
    measures = ['F-measure', 'Precision', 'Recall']
    parsed = (read_onset_analysis(path) for path in files)
    data = [[result[measure] for measure in measures] for result in parsed]

    if savefile is None:
        sp = SavePlot(False, auto_overwrite=True)
    else:
        sp = SavePlot(True, savefile, auto_overwrite=True)

    make_radar_graph(data, sp, labels=plotlabels)
示例#5
0
def RunRLSavePlot(sess, qn, corpus, params, env, saveDir, epoch, sset):
    """Play one trajectory on ``env`` and save its plot.

    ``Trajectory`` is called with its final flag set to True (presumably an
    evaluation/greedy mode -- confirm), and its three results are forwarded
    unchanged to ``SavePlot`` together with ``saveDir``, ``epoch`` and the
    data-set label ``sset``.
    """
    curve, reward, discounted = Trajectory(env, params, sess, qn, corpus,
                                           True)
    SavePlot(params, env, saveDir, epoch, sset,
             curve, reward, discounted)
示例#6
0
    #                        2.11864583, 2.21      , 2.29      , 2.37      ,
    #                        2.47])
    #
    #    falsepos = np.array([0.59, 1.55, 2.55, 2.61, 2.69, 2.77])
    #
    #    falseneg = None

    # NOTE(review): this is the tail of a function whose header is not
    # visible; `user`, `file`, `filepath`, `damping`, `truepos`, `falsepos`
    # and `falseneg` are assumed to be bound earlier in it.

    # Figures go to <user>/onsets/hsj/Visualisation/Iowa (`user` is
    # presumably the home directory -- confirm).
    savepath = os.path.join(user, 'onsets', 'hsj', 'Visualisation', 'Iowa')
    # Output stem: `file` without its extension, suffixed with the damping in
    # one-significant-figure scientific notation (e.g. 1e-4 -> '_1e-04').
    # NOTE(review): if `file` is an absolute path, the os.path.join calls
    # below discard savepath entirely.
    saveroot, _ = os.path.splitext(file)
    savefiledmp = '{}_{:.0e}'.format(saveroot, damping)
    ext = '.pdf'

    # Paths for the response plot, its separate legend, and the audio plot.
    savefile = os.path.join(savepath, savefiledmp + ext)
    savelegfile = os.path.join(savepath, savefiledmp + '_legend' + ext)
    saveaudio = os.path.join(savepath, saveroot + ext)
    # NOTE(review): spa and sl are only referenced by the commented-out
    # get_responses arguments below, so they are currently unused.
    spa = SavePlot(save=False, savefile=saveaudio)

    sp = SavePlot(save=False, savefile=savefile)
    sl = SaveLegend(savelegfile)
    # Plot the responses for `filepath`, passing the truepos / falsepos /
    # falseneg onset lists as tp / fp / fn.
    get_responses(
        filepath,
        damping=damping,
        sp=sp,
        tp=truepos,
        fp=falsepos,
        fn=falseneg,
        #                  sl=sl,
        #                  spa=spa,
        #                  t1=3
    )
示例#7
0

def mix(f0, amp, sample_rate=None):
    """Return one second of summed sinusoids followed by two seconds of silence.

    Parameters
    ----------
    f0 : sequence of float
        Frequencies in Hz, one per partial.
    amp : sequence of float
        ``amp[n]`` scales the partial at ``f0[n]``; must be at least as long
        as ``f0``.
    sample_rate : int, optional
        Samples per second.  Defaults to the module-level ``sr`` so existing
        two-argument calls behave as before.

    Returns
    -------
    numpy.ndarray
        ``3 * sample_rate`` samples: the tone, then ``2 * sample_rate`` zeros.
    """
    # The global `sr` is only looked up when no sample rate is given.
    rate = sr if sample_rate is None else sample_rate

    audio = np.zeros(rate)
    for n, freq in enumerate(f0):
        # endpoint=False gives phase 2*pi*freq*k/rate for k in 0..rate-1,
        # i.e. exactly `freq` Hz.  The default endpoint=True stepped by
        # 2*pi*freq/(rate-1), raising every partial by a factor rate/(rate-1).
        phase = np.linspace(0, 2 * np.pi * freq, rate, endpoint=False)
        audio += np.sin(phase) * amp[n]

    # Pad with two seconds of silence.
    return np.append(audio, np.zeros(2 * rate))


# SavePlot instance created with False as its first (`save`) argument.
sp = SavePlot(False)

# DetectorBank configuration: numerical method, frequency normalisation and
# amplitude normalisation.
method = DetectorBank.central_difference
f_norm = DetectorBank.freq_unnormalized
a_norm = DetectorBank.amp_unnormalized

# Short tag for the numerical method ('rk' Runge-Kutta, 'cd' central
# difference).
# NOTE(review): if method matches neither constant, mthd is left unbound.
if method == DetectorBank.runge_kutta:
    mthd = 'rk'
elif method == DetectorBank.central_difference:
    mthd = 'cd'

# Short tag for the frequency normalisation.
# NOTE(review): only freq_unnormalized and search_normalized are handled;
# any other value leaves nrml unbound.
if f_norm == DetectorBank.freq_unnormalized:
    nrml = 'un'
elif f_norm == DetectorBank.search_normalized:
    nrml = 'sn'
def plot_segments(file, threshold=None, save=False, audio_file=None):
    """Plot per-segment values from ``file`` with onset evaluation markers.

    The values are loaded with ``np.loadtxt(file)`` and drawn one point per
    20 ms segment.  Manual onsets (``get_manual_onsets``) and detected onsets
    (``get_found_onsets``) are matched with
    ``CheckOnsets.compare(manual, found, 50)`` and drawn as vertical lines
    within the plotted range:

    * lime, solid           -- matched detections (true positives)
    * indigo, dashed        -- unmatched manual onsets (false negatives)
    * mediumorchid, dashed  -- unmatched detections (false positives)

    Parameters
    ----------
    file : str
        Path of the segment data.  The name of its parent directory
        (``file.split(os.path.sep)[-2]``) is used as the output file stem, so
        the path must contain at least one separator.
    threshold : float, optional
        When given, ``log(threshold)`` is drawn as a dashed green line.
    save : bool
        Passed to ``SavePlot`` as the ``save`` argument for the main figure.
    audio_file : str, optional
        When given, its waveform is plotted with ``plot_audio`` and handed to
        a ``SavePlot(True, ...)`` targeting ``<title>_waveform.pdf`` under
        ``savepath``.
    """
    sns.set_style('whitegrid')

    plot_onsets = True  # False: draw the curve without onset lines
    plot_ms = False  # True: time axis in milliseconds instead of seconds

    savepath = '/home/keziah/onsets/hsj/Visualisation/Iowa/'
    figdir = os.path.join(savepath, '')  # optional sub-directory of savepath
    # exist_ok=True replaces the racy "exists? then makedirs" check.
    os.makedirs(figdir, exist_ok=True)

    arr = np.loadtxt(file)

    # Range of segments to plot; lower i1 to zoom in on the start.
    i0 = 0
    i1 = len(arr)

    seg_size = 20e-3  # segment length in seconds

    title = file.split(os.path.sep)[-2]

    # Only the sample rate returned by get_found_onsets is kept.
    sr, manual_onsets = get_manual_onsets(file)
    sr, found_onsets = get_found_onsets(file)
    # Detected onsets are divided by sr (presumably samples -> seconds).
    # Rebinding instead of `/=` also works for integer arrays and leaves the
    # array returned by get_found_onsets unmodified.
    found_onsets = found_onsets / sr

    if audio_file is not None:
        sp = SavePlot(True,
                      os.path.join(savepath, title + '_waveform.pdf'),
                      auto_overwrite=True)
        plot_audio(audio_file, sp)

    t = np.arange(0, len(arr), dtype=float)
    t *= seg_size
    if plot_ms:
        t *= 1000

    plt.plot(t[i0:i1], arr[i0:i1])

    # 50 is the matching tolerance given to CheckOnsets.compare
    # (presumably in ms -- confirm).
    correct, diff = CheckOnsets.compare(manual_onsets, found_onsets, 50)

    scale = 1000 if plot_ms else 1  # seconds -> plotted time unit
    t_limit = i1 * seg_size  # onsets later than this (in s) are not drawn

    if plot_onsets:
        if correct.size > 0:
            for n, idx in enumerate(correct):
                onset = found_onsets[idx]
                # Drop the manual onset matched to this detection: shifting
                # the manual onsets by diff[n] / 1000 (presumably ms -> s)
                # makes the matched one coincide with `onset`.
                shifted = manual_onsets + (diff[n] / 1000)
                matched = np.where(np.isclose(shifted, onset))[0]
                manual_onsets = np.delete(manual_onsets, matched)
                if onset <= t_limit:
                    plt.axvline(onset * scale, color='lime',
                                label='True positive')

            # What remains of found_onsets are the false positives.
            found_onsets = np.delete(found_onsets, correct)

        # Manual onsets never matched above are the false negatives.
        for onset in manual_onsets:
            if onset <= t_limit:
                plt.axvline(onset * scale,
                            color='indigo',
                            linestyle='--',
                            label='False negative')

        for onset in found_onsets:
            if onset <= t_limit:
                plt.axvline(onset * scale,
                            color='mediumorchid',
                            linestyle='--',
                            label='False positive')

    if threshold is not None:
        plt.axhline(np.log(threshold), color='green', linestyle='--')

    plt.grid(True)
    plt.xlabel('Time (ms)' if plot_ms else 'Time (s)')
    plt.ylabel('Mean log')
    # No plt.legend(): every onset line carries a label, so a legend would
    # repeat each entry once per line.

    sp = SavePlot(save,
                  os.path.join(figdir, title + '.pdf'),
                  auto_overwrite=True)
    sp.plot(plt)
示例#9
0
"""
Example file to show the use of SavePlot and SaveLegend.
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from save_plot import SavePlot, SaveLegend

# make save_plot objects
# NOTE(review): sl is never used in the visible code, and sp is rebound
# inside the loop below; the example appears truncated before the figure
# and legend are saved.
sp = SavePlot(True, 'waves.pdf')
sl = SaveLegend('legend.pdf', auto_overwrite=True)

# parameters for sine waves
x = np.linspace(0, 2*np.pi, 1000)
freq = np.arange(1,5)
amp = [0.5, 1, 0.75, 0.25]

# colours, linestyles and labels
# (mk and labels are not used in the visible code)
c = ['darkorange','dodgerblue', 'firebrick', 'green']
ls = ['-', '--', ':', '-.']
mk = ['o', None, None, None]
labels = list('{} rad/s'.format(f) for f in freq)

# using seaborn style
# feel free to omit this
sns.set_style('darkgrid')

# plot some waves
for idx, f in enumerate(freq):
    plt.plot(x, amp[idx]*np.sin(f*x), color=c[idx], linestyle=ls[idx])
    # NOTE(review): everything from here to the end of this loop body looks
    # like code pasted from other scripts.  Because of its indentation it
    # runs once per wave inside `for idx, f in enumerate(freq)`, rebinds
    # idx, f and freq (enumerate keeps iterating the original array), and
    # relies on names this example never imports or defines (os,
    # getAllResults, getName, plotResults, getStats, files, sf, getOnsets,
    # plotMean).  `savepath` is read below before it is assigned further
    # down, so the first pass raises NameError unless it is defined earlier
    # in the module.
    outpath = '/home/keziah/onsets/hsj/Sandpit/results/Iowa/'

    #    sys.stdout = open(os.path.join(outpath, 'compare_mirex_summary.txt'), 'w')

    # Metric to report: modes[2] == 'F-measure'.
    modes = ['Precision', 'Recall', 'F-measure']
    which = modes[2]

    # Collected statistics, keyed by capitalised display name.
    tab_d = {}

    df = getAllResults(which)

    # One plot and one statistics entry per row of the results table.
    for idx in df.index:
        name = getName(idx)
        savefile = '{}_{}.pdf'.format(name, which)
        savefile = os.path.join(savepath, savefile)
        sp = SavePlot(False, savefile, auto_overwrite=True, mode='quiet')
        row = df.loc[idx]
        plotResults(row, ylabel=which, sp=sp)

        # Friendlier display names for some groups.
        if name == 'percussion_single_note':
            name = 'Percussion'
        elif name == 'pizz':
            name = 'Pizzicato strings'
        elif name == 'arco':
            name = 'Arco strings'

        print(name.capitalize())
        d = getStats(row)
        print()

        tab_d[name.capitalize()] = d
    # Onset detection on the first file in `files`.
    f = files[0]

    user = os.path.expanduser('~')
    root = os.path.join(user, 'Iowa', 'all')

    file = os.path.join(root, f)
    audio, sr = sf.read(file)

    onset = 0.076
    orig_found = [0.000, 0.216, 0.334, 0.487, 0.590, 0.788, 1.067, 2.690]

    # Semitone offset from A440: 440 * 2**(-9/12) is ~261.6 Hz (C4).
    k = -9

    freq = np.array([440 * 2**(i / 12) for i in [k]])  # [k-1, k, k+1]])

    # Keep the first entry of getOnsets' result, sorted, divided by sr
    # (presumably sample indices -> seconds).
    onsets = getOnsets(audio, sr, freq)
    found = onsets[0]  # onsets[1] #set(onsets[0] + onsets[1] + onsets[2])
    found = sorted(list(found))
    found = np.array(found)
    found = found / sr

    print(found)

    savefile = 'trumpet_C4_right_cap.pdf'
    savepath = '/home/keziah/onsets/hsj/Visualisation/onset_detection/'

    sp = SavePlot(False, os.path.join(savepath, savefile), auto_overwrite=True)

    #    plotMean(audio, sr, freq[0], onset, [orig_found[1]], sp)
    plotMean(audio, sr, freq[0], onset, found, sp)
if __name__ == '__main__':
    # Render the waveform of each configured percussion recording together
    # with its manual ("onset") and detected ("found") onset times, one PDF
    # per entry, written into savepath.
    user = os.path.expanduser('~')
    audio_root_dir = os.path.join(user, 'Iowa', 'all')
    percussion_dir = os.path.join(audio_root_dir, 'Percussion')
    savepath = '.'

    # file: audio path relative to percussion_dir; savename: output PDF name;
    # onset / found: onset lists forwarded to plot_audio.
    params = {
        'bow': {
            'file': 'Vibraphone.bow/Vibraphone.bow.A4.stereo.wav',
            'savename': 'vibraphone_bow_400ms.pdf',
            'onset': [0.174],
            'found': [0.022],
        },
        # Disabled entry (note: scalar onset/found rather than lists):
        # 'strike': {
        #     'file': 'Vibraphone.dampen/Vibraphone.dampen.ff.A4.stereo.wav',
        #     'savename': 'vibraphone_strike.pdf',
        #     'onset': 0.015,
        #     'found': 0.019,
        # },
    }

    for entry in params.values():
        file = os.path.join(percussion_dir, entry['file'])
        savefile = os.path.join(savepath, entry['savename'])
        sp = SavePlot(False, savefile, auto_overwrite=True)
        plot_audio(file, sp,
                   manual_onsets=entry['onset'],
                   found_onsets=entry['found'])

if __name__ == '__main__':

    # Detector configuration: Runge-Kutta method, no frequency or amplitude
    # normalisation.
    method = DetectorBank.runge_kutta
    f_norm = DetectorBank.freq_unnormalized
    a_norm = DetectorBank.amp_unnormalized

    sr = 48000  # sample rate (samples per second)
    f0 = 440  # 100 #
    gain = 1.35

    # Values passed as the first argument of get_rise_relax_times
    # (presumably damping factors -- confirm).
    d = [1e-4, 2e-4, 3e-4, 4e-4, 5e-4]

    # make SavePlot objects before calling get_rise_relax_times()
    sp0 = SavePlot(False)
    sp1 = SavePlot(False)

    sp = [sp0, sp1]

    rstms, rxtms = get_rise_relax_times(d, f0, sr, method, f_norm, a_norm,
                                        gain, True, sp)

    # Heading for the printed results.
    # NOTE(review): compares against the literals 2**16 / 2**17 rather than
    # DetectorBank constants; if a_norm equals neither, `s` is never bound
    # and the print below raises NameError.
    if a_norm == 2**16:
        s = 'AMP UNNORMALIZED'
    elif a_norm == 2**17:
        s = 'AMP NORMALIZED'

#    print('Generated by ' + __file__ + '\n')

    # Print the heading underlined with dashes, then a blank line.
    print(s + '\n' + '-' * len(s) + '\n')