def plot_logmel(args):
    """Plot log Mel feature of one audio per class. 
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    
    sample_rate = config.sample_rate
    window_size = config.window_size
    overlap = config.overlap
    seq_len = config.seq_len
    mel_bins = config.mel_bins
    
    plot_num = 12
    
    # Paths
    meta_csv = os.path.join(workspace, 'validation.csv')
    audios_dir = os.path.join(dataset_dir, 'wav')
    
    # Feature extractor
    feature_extractor = LogMelExtractor(sample_rate=sample_rate, 
                                        window_size=window_size, 
                                        overlap=overlap, 
                                        mel_bins=mel_bins)
    
    

    # Calculate log mel feature of audio clips
    df = pd.read_csv(meta_csv, sep=',')
    
    n = 0
    itemids = []
    features = []
    hasbirds = []
    
    for (_, row) in df.iterrows():
        
        if n == plot_num:
            break
        
        itemid = row['itemid']
        hasbird = row['hasbird']
    
        audio_path = os.path.join(audios_dir, '{}.wav'.format(itemid))
        
        feature = calculate_logmel(audio_path=audio_path, 
                                   sample_rate=sample_rate, 
                                   feature_extractor=feature_extractor)
                
        itemids.append(itemid)
        features.append(feature)
        hasbirds.append(hasbird)
        
        n += 1
        
    # Plot
    rows_num = 3
    cols_num = 4
    
    fig, axs = plt.subplots(rows_num, cols_num, figsize=(10, 5))
    
    for n in range(plot_num):
        row = n // cols_num
        col = n % cols_num
        axs[row, col].matshow(features[n].T, origin='lower', aspect='auto', 
                              cmap='jet', vmin=-10, vmax=-2)
        axs[row, col].set_title('No. {}, hasbird={}'.format(n, hasbirds[n]))
        axs[row, col].set_ylabel('log mel bins')
        axs[row, col].yaxis.set_ticks([])
        axs[row, col].xaxis.set_ticks([0, seq_len])
        axs[row, col].xaxis.set_ticklabels(['0', '10 s'], fontsize='small')
        axs[row, col].xaxis.tick_bottom()
    
    for n in range(plot_num, rows_num * cols_num):
        row = n // cols_num
        col = n % cols_num
        axs[row, col].set_visible(False)
    
    for (n, itemid) in enumerate(itemids):
        print('No. {}, {}.wav'.format(n, itemid))
    
    fig.tight_layout()
    plt.show()
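All of the examples in this listing assume a LogMelExtractor class and a calculate_logmel helper defined elsewhere in each project. As a rough orientation, here is a minimal librosa-based sketch of what they plausibly do; the melW attribute and transform method appear in the examples below, but the implementation details here are assumptions, not any repository's actual code:

import librosa
import numpy as np

class LogMelExtractor(object):
    def __init__(self, sample_rate, window_size, overlap, mel_bins):
        self.window_size = window_size
        self.hop_size = window_size - overlap
        # Mel filter bank, shape (fft_bins, mel_bins); layout assumed
        self.melW = librosa.filters.mel(sr=sample_rate,
                                        n_fft=window_size,
                                        n_mels=mel_bins).T

    def transform(self, audio):
        # Magnitude spectrogram, shape (frames, fft_bins)
        spectrogram = np.abs(librosa.stft(y=audio,
                                          n_fft=self.window_size,
                                          hop_length=self.hop_size)).T
        # Project onto mel bands and compress with a log
        return np.log(np.dot(spectrogram, self.melW) + 1e-8)

def calculate_logmel(audio_path, sample_rate, feature_extractor, seq_len=None):
    # Load, extract, and optionally truncate to a fixed number of frames
    audio, _ = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    feature = feature_extractor.transform(audio)
    if seq_len is not None:
        feature = feature[:seq_len]
    return feature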
Example #2
def plot_logmel(args):
    """Plot log Mel feature of one audio per class. 
    """

    # Arguments & parameters
    audios_dir = args.audios_dir

    sample_rate = config.sample_rate
    window_size = config.window_size
    overlap = config.overlap
    seq_len = config.seq_len
    mel_bins = config.mel_bins
    labels = config.labels

    # Paths
    audio_names = os.listdir(audios_dir)

    # Feature extractor
    feature_extractor = LogMelExtractor(sample_rate=sample_rate,
                                        window_size=window_size,
                                        overlap=overlap,
                                        mel_bins=mel_bins)

    # Cut each wav into 2 s chunks and export the second chunk of each
    font = 14  # base font size for plot labels (assumed; not defined in this snippet)

    chunk_length = 2000  # in milliseconds (pydub convention)
    for audio_name in audio_names:
        if os.path.splitext(audio_name)[1] == '.wav':
            if 'segment' not in audio_name:
                audio = AudioSegment.from_wav(
                    os.path.join(audios_dir, audio_name))

                chunks = make_chunks(audio, chunk_length)

                chunks[1].export(os.path.join(audios_dir,
                                              'segment_' + audio_name),
                                 format='wav')

    for audio_name in audio_names:
        if os.path.splitext(audio_name)[1] == '.wav':
            if 'segment' in audio_name:

                audio_path = os.path.join(audios_dir, audio_name)

                feature = calculate_logmel(audio_path=audio_path,
                                           sample_rate=sample_rate,
                                           feature_extractor=feature_extractor)

                # Plot the log mel spectrogram
                fig, axs = plt.subplots(1, 1)

                axs.matshow(feature.T,
                            origin='lower',
                            aspect='auto',
                            cmap='jet')
                axs.set_xlabel('Time (s)', fontsize=font)
                axs.set_ylabel('Mel bins', fontsize=font)
                axs.xaxis.set_ticks([0, 31, 61])
                axs.yaxis.set_ticks([0, 32, 63])
                axs.xaxis.set_ticklabels(['0', '1', '2'], fontsize=font)
                axs.yaxis.set_ticklabels(['0', '32', '64'], fontsize=font)
                axs.xaxis.tick_bottom()

                axs.spines['top'].set_visible(False)
                axs.spines['right'].set_visible(False)
                plt.savefig(
                    os.path.join(audios_dir,
                                 audio_name.split('.wav')[0] + '_logmel.pdf'))
                plt.close()

                # wave
                audio = AudioSegment.from_wav(
                    os.path.join(audios_dir, audio_name))
                chunk = audio.get_array_of_samples()
                chunk = np.array(chunk)
                time_axis = np.arange(0, 8000)
                fig, axs = plt.subplots(1, 1)
                axs.plot(time_axis, chunk)
                axs.set_xlabel('Time (s)', fontsize=font)
                axs.xaxis.set_label_coords(1, 0.4)
                axs.set_ylabel('Amplitude', fontsize=font)
                axs.xaxis.set_ticks([0, 8000])
                axs.yaxis.set_ticks([])
                axs.xaxis.set_ticklabels(['0', '2'], fontsize=font)
                plt.xlim([0, 8000])
                axs.spines['bottom'].set_position(('data', 0))
                axs.spines['top'].set_visible(False)
                axs.spines['right'].set_visible(False)
                plt.savefig(
                    os.path.join(audios_dir,
                                 audio_name.split('.wav')[0] + '_wave.pdf'))
                plt.close()

                # together
                fig, axs = plt.subplots(2, 1)
                axs[0].plot(time_axis, chunk)
                axs[0].set_ylabel('Amplitude', fontsize=font)
                axs[0].xaxis.set_ticks([])
                axs[0].yaxis.set_ticks([])
                axs[0].set_xlim([0, 8000])
                axs[0].spines['bottom'].set_position(('data', 0))
                axs[0].spines['bottom'].set_color('gray')
                axs[0].spines['top'].set_visible(False)
                axs[0].spines['right'].set_visible(False)

                axs[1].matshow(feature.T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet')
                axs[1].set_xlabel('Time (s)', fontsize=font)
                axs[1].set_ylabel('Mel bins', fontsize=font)
                axs[1].xaxis.set_ticks([0, 31, 61])
                axs[1].yaxis.set_ticks([0, 32, 63])
                axs[1].xaxis.set_ticklabels(['0', '1', '2'], fontsize=font)
                axs[1].yaxis.set_ticklabels(['0', '32', '64'], fontsize=font)
                axs[1].xaxis.tick_bottom()
                axs[1].spines['top'].set_visible(False)
                axs[1].spines['right'].set_visible(False)

                if '0008' in audio_name:
                    # Annotate the cardiac cycle phases on the spectrogram
                    annotations = [(2.5, 'S1', 3.5), (7, 'systole', 7.5),
                                   (12.5, 'S2', 13.2), (16, 'diastole', 18),
                                   (27, 'S1', 28.5), (32, 'systole', 32.2),
                                   (37, 'S2', 37.3), (40, 'diastole', 41.5),
                                   (50, 'S1', 51.5), (56, None, None)]

                    for (x, label_text, text_x) in annotations:
                        axs[1].axvline(x=x,
                                       color='red',
                                       linestyle='--',
                                       ymax=2.05,
                                       lw=1,
                                       clip_on=False)
                        if label_text is not None:
                            plt.text(text_x, 128, label_text,
                                     fontsize=font - 2)

                plt.subplots_adjust(wspace=0, hspace=0.05)
                plt.savefig(
                    os.path.join(
                        audios_dir,
                        audio_name.split('.wav')[0] + '_wave_logmel.pdf'))
                plt.close()
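Example #2 relies on pydub's make_chunks, which measures AudioSegment lengths in milliseconds, so chunk_length = 2000 yields 2 s segments. A standalone sketch of that chunking step (the file name is a placeholder):

from pydub import AudioSegment
from pydub.utils import make_chunks

audio = AudioSegment.from_wav('heart_sound.wav')  # placeholder path
chunks = make_chunks(audio, chunk_length=2000)    # 2000 ms per chunk
# Export the second chunk, as the example above does (requires >= 2 chunks)
chunks[1].export('segment_heart_sound.wav', format='wav')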
Example #3
def plot_mel_masks(args):
    
    # Arguments & parameters
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    scene_type = args.scene_type
    snr = args.snr
    iteration = args.iteration
    model_type = args.model_type
    cuda = args.cuda

    labels = config.labels
    classes_num = len(labels)
    sample_rate = config.sample_rate
    window_size = config.window_size
    overlap = config.overlap
    hop_size = window_size-overlap
    mel_bins = config.mel_bins
    seq_len = config.seq_len
    ix_to_lb = config.ix_to_lb
    
    thres = 0.1
    batch_size = 24

    # Paths
    hdf5_path = os.path.join(workspace, 'features', 'logmel', 
        'scene_type={},snr={}'.format(scene_type, snr), 'development.h5')

    model_path = os.path.join(workspace, 'models', 'main_pytorch', 
        'model_type={}'.format(model_type), 'scene_type={},snr={}'
        ''.format(scene_type, snr), 'holdout_fold{}'.format(holdout_fold), 
        'md_{}_iters.tar'.format(iteration))
    
    yaml_path = os.path.join(workspace, 'mixture.yaml')
    
    audios_dir = os.path.join(workspace, 'mixed_audios', 
                              'scene_type={},snr={}'.format(scene_type, snr))
    
    sep_wavs_dir = os.path.join(workspace, 'separated_wavs', 'main_pytorch', 
        'model_type={}'.format(model_type), 
        'scene_type={},snr={}'.format(scene_type, snr), 
        'holdout_fold{}'.format(holdout_fold))
        
    create_folder(sep_wavs_dir)
    
    # Load yaml file
    load_yaml_time = time.time()
    with open(yaml_path, 'r') as f:
        meta = yaml.safe_load(f)
    print('Load yaml file time: {:.3f} s'.format(time.time() - load_yaml_time))
    
    feature_extractor = LogMelExtractor(
        sample_rate=sample_rate, 
        window_size=window_size, 
        overlap=overlap, 
        mel_bins=mel_bins)

    inverse_melW = feature_extractor.get_inverse_melW()
    
    # Load model
    Model = get_model(model_type)
    model = Model(classes_num, seq_len, mel_bins, cuda)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])

    if cuda:
        model.cuda()

    # Data generator
    generator = InferenceDataGenerator(
        hdf5_path=hdf5_path,
        batch_size=batch_size, 
        holdout_fold=holdout_fold)

    generate_func = generator.generate_validate(
        data_type='validate', 
        shuffle=False, 
        max_iteration=None)
    
    # Evaluate on mini-batch
    for (iteration, data) in enumerate(generate_func):
        
        (batch_x, batch_y, batch_audio_names) = data            
        batch_x = move_data_to_gpu(batch_x, cuda)

        # Predict
        with torch.no_grad():
            model.eval()
            (batch_output, batch_bottleneck) = model(
                batch_x, return_bottleneck=True)
    
        batch_output = batch_output.data.cpu().numpy()
        '''(batch_size, classes_num)'''
        
        batch_bottleneck = batch_bottleneck.data.cpu().numpy()  
        '''(batch_size, classes_num, seq_len, mel_bins)'''

        batch_pred_sed = np.mean(batch_bottleneck, axis=-1)
        batch_pred_sed = np.transpose(batch_pred_sed, (0, 2, 1))    
        '''(batch_size, seq_len, classes_num)'''
        
        batch_gt_masks = []
        
        for n in range(len(batch_audio_names)):
            curr_meta = search_meta_by_mixture_name(meta, batch_audio_names[n])
            curr_events = curr_meta['events']
              
            pred_indexes = np.where(batch_output[n] > thres)[0]
            gt_indexes = get_ground_truth_indexes(curr_events)
 
            gt_sed = get_sed_from_meta(curr_events) # (seq_len, classes_num)
            
            pred_sed = np.zeros((seq_len, classes_num))
            pred_sed[:, pred_indexes] = batch_pred_sed[n][:, pred_indexes]  # (seq_len, classes_num)
 
            (events_stft, scene_stft, _) = generator.get_events_scene_mixture_stft(batch_audio_names[n])
            events_stft = np.dot(events_stft, feature_extractor.melW)
            scene_stft = np.dot(scene_stft, feature_extractor.melW)
            
            gt_mask = ideal_binary_mask(events_stft, scene_stft)    # (seq_len, mel_bins)
            
            gt_masks = gt_mask[:, :, None] * gt_sed[:, None, :] # (seq_len, mel_bins, classes_num)
            gt_masks = gt_masks.astype(np.float32)
            batch_gt_masks.append(gt_masks)
            
            pred_masks = batch_bottleneck[n].transpose(1, 2, 0) # (seq_len, mel_bins, classes_num)

            # Save out separated audio
            if True:
                curr_audio_name = curr_meta['mixture_name']
                audio_path = os.path.join(audios_dir, curr_audio_name)
                (mixed_audio, fs) = read_audio(audio_path, target_fs=sample_rate, mono=True)
                
                out_wav_path = os.path.join(sep_wavs_dir, curr_audio_name)
                write_audio(out_wav_path, mixed_audio, sample_rate)
                
                window = np.hamming(window_size)
                mixed_stft_cmplx = stft(x=mixed_audio, window_size=window_size, hop_size=hop_size, window=window, mode='complex')
                mixed_stft_cmplx = mixed_stft_cmplx[0 : seq_len, :]
                mixed_stft = np.abs(mixed_stft_cmplx)
                
                for k in gt_indexes:
                    masked_stft = np.dot(pred_masks[:, :, k], inverse_melW) * mixed_stft
                    masked_stft_cmplx = real_to_complex(masked_stft, mixed_stft_cmplx)
                    
                    frames = istft(masked_stft_cmplx)
                    cola_constant = get_cola_constant(hop_size, window)
                    sep_audio = overlap_add(frames, hop_size, cola_constant)
                    
                    sep_wav_path = os.path.join(sep_wavs_dir, '{}_{}.wav'.format(os.path.splitext(curr_audio_name)[0], ix_to_lb[k]))
                    write_audio(sep_wav_path, sep_audio, sample_rate)
                    print('Audio written to {}'.format(sep_wav_path))
      
        # Visualize learned representations
        if True:
            for n in range(len(batch_output)):
            
                # Plot segmentation masks. (00013.wav is used for plot in the paper)
                print('audio_name: {}'.format(batch_audio_names[n]))
                print('target: {}'.format(batch_y[n]))
                target_labels = target_to_labels(batch_y[n], labels)
                print('target labels: {}'.format(target_labels))
            
                (events_stft, scene_stft, _) = generator.get_events_scene_mixture_stft(batch_audio_names[n])
    
                fig, axs = plt.subplots(7, 7, figsize=(15, 10))
                for k in range(classes_num):
                    axs[k // 6, k % 6].matshow(batch_bottleneck[n, k].T, origin='lower', aspect='auto', cmap='jet')
                    if labels[k] in target_labels:
                        color = 'r'
                    else:
                        color = 'k'
                    axs[k // 6, k % 6].set_title(labels[k], color=color)
                    axs[k // 6, k % 6].xaxis.set_ticks([])
                    axs[k // 6, k % 6].yaxis.set_ticks([])
                    axs[k // 6, k % 6].set_xlabel('time')
                    axs[k // 6, k % 6].set_ylabel('mel bins')
                    
                axs[6, 5].matshow(np.log(events_stft + 1e-8).T, origin='lower', aspect='auto', cmap='jet')
                axs[6, 5].set_title('Spectrogram (in log scale)')
                axs[6, 5].xaxis.set_ticks([0, 310])
                axs[6, 5].xaxis.set_ticklabels(['0.0', '10.0 s'])
                axs[6, 5].xaxis.tick_bottom()
                axs[6, 5].yaxis.set_ticks([0, 1024])
                axs[6, 5].yaxis.set_ticklabels(['0', '1025'])
                axs[6, 5].set_xlabel('time')
                axs[6, 5].set_ylabel('FFT bins')
                
                axs[6, 6].matshow(np.log(np.dot(events_stft, feature_extractor.melW) + 1e-8).T, origin='lower', aspect='auto', cmap='jet')
                axs[6, 6].set_title('Log mel spectrogram')
                axs[6, 6].xaxis.set_ticks([0, 310])
                axs[6, 6].xaxis.set_ticklabels(['0.0', '10.0 s'])
                axs[6, 6].xaxis.tick_bottom()
                axs[6, 6].yaxis.set_ticks([0, 63])
                axs[6, 6].yaxis.set_ticklabels(['0', '64'])
                axs[6, 6].set_xlabel('time')
                axs[6, 6].set_ylabel('mel bins')
                
                plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=0.5)
                plt.show()
                
                # Plot frame-wise SED
                fig, ax = plt.subplots(1, 1, figsize=(4, 4))
                score_mat = []
                for k in range(classes_num):
                    score = np.mean(batch_bottleneck[n, k], axis=-1)
                    score_mat.append(score)
                    
                score_mat = np.array(score_mat)
                
                ax.matshow(score_mat, origin='lower', aspect='auto', cmap='jet')
                ax.set_title('Frame-wise predictions')
                ax.xaxis.set_ticks([0, 310])
                ax.xaxis.set_ticklabels(['0.0', '10.0 s'])
                ax.xaxis.tick_bottom()
                ax.set_xlabel('time')
                ax.yaxis.set_ticks(np.arange(classes_num))
                ax.yaxis.set_ticklabels(config.labels, fontsize='xx-small')
                ax.yaxis.grid(color='k', linestyle='solid', linewidth=0.3)
                
                plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=0.5)
                plt.show()
                
                # Plot event-wise SED
                est_event_list = get_est_event_list(batch_pred_sed[n:n+1], batch_audio_names[n:n+1], labels)
                event_mat = event_list_to_matrix(est_event_list)
                
                fig, ax = plt.subplots(1, 1, figsize=(4, 4))
                ax.matshow(event_mat.T, origin='lower', aspect='auto', cmap='jet')
                ax.set_title('Event-wise predictions')
                ax.xaxis.set_ticks([0, 310])
                ax.xaxis.set_ticklabels(['0.0', '10.0 s'])
                ax.xaxis.tick_bottom()
                ax.set_xlabel('time')
                ax.yaxis.set_ticks(np.arange(classes_num))
                ax.yaxis.set_ticklabels(config.labels, fontsize='xx-small')
                ax.yaxis.grid(color='k', linestyle='solid', linewidth=0.3)
                
                plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=0.5)
                plt.show()
                
                # Plot event-wise ground truth
                ref_event_list = get_ref_event_list(meta, batch_audio_names[n:n+1])
                event_mat = event_list_to_matrix(ref_event_list)
                
                fig, ax = plt.subplots(1, 1, figsize=(4, 4))
                ax.matshow(event_mat.T, origin='lower', aspect='auto', cmap='jet')
                ax.set_title('Event-wise ground truth')
                ax.xaxis.set_ticks([0, 310])
                ax.xaxis.set_ticklabels(['0.0', '10.0 s'])
                ax.xaxis.tick_bottom()
                ax.set_xlabel('time')
                ax.yaxis.set_ticks(np.arange(classes_num))
                ax.yaxis.set_ticklabels(config.labels, fontsize='xx-small')
                ax.yaxis.grid(color='k', linestyle='solid', linewidth=0.3)
                
                plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=0.5)
                plt.show()
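plot_mel_masks builds its ground-truth masks with an ideal_binary_mask helper that the snippet does not show. Under the usual definition, a time-frequency bin belongs to the events when the event magnitude dominates the background scene; a minimal sketch under that assumption:

import numpy as np

def ideal_binary_mask(events_stft, scene_stft):
    # Assumed definition: 1 where the event magnitude exceeds the
    # background scene magnitude, 0 elsewhere
    return (events_stft > scene_stft).astype(np.float32)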
Example #4
def inference(args):

    # Arguments & parameters
    workspace = args.workspace
    model_type = args.model_type
    holdout_fold = args.holdout_fold
    scene_type = args.scene_type
    snr = args.snr
    iteration = args.iteration
    filename = args.filename
    cuda = args.cuda

    labels = config.labels
    classes_num = len(labels)
    sample_rate = config.sample_rate
    window_size = config.window_size
    overlap = config.overlap
    hop_size = window_size - overlap
    mel_bins = config.mel_bins
    seq_len = config.seq_len
    ix_to_lb = config.ix_to_lb

    threshold = 0.1
    batch_size = 24  # assumed: not set in this snippet; matches plot_mel_masks above

    # Paths
    hdf5_path = os.path.join(workspace, 'features', 'logmel',
                             'scene_type={},snr={}'.format(scene_type, snr),
                             'development.h5')

    model_path = os.path.join(
        workspace, 'models', filename, 'model_type={}'.format(model_type),
        'scene_type={},snr={}'
        ''.format(scene_type, snr), 'holdout_fold{}'.format(holdout_fold),
        'md_{}_iters.tar'.format(iteration))

    yaml_path = os.path.join(workspace, 'mixture.yaml')

    out_stat_path = os.path.join(
        workspace, 'stats', filename, 'model_type={}'.format(model_type),
        'scene_type={},snr={}'
        ''.format(scene_type,
                  snr), 'holdout_fold{}'.format(holdout_fold), 'stat.p')

    create_folder(os.path.dirname(out_stat_path))

    pred_prob_path = os.path.join(
        workspace, 'pred_probs', filename, 'model_type={}'.format(model_type),
        'scene_type={},snr={}'
        ''.format(scene_type,
                  snr), 'holdout_fold{}'.format(holdout_fold), 'pred_prob.p')

    create_folder(os.path.dirname(pred_prob_path))

    # Load yaml file
    load_yaml_time = time.time()

    with open(yaml_path, 'r') as f:
        meta = yaml.safe_load(f)

    logging.info('Load yaml file time: {:.3f} s'.format(time.time() -
                                                        load_yaml_time))

    feature_extractor = LogMelExtractor(sample_rate=sample_rate,
                                        window_size=window_size,
                                        overlap=overlap,
                                        mel_bins=mel_bins)

    # Load model
    Model = get_model(model_type)

    model = Model(classes_num, seq_len, mel_bins, cuda)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])

    if cuda:
        model.cuda()

    # Data generator
    generator = InferenceDataGenerator(hdf5_path=hdf5_path,
                                       batch_size=batch_size,
                                       holdout_fold=holdout_fold)

    generate_func = generator.generate_validate(data_type='validate',
                                                shuffle=False,
                                                max_iteration=None)

    audio_names = []
    at_outputs = []
    at_targets = []

    sed_outputs = []
    sed_targets = []

    ss_outputs = []
    ss_targets = []

    validate_num = len(generator.validate_audio_indexes)

    # Evaluate on mini-batch
    for iteration, data in enumerate(generate_func):

        print('{} / {} audio clips processed'.format(iteration * batch_size,
                                                     validate_num))

        (batch_x, batch_y, batch_audio_names) = data

        batch_x = move_data_to_gpu(batch_x, cuda)

        # Predict
        with torch.no_grad():
            model.eval()
            (batch_output, batch_bottleneck) = model(batch_x,
                                                     return_bottleneck=True)

        batch_output = batch_output.data.cpu().numpy()
        '''(batch_size, classes_num)'''

        batch_bottleneck = batch_bottleneck.data.cpu().numpy()
        '''(batch_size, classes_num, seq_len, mel_bins)'''

        audio_names.append(batch_audio_names)
        at_outputs.append(batch_output)
        at_targets.append(batch_y)

        batch_pred_sed = np.mean(batch_bottleneck, axis=-1)
        batch_pred_sed = np.transpose(batch_pred_sed, (0, 2, 1))
        '''(batch_size, seq_len, classes_num)'''

        for n in range(len(batch_audio_names)):

            gt_meta = search_meta_by_mixture_name(meta, batch_audio_names[n])
            gt_events = gt_meta['events']
            gt_sed = get_sed_from_meta(gt_events)
            '''(seq_len, classes_num)'''

            # Do audio tagging first, then only apply SED to the positive
            # classes to reduce the false positives.
            pred_classes = np.where(batch_output[n] > threshold)[0]
            pred_sed = np.zeros((seq_len, classes_num))
            pred_sed[:, pred_classes] = batch_pred_sed[n][:, pred_classes]
            '''(seq_len, classes_num)'''

            sed_outputs.append(pred_sed)
            sed_targets.append(gt_sed)

            (events_stft, scene_stft, mixture_stft) = \
                generator.get_events_scene_mixture_stft(batch_audio_names[n])
            '''(seq_len, fft_bins)'''

            events_stft = np.dot(events_stft, feature_extractor.melW)
            scene_stft = np.dot(scene_stft, feature_extractor.melW)
            '''(seq_len, mel_bins)'''

            gt_mask = ideal_binary_mask(events_stft, scene_stft)
            '''(seq_len, mel_bins)'''

            gt_masks = gt_mask[:, :, None] * gt_sed[:, None, :]
            gt_masks = gt_masks.astype(np.float32)
            '''(seq_len, mel_bins, classes_num)'''

            pred_masks = batch_bottleneck[n].transpose(1, 2, 0)
            '''(seq_len, mel_bins, classes_num)'''

            ss_outputs.append(pred_masks)
            ss_targets.append(gt_masks)

        # if iteration == 3: break

    audio_names = np.concatenate(audio_names, axis=0)

    at_outputs = np.concatenate(at_outputs, axis=0)
    at_targets = np.concatenate(at_targets, axis=0)
    '''(audio_clips, classes_num)'''

    sed_outputs = np.array(sed_outputs)
    sed_targets = np.array(sed_targets)
    '''(audio_clips, seq_len, classes_num)'''

    ss_outputs = np.array(ss_outputs)
    ss_targets = np.array(ss_targets)
    '''(audio_clips, seq_len, mel_bins, classes_num)'''

    pred_prob = {
        'audio_name': audio_names,
        'at_output': at_outputs,
        'at_target': at_targets,
        'sed_output': sed_outputs,
        'sed_target': sed_targets
    }

    pickle.dump(pred_prob, open(pred_prob_path, 'wb'))
    logging.info('Saved stat to {}'.format(pred_prob_path))

    # Evaluate audio tagging
    at_time = time.time()

    (at_precision, at_recall,
     at_f1_score) = prec_recall_fvalue(at_targets, at_outputs, threshold, None)
    at_auc = metrics.roc_auc_score(at_targets, at_outputs, average=None)
    at_ap = metrics.average_precision_score(at_targets,
                                            at_outputs,
                                            average=None)

    logging.info('Audio tagging time: {:.3f} s'.format(time.time() - at_time))

    # Evaluate SED
    sed_time = time.time()

    # Flatten (audio_clips, seq_len, classes_num) to
    # (audio_clips * seq_len, classes_num) for frame-level metrics
    sed_targets_2d = sed_targets.reshape(-1, sed_targets.shape[-1])
    sed_outputs_2d = sed_outputs.reshape(-1, sed_outputs.shape[-1])

    (sed_precision, sed_recall, sed_f1_score) = prec_recall_fvalue(
        sed_targets_2d, sed_outputs_2d, thres=threshold, average=None)

    sed_auc = metrics.roc_auc_score(sed_targets_2d,
                                    sed_outputs_2d,
                                    average=None)

    sed_ap = metrics.average_precision_score(sed_targets_2d,
                                             sed_outputs_2d,
                                             average=None)

    logging.info('SED time: {:.3f} s'.format(time.time() - sed_time))

    # Evaluate source separation
    ss_time = time.time()
    # Flatten (audio_clips, seq_len, mel_bins, classes_num) to
    # (audio_clips * seq_len * mel_bins, classes_num)
    ss_targets_2d = ss_targets.reshape(-1, ss_targets.shape[-1])
    ss_outputs_2d = ss_outputs.reshape(-1, ss_outputs.shape[-1])

    (ss_precision, ss_recall, ss_f1_score) = prec_recall_fvalue(
        ss_targets_2d, ss_outputs_2d, thres=threshold, average=None)

    logging.info('SS fvalue time: {:.3f} s'.format(time.time() - ss_time))

    ss_time = time.time()
    ss_auc = metrics.roc_auc_score(ss_targets_2d, ss_outputs_2d, average=None)

    logging.info('SS AUC time: {:.3f} s'.format(time.time() - ss_time))

    ss_time = time.time()
    ss_ap = metrics.average_precision_score(ss_targets_2d,
                                            ss_outputs_2d,
                                            average=None)

    logging.info('SS AP time: {:.3f} s'.format(time.time() - ss_time))

    # Write stats
    stat = {
        'at_precision': at_precision,
        'at_recall': at_recall,
        'at_f1_score': at_f1_score,
        'at_auc': at_auc,
        'at_ap': at_ap,
        'sed_precision': sed_precision,
        'sed_recall': sed_recall,
        'sed_f1_score': sed_f1_score,
        'sed_auc': sed_auc,
        'sed_ap': sed_ap,
        'ss_precision': ss_precision,
        'ss_recall': ss_recall,
        'ss_f1_score': ss_f1_score,
        'ss_auc': ss_auc,
        'ss_ap': ss_ap
    }

    pickle.dump(stat, open(out_stat_path, 'wb'))
    logging.info('Saved stat to {}'.format(out_stat_path))
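inference evaluates with a prec_recall_fvalue helper whose implementation is not included here. A sketch matching the call signature used above (threshold the outputs, then compute per-class precision, recall and F1; average=None keeps per-class values, mirroring sklearn's convention):

import numpy as np

def prec_recall_fvalue(targets, outputs, thres, average=None):
    # Binarize predictions at the threshold
    predictions = (outputs > thres).astype(np.float32)
    true_positives = np.sum(predictions * targets, axis=0)
    precision = true_positives / np.maximum(np.sum(predictions, axis=0), 1e-8)
    recall = true_positives / np.maximum(np.sum(targets, axis=0), 1e-8)
    f1_score = 2 * precision * recall / np.maximum(precision + recall, 1e-8)
    if average == 'macro':
        return (np.mean(precision), np.mean(recall), np.mean(f1_score))
    return (precision, recall, f1_score)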
Example #5
def plot_logmel(args):
    """Plot log Mel feature of one audio per class. 
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir

    sample_rate = config.sample_rate
    window_size = config.window_size
    overlap = config.overlap
    seq_len = config.seq_len
    mel_bins = config.mel_bins
    labels = config.labels
    classes_num = len(labels)

    # Paths
    meta_csv = os.path.join(dataset_dir, 'metadata', 'train', 'weak.csv')
    audios_dir = os.path.join(dataset_dir, 'audio', 'train', 'weak')

    # Feature extractor
    feature_extractor = LogMelExtractor(sample_rate=sample_rate,
                                        window_size=window_size,
                                        overlap=overlap,
                                        mel_bins=mel_bins)

    # Calculate log mel feature of audio clips
    (audio_names, event_labels) = read_meta(meta_csv)

    selected_features_list = []
    selected_audio_names = []
    selected_labels = []

    # Select one audio per class and extract feature
    for label in labels:

        for (n, audio_name) in enumerate(audio_names):

            if (label in event_labels[n]
                    and audio_name not in selected_audio_names):

                audio_path = os.path.join(audios_dir, audio_name)

                feature = calculate_logmel(audio_path=audio_path,
                                           sample_rate=sample_rate,
                                           feature_extractor=feature_extractor,
                                           seq_len=seq_len)

                selected_features_list.append(feature)
                selected_audio_names.append(audio_name)
                selected_labels.append(event_labels[n])

                break

    # Plot
    rows_num = 3
    cols_num = 4

    fig, axs = plt.subplots(rows_num, cols_num, figsize=(10, 5))

    for n in range(classes_num):
        row = n // cols_num
        col = n % cols_num
        axs[row, col].matshow(selected_features_list[n].T,
                              origin='lower',
                              aspect='auto',
                              cmap='jet')
        axs[row, col].set_title('No. {}, {}'.format(n, selected_labels[n]),
                                fontsize='small')
        axs[row, col].set_ylabel('log mel')
        axs[row, col].yaxis.set_ticks([])
        axs[row, col].xaxis.set_ticks([0, seq_len])
        axs[row, col].xaxis.set_ticklabels(['0', '10 s'], fontsize='small')
        axs[row, col].xaxis.tick_bottom()

    for n in range(classes_num, rows_num * cols_num):
        row = n // cols_num
        col = n % cols_num
        axs[row, col].set_visible(False)

    for n in range(classes_num):
        print('No. {}, {}'.format(n, selected_audio_names[n]))

    fig.tight_layout()
    plt.show()
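Example #5 loads weak labels through a read_meta helper. A sketch assuming the DCASE-style weak.csv layout (tab-separated, with filename and comma-separated event_labels columns):

import pandas as pd

def read_meta(meta_csv):
    # Assumed format: tab-separated 'filename' and 'event_labels' columns
    df = pd.read_csv(meta_csv, sep='\t')
    audio_names = df['filename'].tolist()
    event_labels = [str(s).split(',') for s in df['event_labels']]
    return audio_names, event_labels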
Example #6
from features import LogMelExtractor
import glob
import os
import numpy as np
import h5py
from multiprocessing.pool import ThreadPool
from utilities import pad_or_trunc
import librosa
import joblib
from functools import partial

feature_extractor = LogMelExtractor(sample_rate=32000,
                                    window_size=2048,
                                    overlap=720,
                                    mel_bins=64)

def covfea(audio):
    try:
        y, sr = librosa.core.load(audio, sr=32000)
        duration = librosa.get_duration(y=y, sr=sr)
        if duration < 11:
            return None
        else:
            # Take the first 10 s (320000 samples at 32 kHz) and peak-normalize
            data = y[0:320000]
            data /= np.max(np.abs(data))
            feature = feature_extractor.transform(data)
            feature = pad_or_trunc(feature, 240)
            return feature

    except KeyboardInterrupt:
        exit(0)
    except Exception as e:
        print('Failed to process {}: {}'.format(audio, e))
        return None
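The snippet above is cut off before glob, joblib and h5py are used; presumably covfea is mapped over a directory of wavs and the surviving features are written to HDF5. A hypothetical driver under those assumptions (paths are placeholders):

audio_paths = sorted(glob.glob(os.path.join('audios', '*.wav')))

# Extract features in parallel; covfea returns None for short clips
features = joblib.Parallel(n_jobs=4)(
    joblib.delayed(covfea)(path) for path in audio_paths)
features = [f for f in features if f is not None]

with h5py.File('features.h5', 'w') as hf:
    hf.create_dataset('logmel', data=np.array(features))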
Example #7
def plot_logmel(args):
    """Plot log Mel feature of one audio per class. 
    """

    # Arguments & parameters
    audios_dir = args.audios_dir

    sample_rate = config.sample_rate
    window_size = config.window_size
    overlap = config.overlap
    seq_len = config.seq_len
    mel_bins = config.mel_bins
    labels = config.labels

    # Paths
    audio_names = os.listdir(audios_dir)

    # Feature extractor
    feature_extractor = LogMelExtractor(sample_rate=sample_rate,
                                        window_size=window_size,
                                        overlap=overlap,
                                        mel_bins=mel_bins)

    feature_list = []

    # Select one audio per class and extract feature
    for label in labels:

        for audio_name in audio_names:

            if label in audio_name:

                audio_path = os.path.join(audios_dir, audio_name)

                feature = calculate_logmel(audio_path=audio_path,
                                           sample_rate=sample_rate,
                                           feature_extractor=feature_extractor)

                feature_list.append(feature)

                break

    # Plot
    rows_num = 3
    cols_num = 4

    fig, axs = plt.subplots(rows_num, cols_num, figsize=(10, 5))

    classes_num = len(labels)

    for n in range(classes_num):
        row = n // cols_num
        col = n % cols_num
        axs[row, col].matshow(feature_list[n].T,
                              origin='lower',
                              aspect='auto',
                              cmap='jet')
        axs[row, col].set_title(labels[n])
        axs[row, col].set_ylabel('log mel')
        axs[row, col].yaxis.set_ticks([])
        axs[row, col].xaxis.set_ticks([0, seq_len])
        axs[row, col].xaxis.set_ticklabels(['0', '10 s'], fontsize='small')
        axs[row, col].xaxis.tick_bottom()

    for n in range(classes_num, rows_num * cols_num):
        row = n // cols_num
        col = n % cols_num
        axs[row, col].set_visible(False)

    fig.tight_layout()
    plt.show()
Example #8
def plot_logmel():
    """Plot log Mel feature of one audio per class. 
    """

    # Arguments & parameters
    # audios_dir = args.audios_dir

    sample_rate = config.sample_rate
    window_size = config.window_size
    overlap = config.overlap
    seq_len = config.seq_len
    mel_bins = config.mel_bins
    labels = config.labels

    # Paths
    audio_names = ['airport-barcelona-0-0-a.wav']

    # Feature extractor
    feature_extractor = LogMelExtractor(sample_rate=sample_rate,
                                        window_size=window_size,
                                        overlap=overlap,
                                        mel_bins=mel_bins)

    feature_list = []

    # Select one audio per class and extract feature
    for label in labels:

        for audio_name in audio_names:

            if label in audio_name:
                # audio_path = os.path.join(audios_dir, audio_name)

                feature = calculate_three_logmel(
                    audio_path=audio_name,
                    sample_rate=sample_rate,
                    feature_extractor=feature_extractor)

                feature_list.append(feature)

                break

    # Plot
    rows_num = 2
    cols_num = 2

    fig, axs = plt.subplots(rows_num, cols_num, figsize=(10, 5))

    classes_num = len(labels)

    # 'feature' holds the three-channel log mel of the matched audio clip
    data = [feature[i] for i in range(3)]
    axs[0, 0].matshow(data[0].T, origin='lower', aspect='auto', cmap='jet')
    axs[0, 1].matshow(data[1].T, origin='lower', aspect='auto', cmap='jet')
    axs[1, 0].matshow(data[2].T, origin='lower', aspect='auto', cmap='jet')

    # for n in range(classes_num):
    #     row = n // cols_num
    #     col = n % cols_num
    #     axs[row, col].matshow(feature_list[n].T, origin='lower', aspect='auto', cmap='jet')
    #     axs[row, col].set_title(labels[n])
    #     axs[row, col].set_ylabel('log mel')
    #     axs[row, col].yaxis.set_ticks([])
    #     axs[row, col].xaxis.set_ticks([0, seq_len])
    #     axs[row, col].xaxis.set_ticklabels(['0', '10 s'], fontsize='small')
    #     axs[row, col].xaxis.tick_bottom()
    #
    # for n in range(classes_num, rows_num * cols_num):
    #     row = n // cols_num
    #     col = n % cols_num
    #     axs[row, col].set_visible(False)

    fig.tight_layout()
    plt.show()
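Example #8 calls calculate_three_logmel, which is not defined in the snippet. Judging from the three panels plotted, it returns a three-channel feature; one common choice for stereo scene audio is per-channel log mel of the left, right and mid signals. A sketch under that assumption:

import librosa
import numpy as np

def calculate_three_logmel(audio_path, sample_rate, feature_extractor):
    # Load stereo audio; shape (2, samples)
    audio, _ = librosa.core.load(audio_path, sr=sample_rate, mono=False)
    left, right = audio[0], audio[1]
    mid = (left + right) / 2.0
    # Stack per-channel log mel features, shape (3, frames, mel_bins)
    return np.stack([feature_extractor.transform(channel)
                     for channel in (left, right, mid)])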