Example #1
0
def _run_model_for_blocks(input_blocks_np: np.ndarray, model_dir: str,
                          checkpoint: Optional[str], input_tensor_name: str,
                          output_tensor_name: str) -> np.ndarray:
    """Applies the separation model to each block of the input signal.

    Each block is a multi-channel mixture; the model returns one
    single-channel waveform per separated source.

    Args:
      input_blocks_np: Mixture samples, np.ndarray with shape
        [num_blocks, num_mics, num_samples_in_block].
      model_dir: Directory holding at least one checkpoint and an
        inference.meta graph file.
      checkpoint: Explicit checkpoint path to restore, or None to use the
        latest checkpoint found in model_dir.
      input_tensor_name: Name of the model's input tensor.
      output_tensor_name: Name of the model's output tensor.

    Returns:
      Separated samples, np.ndarray with shape
      [num_blocks, num_sources, num_samples_in_block].
    """
    meta_path = os.path.join(model_dir, 'inference.meta')
    tf.logging.info('Importing meta graph: %s', meta_path)

    # Fall back to the newest checkpoint when none was given explicitly.
    checkpoint = checkpoint or tf.train.latest_checkpoint(model_dir)
    model = inference.SeparationModel(checkpoint, meta_path,
                                      input_tensor_name, output_tensor_name)

    num_blocks = input_blocks_np.shape[0]
    separated_blocks = []
    for block_index, block in enumerate(input_blocks_np):
        print('Processing block %d of %d...' % (block_index + 1, num_blocks))
        separated_blocks.append(model.separate(block))
    return np.stack(separated_blocks, axis=0)
