class BeatRelativeEncoding:
    """Encode note sequences as events whose timing is relative to beats.

    Each beat is divided into ``units_per_beat`` steps. A ``('SetTime', s)``
    event moves to step ``s`` of the current beat; a ``('SetTimeNext', s)``
    event advances to the next beat and then moves to step ``s``. Notes are
    encoded as ``NoteOn``/``NoteOff`` events (or ``DrumOn``/``DrumOff`` for drum
    notes when enabled), optionally preceded by ``SetVelocity`` events.
    Decoded times are expressed as ``beat + step / units_per_beat``.
    """

    def __init__(self, units_per_beat=12, velocity_unit=4, use_velocity=True,
                 default_velocity=127, use_all_off_event=False,
                 use_drum_events=False, errors='remove', warn_on_errors=False):
        """Build the encoding and its vocabulary.

        Args:
            units_per_beat: number of time steps per beat.
            velocity_unit: width of one velocity quantization bin.
            use_velocity: whether to emit ``SetVelocity`` events when encoding.
            default_velocity: velocity of decoded notes before any
                ``SetVelocity`` event, and the replacement for out-of-range
                (not in 1..127) velocities when encoding.
            use_all_off_event: add a ``('NoteOff', '*')`` event that switches
                off all sounding notes (encoded events are passed through
                ``_compress_note_offs``).
            use_drum_events: encode drum notes with ``DrumOn``/``DrumOff``.
            errors: how ``decode`` treats notes that are never switched off:
                ``'remove'`` drops them; any other value (e.g. ``'fix'``) ends
                them at the sequence's total time.
            warn_on_errors: report decoding problems as ``RuntimeWarning``
                instead of debug log messages.
        """
        self._units_per_beat = units_per_beat
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        wordlist = (['<pad>', '<s>', '</s>']
                    + [('NoteOn', i) for i in range(128)]
                    + [('NoteOff', i) for i in range(128)]
                    + ([('NoteOff', '*')] if use_all_off_event else [])
                    + ([('DrumOn', i) for i in range(128)]
                       + [('DrumOff', i) for i in range(128)]
                       if use_drum_events else [])
                    + [('SetTime', i) for i in range(units_per_beat)]
                    + [('SetTimeNext', i) for i in range(units_per_beat)])
        if use_velocity:
            # ceil(128 / velocity_unit) bins, numbered from 1.
            max_velocity_units = (128 + velocity_unit - 1) // velocity_unit
            wordlist.extend([('SetVelocity', i + 1)
                             for i in range(max_velocity_units)])
        self._default_velocity = default_velocity
        self.vocabulary = Vocabulary(wordlist)

    def encode(self, sequence, as_ids=True, add_start=False, add_end=False):
        """Encode a note sequence as a list of events.

        Args:
            sequence: the note sequence to encode; it is first passed through
                ``note_sequence_utils.normalize_tempo``.
            as_ids: if true, return vocabulary ids instead of event tuples.
            add_start: prepend the vocabulary's start token.
            add_end: append the vocabulary's end token.

        Returns:
            A list of vocabulary ids (``as_ids=True``) or of events.
        """
        sequence = note_sequence_utils.normalize_tempo(sequence)
        queue = _NoteEventQueue(sequence,
                                quantization_step=1 / self._units_per_beat)
        events = [self.vocabulary.start_token] if add_start else []

        last_beat = 0
        last_t = 0
        velocity_quantized = None
        for t, note, is_onset in queue:
            if t > last_t:
                # t is split into (beat, step within beat); presumably it is a
                # count of quantization steps -- confirm with _NoteEventQueue.
                beat = t // self._units_per_beat
                step_in_beat = t % self._units_per_beat
                while beat - last_beat > 1:
                    # Skip to the beginning of the next beat
                    events.append(('SetTimeNext', 0))
                    last_beat += 1
                if beat == last_beat:
                    events.append(('SetTime', step_in_beat))
                else:  # beat == last_beat + 1
                    events.append(('SetTimeNext', step_in_beat))
                    last_beat += 1
                assert beat == last_beat
                last_t = t

            if is_onset:
                note_velocity = note.velocity
                if note_velocity > 127 or note_velocity < 1:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                note_velocity_quantized = note_velocity // self._velocity_unit + 1
                # Only emit SetVelocity when the quantized value changes.
                if velocity_quantized != note_velocity_quantized:
                    velocity_quantized = note_velocity_quantized
                    if self._use_velocity:
                        events.append(('SetVelocity', velocity_quantized))
                if note.is_drum and self._use_drum_events:
                    events.append(('DrumOn', note.pitch))
                else:
                    events.append(('NoteOn', note.pitch))
            else:
                if note.is_drum and self._use_drum_events:
                    events.append(('DrumOff', note.pitch))
                else:
                    events.append(('NoteOff', note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)
        if add_end:
            events.append(self.vocabulary.end_token)
        if as_ids:
            return self.vocabulary.to_ids(events)
        return events

    def decode(self, tokens):
        """Decode events (or vocabulary ids) into a ``music_pb2.NoteSequence``.

        Non-tuple tokens (e.g. ``'<s>'``) are skipped. Time events that do not
        strictly advance time, and note-offs with no matching sounding note,
        are counted as errors and logged via ``_log_errors``. Notes still
        sounding at the end are handled according to the ``errors`` setting.

        Raises:
            RuntimeError: if a token is not in the vocabulary.
        """
        sequence = music_pb2.NoteSequence()
        sequence.ticks_per_quarter = STANDARD_PPQ

        # pitch -> indices into sequence.notes of notes still sounding.
        # Indices are tracked instead of message objects because protobuf
        # messages compare by value, so distinct notes can compare equal.
        notes_on = defaultdict(list)
        error_count = 0

        t = 0.
        current_beat = 0
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                continue
            event, value = token

            if event in ['SetTime', 'SetTimeNext']:
                if event == 'SetTimeNext':
                    current_beat += 1
                new_t = current_beat + value / self._units_per_beat
                if new_t > t:
                    t = new_t
                else:
                    # Time must strictly increase; the event is ignored.
                    error_count += 1
            elif event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ['NoteOn', 'DrumOn']:
                note = sequence.notes.add()
                note.start_time = t
                note.pitch = value
                note.velocity = velocity
                note.is_drum = (event == 'DrumOn')
                notes_on[value].append(len(sequence.notes) - 1)
            elif event in ['NoteOff', 'DrumOff']:
                if value == '*':
                    assert self._use_all_off_event
                    if not any(notes_on.values()):
                        error_count += 1
                    for indices in notes_on.values():
                        for i in indices:
                            sequence.notes[i].end_time = t
                        indices.clear()
                else:
                    indices = notes_on[value]
                    if indices:
                        # Switch off the most recently started note of this pitch.
                        sequence.notes[indices.pop()].end_time = t
                    else:
                        error_count += 1

        sequence.total_time = t
        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        # Handle hanging notes (switched on but never switched off).
        hanging = sorted(i for indices in notes_on.values() for i in indices)
        if hanging:
            if self._errors == 'remove':
                self._log_errors(f'Removing {len(hanging)} hanging note(s)')
                # Delete from the back so the remaining indices stay valid.
                for i in reversed(hanging):
                    del sequence.notes[i]
            else:  # 'fix'
                self._log_errors(f'Ending {len(hanging)} hanging note(s)')
                for i in hanging:
                    sequence.notes[i].end_time = sequence.total_time

        return sequence

    def _log_errors(self, message):
        """Report a decoding problem as a RuntimeWarning or a debug log message."""
        if self._warn_on_errors:
            warnings.warn(message, RuntimeWarning)
        else:
            _LOGGER.debug(message)
class PerformanceEncoding:
    """An encoding of note sequences based on Magenta's PerformanceRNN.

    This is actually very similar to how MIDI works.
    See https://magenta.tensorflow.org/performance-rnn.
    """

    def __init__(self, time_unit=0.01, max_shift_units=100, velocity_unit=4,
                 use_velocity=True, use_all_off_event=False,
                 use_drum_events=False, use_magenta=False, errors='remove',
                 warn_on_errors=False):
        """Build the encoding and its vocabulary.

        Args:
            time_unit: duration of one ``TimeShift`` unit.
            max_shift_units: largest amount a single ``TimeShift`` event moves.
            velocity_unit: width of one velocity quantization bin.
            use_velocity: whether to emit ``SetVelocity`` events when encoding.
            use_all_off_event: add a ``('NoteOff', '*')`` event that switches
                off all sounding notes (encoded events are passed through
                ``_compress_note_offs``).
            use_drum_events: encode drum notes with ``DrumOn``/``DrumOff``;
                requires ``use_magenta``.
            use_magenta: make ``decode`` return a ``music_pb2.NoteSequence``
                instead of a list of ``pretty_midi.Note``.
            errors: how ``decode`` treats notes that are never switched off:
                ``'remove'`` drops them, ``'fix'`` ends them at the final
                decoded time, any other value (e.g. ``'ignore'``) keeps them
                without an end time.
            warn_on_errors: report decoding problems as ``RuntimeWarning``
                instead of debug log messages.
        """
        self._time_unit = time_unit
        self._max_shift_units = max_shift_units
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._use_magenta = use_magenta
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        if use_drum_events:
            # Drum status is only read from / written to NoteSequence objects.
            assert use_magenta

        # ceil(128 / velocity_unit) bins, numbered from 1.
        max_velocity_units = (128 + velocity_unit - 1) // velocity_unit
        wordlist = (['<pad>', '<s>', '</s>']
                    + [('NoteOn', i) for i in range(128)]
                    + [('NoteOff', i) for i in range(128)]
                    + ([('NoteOff', '*')] if use_all_off_event else [])
                    + ([('DrumOn', i) for i in range(128)]
                       + [('DrumOff', i) for i in range(128)]
                       if use_drum_events else [])
                    + [('TimeShift', i + 1) for i in range(max_shift_units)])
        if use_velocity:
            wordlist.extend([('SetVelocity', i + 1)
                             for i in range(max_velocity_units)])
            self._default_velocity = 0
        else:
            self._default_velocity = 127

        self.vocabulary = Vocabulary(wordlist)

    def encode(self, notes, as_ids=True, add_start=False, add_end=False):
        """Encode notes as a list of events.

        Args:
            notes: a ``music_pb2.NoteSequence`` or a list of ``pretty_midi.Note``.
                For a NoteSequence, the drum status of the first note is
                applied to every note.
            as_ids: if true, return vocabulary ids instead of event tuples.
            add_start: prepend the vocabulary's start token.
            add_end: append the vocabulary's end token.

        Returns:
            A list of vocabulary ids (``as_ids=True``) or of events.
        """
        is_drum = False
        if isinstance(notes, music_pb2.NoteSequence):
            is_drum = (len(notes.notes) > 0 and notes.notes[0].is_drum)
            notes = [
                pretty_midi.Note(start=n.start_time, end=n.end_time,
                                 pitch=n.pitch, velocity=n.velocity)
                for n in notes.notes
            ]
        queue = _NoteEventQueue(notes, quantization_step=self._time_unit)
        events = [self.vocabulary.start_token] if add_start else []

        last_t = 0
        velocity = self._default_velocity
        for t, note, is_onset in queue:
            # Long gaps are split into several TimeShift events.
            while last_t < t:
                shift_amount = min(t - last_t, self._max_shift_units)
                last_t += shift_amount
                events.append(('TimeShift', shift_amount))

            if is_onset:
                note_velocity = note.velocity
                if note_velocity > 127 or note_velocity < 1:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                # NOTE(review): raw (unquantized) velocities are compared, so a
                # SetVelocity event may repeat the previous quantized value.
                if velocity != note_velocity:
                    velocity = note_velocity
                    if self._use_velocity:
                        events.append(('SetVelocity',
                                       velocity // self._velocity_unit + 1))
                if is_drum and self._use_drum_events:
                    events.append(('DrumOn', note.pitch))
                else:
                    events.append(('NoteOn', note.pitch))
            else:
                if is_drum and self._use_drum_events:
                    events.append(('DrumOff', note.pitch))
                else:
                    events.append(('NoteOff', note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)
        if add_end:
            events.append(self.vocabulary.end_token)
        if as_ids:
            return self.vocabulary.to_ids(events)
        return events

    def decode(self, tokens):
        """Decode events (or vocabulary ids) into notes.

        Non-tuple tokens (e.g. ``'<s>'``) are skipped. Note-offs with no
        matching sounding note are counted as errors and logged via
        ``_log_errors``. Notes still sounding at the end are handled according
        to the ``errors`` setting.

        Returns:
            A ``music_pb2.NoteSequence`` if ``use_magenta`` is set, otherwise a
            list of ``pretty_midi.Note``.

        Raises:
            RuntimeError: if a token is not in the vocabulary.
        """
        notes = []
        notes_on = defaultdict(list)  # pitch -> notes still sounding
        error_count = 0
        is_drum = False

        # Count time in whole units and multiply once, so that many small
        # shifts do not accumulate floating-point error.
        time_steps = 0
        t = 0
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                continue
            event, value = token

            if event == 'TimeShift':
                time_steps += value
                t = time_steps * self._time_unit
            elif event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ['NoteOn', 'DrumOn']:
                note = pretty_midi.Note(start=t, end=None, pitch=value,
                                        velocity=velocity)
                notes.append(note)
                notes_on[value].append(note)
                is_drum |= (event == 'DrumOn')
            elif event in ['NoteOff', 'DrumOff']:
                if value == '*':
                    assert self._use_all_off_event
                    if not any(notes_on.values()):
                        error_count += 1
                    for note_list in notes_on.values():
                        for note in note_list:
                            note.end = t
                        note_list.clear()
                else:
                    note_list = notes_on[value]
                    if note_list:
                        # Switch off the most recently started note of this pitch.
                        note_list.pop().end = t
                    else:
                        error_count += 1

        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        num_hanging = sum(len(lst) for lst in notes_on.values())
        if num_hanging:
            if self._errors == 'remove':
                self._log_errors('Removing {} hanging note(s)'.format(num_hanging))
                # Filter by identity in one pass instead of list.remove per note.
                hanging_ids = {id(note) for lst in notes_on.values() for note in lst}
                notes = [note for note in notes if id(note) not in hanging_ids]
            elif self._errors == 'fix':
                self._log_errors('Ending {} hanging note(s)'.format(num_hanging))
                for lst in notes_on.values():
                    for note in lst:
                        note.end = t
            else:  # 'ignore'
                self._log_errors('Ignoring {} hanging note(s)'.format(num_hanging))

        if self._use_magenta:
            sequence = music_pb2.NoteSequence()
            sequence.ticks_per_quarter = STANDARD_PPQ
            for src in notes:
                note = sequence.notes.add()
                note.start_time = src.start
                # Ignored hanging notes have end=None, which a protobuf float
                # field rejects; leave end_time at its default instead.
                if src.end is not None:
                    note.end_time = src.end
                note.pitch = src.pitch
                note.velocity = src.velocity
                note.is_drum = is_drum
            return sequence

        return notes

    def _log_errors(self, message):
        """Report a decoding problem as a RuntimeWarning or a debug log message."""
        if self._warn_on_errors:
            warnings.warn(message, RuntimeWarning)
        else:
            # NOTE(review): BeatRelativeEncoding logs via `_LOGGER`; confirm
            # that a module-level `logger` is also defined.
            logger.debug(message)