Example no. 1
    def predict(self, path: str, wav_data=None):
        """Using the model, return the predicted note sequence of a .wav file at the given path.

        Args:
            path (str): The path to a .wav audio file. If path is "binary", the raw bytes must be supplied via wav_data.
            wav_data (bytes): The binary contents of the .wav file, if that is easier to provide. Defaults to None when a real path is given.

        Returns:
            NoteSequence object containing the prediction. Convertible to MIDI.
        """
        if path == "binary":
            if wav_data is None:
                raise ValueError(
                    "The binary option is chosen but a binary is not provided."
                )
        else:
            with open(path, "rb") as f:
                wav_data = f.read()

        ns = note_seq.NoteSequence()
        example_list = [
            audio_label_data_utils.create_example(
                path,
                ns,
                wav_data,
                velocity_range=audio_label_data_utils.velocity_range_from_sequence(ns))
        ]
        to_process = [example_list[0].SerializeToString()]

        print('Processing complete for', path)

        sess = tf.Session()

        sess.run([
            tf.initializers.global_variables(),
            tf.initializers.local_variables()
        ])

        sess.run(self.iterator.initializer, {self.examples: to_process})

        def transcription_data(params):
            del params
            return tf.data.Dataset.from_tensors(sess.run(self.next_record))

        input_fn = infer_util.labels_to_features_wrapper(transcription_data)
        prediction_list = list(
            self.estimator.predict(input_fn, yield_single_examples=False))
        assert len(prediction_list) == 1
        sequence_prediction = note_seq.NoteSequence.FromString(
            prediction_list[0]['sequence_predictions'][0])

        return sequence_prediction
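A minimal usage sketch for this method, assuming the enclosing class is instantiated as shown; the class name Transcriber and the file names are illustrative and not part of the snippet.

model = Transcriber()  # hypothetical class holding estimator, iterator, examples, next_record

# Predict from a .wav file on disk.
sequence = model.predict("piano_take_01.wav")

# Or hand over raw .wav bytes via the "binary" convention.
with open("piano_take_01.wav", "rb") as f:
    sequence = model.predict("binary", wav_data=f.read())

# The returned NoteSequence can be written out as MIDI.
midi_io.sequence_proto_to_midi_file(sequence, "piano_take_01.mid")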
Example no. 2
    def process(files):
        for fn in files:
            print('**\n\n', fn, '\n\n**')
            with open(fn, 'rb', buffering=0) as f:
                wav_data = f.read()
            example_list = list(
                audio_label_data_utils.process_record(
                    wav_data=wav_data,
                    ns=music_pb2.NoteSequence(),
                    example_id=fn,
                    min_length=0,
                    max_length=-1,
                    allow_empty_notesequence=True))
            assert len(example_list) == 1
            # Build a fresh list per file so the iterator only sees this file's record.
            to_process = [example_list[0].SerializeToString()]
            print('Processing complete for', fn)

            sess = tf.Session()

            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            sess.run(iterator.initializer, {examples: to_process})

            def transcription_data(params):
                del params
                return tf.data.Dataset.from_tensors(sess.run(next_record))


            input_fn = infer_util.labels_to_features_wrapper(transcription_data)

            #@title Run inference
            prediction_list = list(
                estimator.predict(
                    input_fn,
                    yield_single_examples=False))
            assert len(prediction_list) == 1

            # Ignore warnings caused by pyfluidsynth
            import warnings
            warnings.filterwarnings("ignore", category=DeprecationWarning) 

            sequence_prediction = music_pb2.NoteSequence.FromString(
                prediction_list[0]['sequence_predictions'][0])

            # Use the file's basename for the output MIDI name.
            pathname = fn.split('/').pop()
            print('**\n\n', pathname, '\n\n**')
            midi_filename = '{outputs}/{file}.mid'.format(outputs=output, file=pathname)
            midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)
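Example 2's process() closes over several names that are never defined in the snippet (output, examples, iterator, next_record, estimator). A sketch of what the enclosing setup might look like, mirroring the graph construction used in the other examples; the checkpoint directory and output path are assumptions.

# Assumed enclosing setup for process(); mirrors the later examples.
output = 'transcriptions'  # directory that receives the generated .mid files

config = configs.CONFIG_MAP['onsets_frames']
hparams = config.hparams
hparams.batch_size = 1
hparams.truncated_length_secs = 0

examples = tf.placeholder(tf.string, [None])
dataset = data.provide_batch(examples=examples,
                             preprocess_examples=True,
                             params=hparams,
                             is_training=False,
                             shuffle_examples=False,
                             skip_n_initial_records=0)
estimator = train_util.create_estimator(config.model_fn, CHECKPOINT_DIR, hparams)
iterator = dataset.make_initializable_iterator()
next_record = iterator.get_next()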
Example no. 3
def run(argv, config_map, data_fn):
    """Create transcriptions."""
    tf.logging.set_verbosity(FLAGS.log)

    config = config_map[FLAGS.config]
    hparams = config.hparams
    # For this script, default to not using cudnn.
    hparams.use_cudnn = False
    hparams.parse(FLAGS.hparams)
    hparams.batch_size = 1
    hparams.truncated_length_secs = 0

    with tf.Graph().as_default():
        examples = tf.placeholder(tf.string, [None])

        dataset = data_fn(examples=examples,
                          preprocess_examples=True,
                          params=hparams,
                          is_training=False,
                          shuffle_examples=False,
                          skip_n_initial_records=0)

        estimator = train_util.create_estimator(
            config.model_fn, os.path.expanduser(FLAGS.model_dir), hparams)

        iterator = dataset.make_initializable_iterator()
        next_record = iterator.get_next()

        with tf.Session() as sess:
            sess.run([
                tf.initializers.global_variables(),
                tf.initializers.local_variables()
            ])

            for filename in argv[1:]:
                tf.logging.info('Starting transcription for %s...', filename)

                # The reason we bounce between two Dataset objects is so we can use
                # the data processing functionality in data.py without having to
                # construct all the Example protos in memory ahead of time or create
                # a temporary tfrecord file.
                tf.logging.info('Processing file...')
                sess.run(
                    iterator.initializer, {
                        examples: [
                            create_example(filename,
                                           FLAGS.load_audio_with_librosa)
                        ]
                    })

                def transcription_data(params):
                    del params
                    return tf.data.Dataset.from_tensors(sess.run(next_record))

                input_fn = infer_util.labels_to_features_wrapper(
                    transcription_data)

                tf.logging.info('Running inference...')
                checkpoint_path = None
                if FLAGS.checkpoint_path:
                    checkpoint_path = os.path.expanduser(FLAGS.checkpoint_path)
                prediction_list = list(
                    estimator.predict(input_fn,
                                      checkpoint_path=checkpoint_path,
                                      yield_single_examples=False))
                assert len(prediction_list) == 1

                sequence_prediction = music_pb2.NoteSequence.FromString(
                    prediction_list[0]['sequence_predictions'][0])

                midi_filename = filename + FLAGS.transcribed_file_suffix + '.midi'
                midi_io.sequence_proto_to_midi_file(sequence_prediction,
                                                    midi_filename)

                tf.logging.info('Transcription written to %s.', midi_filename)
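run() above is clearly meant to be driven by command-line flags; a plausible entry point, wiring it to configs.CONFIG_MAP and data.provide_batch the same way the other examples do. The flag definitions referenced inside run() are omitted and the wiring here is an assumption, not the original script's main().

# Hypothetical entry point for run().
def main(argv):
    run(argv, config_map=configs.CONFIG_MAP, data_fn=data.provide_batch)


if __name__ == '__main__':
    tf.app.run(main)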
Example no. 4
def transcribe(audio, sr, cuda=False):
    """
    Google sucks and wants to use an audio path (raw wav) instead of decoded
    samples, losing the decoupling between file format and DSP.

    Input: audio samples and sample rate. Output: an asmd-like matrix with one
    row per note: (pitch, onset, offset, velocity).
    """

    # Simple hack because google sucks... this way we can accept audio data
    # that is already loaded, keep a reasonable interface, and decouple I/O
    # from processing.
    original_google_sucks = audio_io.wav_data_to_samples
    audio_io.wav_data_to_samples = google_sucks
    audio = np.array(audio)
    config = configs.CONFIG_MAP['onsets_frames']
    hparams = config.hparams
    hparams.use_cudnn = cuda
    hparams.batch_size = 1
    examples = tf.placeholder(tf.string, [None])

    dataset = data.provide_batch(examples=examples,
                                 preprocess_examples=True,
                                 params=hparams,
                                 is_training=False,
                                 shuffle_examples=False,
                                 skip_n_initial_records=0)

    estimator = train_util.create_estimator(config.model_fn, CHECKPOINT_DIR,
                                            hparams)

    iterator = dataset.make_initializable_iterator()
    next_record = iterator.get_next()

    example_list = list(
        audio_label_data_utils.process_record(wav_data=audio,
                                              sample_rate=sr,
                                              ns=music_pb2.NoteSequence(),
                                              example_id="fakeid",
                                              min_length=0,
                                              max_length=-1,
                                              allow_empty_notesequence=True,
                                              load_audio_with_librosa=False))
    assert len(example_list) == 1
    to_process = [example_list[0].SerializeToString()]

    sess = tf.Session()

    sess.run([
        tf.initializers.global_variables(),
        tf.initializers.local_variables()
    ])

    sess.run(iterator.initializer, {examples: to_process})

    def transcription_data(params):
        del params
        return tf.data.Dataset.from_tensors(sess.run(next_record))

    # Put back the original function (it still writes and reloads the file,
    # which is wasteful).
    audio_io.wav_data_to_samples = original_google_sucks
    input_fn = infer_util.labels_to_features_wrapper(transcription_data)

    prediction_list = list(
        estimator.predict(input_fn, yield_single_examples=False))

    assert len(prediction_list) == 1

    notes = music_pb2.NoteSequence.FromString(
        prediction_list[0]['sequence_predictions'][0]).notes

    out = np.empty((len(notes), 4))
    for i, note in enumerate(notes):
        out[i] = [note.pitch, note.start_time, note.end_time, note.velocity]
    return out
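The monkey-patched google_sucks helper is referenced but not defined here; given the comments, it presumably just hands back the samples that were already decoded. A sketch of the shim and of calling transcribe(), with the shim body and the librosa loading step as assumptions:

# Assumed shim: skip wav decoding and return the in-memory samples as-is.
def google_sucks(samples, sample_rate):
    return samples

import librosa

audio, sr = librosa.load('piano_take_01.wav', sr=None, mono=True)
mat = transcribe(audio, sr)  # one row per note: (pitch, onset, offset, velocity)
print(mat.shape)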
Example no. 5
def model_inference(model_fn, model_dir, checkpoint_path, data_fn, hparams,
                    examples_path, output_dir, summary_writer, master,
                    preprocess_examples, shuffle_examples):
    """Runs inference for the given examples."""
    tf.logging.info('model_dir=%s', model_dir)
    tf.logging.info('checkpoint_path=%s', checkpoint_path)
    tf.logging.info('examples_path=%s', examples_path)
    tf.logging.info('output_dir=%s', output_dir)

    estimator = train_util.create_estimator(model_fn,
                                            model_dir,
                                            hparams,
                                            master=master)

    transcription_data = functools.partial(
        data_fn,
        examples=examples_path,
        preprocess_examples=preprocess_examples,
        is_training=False,
        shuffle_examples=shuffle_examples,
        skip_n_initial_records=0)

    input_fn = infer_util.labels_to_features_wrapper(transcription_data)

    start_time = time.time()
    infer_times = []
    num_frames = []

    file_num = 0

    all_metrics = collections.defaultdict(list)

    for predictions in estimator.predict(input_fn,
                                         checkpoint_path=checkpoint_path,
                                         yield_single_examples=False):

        # Remove batch dimension for convenience.
        for k in predictions.keys():
            if predictions[k].shape[0] != 1:
                raise ValueError(
                    'All predictions must have batch size 1, but shape of '
                    '{} was: {}'.format(k, predictions[k].shape[0]))
            predictions[k] = predictions[k][0]

        end_time = time.time()
        infer_time = end_time - start_time
        infer_times.append(infer_time)
        num_frames.append(predictions['frame_predictions'].shape[0])
        tf.logging.info(
            'Infer time %f, frames %d, frames/sec %f, running average %f',
            infer_time, num_frames[-1], num_frames[-1] / infer_time,
            np.sum(num_frames) / np.sum(infer_times))

        tf.logging.info('Scoring sequence %s', predictions['sequence_ids'])

        sequence_prediction = music_pb2.NoteSequence.FromString(
            predictions['sequence_predictions'])
        sequence_label = music_pb2.NoteSequence.FromString(
            predictions['sequence_labels'])

        # Make filenames UNIX-friendly.
        filename_chars = predictions['sequence_ids'].decode('utf-8')
        filename_chars = [c if c.isalnum() else '_' for c in filename_chars]
        filename_safe = ''.join(filename_chars).rstrip()
        filename_safe = '{:04d}_{}'.format(file_num, filename_safe[:200])
        file_num += 1
        output_file = os.path.join(output_dir, filename_safe + '.mid')
        tf.logging.info('Writing inferred midi file to %s', output_file)
        midi_io.sequence_proto_to_midi_file(sequence_prediction, output_file)

        label_output_file = os.path.join(output_dir,
                                         filename_safe + '_label.mid')
        tf.logging.info('Writing label midi file to %s', label_output_file)
        midi_io.sequence_proto_to_midi_file(sequence_label, label_output_file)

        # Also write a pianoroll showing acoustic model output vs labels.
        pianoroll_output_file = os.path.join(output_dir,
                                             filename_safe + '_pianoroll.png')
        tf.logging.info('Writing acoustic logit/label file to %s',
                        pianoroll_output_file)
        with tf.gfile.GFile(pianoroll_output_file, mode='w') as f:
            scipy.misc.imsave(
                f,
                infer_util.posterior_pianoroll_image(
                    predictions['frame_probs'], predictions['frame_labels']))

        # Update histogram and current scalar for metrics.
        with tf.Graph().as_default(), tf.Session().as_default():
            for k, v in predictions.items():
                if not k.startswith('metrics/'):
                    continue
                all_metrics[k].extend(v)
                histogram_name = 'histogram/' + k
                metric_summary = tf.summary.histogram(histogram_name,
                                                      tf.constant(
                                                          all_metrics[k],
                                                          name=histogram_name),
                                                      collections=[])
                summary_writer.add_summary(metric_summary.eval(),
                                           global_step=file_num)
                scalar_name = k
                metric_summary = tf.summary.scalar(scalar_name,
                                                   tf.constant(
                                                       np.mean(all_metrics[k]),
                                                       name=scalar_name),
                                                   collections=[])
                summary_writer.add_summary(metric_summary.eval(),
                                           global_step=file_num)
            summary_writer.flush()

        start_time = time.time()

    # Write final mean values for all metrics.
    with tf.Graph().as_default(), tf.Session().as_default():
        for k, v in all_metrics.items():
            final_scalar_name = 'final/' + k
            metric_summary = tf.summary.scalar(final_scalar_name,
                                               tf.constant(
                                                   np.mean(all_metrics[k]),
                                                   name=final_scalar_name),
                                               collections=[])
            summary_writer.add_summary(metric_summary.eval())
        summary_writer.flush()

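model_inference() needs a summary_writer that accepts the histogram and scalar summaries built above; a minimal call-site sketch using a TF1 FileWriter, with the paths and config name as placeholders rather than values from the original code:

# Hypothetical call site for model_inference().
config = configs.CONFIG_MAP['onsets_frames']
output_dir = '/tmp/onsets_frames/infer'
summary_writer = tf.summary.FileWriter(output_dir)

model_inference(model_fn=config.model_fn,
                model_dir='/tmp/onsets_frames/train',
                checkpoint_path=None,
                data_fn=data.provide_batch,
                hparams=config.hparams,
                examples_path='/tmp/onsets_frames/test.tfrecord',
                output_dir=output_dir,
                summary_writer=summary_writer,
                master='',
                preprocess_examples=True,
                shuffle_examples=False)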
Example no. 6
                                          ns=music_pb2.NoteSequence(),
                                          example_id='accompaniment.wav',
                                          min_length=0,
                                          max_length=-1,
                                          allow_empty_notesequence=True))

to_process.append(example_list[0].SerializeToString())
sess = tf.Session()

sess.run(
    [tf.initializers.global_variables(),
     tf.initializers.local_variables()])

sess.run(iterator.initializer, {examples: to_process})

input_fn = infer_util.labels_to_features_wrapper(transcription_data)
"""# Inference

Run the following cell to transcribe the files you uploaded. Each time it runs it will transcribe one of the uploaded files.
"""
#@title Run inference
prediction_list = list(estimator.predict(input_fn,
                                         yield_single_examples=False))
assert len(prediction_list) == 1
sequence_prediction = note_seq.NoteSequence.FromString(
    prediction_list[0]['sequence_predictions'][0])

#@title Download MIDI
midi_filename = 'transcription.mid'
midi_io.sequence_proto_to_midi_file(sequence_prediction, midi_filename)
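In the Colab notebook this snippet comes from, the "Download MIDI" cell would normally end by pushing the file to the browser; a sketch, assuming a Colab runtime:

# Assumed Colab-only step: trigger a browser download of the transcription.
from google.colab import files
files.download(midi_filename)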