コード例 #1
0
def load_data(data_path, n_class, batch_num=20, n_total=500):
    """Load per-class click waveforms and build train/test crop sets.

    For each class ``c`` in ``range(n_class)`` the wav files under
    ``<data_path>/<c>`` are read.  Each waveform is scaled to unit L2 energy
    and flattened into one 320-sample row.  The rows are split with
    ``split_data``, and ``random_crop`` draws the training set
    (``int(n_total * 4 / 5)``) from the first part and the test set
    (``int(n_total / 5)``) from the second.

    :param data_path: root directory with one sub-directory per class, named
        ``0`` .. ``n_class - 1``.
    :param n_class: number of classes; also the width of the one-hot labels.
    :param batch_num: passed to ``random_crop``.  At most
        ``batch_num * n_total`` non-silent files are read per class.
    :param n_total: passed to ``random_crop`` (split 4/5 train, 1/5 test).
    :return: ``(train_xs, train_ys, test_xs, test_ys)``.  The xs arrays have
        192 columns; the ys arrays are one-hot with ``n_class`` columns.
    """
    max_files = batch_num * n_total
    train_xs_parts, train_ys_parts = [], []
    test_xs_parts, test_ys_parts = [], []

    for c in range(n_class):
        path = "%(path)s/%(class)d" % {'path': data_path, 'class': c}
        wav_files = find_click.list_wav_files(path)

        print("load data : %s, the number of files : %d" %
              (path, len(wav_files)))

        label = np.zeros(n_class)
        label[c] = 1

        # Collect rows and stack once: np.vstack inside the loop would copy
        # the whole matrix for every file (quadratic in the number of files).
        rows = []
        for pathname in wav_files:
            wave_data, _ = find_click.read_wav_file(pathname)

            energy = np.sqrt(np.sum(wave_data**2))
            if energy == 0:
                # Normalising a silent file would fill its row with NaN.
                print("skip silent file : %s" % pathname)
                continue
            rows.append(np.reshape(wave_data / energy, [-1]))
            if len(rows) >= max_files:
                break
        # The (0, 320) seed keeps the 320-sample row-length check and makes
        # an empty class produce a well-shaped empty matrix.
        xs = np.vstack([np.empty((0, 320))] + rows)

        xs0, xs1 = split_data(xs)

        temp_train_xs = random_crop(xs0, batch_num, int(n_total * 4 / 5))
        temp_test_xs = random_crop(xs1, batch_num, int(n_total / 5))

        train_xs_parts.append(temp_train_xs)
        train_ys_parts.append(np.tile(label, (temp_train_xs.shape[0], 1)))
        test_xs_parts.append(temp_test_xs)
        test_ys_parts.append(np.tile(label, (temp_test_xs.shape[0], 1)))

    train_xs = np.vstack([np.empty((0, 192))] + train_xs_parts)
    train_ys = np.vstack([np.empty((0, n_class))] + train_ys_parts)
    test_xs = np.vstack([np.empty((0, 192))] + test_xs_parts)
    test_ys = np.vstack([np.empty((0, n_class))] + test_ys_parts)
    return train_xs, train_ys, test_xs, test_ys
