Esempio n. 1
0
class BeatRelativeEncoding:
    """An event-based encoding of note sequences with beat-relative timing.

    Time is never shifted by an absolute amount. Instead, each beat is divided
    into ``units_per_beat`` steps:

    * ``('SetTime', k)`` moves to step ``k`` of the current beat.
    * ``('SetTimeNext', k)`` moves to step ``k`` of the next beat.

    Notes are started and stopped with ``NoteOn``/``NoteOff`` (or
    ``DrumOn``/``DrumOff`` when drum events are enabled). Quantized velocity is
    set with ``SetVelocity``.

    Args:
        units_per_beat: number of time steps each beat is divided into.
        velocity_unit: width of one velocity quantization bin.
        use_velocity: whether ``encode`` emits ``SetVelocity`` events.
        default_velocity: velocity used by ``decode`` until the first
            ``SetVelocity``. ``encode`` also substitutes it for out-of-range
            velocities.
        use_all_off_event: if true, the vocabulary contains
            ``('NoteOff', '*')`` and ``encode`` post-processes events with
            ``_compress_note_offs`` (presumably merging note-offs into that
            event — TODO confirm).
        use_drum_events: if true, drum notes use ``DrumOn``/``DrumOff``.
        errors: how ``decode`` handles notes that never receive a note-off.
            ``'remove'`` drops them; any other value ends them at the final
            decoded time.
        warn_on_errors: if true, decoding problems are reported with
            ``warnings.warn(..., RuntimeWarning)``; otherwise they are logged
            at debug level.
    """

    def __init__(self, units_per_beat=12, velocity_unit=4, use_velocity=True, default_velocity=127,
                 use_all_off_event=False, use_drum_events=False, errors='remove',
                 warn_on_errors=False):

        self._units_per_beat = units_per_beat
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        wordlist = (['<pad>', '<s>', '</s>'] +
                    [('NoteOn', i) for i in range(128)] +
                    [('NoteOff', i) for i in range(128)] +
                    ([('NoteOff', '*')] if use_all_off_event else []) +
                    ([('DrumOn', i) for i in range(128)] +
                     [('DrumOff', i) for i in range(128)]
                     if use_drum_events else []) +
                    [('SetTime', i) for i in range(units_per_beat)] +
                    [('SetTimeNext', i) for i in range(units_per_beat)])

        if use_velocity:
            # ceil(128 / velocity_unit) bins, numbered from 1; this matches the
            # largest value encode() can produce (127 // velocity_unit + 1).
            max_velocity_units = (128 + velocity_unit - 1) // velocity_unit
            wordlist.extend([('SetVelocity', i + 1) for i in range(max_velocity_units)])
        self._default_velocity = default_velocity

        self.vocabulary = Vocabulary(wordlist)

    def encode(self, sequence, as_ids=True, add_start=False, add_end=False):
        """Encode a note sequence into events (or their vocabulary ids).

        Args:
            sequence: the sequence to encode. It is passed through
                ``note_sequence_utils.normalize_tempo`` first.
            as_ids: if true, return vocabulary ids; otherwise event tuples.
            add_start: prepend the vocabulary's start token.
            add_end: append the vocabulary's end token.

        Returns:
            A list of ids (``as_ids=True``) or of events.
        """
        sequence = note_sequence_utils.normalize_tempo(sequence)

        queue = _NoteEventQueue(sequence, quantization_step=1 / self._units_per_beat)
        events = [self.vocabulary.start_token] if add_start else []

        last_beat = 0
        last_t = 0
        velocity_quantized = None
        # NOTE(review): t appears to be a time index in units of
        # 1 / units_per_beat (it is split with // and %) — confirm against
        # _NoteEventQueue.
        for t, note, is_onset in queue:
            if t > last_t:
                beat = t // self._units_per_beat
                step_in_beat = t % self._units_per_beat

                while beat - last_beat > 1:
                    # Skip to the beginning of the next beat
                    events.append(('SetTimeNext', 0))
                    last_beat += 1

                if beat == last_beat:
                    events.append(('SetTime', step_in_beat))
                else:  # beat == last_beat + 1
                    events.append(('SetTimeNext', step_in_beat))
                    last_beat += 1
                assert beat == last_beat

                last_t = t

            if is_onset:
                note_velocity = note.velocity
                if note_velocity > 127 or note_velocity < 1:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                # Compare quantized values so that velocities falling into the
                # same bin do not produce redundant SetVelocity events.
                note_velocity_quantized = note_velocity // self._velocity_unit + 1
                if velocity_quantized != note_velocity_quantized:
                    velocity_quantized = note_velocity_quantized
                    if self._use_velocity:
                        events.append(('SetVelocity', velocity_quantized))

                if note.is_drum and self._use_drum_events:
                    events.append(('DrumOn', note.pitch))
                else:
                    events.append(('NoteOn', note.pitch))
            else:
                if note.is_drum and self._use_drum_events:
                    events.append(('DrumOff', note.pitch))
                else:
                    events.append(('NoteOff', note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)

        if add_end:
            events.append(self.vocabulary.end_token)

        if as_ids:
            return self.vocabulary.to_ids(events)
        return events

    def decode(self, tokens):
        """Decode events (or ids) back into a ``music_pb2.NoteSequence``.

        Note times are expressed as ``beat + step / units_per_beat``.
        ``total_time`` is set to the last time position reached.

        Args:
            tokens: an iterable of event tuples, special tokens, or integer
                vocabulary ids.

        Returns:
            The decoded ``music_pb2.NoteSequence``.

        Raises:
            RuntimeError: if a token is not in the vocabulary.
        """
        sequence = music_pb2.NoteSequence()
        sequence.ticks_per_quarter = STANDARD_PPQ

        # pitch -> indices into sequence.notes of notes still awaiting a note-off.
        # Indices (not message objects) are stored because protobuf messages
        # compare by value. Matching hanging notes with `in`/`==` could
        # therefore also hit distinct notes that happen to have equal fields.
        notes_on = defaultdict(list)
        error_count = 0

        t = 0.
        current_beat = 0
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                continue  # special (string) tokens carry no musical content
            event, value = token

            if event in ('SetTime', 'SetTimeNext'):
                if event == 'SetTimeNext':
                    current_beat += 1
                new_t = current_beat + value / self._units_per_beat
                if new_t > t:
                    t = new_t
                else:
                    error_count += 1  # time must strictly increase
            elif event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ('NoteOn', 'DrumOn'):
                note = sequence.notes.add()
                note.start_time = t
                note.pitch = value
                note.velocity = velocity
                note.is_drum = (event == 'DrumOn')
                notes_on[value].append(len(sequence.notes) - 1)
            elif event in ('NoteOff', 'DrumOff'):
                if value == '*':
                    assert self._use_all_off_event

                    if not any(notes_on.values()):
                        error_count += 1

                    for indices in notes_on.values():
                        for index in indices:
                            sequence.notes[index].end_time = t
                        indices.clear()
                else:
                    try:
                        index = notes_on[value].pop()
                    except IndexError:
                        error_count += 1  # note-off without a matching note-on
                    else:
                        sequence.notes[index].end_time = t
        sequence.total_time = t

        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        # Handle hanging notes (started but never ended)
        hanging = {index for indices in notes_on.values() for index in indices}
        if hanging:
            if self._errors == 'remove':
                self._log_errors(f'Removing {len(hanging)} hanging note(s)')
                notes_kept = [note for index, note in enumerate(sequence.notes)
                              if index not in hanging]
                del sequence.notes[:]
                sequence.notes.extend(notes_kept)
            else:  # 'fix'
                self._log_errors(f'Ending {len(hanging)} hanging note(s)')
                for index in hanging:
                    sequence.notes[index].end_time = sequence.total_time

        return sequence

    def _log_errors(self, message):
        """Report a decoding problem as a RuntimeWarning or a debug log entry."""
        if self._warn_on_errors:
            warnings.warn(message, RuntimeWarning)
        else:
            _LOGGER.debug(message)
Esempio n. 2
0
class PerformanceEncoding:
    """An encoding of note sequences based on Magenta's PerformanceRNN.

    This is actually very similar to how MIDI works.
    See https://magenta.tensorflow.org/performance-rnn.

    Args:
        time_unit: duration of one ``TimeShift`` step.
        max_shift_units: largest single ``TimeShift`` value. Longer gaps are
            split into several shifts.
        velocity_unit: width of one velocity quantization bin.
        use_velocity: whether ``encode`` emits ``SetVelocity`` events.
        use_all_off_event: if true, the vocabulary contains
            ``('NoteOff', '*')`` and ``encode`` post-processes events with
            ``_compress_note_offs``.
        use_drum_events: if true, drum notes use ``DrumOn``/``DrumOff``.
            Requires ``use_magenta``.
        use_magenta: if true, ``decode`` returns a ``music_pb2.NoteSequence``;
            otherwise it returns a list of ``pretty_midi.Note``.
        errors: how ``decode`` handles notes that never receive a note-off.
            ``'remove'`` drops them; any other value keeps them without an end
            time.
        warn_on_errors: if true, decoding problems are reported with
            ``warnings.warn(..., RuntimeWarning)``; otherwise they are logged
            at debug level.
    """
    def __init__(self,
                 time_unit=0.01,
                 max_shift_units=100,
                 velocity_unit=4,
                 use_velocity=True,
                 use_all_off_event=False,
                 use_drum_events=False,
                 use_magenta=False,
                 errors='remove',
                 warn_on_errors=False):

        self._time_unit = time_unit
        self._max_shift_units = max_shift_units
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._use_magenta = use_magenta
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        # The drum flag can only be represented in the NoteSequence output.
        if use_drum_events:
            assert use_magenta

        max_velocity_units = (128 + velocity_unit - 1) // velocity_unit

        wordlist = (['<pad>', '<s>', '</s>'] + [('NoteOn', i)
                                                for i in range(128)] +
                    [('NoteOff', i) for i in range(128)] +
                    ([('NoteOff', '*')] if use_all_off_event else []) +
                    ([('DrumOn', i) for i in range(128)] +
                     [('DrumOff', i)
                      for i in range(128)] if use_drum_events else []) +
                    [('TimeShift', i + 1) for i in range(max_shift_units)])

        if use_velocity:
            wordlist.extend([('SetVelocity', i + 1)
                             for i in range(max_velocity_units)])
            self._default_velocity = 0
        else:
            self._default_velocity = 127

        self.vocabulary = Vocabulary(wordlist)

    def encode(self, notes, as_ids=True, add_start=False, add_end=False):
        """Encode notes into events (or their vocabulary ids).

        Args:
            notes: a ``music_pb2.NoteSequence`` or a list of
                ``pretty_midi.Note``. For a NoteSequence, the drum flag is
                taken from its first note only.
            as_ids: if true, return vocabulary ids; otherwise event tuples.
            add_start: prepend the vocabulary's start token.
            add_end: append the vocabulary's end token.

        Returns:
            A list of ids (``as_ids=True``) or of events.
        """
        is_drum = False
        if isinstance(notes, music_pb2.NoteSequence):
            is_drum = (len(notes.notes) > 0 and notes.notes[0].is_drum)
            notes = [
                pretty_midi.Note(start=n.start_time,
                                 end=n.end_time,
                                 pitch=n.pitch,
                                 velocity=n.velocity) for n in notes.notes
            ]

        queue = _NoteEventQueue(notes, quantization_step=self._time_unit)
        events = [self.vocabulary.start_token] if add_start else []

        last_t = 0
        velocity = self._default_velocity
        for t, note, is_onset in queue:
            # Split long gaps into shifts of at most max_shift_units each.
            while last_t < t:
                shift_amount = min(t - last_t, self._max_shift_units)
                last_t += shift_amount
                events.append(('TimeShift', shift_amount))

            if is_onset:
                note_velocity = note.velocity
                if note_velocity > 127 or note_velocity < 1:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                # NOTE(review): raw (unquantized) velocities are compared here,
                # so two velocities in the same bin still emit a (redundant)
                # SetVelocity. Kept as-is to preserve the token stream.
                if velocity != note_velocity:
                    velocity = note_velocity
                    if self._use_velocity:
                        events.append(('SetVelocity',
                                       velocity // self._velocity_unit + 1))
                if is_drum and self._use_drum_events:
                    events.append(('DrumOn', note.pitch))
                else:
                    events.append(('NoteOn', note.pitch))
            else:
                if is_drum and self._use_drum_events:
                    events.append(('DrumOff', note.pitch))
                else:
                    events.append(('NoteOff', note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)

        if add_end:
            events.append(self.vocabulary.end_token)

        if as_ids:
            return self.vocabulary.to_ids(events)
        return events

    def decode(self, tokens):
        """Decode events (or ids) back into notes.

        Args:
            tokens: an iterable of event tuples, special tokens, or integer
                vocabulary ids.

        Returns:
            A ``music_pb2.NoteSequence`` if ``use_magenta`` is set, otherwise
            a list of ``pretty_midi.Note``. Hanging notes kept under a
            non-``'remove'`` error policy have ``end=None`` in the list form
            and an unset ``end_time`` in the NoteSequence form.

        Raises:
            RuntimeError: if a token is not in the vocabulary.
        """
        notes = []
        notes_on = defaultdict(list)
        error_count = 0
        is_drum = False

        # Elapsed time is tracked as an integer number of time units and
        # converted on demand. Repeatedly adding floats would accumulate
        # rounding drift.
        steps = 0
        t = 0.
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                continue  # special (string) tokens carry no musical content
            event, value = token

            if event == 'TimeShift':
                steps += value
                t = steps * self._time_unit
            elif event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ['NoteOn', 'DrumOn']:
                note = pretty_midi.Note(start=t,
                                        end=None,
                                        pitch=value,
                                        velocity=velocity)
                notes.append(note)
                notes_on[value].append(note)
                is_drum |= (event == 'DrumOn')
            elif event in ['NoteOff', 'DrumOff']:
                if value == '*':
                    assert self._use_all_off_event

                    if not any(notes_on.values()):
                        error_count += 1

                    for note_list in notes_on.values():
                        for note in note_list:
                            note.end = t
                        note_list.clear()
                else:
                    try:
                        note = notes_on[value].pop()
                    except IndexError:
                        error_count += 1  # note-off without a matching note-on
                    else:
                        note.end = t

        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        num_hanging = sum(len(lst) for lst in notes_on.values())
        if num_hanging:
            if self._errors == 'remove':
                self._log_errors('Removing {} hanging note(s)'.format(
                    num_hanging))
                # Filter by identity in one pass instead of repeated list.remove.
                hanging_ids = {id(note)
                               for lst in notes_on.values() for note in lst}
                notes = [note for note in notes if id(note) not in hanging_ids]
            else:  # 'ignore'
                self._log_errors('Ignoring {} hanging note(s)'.format(
                    num_hanging))

        if self._use_magenta:
            sequence = music_pb2.NoteSequence()
            sequence.ticks_per_quarter = STANDARD_PPQ
            for note0 in notes:
                note = sequence.notes.add()
                note.start_time = note0.start
                # Hanging notes kept under 'ignore' have no end. Assigning None
                # to a protobuf scalar field raises TypeError, so leave it unset.
                if note0.end is not None:
                    note.end_time = note0.end
                note.pitch = note0.pitch
                note.velocity = note0.velocity
                note.is_drum = is_drum
            # No note ends after t, so the final decoded time bounds the sequence.
            sequence.total_time = t
            return sequence

        return notes

    def _log_errors(self, message):
        """Report a decoding problem as a RuntimeWarning or a debug log entry."""
        if self._warn_on_errors:
            warnings.warn(message, RuntimeWarning)
        else:
            logger.debug(message)