def load_data(data_path, n_class, batch_num=20, n_total=500):
    """Load per-class click waveforms and build cropped train/test sets.

    Class ``c`` is read from ``<data_path>/<c>``. Each wav file is scaled to
    unit L2 energy, flattened, and stored as one row of a 320-column matrix
    (the width is fixed below, so every file must hold 320 samples or
    ``np.vstack`` raises). At most ``batch_num * n_total`` files are read per
    class. The rows are divided by ``split_data`` and cropped to 192 columns
    by ``random_crop``.

    :param data_path: root directory with one sub-directory per class index
    :param n_class: number of classes, read from sub-directories 0..n_class-1
    :param batch_num: forwarded to ``random_crop``; also caps files per class
    :param n_total: forwarded to ``random_crop`` (4/5 for train, 1/5 for test)
    :return: ``(train_xs, train_ys, test_xs, test_ys)`` with one-hot labels
    """
    train_xs = np.empty((0, 192))
    train_ys = np.empty((0, n_class))
    test_xs = np.empty((0, 192))
    test_ys = np.empty((0, n_class))
    max_files = batch_num * n_total
    for c in range(n_class):
        path = "%(path)s/%(class)d" % {'path': data_path, 'class': c}
        wav_files = find_click.list_wav_files(path)
        print("load data : %s, the number of files : %d" % (path, len(wav_files)))
        label = np.zeros(n_class)
        label[c] = 1

        # Collect rows in a list and stack once; calling np.vstack inside the
        # loop copies the whole matrix on every file (quadratic time).
        rows = []
        for count, pathname in enumerate(wav_files, start=1):
            wave_data, frame_rate = find_click.read_wav_file(pathname)
            energy = np.sqrt(np.sum(wave_data ** 2))
            # A silent file would turn into 0/0 = NaN and poison the training
            # set; leave such a file unnormalised instead.
            if energy > 0:
                wave_data = wave_data / energy
            rows.append(np.reshape(wave_data, [-1]))
            if count >= max_files:
                break
        # The leading (0, 320) block keeps the width check and the float64
        # dtype, and also makes an empty class directory yield a (0, 320) array.
        xs = np.vstack([np.empty((0, 320))] + rows)

        xs0, xs1 = split_data(xs)
        temp_train_xs = random_crop(xs0, batch_num, int(n_total * 4 / 5))
        temp_test_xs = random_crop(xs1, batch_num, int(n_total / 5))
        temp_train_ys = np.tile(label, (temp_train_xs.shape[0], 1))
        temp_test_ys = np.tile(label, (temp_test_xs.shape[0], 1))
        train_xs = np.vstack((train_xs, temp_train_xs))
        train_ys = np.vstack((train_ys, temp_train_ys))
        test_xs = np.vstack((test_xs, temp_test_xs))
        test_ys = np.vstack((test_ys, temp_test_ys))
    return train_xs, train_ys, test_xs, test_ys
