import csv
import json
import os
from shutil import copyfile

import h5py
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import datasetSQL
import network_arch
import similarity
import similarity_analysis


def add_classes(db_path, class_csv):
    """Insert every class (MID + display name) from the class-label CSV."""
    db = datasetSQL.LabelSet(db_path)
    with open(class_csv, 'r') as f:
        for item in csv.DictReader(f):
            class_item = {}
            class_item['ASID'] = item['mid']
            # Escape single quotes for SQL by doubling them.
            class_item['class_name'] = item['display_name'].replace("'", "''")
            print(item)
            db.__insert__(class_item, 'classes')
    db.__commit__()

def initiate_database(db_path, segment_csv):
    """Create the schema and register every segment ID from the segment CSV."""
    db = datasetSQL.LabelSet(db_path)
    db.initialize()
    with open(segment_csv, 'r') as f:
        for item in csv.DictReader(f):
            segment = {}
            segment['segment_id'] = item['YTID']
            print(segment)
            db.__insert__(segment, 'segments')
    db.__commit__()

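# The functions below reference several module-level names that are not defined
# in this section (`device`, `embedding_length`, `target_input_length`,
# `target_input_dim`, `training_batch_size`, the four weight paths,
# `ontology_json`, and a `padding` helper). What follows is a minimal sketch of
# plausible definitions so the section reads self-contained; the actual values
# and the real `padding` implementation may differ.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embedding_length = 128       # assumed embedding dimensionality
target_input_length = 161    # frames per training example (5 s at hop 500, sr 16 kHz)
target_input_dim = 64        # mel bands, matching compute_features()
training_batch_size = 32     # assumed batch size
tmp_model_weight_path0 = 'weights/tmp_emb.pt'        # hypothetical paths
tmp_model_weight_path1 = 'weights/tmp_cascade.pt'
best_model_weight_path0 = 'weights/best_emb.pt'
best_model_weight_path1 = 'weights/best_cascade.pt'
ontology_json = 'ontology/ontology.json'


def padding(f, target_length=target_input_length):
    """Sketch of the `padding` helper used below, assuming it zero-pads a
    (frames, dims) matrix along the frame axis (and truncates if too long)."""
    n, d = f.shape
    if n >= target_length:
        return f[:target_length]
    out = np.zeros((target_length, d), dtype=f.dtype)
    out[:n] = f
    return out
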
def evaluate(weight_path, db_path, feature_path):
    """Embed every segment with the trained EmbNet, then score retrieval mAP."""
    db = datasetSQL.LabelSet(db_path)
    h5r = h5py.File(feature_path, 'r')
    h5w = h5py.File('/tmp/esc_tmp_1.hdf5', 'w')
    db.cursor.execute(
        "SELECT segment_id FROM segments WHERE audio_file NOT NULL ORDER BY segment_id ASC;")
    segment_list = [record[0].decode('utf-8') for record in db.cursor.fetchall()]
    n_segment = len(segment_list)
    h5w.create_dataset('max', data=np.zeros((n_segment, embedding_length)))
    emb_net = network_arch.EmbNet().to(device)
    emb_net.load_weight(weight_path)
    emb_net.eval()
    for i, segment_id in enumerate(segment_list):
        f = h5r[segment_id][:]
        if len(f) < 161:  # pad short clips to 161 frames
            f = padding(f, 161)
        n, D = f.shape
        data = np.zeros((1, 1, n, D))
        data[0, 0, :, :] = f
        torch_data = torch.from_numpy(data).float().to(device)
        with torch.no_grad():
            embedding = emb_net(torch_data)
            # Average over dim 1 to obtain one fixed-length embedding.
            embedding = embedding.mean(dim=1)
        h5w['max'][i] = embedding.cpu().numpy()[0]
        h5w.create_dataset(segment_id, data=embedding.cpu().numpy()[0])
    h5w.close()
    # Pairwise distances between all embeddings, then mAP over the label set.
    h5r2 = h5py.File('/tmp/esc_tmp_1.hdf5', 'r')
    h5w2 = h5py.File('/tmp/esc_tmp_dist_1.hdf5', 'w')
    similarity.Dist_gpu(h5r2, h5w2)
    h5w2.close()
    h5r3 = h5py.File('/tmp/esc_tmp_dist_1.hdf5', 'r')
    return similarity_analysis.mAP2(h5r3, db)

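# Note: `EmbNet.load_weight` / `CascadeNet.load_weight` live in network_arch
# and are not shown in this section; presumably they are thin wrappers along
# the lines of
#     def load_weight(self, path):
#         self.load_state_dict(torch.load(path, map_location=device))
# but that is an assumption about code outside this file.
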
def add_labels(db_path, segment_csv):
    """Attach each segment's class labels (label_type 0) from the segment CSV."""
    db = datasetSQL.LabelSet(db_path)
    with open(segment_csv, 'r') as f:
        for line in f.readlines():
            segment_id = line.split(',')[0]
            if segment_id == "YTID":  # skip the header line
                continue
            # Fields 3+ hold the quoted, comma-separated list of class MIDs.
            labels = ','.join(line.split(',')[3:]).replace('"', '').strip()
            for label in labels.split(','):
                sql = """
                    INSERT INTO labels (segment_id, class_id, label_type)
                    SELECT segments.segment_id, classes.class_id, 0
                    FROM segments CROSS JOIN classes
                    WHERE segments.segment_id = '{0}' AND classes.ASID = '{1}'
                """.format(segment_id, label)
                print(sql)
                db.cursor.execute(sql)
    db.__commit__()

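# For reference, add_labels() assumes the AudioSet segment-CSV layout, e.g.:
#   --PJHxphWEs, 30.000, 40.000, "/m/09x0r,/t/dd00088"
# (illustrative row): field 0 is the YTID and fields 3 onward form the quoted
# MID list, which is why the parser rejoins everything from index 3 before
# stripping the quotes.
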
def link_audio_file(db_path, filecsv):
    """Record the .wav filename for each segment."""
    db = datasetSQL.LabelSet(db_path)
    with open(filecsv, 'r') as f:
        for line in f.readlines():
            item = {}
            line = line.rstrip()
            item["YTID"] = line.split(',')[-1]
            item["filename"] = ','.join(line.split(',')[:-1]) + '.wav'
            sql = """
                UPDATE segments
                SET audio_file = '{0}'
                WHERE segment_id = '{1}'
            """.format(item['filename'].replace("'", "''"), item['YTID'])
            db.cursor.execute(sql)
    db.__commit__()

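# link_audio_file() expects each line of `filecsv` to be "<filename>,<YTID>",
# where <filename> (without the .wav extension) may itself contain commas --
# hence the rejoin of every field except the last.
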
def compute_features(db_path, feature_path, wav_root):
    """Compute a 64-band log-mel spectrogram for every linked audio file."""
    h5w = h5py.File(feature_path, 'w')
    db = datasetSQL.LabelSet(db_path)
    trg_sr = 16000
    sql = """
        SELECT segment_id, audio_file FROM segments
        WHERE audio_file NOT NULL
    """
    segment_list = db.cursor.execute(sql)
    for segment_tuple in segment_list:
        segment_id = segment_tuple[0].decode('utf-8')
        audio_file = segment_tuple[1].decode('utf-8')
        print(segment_id, audio_file)
        try:
            # librosa.load resamples to trg_sr and downmixes to mono.
            y = librosa.load(os.path.join(wav_root, audio_file), sr=trg_sr)[0]
            mel = librosa.feature.melspectrogram(
                y=y, sr=trg_sr, n_fft=1024, hop_length=500, n_mels=64)
            log_mel = librosa.power_to_db(mel).T  # transpose to (frames, mels)
            print(log_mel.shape)
            h5w.create_dataset(segment_id, data=log_mel)
        except Exception:
            # Unreadable audio: unlink the file so later stages skip it.
            print("Failure:", segment_id, audio_file)
            sql = """
                UPDATE segments
                SET audio_file=NULL
                WHERE segment_id = '{0}'
            """.format(segment_id)
            db.cursor.execute(sql)
            db.__commit__()
            continue

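# With sr=16000 and hop_length=500, librosa yields 1 + floor(n_samples / 500)
# frames, so a 5-second ESC clip (80000 samples) gives 161 frames -- presumably
# the source of the hard-coded 161 in evaluate() and a natural value for
# target_input_length in train().
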
def train(db_path, feature_path):
    """Train CascadeNet with BCE loss; checkpoint whenever ESC-10 mAP improves."""
    db = datasetSQL.LabelSet(db_path)
    h5r = h5py.File(feature_path, 'r')
    db.cursor.execute("SELECT COUNT(*) FROM classes;")
    n_class = db.cursor.fetchone()[0]
    cascade_net = network_arch.CascadeNet(n_class).to(device)
    if os.path.isfile(tmp_model_weight_path1):
        cascade_net.load_weight(tmp_model_weight_path1)
    optimizer = optim.Adam(cascade_net.parameters(), lr=1e-4)
    loss_function = nn.BCELoss()
    cascade_net.train()
    best_score = 0  # best evaluation mAP seen so far
    class_list = np.random.permutation(range(2, n_class + 1))  # (currently unused)
    for i_epoch in range(100):
        losses = 0
        db.cursor.execute(
            "SELECT segment_id FROM segments WHERE audio_file NOT NULL")
        segment_list = [
            record[0].decode('utf-8').strip() for record in db.cursor.fetchall()
        ]
        n_segment = len(segment_list)
        print(i_epoch, n_segment)
        order_list = np.random.permutation(range(n_segment))
        for start_index in range(0, n_segment, training_batch_size):
            if start_index + training_batch_size > n_segment:
                continue  # drop the ragged final batch
            batch_data = np.zeros(
                (training_batch_size, 1, target_input_length, target_input_dim))
            batch_target = np.zeros((training_batch_size, n_class))
            optimizer.zero_grad()
            for i in range(training_batch_size):
                segment_id = segment_list[order_list[start_index + i]]
                try:
                    f = h5r[segment_id][:]
                    if len(f) != target_input_length:
                        f = padding(f)
                    batch_data[i, 0, :, :] = f
                    sql = """
                        SELECT class_id FROM labels
                        WHERE segment_id = '{0}'
                    """.format(segment_id)
                    db.cursor.execute(sql)
                    for record in db.cursor.fetchall():
                        batch_target[i, record[0] - 1] = 1  # class_id is 1-based
                except Exception:
                    print(segment_id + ' None!')
            torch_train = torch.from_numpy(batch_data).float().to(device)
            torch_target = torch.from_numpy(batch_target).float().to(device)
            torch_output = cascade_net(torch_train)
            loss = loss_function(torch_output, torch_target)
            loss.backward()
            optimizer.step()
            print(start_index, loss.item())
            losses += loss.item()  # .item() detaches, so graphs are not retained
        print("epoch {0} loss: {1}".format(i_epoch, losses))
        torch.save(cascade_net.state_dict(), tmp_model_weight_path1)
        torch.save(cascade_net.emb_net.state_dict(), tmp_model_weight_path0)
        print("Model has been saved...")
        # Evaluate every epoch; keep a copy of the weights whenever mAP improves.
        criteria_i = evaluate(tmp_model_weight_path0,
                              'database/db_esc10.sqlite',
                              'database/mel_esc10.hdf5')
        if criteria_i > best_score:
            best_score = criteria_i
            copyfile(tmp_model_weight_path0, best_model_weight_path0)
            copyfile(tmp_model_weight_path1, best_model_weight_path1)

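# train() checkpoints two things each epoch: the full classifier
# (tmp_model_weight_path1) and its embedding sub-network alone
# (tmp_model_weight_path0). evaluate() reloads only the embedding net, so the
# model is selected by retrieval mAP on ESC-10 rather than by training loss.
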
def refine_labels(db_path):
    """Propagate labels up the ontology: loop five times so every labeled
    class also has all of its ancestor classes labeled, and mark which
    classes are leaf nodes."""
    with open(ontology_json, 'r') as f:
        ontology_list = json.load(f)
    db = datasetSQL.LabelSet(db_path)
    sql = """
        ALTER TABLE classes
        ADD COLUMN leaf_node BOOL
    """
    try:
        db.cursor.execute(sql)
    except Exception:
        pass  # column already exists
    for loop_i in range(5):
        for class_dict in ontology_list:
            parent_asid = class_dict['id']
            sql = """
                SELECT class_id FROM classes
                WHERE ASID = '{0}'
            """.format(parent_asid)
            db.cursor.execute(sql)
            record = db.cursor.fetchone()
            if record:
                parent_id = record[0]
            else:
                continue  # some ontology classes do not exist in the dataset
            leaf = 1 if len(class_dict['child_ids']) == 0 else 0
            sql = """
                UPDATE classes SET leaf_node={0}
                WHERE class_id = {1}
            """.format(leaf, parent_id)
            db.cursor.execute(sql)
            for child_asid in class_dict['child_ids']:
                sql = """
                    SELECT class_id FROM classes
                    WHERE ASID = '{0}'
                """.format(child_asid)
                db.cursor.execute(sql)
                record = db.cursor.fetchone()
                if record:
                    child_id = record[0]
                else:
                    continue
                print(parent_asid, child_asid)
                # Every segment labeled with the child gains the parent label,
                # unless it already has it.
                sql = """
                    INSERT INTO labels (segment_id, class_id, label_type)
                    SELECT segment_id, {1}, 0 FROM segments
                    WHERE segment_id IN
                        (SELECT segment_id FROM labels WHERE class_id = {0})
                    AND segment_id NOT IN
                        (SELECT segment_id FROM labels WHERE class_id = {1})
                """.format(child_id, parent_id)
                db.cursor.execute(sql)
    db.__commit__()
    db.__close__()

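# A sketch of the intended end-to-end order of these steps, with hypothetical
# paths (the real entry point is not part of this section):
if __name__ == '__main__':
    db_path = 'database/db_audioset.sqlite'        # hypothetical
    feature_path = 'database/mel_audioset.hdf5'    # hypothetical
    initiate_database(db_path, 'csv/balanced_train_segments.csv')
    add_classes(db_path, 'csv/class_labels_indices.csv')
    add_labels(db_path, 'csv/balanced_train_segments.csv')
    refine_labels(db_path)
    link_audio_file(db_path, 'csv/file_list.csv')  # hypothetical file listing
    compute_features(db_path, feature_path, 'audio/')
    train(db_path, feature_path)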