def _init(cfg, logdir, train_mode, **kwargs):
    set_random_seed(cfg.get('random_seed', None))
    encoding = cfg['encoding'].configure()

    with open(cfg.get('style_list')) as f:
        style_list = [line.rstrip('\n') for line in f]
    style_vocabulary = Vocabulary(style_list,
                                  pad_token=None, start_token=None, end_token=None)

    model = cfg['model'].configure(RNNSeq2Seq,
                                   train_mode=train_mode,
                                   vocabulary=encoding.vocabulary,
                                   style_vocabulary=style_vocabulary)
    trainer = cfg['trainer'].configure(BasicTrainer,
                                       dataset_manager=model.dataset_manager,
                                       training_ops=model.training_ops,
                                       logdir=logdir,
                                       write_summaries=train_mode)

    if train_mode:
        # Configure the dataset manager with the training and validation data.
        load_data_kwargs = dict(input_encoding=encoding,
                                output_encoding=encoding,
                                style_vocabulary=style_vocabulary)
        cfg['data_prep'].configure(
            prepare_train_and_val_data,
            dataset_manager=model.dataset_manager,
            train_generator=cfg['train_data'].configure(load_data, log=True,
                                                        **load_data_kwargs),
            val_generator=cfg['val_data'].configure(load_data, **load_data_kwargs),
            output_types=(tf.int32, tf.int32, tf.int32, tf.int32),
            output_shapes=([None], [], [None], [None]))

    return model, trainer, encoding, style_vocabulary
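# Illustrative sketch (not part of the original module): `_init` above reads the
# config keys 'random_seed', 'encoding', 'style_list', 'model', 'trainer' and, in
# training mode, 'data_prep', 'train_data' and 'val_data'. The 'style_list' entry
# is expected to point to a plain-text file with one style name per line. The
# helper below is hypothetical and only demonstrates that file format.
def _write_style_list_example(path, style_names):
    """Write one style name per line, as `_init` expects to read it back."""
    with open(path, 'w') as f:
        for name in style_names:
            f.write(name + '\n')
    # Usage (made-up style names): _write_style_list_example('styles', ['funk', 'bossa'])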
class TranslationExperiment:

    def __init__(self, logdir, train_mode):
        set_random_seed(self._cfg.get('random_seed', None))
        self.input_encoding = self._cfg['input_encoding'].configure()
        self.output_encoding = self._cfg['output_encoding'].configure()

        with open(self._cfg.get('style_list')) as f:
            style_list = [line.rstrip('\n') for line in f]
        self.style_vocabulary = Vocabulary(style_list,
                                           pad_token=None, start_token=None, end_token=None)

        self.input_shapes = ([self.input_encoding.num_rows, None], [], [None], [None])
        self.input_types = (tf.float32, tf.int32, tf.int32, tf.int32)
        self.dataset_manager = DatasetManager(
            output_types=self.input_types,
            output_shapes=tuple([None, *shape] for shape in self.input_shapes))

        self.model = self._cfg['model'].configure(
            CNNRNNSeq2Seq,
            dataset_manager=self.dataset_manager,
            train_mode=train_mode,
            vocabulary=self.output_encoding.vocabulary,
            style_vocabulary=self.style_vocabulary)
        self.trainer = self._cfg['trainer'].configure(
            BasicTrainer,
            dataset_manager=self.dataset_manager,
            training_ops=self.model.training_ops,
            logdir=logdir,
            write_summaries=train_mode)

        self._load_data_kwargs = dict(input_encoding=self.input_encoding,
                                      output_encoding=self.output_encoding,
                                      style_vocabulary=self.style_vocabulary)
        if train_mode:
            # Configure the dataset manager with the training and validation data.
            self._cfg['data_prep'].configure(
                prepare_train_and_val_data,
                dataset_manager=self.dataset_manager,
                train_generator=self._cfg['train_data'].configure(
                    load_data, log=True, **self._load_data_kwargs),
                val_generator=self._cfg['val_data'].configure(
                    load_data, **self._load_data_kwargs),
                output_types=self.input_types,
                output_shapes=self.input_shapes)

    def train(self, args):
        LOGGER.info('Starting training.')
        self.trainer.train()

    def run(self, args):
        self.trainer.load_variables(checkpoint_file=args.checkpoint)
        data = pickle.load(args.input_file)

        def generator():
            style_id = self.style_vocabulary.to_id(args.target_style)
            for example in data:
                segment_id, notes = example
                yield self.input_encoding.encode(notes), style_id, [], []

        dataset = make_simple_dataset(generator,
                                      output_types=self.input_types,
                                      output_shapes=self.input_shapes,
                                      batch_size=args.batch_size)
        output_ids = self.model.run(self.trainer.session, dataset,
                                    args.sample, args.softmax_temperature)

        outputs = [(segment_id, self.output_encoding.decode(seq))
                   for seq, (segment_id, _) in zip(output_ids, data)]
        pickle.dump(outputs, args.output_file)
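# Illustrative sketch (assumptions flagged): `TranslationExperiment.run` above
# unpickles `args.input_file` as a sequence of (segment_id, notes) pairs and
# pickles (segment_id, decoded_notes) pairs to `args.output_file`. The helper
# below is hypothetical and only prepares an input file in that layout;
# `segments` is assumed to be an iterable of such pairs.
def _write_run_input_example(path, segments):
    """Pickle (segment_id, notes) pairs the way `run` expects to load them."""
    import pickle
    with open(path, 'wb') as f:
        pickle.dump(list(segments), f)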
class BeatRelativeEncoding:

    def __init__(self, units_per_beat=12, velocity_unit=4, use_velocity=True,
                 default_velocity=127, use_all_off_event=False, use_drum_events=False,
                 errors='remove', warn_on_errors=False):
        self._units_per_beat = units_per_beat
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        wordlist = (['<pad>', '<s>', '</s>'] +
                    [('NoteOn', i) for i in range(128)] +
                    [('NoteOff', i) for i in range(128)] +
                    ([('NoteOff', '*')] if use_all_off_event else []) +
                    ([('DrumOn', i) for i in range(128)] +
                     [('DrumOff', i) for i in range(128)] if use_drum_events else []) +
                    [('SetTime', i) for i in range(units_per_beat)] +
                    [('SetTimeNext', i) for i in range(units_per_beat)])

        if use_velocity:
            max_velocity_units = (128 + velocity_unit - 1) // velocity_unit
            wordlist.extend([('SetVelocity', i + 1) for i in range(max_velocity_units)])
        self._default_velocity = default_velocity

        self.vocabulary = Vocabulary(wordlist)

    def encode(self, sequence, as_ids=True, add_start=False, add_end=False):
        sequence = note_sequence_utils.normalize_tempo(sequence)
        queue = _NoteEventQueue(sequence, quantization_step=1 / self._units_per_beat)

        events = [self.vocabulary.start_token] if add_start else []
        last_beat = 0
        last_t = 0
        velocity_quantized = None
        for t, note, is_onset in queue:
            if t > last_t:
                beat = t // self._units_per_beat
                step_in_beat = t % self._units_per_beat
                while beat - last_beat > 1:
                    # Skip to the beginning of the next beat
                    events.append(('SetTimeNext', 0))
                    last_beat += 1
                if beat == last_beat:
                    events.append(('SetTime', step_in_beat))
                else:  # beat == last_beat + 1
                    events.append(('SetTimeNext', step_in_beat))
                    last_beat += 1
                assert beat == last_beat
                last_t = t

            if is_onset:
                note_velocity = note.velocity
                if note_velocity > 127 or note_velocity < 1:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                note_velocity_quantized = note_velocity // self._velocity_unit + 1
                if velocity_quantized != note_velocity_quantized:
                    velocity_quantized = note_velocity_quantized
                    if self._use_velocity:
                        events.append(('SetVelocity', velocity_quantized))

                if note.is_drum and self._use_drum_events:
                    events.append(('DrumOn', note.pitch))
                else:
                    events.append(('NoteOn', note.pitch))
            else:
                if note.is_drum and self._use_drum_events:
                    events.append(('DrumOff', note.pitch))
                else:
                    events.append(('NoteOff', note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)

        if add_end:
            events.append(self.vocabulary.end_token)

        if as_ids:
            return self.vocabulary.to_ids(events)
        return events

    def decode(self, tokens):
        sequence = music_pb2.NoteSequence()
        sequence.ticks_per_quarter = STANDARD_PPQ

        notes_on = defaultdict(list)
        error_count = 0
        t = 0.
        current_beat = 0
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                continue

            event, value = token
            if event in ['SetTime', 'SetTimeNext']:
                if event == 'SetTimeNext':
                    current_beat += 1
                new_t = current_beat + value / self._units_per_beat
                if new_t > t:
                    t = new_t
                else:
                    error_count += 1
                continue

            if event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ['NoteOn', 'DrumOn']:
                note = sequence.notes.add()
                note.start_time = t
                note.pitch = value
                note.velocity = velocity
                note.is_drum = (event == 'DrumOn')
                notes_on[note.pitch].append(note)
            elif event in ['NoteOff', 'DrumOff']:
                if value == '*':
                    assert self._use_all_off_event
                    if not any(notes_on.values()):
                        error_count += 1
                    for note_list in notes_on.values():
                        for note in note_list:
                            note.end_time = t
                        note_list.clear()
                else:
                    try:
                        note = notes_on[value].pop()
                        note.end_time = t
                    except IndexError:
                        error_count += 1

        sequence.total_time = t

        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        # Handle hanging notes
        num_hanging = sum(len(lst) for lst in notes_on.values())
        if any(notes_on.values()):
            if self._errors == 'remove':
                self._log_errors(f'Removing {num_hanging} hanging note(s)')
                notes_filtered = list(sequence.notes)
                for hanging_notes in notes_on.values():
                    notes_filtered = [n for n in notes_filtered if n not in hanging_notes]
                del sequence.notes[:]
                sequence.notes.extend(notes_filtered)
            else:  # 'fix'
                self._log_errors(f'Ending {num_hanging} hanging note(s)')
                for hanging_notes in notes_on.values():
                    for note in hanging_notes:
                        note.end_time = sequence.total_time

        return sequence

    def _log_errors(self, message):
        if self._warn_on_errors:
            warnings.warn(message, RuntimeWarning)
        else:
            _LOGGER.debug(message)
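# Hedged usage sketch (illustrative only, not part of the module): round-trip a
# single note through BeatRelativeEncoding. Assumes `music_pb2` is the protobuf
# module imported by this file and that `note_sequence_utils.normalize_tempo`
# accepts a sequence with a single tempo entry; the note values are made up.
def _beat_relative_roundtrip_example():
    encoding = BeatRelativeEncoding(units_per_beat=12, use_velocity=True)
    sequence = music_pb2.NoteSequence()
    sequence.tempos.add(qpm=120.)
    note = sequence.notes.add()
    note.pitch, note.velocity = 60, 100        # middle C, medium-loud
    note.start_time, note.end_time = 0.0, 0.5  # made-up timing
    sequence.total_time = 0.5
    ids = encoding.encode(sequence, as_ids=True, add_start=True, add_end=True)
    return encoding.decode(ids)  # a NoteSequence reconstructed from the events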
class PerformanceEncoding:
    """An encoding of note sequences based on Magenta's PerformanceRNN.

    The event vocabulary is very similar to how MIDI itself represents a
    performance. See https://magenta.tensorflow.org/performance-rnn.
    """

    def __init__(self, time_unit=0.01, max_shift_units=100, velocity_unit=4,
                 use_velocity=True, use_all_off_event=False, use_drum_events=False,
                 use_magenta=False, errors='remove', warn_on_errors=False):
        self._time_unit = time_unit
        self._max_shift_units = max_shift_units
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._use_magenta = use_magenta
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        if use_drum_events:
            assert use_magenta

        max_velocity_units = (128 + velocity_unit - 1) // velocity_unit
        wordlist = (['<pad>', '<s>', '</s>'] +
                    [('NoteOn', i) for i in range(128)] +
                    [('NoteOff', i) for i in range(128)] +
                    ([('NoteOff', '*')] if use_all_off_event else []) +
                    ([('DrumOn', i) for i in range(128)] +
                     [('DrumOff', i) for i in range(128)] if use_drum_events else []) +
                    [('TimeShift', i + 1) for i in range(max_shift_units)])

        if use_velocity:
            wordlist.extend([('SetVelocity', i + 1) for i in range(max_velocity_units)])
            self._default_velocity = 0
        else:
            self._default_velocity = 127

        self.vocabulary = Vocabulary(wordlist)

    def encode(self, notes, as_ids=True, add_start=False, add_end=False):
        is_drum = False
        if isinstance(notes, music_pb2.NoteSequence):
            is_drum = (len(notes.notes) > 0 and notes.notes[0].is_drum)
            notes = [
                pretty_midi.Note(start=n.start_time, end=n.end_time,
                                 pitch=n.pitch, velocity=n.velocity)
                for n in notes.notes
            ]
        queue = _NoteEventQueue(notes, quantization_step=self._time_unit)

        events = [self.vocabulary.start_token] if add_start else []
        last_t = 0
        velocity = self._default_velocity
        for t, note, is_onset in queue:
            while last_t < t:
                shift_amount = min(t - last_t, self._max_shift_units)
                last_t += shift_amount
                events.append(('TimeShift', shift_amount))

            if is_onset:
                note_velocity = note.velocity
                if note_velocity > 127 or note_velocity < 1:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                if velocity != note_velocity:
                    velocity = note_velocity
                    if self._use_velocity:
                        events.append(('SetVelocity', velocity // self._velocity_unit + 1))

                if is_drum and self._use_drum_events:
                    events.append(('DrumOn', note.pitch))
                else:
                    events.append(('NoteOn', note.pitch))
            else:
                if is_drum and self._use_drum_events:
                    events.append(('DrumOff', note.pitch))
                else:
                    events.append(('NoteOff', note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)

        if add_end:
            events.append(self.vocabulary.end_token)

        if as_ids:
            return self.vocabulary.to_ids(events)
        return events

    def decode(self, tokens):
        notes = []
        notes_on = defaultdict(list)
        error_count = 0
        is_drum = False

        t = 0
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                continue

            event, value = token
            if event == 'TimeShift':
                t += value * self._time_unit
            elif event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ['NoteOn', 'DrumOn']:
                note = pretty_midi.Note(start=t, end=None, pitch=value, velocity=velocity)
                notes.append(note)
                notes_on[value].append(note)
                is_drum |= (event == 'DrumOn')
            elif event in ['NoteOff', 'DrumOff']:
                if value == '*':
                    assert self._use_all_off_event
                    if not any(notes_on.values()):
                        error_count += 1
                    for note_list in notes_on.values():
                        for note in note_list:
                            note.end = t
                        note_list.clear()
                else:
                    try:
                        note = notes_on[value].pop()
                        note.end = t
                    except IndexError:
                        error_count += 1

        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        if any(notes_on.values()):
            if self._errors == 'remove':
                self._log_errors('Removing {} hanging note(s)'.format(
                    sum(len(l) for l in notes_on.values())))
                for notes_on_list in notes_on.values():
                    for note in notes_on_list:
                        notes.remove(note)
            else:  # 'ignore'
                self._log_errors('Ignoring {} hanging note(s)'.format(
                    sum(len(l) for l in notes_on.values())))

        if self._use_magenta:
            sequence = music_pb2.NoteSequence()
            sequence.ticks_per_quarter = STANDARD_PPQ
            for note0 in notes:
                note = sequence.notes.add()
                note.start_time = note0.start
                note.end_time = note0.end
                note.pitch = note0.pitch
                note.velocity = note0.velocity
                note.is_drum = is_drum
            return sequence
        return notes

    def _log_errors(self, message):
        if self._warn_on_errors:
            warnings.warn(message, RuntimeWarning)
        else:
            _LOGGER.debug(message)
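# Hedged usage sketch (illustrative only, not part of the module): encode a
# couple of pretty_midi notes with PerformanceEncoding and decode them back.
# The note values are made up; with the default use_magenta=False, `decode`
# returns a list of pretty_midi.Note objects rather than a NoteSequence.
def _performance_roundtrip_example():
    encoding = PerformanceEncoding(time_unit=0.01, use_velocity=True)
    notes = [
        pretty_midi.Note(velocity=80, pitch=60, start=0.0, end=0.5),
        pretty_midi.Note(velocity=90, pitch=64, start=0.5, end=1.0),
    ]
    ids = encoding.encode(notes, as_ids=True, add_start=True, add_end=True)
    return encoding.decode(ids)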