示例#1
0
def PostProcess(pred,
                mode="note",
                onset_th=7.5,
                lower_onset_th=None,
                split_bound=36,
                dura_th=2,
                frm_th=1,
                t_unit=0.02):
    """Convert a raw model prediction into a MIDI object.

    Parameters
    ----------
    pred : np.ndarray
        Model output, indexed as ``pred[:, :, ch]``; presumably
        (time, pitch-bin, channel) -- TODO confirm against the model.
    mode : str
        ``"note"`` / ``"mpe_note"``: normalise onset/duration channels and
        decode notes with ``infer_piece``.
        ``"frame"`` / ``"mpe_frame"``: threshold a frame-level activation and
        extract contiguous runs with ``find_occur``.
    onset_th : float
        Onset threshold forwarded to the normalisation helper (note modes).
    lower_onset_th : float or None
        If not None, ``norm_split_onset_dura`` is used and receives this value
        together with ``split_bound``; otherwise ``norm_onset_dura`` is used.
    split_bound : int
        Forwarded to ``norm_split_onset_dura`` (only when ``lower_onset_th``
        is given).
    dura_th : float
        Duration threshold forwarded to the normalisation helper (note modes).
    frm_th : float
        Frame modes: entries of ``norm(mix)`` greater than this become active.
    t_unit : float
        Time unit. Frame modes pass it to ``find_occur`` and ``to_midi``;
        note modes use ``t_unit / 2`` for both ``infer_piece`` and ``to_midi``.

    Returns
    -------
    The object produced by ``to_midi``.

    Raises
    ------
    ValueError
        If ``mode`` is not supported, or (frame modes) if ``pred`` does not
        have 2 or 3 channels.
    """
    if mode in ("note", "mpe_note"):
        if lower_onset_th is not None:
            norm_pred = norm_split_onset_dura(pred,
                                              onset_th=onset_th,
                                              lower_onset_th=lower_onset_th,
                                              split_bound=split_bound,
                                              dura_th=dura_th,
                                              interpolate=True)
        else:
            norm_pred = norm_onset_dura(pred,
                                        onset_th=onset_th,
                                        dura_th=dura_th,
                                        interpolate=True)

        # Notes are converted to MIDI with t_unit / 2 (presumably because
        # interpolate=True doubles the frame rate -- TODO confirm). Decoding
        # must use the same unit; it used to be hard-coded to 0.01, which only
        # matched the default t_unit of 0.02.
        half_unit = t_unit / 2
        notes = infer_piece(down_sample(norm_pred), t_unit=half_unit)
        midi = to_midi(notes, t_unit=half_unit)

    elif mode in ("frame", "mpe_frame"):
        ch_num = pred.shape[2]
        if ch_num == 2:
            mix = pred[:, :, 1]
        elif ch_num == 3:
            # Average channels 1 and 2 into a single activation map.
            mix = (pred[:, :, 1] + pred[:, :, 2]) / 2
        else:
            raise ValueError("Unknown channel length: {}".format(ch_num))

        p = norm(mix)
        p = np.where(p > frm_th, 1, 0)
        p = roll_down_sample(p)

        notes = []
        for idx in range(p.shape[1]):
            for nn in find_occur(p[:, idx], t_unit=t_unit):
                # "onset" is stored unscaled as the note start and to_midi
                # scales it by t_unit, so it is a frame index: use it directly
                # as the row of `mix`. (Multiplying by t_unit, as before,
                # collapsed early onsets to frame 0.)
                onset_frame = int(nn["onset"])
                notes.append({
                    "pitch": idx,
                    "start": nn["onset"],
                    "end": nn["offset"],
                    # NOTE(review): idx * 4 assumes roll_down_sample reduces
                    # the pitch axis by a factor of 4 -- TODO confirm.
                    "stren": mix[onset_frame, idx * 4],
                })
        midi = to_midi(notes, t_unit=t_unit)

    else:
        raise ValueError(
            "Supported modes are ['note', 'mpe_note', 'frame', 'mpe_frame']. "
            "Given mode: {}".format(mode))

    return midi
示例#2
0
def plot3(pred, th=0.5):
    """Show channels 1 and 2 of ``pred`` as two stacked heatmaps.

    Each channel ``pred[:, :, ch]`` is standardised to zero mean and unit
    standard deviation. Values below ``th`` are then zeroed, and the result
    is passed through ``roll_down_sample``. Finally it is drawn with
    ``imshow`` (transposed, ``origin="lower"``) on its own subplot. Channel 1
    goes on the top axes and channel 2 on the bottom.

    Parameters
    ----------
    pred : np.ndarray
        Array with at least three channels on its last axis; presumably
        (time, pitch-bin, channel) -- TODO confirm.
    th : float, optional
        Threshold applied to the standardised values; entries below it are
        set to 0. Defaults to 0.5, the previously hard-coded value.
    """
    _, axes = plt.subplots(nrows=2)
    for ax, ch in zip(axes, (1, 2)):
        chan = pred[:, :, ch]
        # The arithmetic yields a new array, so the in-place thresholding
        # below never writes into `pred`. A constant channel (std == 0)
        # produces NaNs here.
        chan = (chan - np.mean(chan)) / np.std(chan)
        chan[chan < th] = 0
        chan = roll_down_sample(chan)
        ax.imshow(chan.transpose(), origin="lower", aspect="auto")
    plt.show()
示例#3
0
def down_sample(pred):
    """Run ``roll_down_sample`` on every channel of ``pred`` and restack them.

    Channel 0 is reduced with ``roll_down_sample``'s default arguments; every
    subsequent channel is reduced with ``occur_num=3``. If ``pred`` has a
    single channel, the 2-D output of ``roll_down_sample`` is returned
    unchanged. Otherwise the per-channel outputs are stacked along a third
    axis with ``np.dstack``, in channel order.
    """
    layers = [roll_down_sample(pred[:, :, 0])]
    layers.extend(
        roll_down_sample(pred[:, :, ch], occur_num=3)
        for ch in range(1, pred.shape[2])
    )
    if len(layers) == 1:
        return layers[0]
    return np.dstack(layers)