Example #1
def add_classes(db_path, class_csv):
    db = datasetSQL.LabelSet(db_path)
    with open(class_csv, 'r') as f:
        for item in csv.DictReader(f):
            class_item = {}
            class_item['ASID'] = item['mid']  # AudioSet class ID, e.g. /m/09x0r
            class_item['class_name'] = item['display_name'].replace("'", "''")  # double single quotes for SQL escaping
            print(item)
            db.__insert__(class_item, 'classes')
        db.__commit__()
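All eight examples appear to come from a single module and rely on imports and module-level constants that the listing omits (datasetSQL, network_arch, similarity, and similarity_analysis are project modules; device, embedding_length, target_input_length, target_input_dim, training_batch_size, the weight paths, and ontology_json are globals). A minimal sketch of that preamble, in which every concrete value is an assumption:

# Hedged sketch of the shared preamble; all concrete values are assumptions.
import csv
import os
from shutil import copyfile

import h5py
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import datasetSQL           # project module: SQLite wrapper (LabelSet)
import network_arch         # project module: EmbNet / CascadeNet
import similarity           # project module: pairwise distances (Dist_gpu)
import similarity_analysis  # project module: mAP evaluation (mAP2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embedding_length = 128       # assumed size of EmbNet's output vector
target_input_length = 161    # frames per example; matches evaluate()'s padding
target_input_dim = 64        # mel bins; matches compute_features' n_mels=64
training_batch_size = 32     # assumed

tmp_model_weight_path0 = '/tmp/emb_net.pt'             # assumed paths
tmp_model_weight_path1 = '/tmp/cascade_net.pt'
best_model_weight_path0 = 'weights/emb_net_best.pt'
best_model_weight_path1 = 'weights/cascade_net_best.pt'
ontology_json = 'ontology.json'                        # AudioSet ontology file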
Example #2
def initiate_database(db_path, segment_csv):
    db = datasetSQL.LabelSet(db_path)
    db.initialize()  # presumably creates the table schema
    with open(segment_csv, 'r') as f:
        for item in csv.DictReader(f):
            segment = {}
            segment['segment_id'] = item['YTID']
            print(segment)
            db.__insert__(segment, 'segments')
    db.__commit__()
Example #3
def evaluate(weight_path, db_path, feature_path):
    db = datasetSQL.LabelSet(db_path)
    h5r = h5py.File(feature_path, 'r')
    h5w = h5py.File('/tmp/esc_tmp_1.hdf5', 'w')  # temporary store for per-clip embeddings
    db.cursor.execute(
        "SELECT segment_id FROM segments WHERE audio_file NOT NULL ORDER BY segment_id ASC;"
    )
    segment_list = [
        record[0].decode('utf-8') for record in db.cursor.fetchall()
    ]
    n_segment = len(segment_list)
    h5w.create_dataset('max', data=np.zeros((n_segment, embedding_length)))

    emb_net = network_arch.EmbNet().to(device)
    emb_net.load_weight(weight_path)
    emb_net.eval()

    for i, segment_id in enumerate(segment_list):
        f = h5r[segment_id][:]

        if len(f) < 161:
            f = padding(f, 161)
        n, D = f.shape
        data = np.zeros((1, 1, n, D))
        data[0, 0, :, :] = f
        torch_data = torch.from_numpy(data).float().to(device)

        with torch.no_grad():
            embedding = emb_net(torch_data)
            # Mean-pool over the second dimension to get one vector per clip.
            embedding = embedding.mean(dim=1)

        h5w['max'][i] = embedding.cpu().numpy()[0]
        h5w.create_dataset(segment_id, data=embedding.cpu().numpy()[0])
    h5w.close()

    h5r2 = h5py.File('/tmp/esc_tmp_1.hdf5', 'r')
    h5w2 = h5py.File('/tmp/esc_tmp_dist_1.hdf5', 'w')
    similarity.Dist_gpu(h5r2, h5w2)
    h5w2.close()

    h5r3 = h5py.File('/tmp/esc_tmp_dist_1.hdf5', 'r')
    return similarity_analysis.mAP2(h5r3, db)
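Both evaluate() and train() call a padding() helper that the listing leaves out. It is invoked as padding(f, 161) and as padding(f), so it presumably fits a (frames, dims) feature to a target number of frames. A minimal sketch under that assumption:

def padding(f, target_length=161):
    # Hedged sketch: zero-pad short features and truncate long ones along
    # the time axis. The default of 161 mirrors evaluate()'s call site;
    # the real helper likely defaults to target_input_length.
    n, d = f.shape
    out = np.zeros((target_length, d), dtype=f.dtype)
    out[:min(n, target_length)] = f[:target_length]
    return out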
Example #4
def add_labels(db_path, segment_csv):
    db = datasetSQL.LabelSet(db_path)
    with open(segment_csv, 'r') as f:
        for line in f.readlines():
            segment_id = line.split(',')[0]
            if segment_id == "YTID":  # header line
                continue
            # Field 4 onward holds the quoted, comma-separated positive_labels list.
            labels = ','.join(line.split(',')[3:]).replace('"', '').strip()
            for label in labels.split(','):
                sql = """
                INSERT INTO labels (segment_id, class_id, label_type)                
                SELECT segments.segment_id, classes.class_id, 0 FROM segments CROSS JOIN classes WHERE segments.segment_id = '{0}' AND classes.ASID = '{1}'
                """.format(segment_id, label)
                print(sql)
                db.cursor.execute(sql)
    db.__commit__()
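add_labels, like several of the other examples, splices values into SQL with str.format, which breaks on quotes and invites injection. Assuming db.cursor is a plain sqlite3 cursor, the same insert works with parameter binding; a hedged rewrite of the inner statement:

sql = """
INSERT INTO labels (segment_id, class_id, label_type)
SELECT segments.segment_id, classes.class_id, 0
FROM segments CROSS JOIN classes
WHERE segments.segment_id = ? AND classes.ASID = ?
"""
db.cursor.execute(sql, (segment_id, label))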
Example #5
def link_audio_file(db_path, filecsv):
    db = datasetSQL.LabelSet(db_path)
    with open(filecsv, 'r') as f:
        for line in f.readlines():
            item = {}
            line = line.rstrip()
            # The YTID is the last field; the filename itself may contain commas.
            item["YTID"] = line.split(',')[-1]
            item["filename"] = ','.join(line.split(',')[:-1]) + '.wav'

            sql = """
            UPDATE segments SET audio_file = '{0}'
            WHERE segment_id = '{1}'
            """.format(item['filename'].replace("'", "''"), item['YTID'])
            db.cursor.execute(sql)
        db.__commit__()
Example #6
def compute_features(db_path, feature_path, wav_root):
    h5w = h5py.File(feature_path, 'w')
    db = datasetSQL.LabelSet(db_path)
    trg_sr = 16000  # resample everything to 16 kHz

    sql = """
    SELECT segment_id, audio_file FROM segments
    WHERE audio_file NOT NULL
    """
    segment_list = db.cursor.execute(sql)
    for segment_tuple in segment_list:
        segment_id, audio_file = segment_tuple[0].decode(
            'utf-8'), segment_tuple[1].decode('utf-8')
        print(segment_id, audio_file)
        """
        y, src_sr = soundfile.read(os.path.join(wav_root, audio_file))
        if len(y.shape) > 1:
            y = y[:,0]
        y = librosa.core.resample(y, src_sr, trg_sr)
        """
        y = librosa.load(os.path.join(wav_root, audio_file), trg_sr)[0]
        try:
            mel = librosa.feature.melspectrogram(y=y,
                                                 sr=trg_sr,
                                                 n_fft=1024,
                                                 hop_length=500,
                                                 n_mels=64)
            log_mel = librosa.power_to_db(mel).T
            print(log_mel.shape)
            h5w.create_dataset(segment_id, data=log_mel)

        except Exception:
            # Decoding or feature extraction failed: unlink the file so later
            # queries (WHERE audio_file NOT NULL) skip this segment.
            print("Failure:", segment_id, audio_file)
            sql = """
            UPDATE segments SET audio_file=NULL WHERE segment_id = '{0}'
            """.format(segment_id)
            db.cursor.execute(sql)
            db.__commit__()
            continue
    return
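With trg_sr = 16000 and hop_length = 500 the features run at 32 frames per second, so a full 10-second AudioSet segment gives about 321 frames (librosa's centered framing adds one), and the 161-frame length used elsewhere corresponds to roughly five seconds. A quick check of that arithmetic:

trg_sr, hop_length = 16000, 500
print(trg_sr / hop_length)               # 32.0 frames per second
print(1 + (10 * trg_sr) // hop_length)   # 321 frames for a 10 s clip
print(161 * hop_length / trg_sr)         # ~5.03 s covered by 161 frames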
Example #7
def train(db_path, feature_path):
    db = datasetSQL.LabelSet(db_path)
    h5r = h5py.File(feature_path, 'r')
    db.cursor.execute("SELECT COUNT(*) FROM classes;")
    n_class = db.cursor.fetchone()[0]

    cascade_net = network_arch.CascadeNet(n_class).to(device)

    if os.path.isfile(tmp_model_weight_path1):
        cascade_net.load_weight(tmp_model_weight_path1)

    optimizer = optim.Adam(cascade_net.parameters(), lr=1e-4)
    loss_function = nn.BCELoss()

    cascade_net.train()
    best_score = 0  # best evaluation score seen so far

    for i_epoch in range(100):
        losses = 0

        db.cursor.execute(
            "SELECT segment_id FROM segments WHERE audio_file NOT NULL")

        segment_list = [
            record[0].decode('utf-8').strip()
            for record in db.cursor.fetchall()
        ]
        n_segment = len(segment_list)
        print(i_epoch, n_segment)

        order_list = np.random.permutation(range(n_segment))
        for start_index in range(0, n_segment, training_batch_size):
            if start_index + training_batch_size > n_segment:
                continue  # drop the final partial batch
            batch_data = np.zeros((training_batch_size, 1, target_input_length,
                                   target_input_dim))
            batch_target = np.zeros((training_batch_size, n_class))
            optimizer.zero_grad()

            for i in range(training_batch_size):
                segment_id = segment_list[order_list[start_index + i]]
                try:
                    f = h5r[segment_id][:]
                    if len(f) != target_input_length:
                        f = padding(f)
                    batch_data[i, 0, :, :] = f

                    sql = """
                    SELECT class_id FROM labels WHERE segment_id = '{0}'
                    """.format(segment_id)
                    db.cursor.execute(sql)
                    records = db.cursor.fetchall()
                    for record in records:
                        batch_target[i, record[0] - 1] = 1

                except Exception:
                    print(segment_id + ' None!')

            torch_train = torch.from_numpy(batch_data).float().to(device)
            torch_target = torch.from_numpy(batch_target).float().to(device)

            torch_output = cascade_net(torch_train)
            loss = loss_function(torch_output, torch_target)
            loss.backward()
            optimizer.step()
            print(start_index, loss.item())
            losses += loss.item()  # accumulate a float, not a graph-bearing tensor

        print("epoch {0} loss: {1}".format(i_epoch, losses))
        torch.save(cascade_net.state_dict(), tmp_model_weight_path1)
        torch.save(cascade_net.emb_net.state_dict(), tmp_model_weight_path0)

        print("Model has been saved...")

        # Evaluate after every epoch and keep the best-scoring weights.
        criteria_i = evaluate(tmp_model_weight_path0,
                              'database/db_esc10.sqlite',
                              'database/mel_esc10.hdf5')
        if criteria_i > best_score:
            best_score = criteria_i
            copyfile(tmp_model_weight_path0, best_model_weight_path0)
            copyfile(tmp_model_weight_path1, best_model_weight_path1)
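nn.BCELoss expects probabilities in [0, 1], so CascadeNet presumably ends in a sigmoid over its n_class outputs. Had it returned raw logits, nn.BCEWithLogitsLoss would be the numerically safer choice; a small sketch of the two equivalent conventions (shapes here are illustrative):

logits = torch.randn(4, 10)                     # (batch, n_class) raw outputs
targets = torch.randint(0, 2, (4, 10)).float()  # multi-hot labels

# Convention used above: the network applies sigmoid, the loss is BCELoss.
loss_a = nn.BCELoss()(torch.sigmoid(logits), targets)

# Equivalent but more stable: raw logits with the sigmoid folded into the loss.
loss_b = nn.BCEWithLogitsLoss()(logits, targets)

assert torch.allclose(loss_a, loss_b, atol=1e-6)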
Example #8
def refine_labels(db_path):
    # Strategy: loop five times so every labeled segment also gets its
    # ancestor classes labeled (each pass propagates labels one level up).
    import json
    with open(ontology_json, 'r') as f:
        ontology_list = json.load(f)
    db = datasetSQL.LabelSet(db_path)

    sql = """
       ALTER TABLE classes
       ADD COLUMN leaf_node BOOL
    """
    try:
        db.cursor.execute(sql)
    except Exception:
        pass  # the column already exists from a previous run

    for loop_i in range(5):
        for class_dict in ontology_list:
            parent_asid = class_dict['id']
            sql = """
            SELECT class_id from classes WHERE ASID = '{0}'
            """.format(parent_asid)
            db.cursor.execute(sql)
            record = db.cursor.fetchone()
            if record:  # some ontology classes are absent from the dataset
                parent_id = record[0]
            else:
                continue

            if len(class_dict['child_ids']) == 0:  # no children: a leaf class
                sql = """
                UPDATE classes SET leaf_node=1
                WHERE class_id = {0}
                """.format(parent_id)
            else:
                sql = """
                UPDATE classes SET leaf_node=0
                WHERE class_id = {0}
                """.format(parent_id)

            db.cursor.execute(sql)

            for child_asid in class_dict['child_ids']:
                sql = """
                SELECT class_id from classes WHERE ASID = '{0}'
                """.format(child_asid)
                db.cursor.execute(sql)
                record = db.cursor.fetchone()
                if record:  # some ontology classes are absent from the dataset
                    child_id = record[0]
                else:
                    continue
                print(parent_asid, child_asid)
                sql = """
                INSERT INTO labels (segment_id, class_id, label_type)
                SELECT segment_id, {1},0 from segments
                WHERE segment_id IN
                (SELECT segment_id FROM labels WHERE class_id = {0})
                AND segment_id NOT IN
                (SELECT segment_id FROM labels WHERE class_id = {1})                
                """.format(child_id, parent_id)
                db.cursor.execute(sql)
                #print (db.cursor.fetchall())

    db.__commit__()
    db.__close__()
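Taken together, the eight examples form one pipeline: create the schema, load the AudioSet metadata, link audio files, extract features, propagate ontology labels, and train. A hedged driver under the assumption that the standard AudioSet CSVs and a file list are present (every path here is illustrative):

db_path = 'database/db_audioset.sqlite'        # assumed paths throughout
feature_path = 'database/mel_audioset.hdf5'

initiate_database(db_path, 'csv/balanced_train_segments.csv')
add_classes(db_path, 'csv/class_labels_indices.csv')
add_labels(db_path, 'csv/balanced_train_segments.csv')
link_audio_file(db_path, 'csv/file_list.csv')
compute_features(db_path, feature_path, 'wav/')
refine_labels(db_path)
train(db_path, feature_path)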