コード例 #2
0
def run_cnn_detection(file_name,
                      snr_threshold_low=5,
                      snr_threshold_high=20,
                      save_npy=False,
                      dst_path='',
                      tar_fs=192000):
    """Detect clicks in a wav file with a pre-trained CNN and optionally save them.

    Steps, as implemented below:

    1. Read the wav file and keep one channel: the second channel when the
       file has more than one, otherwise the first.
    2. High-pass filter the channel with an 8th-order Butterworth filter
       (5000 Hz cut-off) via ``signal.filtfilt``.  Rescale so that the
       largest sample equals ``2**15 - 1``.
    3. Split the filtered signal into 192000-sample segments, normalise each
       with ``local_normalize`` and run the CNN restored from
       ``params_cnn/allconv_cnn4click_norm_quater_manual.ckpt-300``.  Every
       output position predicted as ``click_label`` (0) adds 1 to a
       256-sample window of ``detected_visual``; consecutive output positions
       are 8 samples apart.
    4. Group the marked samples into candidate regions with
       ``connected_component``.  When ``snr_threshold_low > 0``, split each
       region with a TKEO threshold (3x its mean).  Keep only sub-clicks whose
       SNR lies in ``[snr_threshold_low, snr_threshold_high]``, pad them by
       100 samples on each side, and merge them with ``merge_pos``.
    5. If ``save_npy`` is set, take a 320-sample window centred on each kept
       click's maximum and pass it through ``resample(..., fs, tar_fs)`` and
       ``cut_data``.  Reject it when the peak bin of a random 192-sample
       crop's spectrum is outside 20..75.  Group the rest into click trains
       (a new train starts when a click is > 1.0 s after the previous one or
       the train exceeds 2.0 s).  Save each train as
       ``<dst_path>/<wavname>/<wavname>_N<k>.npy``.

    Input sampling rates higher than ``tar_fs`` are rejected (down-sampling
    is not supported).

    :param file_name: path of the input wav file.
    :param snr_threshold_low: lower SNR bound in dB; a value <= 0 disables
        the SNR/TKEO refinement and returns the raw CNN regions.
    :param snr_threshold_high: upper SNR bound in dB.
    :param save_npy: whether to write the detected click trains to disk.
    :param dst_path: directory under which the click trains are stored.
    :param tar_fs: sampling rate passed to ``resample`` for saved clicks.
    :return: ``None`` if the file's rate exceeds ``tar_fs``.  Otherwise
        ``(detected_list, fs, audio, audio_filted, detected_visual)``, where
        ``detected_list`` holds ``(start, end)`` sample ranges.
    """

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1
    # NOTE(review): `graph` is never used below.
    graph = tf.get_default_graph()

    signal_len = 320
    count = 0
    # NOTE(review): `click_arr` is never used below.
    click_arr = []
    audio, fs = find_click.read_wav_file(file_name)
    # Keep a single channel: the second one if present, else the first.
    if audio.shape[1] > 1:
        audio = audio[:, 1]
    else:
        audio = audio[:, 0]

    # Base file name without directory and extension, used for output names.
    [path, wavname_ext] = os.path.split(file_name)
    wavname = wavname_ext.split('/')[-1]
    wavname = wavname.split('.')[0]

    if fs > tar_fs:
        print('down sample was not supported! current sampling rate is %d' %
              fs)
        return None

    len_audio = len(audio)

    # Wall-clock timing, used for the real-time ratio printed at the end.
    start_t = time.time()
    time_len = len_audio / fs
    print('current audio length:', time_len)

    # 8th-order Butterworth high-pass at fl Hz (wn is normalised to Nyquist).
    fl = 5000
    wn = 2 * fl / fs
    b, a = signal.butter(8, wn, 'high')

    audio_filted = signal.filtfilt(b, a, audio)
    # Rescale so the largest sample becomes 2**15 - 1.
    scale = (2**15 - 1) / max(audio_filted)
    audio_filted *= scale

    # Feed the network in segments of seg_length samples, each normalised
    # on its own with local_normalize.
    seg_length = 192000
    data_seg = []
    if len_audio > seg_length:
        seg_num = math.ceil(len_audio / seg_length)
        for i in range(int(seg_num)):
            start_seg = seg_length * i
            if seg_length * (i + 1) > len_audio:
                end_seg = len_audio
            else:
                end_seg = seg_length * (i + 1)
            # NOTE(review): the branch above already bounds end_seg by
            # len_audio; this extra clamp drops the very last sample.
            if end_seg > len_audio - 1:
                end_seg = len_audio - 1
            data = audio_filted[start_seg:end_seg]
            data_norm = local_normalize(data)
            data_norm = data_norm[0]
            data_seg.append(data_norm)
    else:
        audio_norm = local_normalize(audio_filted)
        audio_norm = audio_norm[0]
        data_seg.append(audio_norm)

    # Per-sample count of CNN windows that classified the sample as a click;
    # used below to locate click regions.
    detected_visual = np.zeros_like(audio_filted)

    # Load the pre-trained model.
    with tf.Session(config=config) as sess:
        saver = tf.train.import_meta_graph(
            'params_cnn/allconv_cnn4click_norm_quater_manual.ckpt-300.meta')
        saver.restore(
            sess, 'params_cnn/allconv_cnn4click_norm_quater_manual.ckpt-300')
        # Look up the graph's input placeholders and output tensors by name.
        x = sess.graph.get_operation_by_name('x').outputs[0]
        # NOTE(review): `y`, `train_step` and `accuracy` are never used.
        y = sess.graph.get_operation_by_name('y').outputs[0]
        is_batch = sess.graph.get_operation_by_name('is_batch').outputs[0]
        keep_pro_l4_l5 = sess.graph.get_operation_by_name(
            'keep_pro_l4_l5').outputs[0]
        collection = sess.graph.get_collection('saved_module')
        y_net_out6 = collection[0]
        train_step = collection[1]
        accuracy = collection[2]
        # Network output class that means "click".
        click_label = 0

        for i in range(len(data_seg)):
            audio_norm = data_seg[i]
            y_out = sess.run(y_net_out6,
                             feed_dict={
                                 x: audio_norm.reshape(1, -1),
                                 keep_pro_l4_l5: 1.0,
                                 is_batch: False
                             })
            # Assumes y_out holds col_num x 2 scores with col_num on axis 2
            # (e.g. shape (1, 1, col_num, 2)) -- TODO confirm.
            col_num = y_out.shape[2]
            y_out = y_out.reshape(col_num, 2)
            y_out = softmax(y_out)
            predict = np.argmax(y_out, axis=1)
            for j in range(len(predict)):
                # NOTE(review): `pro` is unused (a disabled variant also
                # required pro > 0.9).
                pro = y_out[j][predict[j]]
                if predict[
                        j] == click_label:
                    # Output position j covers 256 samples starting 8*j
                    # samples into segment i.
                    start_point = seg_length * i + 8 * j
                    end_point = start_point + 256
                    detected_visual[start_point:end_point] += 1

    # Group marked samples into candidate click regions.
    detected_list = connected_component(detected_visual,
                                        threshold=1,
                                        len_threshold=64)
    if detected_list == []:
        print('no click was detected!')
        return detected_list, fs, audio, audio_filted, detected_visual

    # Debug: number of candidate clicks before filtering.
    print('未过滤click数: %d' % len(detected_list))

    update_detected_list = []
    if snr_threshold_low > 0:
        print('启用snr过滤, snr_threshold_low=', snr_threshold_low)
        # Split each region with TKEO and keep only sub-clicks whose SNR lies
        # in [snr_threshold_low, snr_threshold_high].
        # NOTE(review): `index_to_remove` is never used.
        index_to_remove = []
        for i in range(len(detected_list)):
            detected_pos = detected_list[i]
            # SNR estimate.
            click = audio_filted[detected_pos[0]:detected_pos[1] + 1]
            tkeo = tkeo_algorithm(click)
            tkeo_mean = np.mean(tkeo)
            click_pos_list = connected_component(tkeo,
                                                 threshold=3 * tkeo_mean,
                                                 len_threshold=0)
            tmp_pos = []
            for pos in click_pos_list:
                # Sub-click bounds in audio_filted, padded by 6 samples.
                # NOTE(review): `start` can be negative for a region at the
                # very beginning of the file.
                start = pos[0] + detected_pos[0] - 6
                end = pos[1] + detected_pos[0] + 6
                singel_click = audio_filted[start:end]
                detected_clicks_energy = calcu_energy(singel_click)
                if math.isnan(detected_clicks_energy):
                    continue
                # Noise estimate: up to 512 samples on each side.
                # NOTE(review): when start < 512 the negative lower bound
                # indexes from the end of the array, usually yielding an
                # empty (or wrong) "before" window.
                noise_estimate1 = audio_filted[start - 512:start]
                noise_estimate2 = audio_filted[end:end + 512]
                noise_estimate = np.hstack((noise_estimate1, noise_estimate2))
                noise_energy = calcu_energy(noise_estimate)
                if noise_energy <= 0:
                    continue
                snr = 10 * math.log10(detected_clicks_energy / noise_energy)
                if snr < snr_threshold_low or snr > snr_threshold_high:
                    continue
                # Widen the kept sub-click by 100 samples each side,
                # clipped to the signal bounds.
                if start - 100 < 0:
                    start = 0
                else:
                    start -= 100
                if end + 100 > len_audio:
                    end = len_audio
                else:
                    end += 100
                ext_pos = (start, end)
                tmp_pos.append(ext_pos)
            tmp_pos = merge_pos(tmp_pos)
            if tmp_pos == []:
                continue
            # NOTE(review): both branches are identical; the length test has
            # no effect.
            if len(tmp_pos) >= 3:
                update_detected_list += tmp_pos
            else:
                update_detected_list += tmp_pos
    else:
        update_detected_list = detected_list

    print('过滤后剩余:', len(update_detected_list))

    # Processing time and real-time ratio.
    end_t = time.time()
    cost_t = end_t - start_t
    print('current audio\'s sample rate:', fs)
    print('current audio length:', time_len)
    real_time_ratio = cost_t / time_len
    print('cost time:', cost_t)
    print('real time ratio:', real_time_ratio)

    if save_npy:

        dst = "%(path)s/%(pre)s" \
              % {'path': dst_path, 'pre': wavname}
        if not os.path.exists(dst):
            mkdir(dst)
        print(dst)

        # Click-train grouping state: time (s) of the previous kept click
        # and of the first click of the current train.
        pre_time_stamp = 0
        start_time = 0
        tmp_clicktrain = np.empty((0, signal_len))
        train_num = 0
        for pos_tuple in update_detected_list:
            temp_click = audio_filted[pos_tuple[0]:pos_tuple[1]]
            # Centre of the detected range, in seconds.
            current_time_stamp = (pos_tuple[0] + pos_tuple[1]) / (2 * fs)

            # signal_len-sample window centred on the maximum of the range.
            max_index = np.argmax(temp_click)
            max_index += pos_tuple[0]
            t_start = max_index - int(signal_len / 2)
            if t_start < 0:
                continue
            t_end = max_index + int(signal_len / 2)
            # NOTE(review): `break` abandons every remaining click, not just
            # this one; `continue` may have been intended.
            if t_end > len_audio:
                break
            click_data = audio_filted[t_start:t_end]

            click_data = resample(click_data, fs, tar_fs)

            click_data = cut_data(click_data, signal_len)

            click_data = click_data.astype(np.short)

            # Magnitude spectrum of a random 192-sample crop.
            # NOTE(review): the crop offset is random, so this filter is not
            # deterministic.
            beg_idx = np.random.randint(64, (64 + 32))
            crop_x = click_data[beg_idx:(beg_idx + 192)]
            crop_x = np.reshape(crop_x, [1, 192])

            crop_x = np.fft.fft(crop_x)
            crop_x = np.sqrt(crop_x.real**2 + crop_x.imag**2)

            crop_x = crop_x[0, :96]
            crop_x = np.reshape(crop_x, [1, 96])
            # Drop clicks whose spectral peak bin is below 20 or above 75.
            # Each bin is (rate of click_data) / 192 Hz wide, i.e. about
            # 1 kHz if resample produced tar_fs = 192000.
            peak_index = np.argmax(crop_x)
            if peak_index < 20 or peak_index > 75:
                continue

            # Start a new train when the current one exceeds 2.0 s or the
            # gap to the previous click exceeds 1.0 s.
            train_duration = current_time_stamp - start_time
            train_interval = current_time_stamp - pre_time_stamp
            if train_duration > 2.0 or train_interval > 1.0:
                if tmp_clicktrain.shape[0] != 0:
                    train_num += 1
                    npy_path = "%(path)s/%(pre)s_N%(num)d.npy" \
                          % {'path': dst, 'pre': wavname, 'num': train_num}
                    np.save(npy_path, np.array(tmp_clicktrain, dtype=np.short))

                # NOTE(review): this makes tmp_clicktrain 1-D, so a train
                # with a single click is saved as a 1-D array.
                tmp_clicktrain = click_data
                start_time = current_time_stamp
                pre_time_stamp = current_time_stamp
            else:
                tmp_clicktrain = np.vstack((tmp_clicktrain, click_data))
                pre_time_stamp = current_time_stamp
            count += 1

        # Flush the last train.
        if tmp_clicktrain.shape[0] != 0:
            train_num += 1
            npy_path = "%(path)s/%(pre)s_N%(num)d.npy" \
                       % {'path': dst, 'pre': wavname, 'num': train_num}
            np.save(npy_path, np.array(tmp_clicktrain, dtype=np.short))
        print('click train num:', train_num)

        print("count = %(count)d" % {'count': count})

    return update_detected_list, fs, audio, audio_filted, detected_visual