Example #2
0
def main():
    """Evaluates a source separation model and writes per-example scores.

    Parses command-line flags, runs the separation model over every example
    listed in the data list file, prints SI-SNR statistics as it goes, and
    writes a per-example csv of mixture/separated SI-SNR values to
    <output_path>/scores.csv.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a source separation model.')
    parser.add_argument('-cp',
                        '--checkpoint_path',
                        help='Path for model checkpoint files.',
                        required=True)
    parser.add_argument('-mp',
                        '--metagraph_path',
                        help='Path for inference metagraph.',
                        required=True)
    parser.add_argument('-dp',
                        '--data_list_path',
                        help='Path for list of files.',
                        required=True)
    parser.add_argument('-op',
                        '--output_path',
                        help='Path of resulting csv file.',
                        required=True)
    args = parser.parse_args()

    model = inference.SeparationModel(args.checkpoint_path,
                                      args.metagraph_path)

    file_list = data_io.read_lines_from_file(args.data_list_path,
                                             skip_fields=1)
    with model.graph.as_default():
        dataset = data_io.wavs_to_dataset(file_list,
                                          batch_size=1,
                                          num_samples=160000,
                                          repeat=False)
        # Strip batch and mic dimensions.
        dataset['receiver_audio'] = dataset['receiver_audio'][0, 0]
        dataset['source_images'] = dataset['source_images'][0, :, 0]

    # Separate with a trained model.
    i = 1
    min_count = 1
    max_count = 4
    sisnri_per_source_count = {c: [] for c in range(min_count, max_count + 1)}
    columns_mix = [
        'SISNR_mixture_source%d' % j for j in range(min_count, max_count + 1)
    ]
    columns_sep = [
        'SISNR_separated_source%d' % j for j in range(min_count, max_count + 1)
    ]
    df = pd.DataFrame(columns=columns_mix + columns_sep)
    while True:
        try:
            waveforms = model.sess.run(dataset)
        except tf.errors.OutOfRangeError:
            # Dataset exhausted.
            break
        separated_waveforms = model.separate(waveforms['receiver_audio'])
        source_waveforms = waveforms['source_images']
        if np.allclose(source_waveforms, 0):
            # Fixed message: the literals previously joined without a space
            # and misspelled "Skipping".
            print('WARNING: all-zeros source_waveforms tensor encountered. '
                  'Skipping this example...')
            continue
        sisnr_sep, sisnr_mix = compute_perminv_sisnri(
            source_waveforms, separated_waveforms, waveforms['receiver_audio'])
        sisnr_sep = sisnr_sep.numpy()
        sisnr_mix = sisnr_mix.numpy()
        sisnr_imp = np.mean(sisnr_sep - sisnr_mix)
        source_count = len(sisnr_sep)

        # One csv row holding per-source mixture and separated SI-SNR.
        row_dict = {
            col: sisnr
            for col, sisnr in zip(columns_mix[:len(sisnr_mix)], sisnr_mix)
        }
        row_dict.update({
            col: sisnr
            for col, sisnr in zip(columns_sep[:len(sisnr_sep)], sisnr_sep)
        })
        # DataFrame.append was removed in pandas 2.0; concat is the
        # supported replacement.
        df = pd.concat([df, pd.DataFrame([row_dict])], ignore_index=True)
        sisnri_per_source_count[source_count].append(sisnr_imp)
        # Fixed: added the missing space after "dB," so the two literals
        # don't run together (matches the format used by evaluate()).
        print('Example %d: SI-SNR sep = %.1f dB, SI-SNR mix = %.1f dB, '
              'SI-SNR imp = %.1f dB, source count = %d' %
              (i, np.mean(sisnr_sep), np.mean(sisnr_mix),
               np.mean(sisnr_sep - sisnr_mix), source_count))
        if not i % 20:
            # Report mean statistics every so often.
            _print_score_stats(sisnri_per_source_count)
        i += 1

    # Report final mean statistics.
    print('\nFinal statistics:')
    _print_score_stats(sisnri_per_source_count)

    # Write csv.
    csv_path = os.path.join(args.output_path, 'scores.csv')
    print('\nWriting csv to %s.' % csv_path)
    df.to_csv(csv_path)
def evaluate(checkpoint_path, metagraph_path, data_list_path, output_path):
    """Evaluate a model on FUSS data.

    Runs the separation model over every example in the data list, tracks
    SI-SNR / SI-SNRi per reference-source count plus under/equal/over
    separation rates, and periodically (and finally) writes a per-example
    csv and a text summary next to it.

    Args:
      checkpoint_path: Path to the model checkpoint to restore.
      metagraph_path: Path to the inference metagraph file.
      data_list_path: Path to the list of wav files to evaluate.
      output_path: Either a path ending in '.csv' used directly, or a
        directory in which 'scores.csv' is created.
    """
    model = inference.SeparationModel(checkpoint_path, metagraph_path)

    file_list = data_io.read_lines_from_file(data_list_path, skip_fields=1)
    with model.graph.as_default():
        dataset = data_io.wavs_to_dataset(file_list,
                                          batch_size=1,
                                          num_samples=160000,
                                          repeat=False)
        # Strip batch and mic dimensions.
        dataset['receiver_audio'] = dataset['receiver_audio'][0, 0]
        dataset['source_images'] = dataset['source_images'][0, :, 0]

    # Resolve the csv path up front. Previously this was only computed
    # inside the loop when the first example arrived, so an empty dataset
    # caused a NameError at the final reporting step.
    if output_path.endswith('.csv'):
        csv_path = output_path
    else:
        csv_path = os.path.join(output_path, 'scores.csv')

    # Separate with a trained model.
    i = 1
    max_count = 4
    dict_per_source_count = lambda: {c: [] for c in range(1, max_count + 1)}
    sisnr_per_source_count = dict_per_source_count()
    sisnri_per_source_count = dict_per_source_count()
    under_seps = []
    equal_seps = []
    over_seps = []
    df = None
    while True:
        try:
            waveforms = model.sess.run(dataset)
        except tf.errors.OutOfRangeError:
            # Dataset exhausted.
            break
        separated_waveforms = model.separate(waveforms['receiver_audio'])
        source_waveforms = waveforms['source_images']
        if np.allclose(source_waveforms, 0):
            # Fixed message: the literals previously joined without a space
            # and misspelled "Skipping".
            print('WARNING: all-zeros source_waveforms tensor encountered. '
                  'Skipping this example...')
            continue
        metrics_dict = compute_metrics(source_waveforms, separated_waveforms,
                                       waveforms['receiver_audio'])
        metrics_dict = {k: v.numpy() for k, v in metrics_dict.items()}
        sisnr_sep = metrics_dict['sisnr_separated']
        sisnr_mix = metrics_dict['sisnr_mixture']
        sisnr_imp = metrics_dict['sisnr_improvement']
        weights_active_pairs = metrics_dict['weights_active_pairs']

        # Create and initialize the dataframe if it doesn't exist.
        if df is None:
            # Non-scalar metrics get one column per source; scalars get one.
            columns = []
            for metric_name, metric_value in metrics_dict.items():
                if metric_value.shape:
                    # Per-source metric.
                    for i_src in range(1, max_count + 1):
                        columns.append(metric_name + '_source%d' % i_src)
                else:
                    # Scalar metric.
                    columns.append(metric_name)
            columns.sort()
            df = pd.DataFrame(columns=columns)

        # Update dataframe with new metrics.
        row_dict = {}
        for metric_name, metric_value in metrics_dict.items():
            if metric_value.shape:
                # Per-source metric.
                for i_src in range(1, max_count + 1):
                    row_dict[metric_name +
                             '_source%d' % i_src] = metric_value[i_src - 1]
            else:
                # Scalar metric.
                row_dict[metric_name] = metric_value
        # DataFrame.append was removed in pandas 2.0; concat is the
        # supported replacement.
        df = pd.concat([df, pd.DataFrame([row_dict])], ignore_index=True)

        # Store metrics per source count and report results so far.
        under_seps.append(metrics_dict['under_separation'])
        equal_seps.append(metrics_dict['equal_separation'])
        over_seps.append(metrics_dict['over_separation'])
        sisnr_per_source_count[metrics_dict['num_active_refs']].extend(
            sisnr_sep[weights_active_pairs].tolist())
        sisnri_per_source_count[metrics_dict['num_active_refs']].extend(
            sisnr_imp[weights_active_pairs].tolist())
        print('Example %d: SI-SNR sep = %.1f dB, SI-SNR mix = %.1f dB, '
              'SI-SNR imp = %.1f dB, ref count = %d, sep count = %d' %
              (i, np.mean(sisnr_sep), np.mean(sisnr_mix),
               np.mean(sisnr_sep - sisnr_mix), metrics_dict['num_active_refs'],
               metrics_dict['num_active_seps']))
        if not i % 20:
            # Report mean statistics and save csv every so often.
            lines = [
                'Metrics after %d examples:' % i,
                _report_score_stats(sisnr_per_source_count,
                                    'SI-SNR',
                                    counts=[1]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[2]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[3]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[4]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[2, 3, 4]),
                'Under separation: %.2f' % np.mean(under_seps),
                'Equal separation: %.2f' % np.mean(equal_seps),
                'Over separation: %.2f' % np.mean(over_seps),
            ]
            print('')
            for line in lines:
                print(line)
            with open(csv_path.replace('.csv', '_summary.txt'), 'w+') as f:
                f.writelines([line + '\n' for line in lines])

            print('\nWriting csv to %s.\n' % csv_path)
            df.to_csv(csv_path)
        i += 1

    # Nothing to report or write when no usable example was seen; this also
    # protects df.to_csv below from being called on None.
    if df is None:
        print('No examples were evaluated.')
        return

    # Report final mean statistics.
    lines = [
        'Final statistics:',
        _report_score_stats(sisnr_per_source_count, 'SI-SNR', counts=[1]),
        _report_score_stats(sisnri_per_source_count, 'SI-SNRi', counts=[2]),
        _report_score_stats(sisnri_per_source_count, 'SI-SNRi', counts=[3]),
        _report_score_stats(sisnri_per_source_count, 'SI-SNRi', counts=[4]),
        _report_score_stats(sisnri_per_source_count,
                            'SI-SNRi',
                            counts=[2, 3, 4]),
        'Under separation: %.2f' % np.mean(under_seps),
        'Equal separation: %.2f' % np.mean(equal_seps),
        'Over separation: %.2f' % np.mean(over_seps),
    ]
    print('')
    for line in lines:
        print(line)
    with open(csv_path.replace('.csv', '_summary.txt'), 'w+') as f:
        f.writelines([line + '\n' for line in lines])

    # Write final csv.
    print('\nWriting csv to %s.' % csv_path)
    df.to_csv(csv_path)
Example #4
0
            os.makedirs(out_sep_fodler, exist_ok=True)
            out_file = osp.join(out_sep_fodler, f"{cnt}.wav")
            sf.write(out_file, sep_wav, samplerate=16000)


if __name__ == '__main__':
    import glob

    # Command-line interface: separate every wav found under audio_path.
    arg_parser = argparse.ArgumentParser(description="")
    arg_parser.add_argument("-a", '--audio_path', type=str, required=True)
    arg_parser.add_argument("-o", '--output_folder', type=str, required=True)
    arg_parser.add_argument("-c", "--checkpoint", type=str, required=True)
    arg_parser.add_argument("-i", "--inference", type=str, required=True)
    cli = arg_parser.parse_args()

    # Look for wavs directly under audio_path first, then fall back to a
    # "soundscapes" subfolder.
    wav_list = []
    for folder in (cli.audio_path, osp.join(cli.audio_path, "soundscapes")):
        wav_list = glob.glob(osp.join(folder, "*.wav"))
        if wav_list:
            break
    if not wav_list:
        raise IndexError(
            f"Empty wav_list, you need to give a valid audio_path. Not valid: {cli.audio_path}"
        )

    separation_model = inference.SeparationModel(cli.checkpoint, cli.inference)
    main(wav_list, separation_model, cli.output_folder)