def _snr_filter_detections(detected_list, audio_filted, snr_threshold_low,
                           snr_threshold_high):
    """Split CNN detections into TKEO candidates and keep those in the SNR band.

    Each detected region is split into candidates with ``connected_component``
    on its TKEO (threshold 3x the TKEO mean). A candidate, padded by 6 samples,
    is kept when ``snr_threshold_low <= SNR <= snr_threshold_high``. SNR is in
    dB and compares ``calcu_energy`` of the candidate with ``calcu_energy`` of
    up to 512 samples on each side of it. Kept candidates are padded by 100
    samples, clamped to the signal, and merged per region with ``merge_pos``.

    :return: list of ``(start, end)`` sample ranges
    """
    len_audio = len(audio_filted)
    update_detected_list = []
    for detected_pos in detected_list:
        region = audio_filted[detected_pos[0]:detected_pos[1] + 1]
        tkeo = tkeo_algorithm(region)
        tkeo_mean = np.mean(tkeo)
        click_pos_list = connected_component(
            tkeo, threshold=3 * tkeo_mean, len_threshold=0)
        tmp_pos = []
        for pos in click_pos_list:
            # Clamp to [0, len_audio]: a negative slice bound would silently
            # wrap around to the end of the signal.
            start = max(pos[0] + detected_pos[0] - 6, 0)
            end = min(pos[1] + detected_pos[0] + 6, len_audio)
            detected_clicks_energy = calcu_energy(audio_filted[start:end])
            # `not (e > 0)` rejects NaN as well as 0 (math.log10(0) raises).
            if not (detected_clicks_energy > 0):
                continue
            noise_estimate = np.hstack((
                audio_filted[max(start - 512, 0):start],
                audio_filted[end:end + 512]))
            noise_energy = calcu_energy(noise_estimate)
            if not (noise_energy > 0):
                continue
            snr = 10 * math.log10(detected_clicks_energy / noise_energy)
            if snr < snr_threshold_low or snr > snr_threshold_high:
                continue
            tmp_pos.append((max(start - 100, 0), min(end + 100, len_audio)))
        update_detected_list.extend(merge_pos(tmp_pos))
    return update_detected_list


def _save_click_trains(update_detected_list, audio_filted, fs, tar_fs,
                       signal_len, dst, wavname):
    """Cut ``signal_len``-sample clicks around detections and save click trains.

    Each click is centred on the argmax of its detection, resampled to
    ``tar_fs``, cut to ``signal_len`` samples, and converted to int16. It is
    then kept only if the FFT-magnitude peak of a random 192-sample crop lies
    in bins 20..75. A new train starts when the train would exceed 2 s or the
    gap to the previous click exceeds 1 s. Each train is saved to
    ``<dst>/<wavname>_N<k>.npy`` as an ``(n_clicks, signal_len)`` int16 array.

    :return: ``(number of clicks kept, number of trains saved)``
    """
    len_audio = len(audio_filted)
    half = int(signal_len / 2)
    count = 0
    train_num = 0
    pre_time_stamp = 0
    start_time = 0
    tmp_clicktrain = np.empty((0, signal_len))

    def save_train(train, num):
        npy_path = "%(path)s/%(pre)s_N%(num)d.npy" \
            % {'path': dst, 'pre': wavname, 'num': num}
        np.save(npy_path, np.array(train, dtype=np.short))

    for pos_tuple in update_detected_list:
        temp_click = audio_filted[pos_tuple[0]:pos_tuple[1]]
        current_time_stamp = (pos_tuple[0] + pos_tuple[1]) / (2 * fs)
        max_index = np.argmax(temp_click) + pos_tuple[0]
        t_start = max_index - half
        if t_start < 0:
            continue
        t_end = max_index + half
        if t_end > len_audio:
            break
        click_data = audio_filted[t_start:t_end]
        click_data = resample(click_data, fs, tar_fs)
        click_data = cut_data(click_data, signal_len)
        # Clip before the int16 cast so any resampling overshoot saturates
        # instead of wrapping around.
        click_data = np.clip(click_data, -32768, 32767).astype(np.short)

        beg_idx = np.random.randint(64, (64 + 32))
        crop_x = np.reshape(click_data[beg_idx:(beg_idx + 192)], [1, 192])
        crop_x = np.fft.fft(crop_x)
        crop_x = np.sqrt(crop_x.real ** 2 + crop_x.imag ** 2)[0, :96]
        # Drop clicks whose spectral peak bin is below 20 or above 75.
        peak_index = np.argmax(crop_x)
        if peak_index < 20 or peak_index > 75:
            continue

        train_duration = current_time_stamp - start_time
        train_interval = current_time_stamp - pre_time_stamp
        if train_duration > 2.0 or train_interval > 1.0:
            if tmp_clicktrain.shape[0] != 0:
                train_num += 1
                save_train(tmp_clicktrain, train_num)
            # Keep trains 2-D even with a single click; a bare 1-D click
            # would otherwise be saved with shape (signal_len,).
            tmp_clicktrain = np.reshape(click_data, (1, -1))
            start_time = current_time_stamp
        else:
            tmp_clicktrain = np.vstack((tmp_clicktrain, click_data))
        pre_time_stamp = current_time_stamp
        count += 1

    if tmp_clicktrain.shape[0] != 0:
        train_num += 1
        save_train(tmp_clicktrain, train_num)
    return count, train_num