コード例 #3
0
def detect_save_click(class_path,
                      class_name,
                      snr_threshold_low=5,
                      snr_threshold_high=100):
    """Detect clicks with the FDR/TKEO detector and save them per wav file.

    ``class_path`` is listed with ``find_click.list_files``.  Each entry, or
    ``class_path`` itself when the listing is empty, is treated as a folder
    of wav files.  For every wav file:

    * ``find_click.find_click_fdr_tkeo`` returns click index ranges and the
      processed signal ``xn``;
    * ``xn`` is rescaled so its peak is ``2**12 - 1``;
    * a click is kept when its SNR (dB, versus up to 256 samples of signal
      on each side) lies in ``[snr_threshold_low, snr_threshold_high]``;
    * kept clicks are passed through ``resample(..., 96000)`` and
      ``cut_data(..., 320)``, then converted to ``np.short``;
    * the clicks are saved to
      ``./TKEO_wk3_complete/<class_name>/<folder>/<wavname>_N<n>.npy``.

    :param class_path: directory of one class (sub-folders or wav files).
    :param class_name: class label used in the output directory name.
    :param snr_threshold_low: minimum accepted SNR in dB.
    :param snr_threshold_high: maximum accepted SNR in dB.
    """
    tar_fs = 96000
    signal_len = 320
    # Parameters of the FDR/TKEO click detector.
    fl = 5000
    fwhm = 0.0004
    fdr_threshold = 0.65
    # Samples on each side of a click used as the noise estimate.
    noise_len = 256

    folder_list = find_click.list_files(class_path)
    if not folder_list:
        # No sub-folders: the wav files live directly in class_path.
        folder_list = [class_path]

    for folder in folder_list:
        print(folder)
        count = 0
        wav_files = find_click.list_wav_files(folder)

        path_name = folder.split('/')[-1]

        dst_path = "./TKEO_wk3_complete/%(class)s/%(type)s" % {
            'class': class_name,
            'type': path_name
        }
        if not os.path.exists(dst_path):
            mkdir(dst_path)

        for pathname in wav_files:
            print(pathname)

            wave_data, frameRate = find_click.read_wav_file(pathname)
            # File name without directory and extension.
            wavname = os.path.split(pathname)[1].split('.')[0]

            click_index, xn = find_click.find_click_fdr_tkeo(
                wave_data, frameRate, fl, fwhm, fdr_threshold, signal_len, 8)

            peak = np.max(xn)
            if peak <= 0:
                # Scaling by 1/peak would fill xn with inf/NaN, and NaN SNRs
                # pass the range check below.
                print('skip file without positive signal: %s' % pathname)
                continue
            # Rescale so the peak becomes 2**12 - 1.  Assigning into xn[:]
            # keeps xn's dtype, as element-wise assignment did.
            xn[:] = xn * ((2**12 - 1) / peak)

            click_arr = []
            for index in click_index:
                click_data = xn[index[0]:index[1]]

                # SNR filter: click energy versus neighbouring samples.
                detected_clicks_energy = calcu_click_energy(
                    click_data.reshape(1, -1))
                # Clamp at 0: a negative lower bound would index from the end
                # of xn and usually yield an empty window.
                noise_before = xn[max(index[0] - noise_len, 0):index[0]]
                noise_after = xn[index[1] + 1:index[1] + 1 + noise_len]
                noise_energy = calcu_energy(
                    np.hstack((noise_before, noise_after)))
                if noise_energy <= 0 or detected_clicks_energy <= 0:
                    continue
                snr = 10 * math.log10(detected_clicks_energy / noise_energy)
                if snr < snr_threshold_low or snr > snr_threshold_high:
                    continue

                # Resampling happens after TKEO detection.
                click_data = resample(click_data, frameRate, tar_fs)
                click_data = cut_data(click_data, signal_len)
                click_arr.append(click_data.astype(np.short))
                count += 1

            dst = "%(path)s/%(pre)s_N%(num)d.npy" \
                % {'path': dst_path, 'pre': wavname, 'num': len(click_arr)}
            print(dst)
            np.save(dst, np.array(click_arr, dtype=np.short))

        print("count = %(count)d" % {'count': count})
