Example #1
import argparse
import os

import tensorflow.compat.v1 as tf

# Modules local to the DCASE2020 FUSS baseline repo.
import data_io
import model
import train_with_estimator


def main():
    parser = argparse.ArgumentParser(
        description='Train the DCASE2020 FUSS baseline source separation model.'
    )
    parser.add_argument('-dd',
                        '--data_dir',
                        help='Data directory.',
                        required=True)
    parser.add_argument('-md',
                        '--model_dir',
                        help='Directory for checkpoints and summaries.',
                        required=True)
    args = parser.parse_args()

    hparams = model.get_model_hparams()
    hparams.sr = 16000.0
    hparams.num_sources_for_summaries = [1, 2, 3, 4]

    roomsim_params = {
        'num_sources': len(hparams.signal_names),
        'num_receivers': 1,
        'num_samples': int(hparams.sr * 10.0),
    }
    tf.logging.info('Params: %s', roomsim_params.values())

    feature_spec = data_io.get_roomsim_spec(**roomsim_params)
    inference_spec = data_io.get_inference_spec()

    train_list = os.path.join(args.data_dir, 'train_example_list.txt')
    validation_list = os.path.join(args.data_dir,
                                   'validation_example_list.txt')

    params = {
        'feature_spec': feature_spec,
        'inference_spec': inference_spec,
        'hparams': hparams,
        'io_params': {
            'parallel_readers': 512,
            'num_samples': int(hparams.sr * 10.0)
        },
        'input_data_train': train_list,
        'input_data_eval': validation_list,
        'model_dir': args.model_dir,
        'train_batch_size': 3,
        'eval_batch_size': 3,
        'train_steps': 20000000,
        'eval_suffix': 'validation',
        'eval_examples': 800,
        'save_checkpoints_secs': 600,
        'save_summary_steps': 1000,
        'keep_checkpoint_every_n_hours': 4,
        'write_inference_graph': True,
        'randomize_training': True,
    }
    tf.logging.info(params)
    params['input_data_train'] = data_io.read_lines_from_file(
        params['input_data_train'], skip_fields=1)
    params['input_data_eval'] = data_io.read_lines_from_file(
        params['input_data_eval'], skip_fields=1)
    train_with_estimator.execute(model.model_fn, data_io.input_fn, **params)
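Assuming the usual entry-point guard (not shown above), the script can be run directly from a shell; the script name and paths below are illustrative placeholders, not names fixed by the snippet:

if __name__ == '__main__':
    main()

# Hypothetical shell invocation:
#   python3 train_model.py --data_dir /path/to/fuss_data --model_dir /tmp/fuss_model

The entry point itself only parses arguments and assembles the params dictionary; all training work is delegated to train_with_estimator.execute.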
Example #2
import argparse
import os

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf

# Modules local to the DCASE2020 FUSS baseline repo. The helpers
# compute_perminv_sisnri and _print_score_stats are defined elsewhere
# in this script.
import data_io
import inference


def main():
    parser = argparse.ArgumentParser(
        description='Evaluate a source separation model.')
    parser.add_argument('-cp',
                        '--checkpoint_path',
                        help='Path for model checkpoint files.',
                        required=True)
    parser.add_argument('-mp',
                        '--metagraph_path',
                        help='Path for inference metagraph.',
                        required=True)
    parser.add_argument('-dp',
                        '--data_list_path',
                        help='Path for list of files.',
                        required=True)
    parser.add_argument('-op',
                        '--output_path',
                        help='Path of resulting csv file.',
                        required=True)
    args = parser.parse_args()

    model = inference.SeparationModel(args.checkpoint_path,
                                      args.metagraph_path)

    file_list = data_io.read_lines_from_file(args.data_list_path,
                                             skip_fields=1)
    with model.graph.as_default():
        dataset = data_io.wavs_to_dataset(file_list,
                                          batch_size=1,
                                          num_samples=160000,
                                          repeat=False)
        # Strip batch and mic dimensions.
        dataset['receiver_audio'] = dataset['receiver_audio'][0, 0]
        dataset['source_images'] = dataset['source_images'][0, :, 0]

    # Separate with a trained model.
    i = 1
    min_count = 1
    max_count = 4
    sisnri_per_source_count = {c: [] for c in range(min_count, max_count + 1)}
    columns_mix = [
        'SISNR_mixture_source%d' % j for j in range(min_count, max_count + 1)
    ]
    columns_sep = [
        'SISNR_separated_source%d' % j for j in range(min_count, max_count + 1)
    ]
    df = pd.DataFrame(columns=columns_mix + columns_sep)
    while True:
        try:
            waveforms = model.sess.run(dataset)
        except tf.errors.OutOfRangeError:
            break
        separated_waveforms = model.separate(waveforms['receiver_audio'])
        source_waveforms = waveforms['source_images']
        if np.allclose(source_waveforms, 0):
            print('WARNING: all-zeros source_waveforms tensor encountered. '
                  'Skipping this example...')
            continue
        sisnr_sep, sisnr_mix = compute_perminv_sisnri(
            source_waveforms, separated_waveforms, waveforms['receiver_audio'])
        sisnr_sep = sisnr_sep.numpy()
        sisnr_mix = sisnr_mix.numpy()
        sisnr_imp = np.mean(sisnr_sep - sisnr_mix)
        source_count = len(sisnr_sep)

        row_dict = {
            col: sisnr
            for col, sisnr in zip(columns_mix[:len(sisnr_mix)], sisnr_mix)
        }
        row_dict.update({
            col: sisnr
            for col, sisnr in zip(columns_sep[:len(sisnr_sep)], sisnr_sep)
        })
        new_row = pd.Series(row_dict)
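        # Note: pd.DataFrame.append was removed in pandas 2.0; on newer
        # pandas, use pd.concat([df, new_row.to_frame().T], ignore_index=True).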
        df = df.append(new_row, ignore_index=True)
        sisnri_per_source_count[source_count].append(sisnr_imp)
        print('Example %d: SI-SNR sep = %.1f dB, SI-SNR mix = %.1f dB, '
              'SI-SNR imp = %.1f dB, source count = %d' %
              (i, np.mean(sisnr_sep), np.mean(sisnr_mix),
               np.mean(sisnr_sep - sisnr_mix), source_count))
        if not i % 20:
            # Report mean statistics every so often.
            _print_score_stats(sisnri_per_source_count)
        i += 1

    # Report final mean statistics.
    print('\nFinal statistics:')
    _print_score_stats(sisnri_per_source_count)

    # Write csv.
    csv_path = os.path.join(args.output_path, 'scores.csv')
    print('\nWriting csv to %s.' % csv_path)
    df.to_csv(csv_path)
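The helper compute_perminv_sisnri used above is defined elsewhere in the same script. For orientation, scale-invariant SNR is conventionally the energy ratio between the projection of the estimate onto the reference and the residual; below is a minimal NumPy sketch of that standard definition, not the repo's implementation:

import numpy as np

def si_snr(reference, estimate, eps=1e-8):
    """Scale-invariant SNR in dB (standard definition, sketched here)."""
    # Project the estimate onto the reference to obtain the scaled target.
    alpha = np.sum(reference * estimate) / (np.sum(reference**2) + eps)
    target = alpha * reference
    noise = estimate - target
    return 10.0 * np.log10((np.sum(target**2) + eps) / (np.sum(noise**2) + eps))

The permutation-invariant variant evaluates this score over all pairings of reference and separated sources and keeps the best assignment; SI-SNRi is then the separated SI-SNR minus the SI-SNR of the unprocessed mixture.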
Example #3
import os

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf

# Modules local to the DCASE2020 FUSS baseline repo. The helpers
# compute_metrics and _report_score_stats are defined elsewhere in
# this script.
import data_io
import inference


def evaluate(checkpoint_path, metagraph_path, data_list_path, output_path):
    """Evaluate a model on FUSS data."""
    model = inference.SeparationModel(checkpoint_path, metagraph_path)

    file_list = data_io.read_lines_from_file(data_list_path, skip_fields=1)
    with model.graph.as_default():
        dataset = data_io.wavs_to_dataset(file_list,
                                          batch_size=1,
                                          num_samples=160000,
                                          repeat=False)
        # Strip batch and mic dimensions.
        dataset['receiver_audio'] = dataset['receiver_audio'][0, 0]
        dataset['source_images'] = dataset['source_images'][0, :, 0]

    # Separate with a trained model.
    i = 1
    max_count = 4
    dict_per_source_count = lambda: {c: [] for c in range(1, max_count + 1)}
    sisnr_per_source_count = dict_per_source_count()
    sisnri_per_source_count = dict_per_source_count()
    under_seps = []
    equal_seps = []
    over_seps = []
    df = None
    while True:
        try:
            waveforms = model.sess.run(dataset)
        except tf.errors.OutOfRangeError:
            break
        separated_waveforms = model.separate(waveforms['receiver_audio'])
        source_waveforms = waveforms['source_images']
        if np.allclose(source_waveforms, 0):
            print('WARNING: all-zeros source_waveforms tensor encountered. '
                  'Skipping this example...')
            continue
        metrics_dict = compute_metrics(source_waveforms, separated_waveforms,
                                       waveforms['receiver_audio'])
        metrics_dict = {k: v.numpy() for k, v in metrics_dict.items()}
        sisnr_sep = metrics_dict['sisnr_separated']
        sisnr_mix = metrics_dict['sisnr_mixture']
        sisnr_imp = metrics_dict['sisnr_improvement']
        weights_active_pairs = metrics_dict['weights_active_pairs']

        # Create and initialize the dataframe if it doesn't exist.
        if df is None:
            # Need to create the dataframe.
            columns = []
            for metric_name, metric_value in metrics_dict.items():
                if metric_value.shape:
                    # Per-source metric.
                    for i_src in range(1, max_count + 1):
                        columns.append(metric_name + '_source%d' % i_src)
                else:
                    # Scalar metric.
                    columns.append(metric_name)
            columns.sort()
            df = pd.DataFrame(columns=columns)
            if output_path.endswith('.csv'):
                csv_path = output_path
            else:
                csv_path = os.path.join(output_path, 'scores.csv')

        # Update dataframe with new metrics.
        row_dict = {}
        for metric_name, metric_value in metrics_dict.items():
            if metric_value.shape:
                # Per-source metric.
                for i_src in range(1, max_count + 1):
                    row_dict[metric_name +
                             '_source%d' % i_src] = metric_value[i_src - 1]
            else:
                # Scalar metric.
                row_dict[metric_name] = metric_value
        new_row = pd.Series(row_dict)
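        # Note: pd.DataFrame.append was removed in pandas 2.0; on newer
        # pandas, use pd.concat([df, new_row.to_frame().T], ignore_index=True).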
        df = df.append(new_row, ignore_index=True)

        # Store metrics per source count and report results so far.
        under_seps.append(metrics_dict['under_separation'])
        equal_seps.append(metrics_dict['equal_separation'])
        over_seps.append(metrics_dict['over_separation'])
        sisnr_per_source_count[metrics_dict['num_active_refs']].extend(
            sisnr_sep[weights_active_pairs].tolist())
        sisnri_per_source_count[metrics_dict['num_active_refs']].extend(
            sisnr_imp[weights_active_pairs].tolist())
        print('Example %d: SI-SNR sep = %.1f dB, SI-SNR mix = %.1f dB, '
              'SI-SNR imp = %.1f dB, ref count = %d, sep count = %d' %
              (i, np.mean(sisnr_sep), np.mean(sisnr_mix),
               np.mean(sisnr_sep - sisnr_mix), metrics_dict['num_active_refs'],
               metrics_dict['num_active_seps']))
        if not i % 20:
            # Report mean statistics and save csv every so often.
            lines = [
                'Metrics after %d examples:' % i,
                _report_score_stats(sisnr_per_source_count,
                                    'SI-SNR',
                                    counts=[1]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[2]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[3]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[4]),
                _report_score_stats(sisnri_per_source_count,
                                    'SI-SNRi',
                                    counts=[2, 3, 4]),
                'Under separation: %.2f' % np.mean(under_seps),
                'Equal separation: %.2f' % np.mean(equal_seps),
                'Over separation: %.2f' % np.mean(over_seps),
            ]
            print('')
            for line in lines:
                print(line)
            with open(csv_path.replace('.csv', '_summary.txt'), 'w+') as f:
                f.writelines([line + '\n' for line in lines])

            print('\nWriting csv to %s.\n' % csv_path)
            df.to_csv(csv_path)
        i += 1

    # Report final mean statistics.
    lines = [
        'Final statistics:',
        _report_score_stats(sisnr_per_source_count, 'SI-SNR', counts=[1]),
        _report_score_stats(sisnri_per_source_count, 'SI-SNRi', counts=[2]),
        _report_score_stats(sisnri_per_source_count, 'SI-SNRi', counts=[3]),
        _report_score_stats(sisnri_per_source_count, 'SI-SNRi', counts=[4]),
        _report_score_stats(sisnri_per_source_count,
                            'SI-SNRi',
                            counts=[2, 3, 4]),
        'Under separation: %.2f' % np.mean(under_seps),
        'Equal separation: %.2f' % np.mean(equal_seps),
        'Over separation: %.2f' % np.mean(over_seps),
    ]
    print('')
    for line in lines:
        print(line)
    with open(csv_path.replace('.csv', '_summary.txt'), 'w+') as f:
        f.writelines([line + '\n' for line in lines])

    # Write final csv.
    print('\nWriting csv to %s.' % csv_path)
    df.to_csv(csv_path)
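A minimal sketch of calling this function directly; every path below is a placeholder rather than an output guaranteed by the training script:

evaluate(
    checkpoint_path='/tmp/fuss_model/model.ckpt-100000',
    metagraph_path='/tmp/fuss_model/inference.meta',
    data_list_path='/path/to/fuss_data/eval_example_list.txt',
    output_path='/tmp/fuss_eval')

Since output_path here does not end in '.csv', the per-example scores land in /tmp/fuss_eval/scores.csv, and the aggregate statistics are written alongside it as scores_summary.txt.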