def run_cnn_detection(file_name, snr_threshold_low=5, snr_threshold_high=20,
                      save_npy=False, dst_path='', tar_fs=192000):
    """Detect clicks in a wav file with the pretrained all-conv CNN.

    Only audio whose sample rate does not exceed ``tar_fs`` is supported
    (down-sampling is not implemented).

    :param file_name: input audio file
    :param snr_threshold_low: minimum SNR in dB; a value <= 0 disables the
        TKEO/SNR refinement and returns the raw CNN detections
    :param snr_threshold_high: maximum SNR in dB
    :param save_npy: if True, save detected click trains as .npy files
    :param dst_path: directory under which click trains are stored
    :param tar_fs: sampling rate of the saved clicks
    :return: ``None`` if the sample rate exceeds ``tar_fs``; otherwise
        ``(detections, fs, audio, audio_filted, detected_visual)``
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1

    signal_len = 320
    audio, fs = find_click.read_wav_file(file_name)
    # Multi-channel recordings: use the second channel.
    audio = audio[:, 1] if audio.shape[1] > 1 else audio[:, 0]

    wavname_ext = os.path.split(file_name)[1]
    wavname = wavname_ext.split('/')[-1].split('.')[0]

    if fs > tar_fs:
        print('down sample was not supported! current sampling rate is %d' % fs)
        return None

    len_audio = len(audio)
    start_t = time.time()
    time_len = len_audio / fs
    print('current audio length:', time_len)

    # 8th-order Butterworth high-pass at 5 kHz, applied forward-backward.
    fl = 5000
    b, a = signal.butter(8, 2 * fl / fs, 'high')
    audio_filted = signal.filtfilt(b, a, audio)
    # Scale the absolute peak to int16 full scale. Scaling by the positive
    # peak alone let a larger negative peak overflow the later np.short cast.
    peak = np.max(np.abs(audio_filted))
    if peak > 0:
        audio_filted *= (2 ** 15 - 1) / peak

    # Run the network on consecutive segments of at most seg_length samples,
    # each normalised independently. Stepping the start index never creates
    # an empty tail segment and keeps the last sample.
    seg_length = 192000
    data_seg = [local_normalize(audio_filted[seg_start:seg_start + seg_length])[0]
                for seg_start in range(0, len_audio, seg_length)]

    # detected_visual counts, per sample, how many "click" windows cover it.
    detected_visual = np.zeros_like(audio_filted)
    click_label = 0
    # A fresh graph per call: importing the meta graph repeatedly into the
    # process default graph would keep growing it across calls.
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        saver = tf.train.import_meta_graph(
            'params_cnn/allconv_cnn4click_norm_quater_manual.ckpt-300.meta')
        saver.restore(
            sess, 'params_cnn/allconv_cnn4click_norm_quater_manual.ckpt-300')
        x = sess.graph.get_operation_by_name('x').outputs[0]
        is_batch = sess.graph.get_operation_by_name('is_batch').outputs[0]
        keep_pro_l4_l5 = sess.graph.get_operation_by_name(
            'keep_pro_l4_l5').outputs[0]
        y_net_out6 = sess.graph.get_collection('saved_module')[0]

        for i, audio_norm in enumerate(data_seg):
            y_out = sess.run(y_net_out6, feed_dict={
                x: audio_norm.reshape(1, -1),
                keep_pro_l4_l5: 1.0,
                is_batch: False,
            })
            # Assumes the output is laid out as (..., ..., col_num, 2) — the
            # reshape below relies on axis 2 being the column count.
            col_num = y_out.shape[2]
            y_out = softmax(y_out.reshape(col_num, 2))
            predict = np.argmax(y_out, axis=1)
            # Column j predicted as a click marks samples
            # [8*j, 8*j + 256) of segment i.
            for j in np.flatnonzero(predict == click_label):
                start_point = seg_length * i + 8 * j
                detected_visual[start_point:start_point + 256] += 1

    detected_list = connected_component(detected_visual, threshold=1,
                                        len_threshold=64)
    if len(detected_list) == 0:
        print('no click was detected!')
        return detected_list, fs, audio, audio_filted, detected_visual
    # Debug: number of clicks before filtering.
    print('未过滤click数: %d' % len(detected_list))

    if snr_threshold_low > 0:
        print('启用snr过滤, snr_threshold_low=', snr_threshold_low)
        update_detected_list = _snr_filter_detections(
            detected_list, audio_filted, snr_threshold_low, snr_threshold_high)
    else:
        update_detected_list = detected_list
    print('过滤后剩余:', len(update_detected_list))

    cost_t = time.time() - start_t
    print('current audio\'s sample rate:', fs)
    print('current audio length:', time_len)
    print('cost time:', cost_t)
    print('real time ratio:', cost_t / time_len)

    count = 0
    if save_npy:
        dst = "%(path)s/%(pre)s" % {'path': dst_path, 'pre': wavname}
        if not os.path.exists(dst):
            mkdir(dst)
        print(dst)
        count, train_num = _save_click_trains(
            update_detected_list, audio_filted, fs, tar_fs, signal_len,
            dst, wavname)
        print('click train num:', train_num)
    print("count = %(count)d" % {'count': count})
    return update_detected_list, fs, audio, audio_filted, detected_visual
def detect_save_click(class_path, class_name, snr_threshold_low=5,
                      snr_threshold_high=100, dst_root="./TKEO_wk3_complete"):
    """Detect clicks with the FDR/TKEO detector and save them per wav file.

    Every sub-folder of ``class_path`` is processed; if it has none,
    ``class_path`` itself is processed. For each wav file, clicks found by
    ``find_click.find_click_fdr_tkeo`` are kept when their SNR (dB) lies in
    ``[snr_threshold_low, snr_threshold_high]``. Kept clicks are resampled to
    96 kHz, cut to 320 samples, and saved as int16 in
    ``<dst_root>/<class_name>/<folder>/<wavname>_N<n_clicks>.npy``.

    :param class_path: directory holding the class's recordings
    :param class_name: class label used in the output path
    :param snr_threshold_low: minimum SNR in dB
    :param snr_threshold_high: maximum SNR in dB
    :param dst_root: output root directory (default keeps the old location)
    """
    tar_fs = 96000
    signal_len = 320
    fl = 5000
    fwhm = 0.0004
    fdr_threshold = 0.65

    folder_list = find_click.list_files(class_path)
    if len(folder_list) == 0:
        folder_list = [class_path]
    for folder in folder_list:
        print(folder)
        count = 0
        wav_files = find_click.list_wav_files(folder)
        path_name = folder.split('/')[-1]
        dst_path = "%(root)s/%(class)s/%(type)s" % {
            'root': dst_root,
            'class': class_name,
            'type': path_name
        }
        if not os.path.exists(dst_path):
            mkdir(dst_path)

        for pathname in wav_files:
            print(pathname)
            wave_data, frameRate = find_click.read_wav_file(pathname)
            wavname_ext = os.path.split(pathname)[1]
            wavname = wavname_ext.split('/')[-1].split('.')[0]

            click_index, xn = find_click.find_click_fdr_tkeo(
                wave_data, frameRate, fl, fwhm, fdr_threshold, signal_len, 8)
            # Scale so the positive peak maps to 2**12 - 1. This is vectorised;
            # it assumes xn is a floating-point 1-D signal.
            xn = xn * ((2 ** 12 - 1) / np.max(xn))

            click_arr = []
            for index in click_index:
                click_data = xn[index[0]:index[1]]
                # SNR filter.
                detected_clicks_energy = calcu_click_energy(
                    click_data.reshape(1, -1))
                # Noise reference: up to 256 samples on each side of the click.
                # Clamp the left edge so clicks near the start do not wrap
                # around to the end of the signal via a negative index.
                noise_estimate = np.hstack((
                    xn[max(index[0] - 256, 0):index[0]],
                    xn[index[1] + 1:index[1] + 257]))
                noise_energy = calcu_energy(noise_estimate)
                # `not (e > 0)` rejects NaN (e.g. empty noise window) as well.
                if not (noise_energy > 0) or not (detected_clicks_energy > 0):
                    continue
                snr = 10 * math.log10(detected_clicks_energy / noise_energy)
                if snr < snr_threshold_low or snr > snr_threshold_high:
                    continue
                click_data = resample(click_data, frameRate, tar_fs)
                click_data = cut_data(click_data, signal_len)
                click_arr.append(click_data.astype(np.short))
                count = count + 1

            dst = "%(path)s/%(pre)s_N%(num)d.npy" \
                % {'path': dst_path, 'pre': wavname, 'num': len(click_arr)}
            print(dst)
            np.save(dst, np.array(click_arr, dtype=np.short))
        print("count = %(count)d" % {'count': count})
def load_lwy_data(batch_num=20, n_total=500, class_paths=None):
    """Load the 8-class click dataset and build cropped train/test sets.

    Each wav file is flattened and stored as one row of a 320-column matrix.
    Unlike ``load_data``, no energy normalisation is applied. At most
    ``(batch_num + 10) * n_total`` files are read per class. Rows are divided
    by ``split_data`` and cropped to 192 columns by ``random_crop``.

    :param batch_num: forwarded to ``random_crop``; also caps files per class
    :param n_total: forwarded to ``random_crop`` (4/5 for train, 1/5 for test)
    :param class_paths: optional mapping ``{"<class index>": directory}``;
        defaults to the original hard-coded dataset locations
    :return: ``(train_xs, train_ys, test_xs, test_ys)`` with one-hot labels
    """
    if class_paths is None:
        class_paths = {
            "0": "/home/fish/ROBB/CNN_click/click/WavData/BBW/Blainvilles_beaked_whale_(Mesoplodon_densirostris)",
            "1": "/home/fish/ROBB/CNN_click/click/WavData/Gm/Pilot_whale_(Globicephala_macrorhynchus)",
            "2": "/home/fish/ROBB/CNN_click/click/WavData/Gg/Rissos_(Grampus_grisieus)",
            "3": "/home/fish/ROBB/CNN_click/click/WavData/Tt/palmyra2006",
            "4": "/home/fish/ROBB/CNN_click/click/WavData/Dc/Dc",
            "5": "/home/fish/ROBB/CNN_click/click/WavData/Dd/Dd",
            "6": "/home/fish/ROBB/CNN_click/click/WavData/Melon/palmyra2006",
            "7": "/home/fish/ROBB/CNN_click/click/WavData/Spinner/palmyra2006",
        }
    n_class = len(class_paths)
    max_files = (batch_num + 10) * n_total
    train_xs = np.empty((0, 192))
    train_ys = np.empty((0, n_class))
    test_xs = np.empty((0, 192))
    test_ys = np.empty((0, n_class))
    for key, path in class_paths.items():
        c = int(key)
        wav_files = find_click.list_wav_files(path)
        print("load data : %s, the number of files : %d, class: %d"
              % (path, len(wav_files), c))
        label = np.zeros(n_class)
        label[c] = 1

        # Stack once at the end instead of np.vstack per file (quadratic).
        rows = []
        for count, pathname in enumerate(wav_files, start=1):
            wave_data, frame_rate = find_click.read_wav_file(pathname)
            rows.append(np.reshape(wave_data, [-1]))
            if count >= max_files:
                break
        xs = np.vstack([np.empty((0, 320))] + rows)

        xs0, xs1 = split_data(xs)
        temp_train_xs = random_crop(xs0, batch_num, int(n_total * 4 / 5))
        temp_test_xs = random_crop(xs1, batch_num, int(n_total / 5))
        temp_train_ys = np.tile(label, (temp_train_xs.shape[0], 1))
        temp_test_ys = np.tile(label, (temp_test_xs.shape[0], 1))
        train_xs = np.vstack((train_xs, temp_train_xs))
        train_ys = np.vstack((train_ys, temp_train_ys))
        test_xs = np.vstack((test_xs, temp_test_xs))
        test_ys = np.vstack((test_ys, temp_test_ys))
    return train_xs, train_ys, test_xs, test_ys
"1"] = "/media/ywy/本地磁盘/Data/MobySound/3rd_Workshop/Training_Data/Pilot_whale_(Globicephala_macrorhynchus)" dict[ "2"] = "/media/ywy/本地磁盘/Data/MobySound/3rd_Workshop/Training_Data/Rissos_(Grampus_grisieus)" for key in dict: print(dict[key]) count = 0 wav_files = find_click.list_wav_files(dict[key]) dst_path = "./Data/ClickC8/%(class)s" % {'class': key} mkdir(dst_path) for pathname in wav_files: print(pathname) wave_data, frameRate = find_click.read_wav_file(pathname) fl = 5000 fwhm = 0.0008 fdr_threshold = 0.62 click_index, xn = find_click.find_click_fdr_tkeo( wave_data, frameRate, fl, fwhm, fdr_threshold, signal_len, 8) scale = (2**15 - 1) / max(xn) for i in np.arange(xn.size): xn[i] = xn[i] * scale for j in range(click_index.shape[0]): index = click_index[j] # click_data = wave_data[index[0]:index[1], 0]
def load_data_lstm(data_path, n_class, batch_num=20, n_total=500):
    """Build CNN-feature sequences for LSTM training from click wav files.

    A two-layer CNN is built and its weights are restored from
    ``params/cnn_net.ckpt``. For each class, wav files are energy-normalised
    and split with ``split_data``. Each sample is then ``batch_num`` random
    192-sample crops whose 256-d ``h_fc1`` activations form one sequence:
    ``int(n_total * 4 / 5)`` training sequences and ``int(n_total / 5)`` test
    sequences per class. The CNN's majority-vote accuracy on the test crops
    is printed.

    :param data_path: root directory with one sub-directory per class index
    :param n_class: number of classes
    :param batch_num: crops per sequence; also caps files per class
    :param n_total: total sequences per class (4/5 train, 1/5 test)
    :return: ``(train, test)``; each element is
        ``[features of shape (1, batch_num, 256), one-hot of shape (1, n_class)]``
    """
    train = []
    test = []
    max_files = batch_num * n_total

    # Build in a private graph so repeated calls (or other functions building
    # the same network) do not create renamed variables that break restore.
    graph = tf.Graph()
    with graph.as_default():
        x_in = tf.placeholder("float", [None, 192])  # input
        x_image = tf.reshape(x_in, [-1, 1, 192, 1])
        # First convolutional layer.
        W_conv1 = weight_variable([1, 5, 1, 32])
        b_conv1 = bias_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
        h_pool1 = max_pool_1x2(h_conv1)
        # Second convolutional layer.
        W_conv2 = weight_variable([1, 5, 32, 32])
        b_conv2 = bias_variable([32])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
        h_pool2 = max_pool_1x2(h_conv2)
        # Fully connected layer.
        W_fc1 = weight_variable([1 * 48 * 32, 256])
        b_fc1 = bias_variable([256])
        h_pool2_flat = tf.reshape(h_pool2, [-1, 1 * 48 * 32])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
        # Dropout.
        keep_prob = tf.placeholder("float")
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob)
        # Output layer.
        W_fc2 = weight_variable([256, n_class])
        b_fc2 = bias_variable([n_class])
        y = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

    def crop_batch(rows, seq_idx):
        # batch_num random 192-sample crops (offset 64..95) for sequence
        # seq_idx, cycling through the rows.
        sample_num = rows.shape[0]
        crops = []
        for j in range(batch_num * seq_idx, batch_num * (seq_idx + 1)):
            beg_idx = np.random.randint(64, (64 + 32))
            crops.append(rows[j % sample_num, beg_idx:beg_idx + 192])
        return np.vstack([np.empty((0, 192))] + crops)

    def as_sample(arr, c):
        # Wrap a (batch_num, d) array and its one-hot label the way callers
        # expect: [array (1, batch_num, d), array (1, n_class)].
        one_hot = [0] * n_class
        one_hot[c] = 1
        return (list(np.expand_dims(np.expand_dims(arr, axis=0), axis=0))
                + list(np.array([[one_hot]])))

    with tf.Session(graph=graph) as sess:
        sess.run(init)
        saver.restore(sess, "params/cnn_net.ckpt")  # load trained CNN weights

        def features(crops):
            # h_fc1 is computed before dropout, so keep_prob need not be fed.
            # Cast to float64 to match the original np.empty-based stacking.
            return sess.run(h_fc1, feed_dict={x_in: crops}).astype(np.float64)

        test_cnn = []
        for c in range(n_class):
            path = "%(path)s/%(class)d" % {'path': data_path, 'class': c}
            wav_files = find_click.list_wav_files(path)
            print("load data : %s, the number of files : %d" % (path, len(wav_files)))
            rows = []
            for count, pathname in enumerate(wav_files, start=1):
                wave_data, frame_rate = find_click.read_wav_file(pathname)
                energy = np.sqrt(np.sum(wave_data ** 2))
                if energy > 0:  # avoid NaN rows from silent files
                    wave_data = wave_data / energy
                rows.append(np.reshape(wave_data, [-1]))
                # `>=` reads exactly max_files files, matching load_data.
                if count >= max_files:
                    break
            xs = np.vstack([np.empty((0, 320))] + rows)
            xs0, xs1 = split_data(xs)

            for i in range(int(n_total * 4 / 5)):
                train.append(as_sample(features(crop_batch(xs0, i)), c))
            for i in range(int(n_total / 5)):
                crops = crop_batch(xs1, i)
                test.append(as_sample(features(crops), c))
                test_cnn.append(as_sample(crops, c))

        # Majority vote of per-crop CNN predictions over each test sequence.
        hits = 0
        for crops_arr, ref_y in test_cnn:
            out_y = sess.run(y, feed_dict={x_in: crops_arr[0], keep_prob: 1.0})
            votes = np.bincount(np.argmax(out_y, axis=1), minlength=n_class)
            if np.argmax(votes) == np.argmax(ref_y):
                hits += 1
        if test_cnn:
            print('cnn test accuracy: ', round(hits / len(test_cnn), 3))
    return train, test
def test_cnn_batch_data(data_path, n_class, batch_num=20, n_total=500):
    """Evaluate the CNN on held-out click batches with three voting schemes.

    For each class, only wav files after the first 4/5 of the file list are
    used, so that clicks used for training are not reused. At most
    ``batch_num * n_total`` of them are read. Files are energy-normalised and
    grouped into ``int(n_total / 5)`` batches of ``batch_num`` random
    192-sample crops. A batch's class is decided by majority vote of per-crop
    predictions, by the sum of logits, and by the sum of softmax outputs.
    The accuracy of each scheme is printed.

    :param data_path: root directory with one sub-directory per class index
    :param n_class: number of classes
    :param batch_num: crops per batch; also caps test files per class
    :param n_total: 1/5 of this is the number of batches per class
    """
    max_files = batch_num * n_total
    click_batch = []
    for c in range(n_class):
        path = "%(path)s/%(class)d" % {'path': data_path, 'class': c}
        wav_files = find_click.list_wav_files(path)
        print("load data : %s, the number of files : %d" % (path, len(wav_files)))
        # Skip exactly the first 4/5 of the files (indices < split_idx). The
        # cap counts only files actually read, so a large directory cannot
        # shrink the test set to a single file.
        split_idx = int(len(wav_files) * 4 / 5)
        rows = []
        for pathname in wav_files[split_idx:split_idx + max_files]:
            wave_data, frame_rate = find_click.read_wav_file(pathname)
            energy = np.sqrt(np.sum(wave_data ** 2))
            if energy > 0:  # avoid NaN rows from silent files
                wave_data = wave_data / energy
            rows.append(np.reshape(wave_data, [-1]))
        xs = np.vstack([np.empty((0, 320))] + rows)
        sample_num = xs.shape[0]
        if sample_num == 0:
            print("no test files for class %d, skipped" % c)
            continue

        one_hot = [0] * n_class
        one_hot[c] = 1
        label = list(np.array([[one_hot]]))
        for i in range(int(n_total / 5)):
            crops = []
            for j in range(batch_num * i, batch_num * (i + 1)):
                beg_idx = np.random.randint(64, (64 + 32))
                crops.append(xs[j % sample_num, beg_idx:beg_idx + 192])
            tmp_xs = np.vstack([np.empty((0, 192))] + crops)
            tmp_xs = list(np.expand_dims(np.expand_dims(tmp_xs, axis=0), axis=0))
            click_batch.append(tmp_xs + label)

    if not click_batch:
        print('no test batches were built!')
        return

    # Build in a private graph so this does not clash with other functions
    # that build the same network in the default graph.
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder("float", [None, 192])  # input
        x_image = tf.reshape(x, [-1, 1, 192, 1])
        # First convolutional layer.
        W_conv1 = weight_variable([1, 5, 1, 32])
        b_conv1 = bias_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
        h_pool1 = max_pool_1x2(h_conv1)
        # Second convolutional layer.
        W_conv2 = weight_variable([1, 5, 32, 32])
        b_conv2 = bias_variable([32])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
        h_pool2 = max_pool_1x2(h_conv2)
        # Fully connected layer.
        W_fc1 = weight_variable([1 * 48 * 32, 256])
        b_fc1 = bias_variable([256])
        h_pool2_flat = tf.reshape(h_pool2, [-1, 1 * 48 * 32])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
        # Dropout.
        keep_prob = tf.placeholder("float")
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob)
        # Output layer: logits feed both the softmax and the weight voting.
        W_fc2 = weight_variable([256, n_class])
        b_fc2 = bias_variable([n_class])
        logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
        y = tf.nn.softmax(logits)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(init)
        saver.restore(sess, "params/cnn_net.ckpt")  # load trained CNN weights

        majority_hits = 0
        weight_hits = 0
        softmax_hits = 0
        for batch_xs, ref_y in click_batch:
            # One forward pass per batch yields everything all three voting
            # schemes need.
            probs, scores = sess.run(
                [y, logits], feed_dict={x: batch_xs[0], keep_prob: 1.0})
            true_class = np.argmax(ref_y)
            votes = np.bincount(np.argmax(probs, axis=1), minlength=n_class)
            majority_hits += int(np.argmax(votes) == true_class)
            weight_hits += int(np.argmax(scores.sum(axis=0)) == true_class)
            softmax_hits += int(np.argmax(probs.sum(axis=0)) == true_class)

        n_batches = len(click_batch)
        print('cnn test accuracy (majority voting): ',
              round(majority_hits / n_batches, 3))
        print('cnn test accuracy (weight voting): ',
              round(weight_hits / n_batches, 3))
        print('cnn test accuracy (sum of softmax voting): ',
              round(softmax_hits / n_batches, 3))