コード例 #4
0
def load_lwy_data(batch_num=20, n_total=500, class_dirs=None):
    """Load the multi-species click set and build train/test crop sets.

    Unlike ``load_data``, the waveforms are NOT energy-normalised here.  Each
    wav file is flattened into one 320-sample row, the rows are split with
    ``split_data``, and ``random_crop`` draws ``int(n_total * 4 / 5)``
    training and ``int(n_total / 5)`` test crops per class.

    :param batch_num: passed to ``random_crop``.  At most
        ``(batch_num + 10) * n_total`` files are read per class.
    :param n_total: passed to ``random_crop`` (split 4/5 train, 1/5 test).
    :param class_dirs: optional mapping from class index (int or numeric
        string) to the directory holding that class's wav files.  It is
        iterated in insertion order.  Defaults to the eight directories
        below.
    :return: ``(train_xs, train_ys, test_xs, test_ys)``.  The xs arrays have
        192 columns; the ys arrays are one-hot with ``len(class_dirs)``
        columns.
    """
    if class_dirs is None:
        root = "/home/fish/ROBB/CNN_click/click/WavData"
        class_dirs = {
            "0": root + "/BBW/Blainvilles_beaked_whale_(Mesoplodon_densirostris)",
            "1": root + "/Gm/Pilot_whale_(Globicephala_macrorhynchus)",
            "2": root + "/Gg/Rissos_(Grampus_grisieus)",
            "3": root + "/Tt/palmyra2006",
            "4": root + "/Dc/Dc",
            "5": root + "/Dd/Dd",
            "6": root + "/Melon/palmyra2006",
            "7": root + "/Spinner/palmyra2006",
        }

    n_class = len(class_dirs)
    max_files = (batch_num + 10) * n_total
    train_xs_parts, train_ys_parts = [], []
    test_xs_parts, test_ys_parts = [], []

    for key, path in class_dirs.items():
        c = int(key)
        wav_files = find_click.list_wav_files(path)

        print("load data : %s, the number of files : %d, class: %d" %
              (path, len(wav_files), c))

        label = np.zeros(n_class)
        label[c] = 1

        # Collect rows and stack once: np.vstack inside the loop would copy
        # the whole matrix for every file (quadratic in the number of files).
        rows = []
        for pathname in wav_files:
            wave_data, _ = find_click.read_wav_file(pathname)
            rows.append(np.reshape(wave_data, [-1]))
            if len(rows) >= max_files:
                break
        # The float (0, 320) seed keeps the row-length check, the float64
        # result dtype and a well-shaped empty matrix.
        xs = np.vstack([np.empty((0, 320))] + rows)

        xs0, xs1 = split_data(xs)

        temp_train_xs = random_crop(xs0, batch_num, int(n_total * 4 / 5))
        temp_test_xs = random_crop(xs1, batch_num, int(n_total / 5))

        train_xs_parts.append(temp_train_xs)
        train_ys_parts.append(np.tile(label, (temp_train_xs.shape[0], 1)))
        test_xs_parts.append(temp_test_xs)
        test_ys_parts.append(np.tile(label, (temp_test_xs.shape[0], 1)))

    train_xs = np.vstack([np.empty((0, 192))] + train_xs_parts)
    train_ys = np.vstack([np.empty((0, n_class))] + train_ys_parts)
    test_xs = np.vstack([np.empty((0, 192))] + test_xs_parts)
    test_ys = np.vstack([np.empty((0, n_class))] + test_ys_parts)
    return train_xs, train_ys, test_xs, test_ys
コード例 #5
0
        "1"] = "/media/ywy/本地磁盘/Data/MobySound/3rd_Workshop/Training_Data/Pilot_whale_(Globicephala_macrorhynchus)"
    dict[
        "2"] = "/media/ywy/本地磁盘/Data/MobySound/3rd_Workshop/Training_Data/Rissos_(Grampus_grisieus)"

    for key in dict:
        print(dict[key])
        count = 0
        wav_files = find_click.list_wav_files(dict[key])

        dst_path = "./Data/ClickC8/%(class)s" % {'class': key}
        mkdir(dst_path)

        for pathname in wav_files:

            print(pathname)
            wave_data, frameRate = find_click.read_wav_file(pathname)

            fl = 5000
            fwhm = 0.0008
            fdr_threshold = 0.62
            click_index, xn = find_click.find_click_fdr_tkeo(
                wave_data, frameRate, fl, fwhm, fdr_threshold, signal_len, 8)

            scale = (2**15 - 1) / max(xn)
            for i in np.arange(xn.size):
                xn[i] = xn[i] * scale

            for j in range(click_index.shape[0]):
                index = click_index[j]
                # click_data = wave_data[index[0]:index[1], 0]
コード例 #6
0
def load_data_lstm(data_path, n_class, batch_num=20, n_total=500):
    """Build LSTM train/test sequences from CNN features of click crops.

    A two-conv-layer CNN is rebuilt and its weights are restored from
    ``params/cnn_net.ckpt``.  Its 256-unit ``h_fc1`` activation is used as
    the per-crop feature.

    For every class ``c``, the wav files in ``<data_path>/<c>`` are read;
    silent files are skipped and at most ``batch_num * n_total`` are kept.
    Each waveform is normalised to unit L2 energy and the set is split with
    ``split_data``.  Sequence ``i`` consists of ``batch_num`` random
    192-sample crops (start offset in [64, 96)) of rows
    ``batch_num*i .. batch_num*(i+1)-1``, taken modulo the row count.

    Every sample is ``[features, label]``, where ``features`` has shape
    ``(1, batch_num, 256)`` and ``label`` is a one-hot int array of shape
    ``(1, n_class)``.  Each class contributes ``int(n_total * 4 / 5)``
    training and ``int(n_total / 5)`` test sequences.

    As a side check, the restored CNN classifies each test sequence by
    majority vote over its crops, and the resulting accuracy is printed.

    :return: ``(train, test)`` lists of samples as described above.
    :raises ValueError: if a split to crop from contains no waveform.
    """
    train = []
    test = []
    max_files = batch_num * n_total

    x_in = tf.placeholder("float", [None, 192])

    # Input: each 192-sample crop as a 1 x 192 single-channel image.
    x_image = tf.reshape(x_in, [-1, 1, 192, 1])

    # First convolution layer.
    W_conv1 = weight_variable([1, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_1x2(h_conv1)

    # Second convolution layer.
    W_conv2 = weight_variable([1, 5, 32, 32])
    b_conv2 = bias_variable([32])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_1x2(h_conv2)

    # Fully connected layer; its activation is the per-crop feature.
    W_fc1 = weight_variable([1 * 48 * 32, 256])
    b_fc1 = bias_variable([256])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 1 * 48 * 32])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    # Dropout (keep_prob is fed as 1.0 at inference).
    keep_prob = tf.placeholder("float")
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob)

    # Output layer.  Variable creation order must match the checkpoint.
    W_fc2 = weight_variable([256, n_class])
    b_fc2 = bias_variable([n_class])
    y = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

    init = tf.global_variables_initializer()

    saver = tf.train.Saver()

    def crop_batch(source, i):
        """Return the (batch_num, 192) random crops of sequence ``i``."""
        n_rows = source.shape[0]
        if n_rows == 0:
            raise ValueError('no click waveform to crop from')
        crops = np.empty((batch_num, 192))
        for k in range(batch_num):
            beg_idx = np.random.randint(64, (64 + 32))
            row = source[(batch_num * i + k) % n_rows]
            crops[k] = row[beg_idx:(beg_idx + 192)]
        return crops

    with tf.Session() as sess:
        sess.run(init)
        saver.restore(sess, "params/cnn_net.ckpt")  # load trained weights

        test_cnn = []

        for c in range(n_class):
            path = "%(path)s/%(class)d" % {'path': data_path, 'class': c}
            wav_files = find_click.list_wav_files(path)

            print("load data : %s, the number of files : %d" %
                  (path, len(wav_files)))

            # Collect rows and stack once (np.vstack in the loop is
            # quadratic).  The cap uses '>=' like load_data; the former '>'
            # read one extra file.
            rows = []
            for pathname in wav_files:
                wave_data, _ = find_click.read_wav_file(pathname)
                energy = np.sqrt(np.sum(wave_data**2))
                if energy == 0:
                    # Normalising a silent file would fill its row with NaN.
                    print("skip silent file : %s" % pathname)
                    continue
                rows.append(np.reshape(wave_data / energy, [-1]))
                if len(rows) >= max_files:
                    break
            xs = np.vstack([np.empty((0, 320))] + rows)

            xs0, xs1 = split_data(xs)

            one_hot = [0] * n_class
            one_hot[c] = 1

            for i in range(int(n_total * 4 / 5)):
                crops = crop_batch(xs0, i)
                # One run per sequence instead of one per crop; cast to
                # float64 as the previous np.vstack accumulation did.
                frames = sess.run(h_fc1, feed_dict={x_in: crops})
                frames = frames.astype(np.float64)[np.newaxis]
                train.append([frames, np.array([one_hot])])

            for i in range(int(n_total / 5)):
                crops = crop_batch(xs1, i)
                frames = sess.run(h_fc1, feed_dict={x_in: crops})
                frames = frames.astype(np.float64)[np.newaxis]
                label = np.array([one_hot])
                test.append([frames, label])
                # Raw crops are kept for the CNN accuracy check below.
                test_cnn.append([crops[np.newaxis], label])

        # Majority vote of the CNN's per-crop predictions for each sequence.
        count = 0
        for crops, ref_y in test_cnn:
            out_y = sess.run(y, feed_dict={x_in: crops[0], keep_prob: 1.0})
            votes = np.bincount(np.argmax(out_y, 1), minlength=n_class)
            if np.argmax(votes) == np.argmax(ref_y):
                count += 1

        if test_cnn:
            print('cnn test accuracy: ', round(count / len(test_cnn), 3))

    return train, test
コード例 #7
0
def test_cnn_batch_data(data_path, n_class, batch_num=20, n_total=500):
    """Evaluate the restored click CNN on batches of held-out clicks.

    Clicks used to train the network must not be used for testing.  For
    each class ``c``, only the wav files in ``<data_path>/<c>`` at index
    ``int(len(files) * 4 / 5)`` or later are read, mirroring the 4/5 : 1/5
    train/test split.  Reading stops at file position
    ``batch_num * n_total``.  Each waveform is normalised to unit L2 energy,
    and ``int(n_total / 5)`` batches of ``batch_num`` random 192-sample
    crops (start offset in [64, 96)) are built per class.

    The CNN is rebuilt and restored from ``params/cnn_net.ckpt``.  Each batch
    is classified three ways, and one accuracy is printed per rule:

    * majority voting over the per-crop arg-max predictions;
    * weight voting: arg-max of the summed pre-softmax logits;
    * arg-max of the summed softmax probabilities.

    :raises ValueError: if a class has no test waveform but batches are
        requested.
    """
    max_files = batch_num * n_total
    n_batches = int(n_total / 5)
    click_batch = []
    for c in range(n_class):
        path = "%(path)s/%(class)d" % {'path': data_path, 'class': c}
        wav_files = find_click.list_wav_files(path)
        print("load data : %s, the number of files : %d" % (path, len(wav_files)))

        # Use only the last 1/5 of the files.  The 0-based index fixes an
        # off-by-one that let the last training file into the test set.
        split_idx = int(len(wav_files) * 4 / 5)
        rows = []
        for file_idx, pathname in enumerate(wav_files):
            if file_idx < split_idx:
                continue
            wave_data, _ = find_click.read_wav_file(pathname)
            energy = np.sqrt(np.sum(wave_data ** 2))
            if energy == 0:
                # Normalising a silent file would fill its row with NaN.
                print("skip silent file : %s" % pathname)
            else:
                rows.append(np.reshape(wave_data / energy, [-1]))
            # The cap is on the file position, not on the number loaded.
            if file_idx + 1 >= max_files:
                break
        # Stack once instead of calling np.vstack per file (quadratic).
        xs = np.vstack([np.empty((0, 320))] + rows)

        sample_num = xs.shape[0]
        if sample_num == 0 and n_batches > 0 and batch_num > 0:
            raise ValueError("no test click found in %s" % path)

        one_hot = [0] * n_class
        one_hot[c] = 1
        for i in range(n_batches):
            crops = np.empty((batch_num, 192))
            for k in range(batch_num):
                beg_idx = np.random.randint(64, (64 + 32))
                row = xs[(batch_num * i + k) % sample_num]
                crops[k] = row[beg_idx:(beg_idx + 192)]
            # Sample layout: [(1, batch_num, 192) crops, (1, n_class) label].
            click_batch.append([crops[np.newaxis], np.array([one_hot])])

    x = tf.placeholder("float", [None, 192])
    # Input: each 192-sample crop as a 1 x 192 single-channel image.
    x_image = tf.reshape(x, [-1, 1, 192, 1])

    # First convolution layer.
    W_conv1 = weight_variable([1, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_1x2(h_conv1)

    # Second convolution layer.
    W_conv2 = weight_variable([1, 5, 32, 32])
    b_conv2 = bias_variable([32])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_1x2(h_conv2)

    # Fully connected layer.
    W_fc1 = weight_variable([1 * 48 * 32, 256])
    b_fc1 = bias_variable([256])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 1 * 48 * 32])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    # Dropout (keep_prob is fed as 1.0 at inference).
    keep_prob = tf.placeholder("float")
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob)

    # Output layer.  Variable creation order must match the checkpoint.
    W_fc2 = weight_variable([256, n_class])
    b_fc2 = bias_variable([n_class])
    logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
    y = tf.nn.softmax(logits)

    init = tf.global_variables_initializer()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init)
        saver.restore(sess, "params/cnn_net.ckpt")  # load trained weights

        # One batched run per click batch feeds all three voting rules.
        n_majority = 0
        n_weight = 0
        n_softmax = 0
        for crops, ref_y in click_batch:
            logit_out, prob_out = sess.run(
                [logits, y], feed_dict={x: crops[0], keep_prob: 1.0})
            target = np.argmax(ref_y)
            votes = np.bincount(np.argmax(prob_out, 1), minlength=n_class)
            if np.argmax(votes) == target:
                n_majority += 1
            if np.argmax(logit_out.sum(axis=0, dtype=np.float64)) == target:
                n_weight += 1
            if np.argmax(prob_out.sum(axis=0, dtype=np.float64)) == target:
                n_softmax += 1

        if not click_batch:
            print('no test batch was built')
            return

        n_total_batches = len(click_batch)
        print('cnn test accuracy (majority voting): ', round(n_majority / n_total_batches, 3))
        print('cnn test accuracy (weight voting): ', round(n_weight / n_total_batches, 3))
        print('cnn test accuracy (sum of softmax voting): ', round(n_softmax / n_total_batches, 3))