コード例 #1
0
    def __init__(self,
                 units_per_beat=12,
                 velocity_unit=4,
                 use_velocity=True,
                 default_velocity=127,
                 use_all_off_event=False,
                 use_drum_events=False,
                 errors='remove',
                 warn_on_errors=False):
        """Store the encoding options and build the token vocabulary.

        Args:
            units_per_beat: Number of distinct ``SetTime`` and ``SetTimeNext``
                values added to the vocabulary (``0 .. units_per_beat - 1``).
            velocity_unit: Width of one velocity bin; determines how many
                ``SetVelocity`` tokens are created.
            use_velocity: If true, ``SetVelocity`` tokens are added.
            default_velocity: Stored as ``self._default_velocity``.
            use_all_off_event: If true, the ``('NoteOff', '*')`` token is
                added.
            use_drum_events: If true, ``DrumOn``/``DrumOff`` tokens are added
                for all 128 pitches.
            errors: Stored as ``self._errors``; not interpreted here.
            warn_on_errors: Stored as ``self._warn_on_errors``; not
                interpreted here.
        """
        self._units_per_beat = units_per_beat
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._default_velocity = default_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        # The order of the tokens below defines their position in the
        # vocabulary, so it must not be changed.
        pitches = range(128)
        steps = range(units_per_beat)

        tokens = ['<pad>', '<s>', '</s>']
        tokens.extend(('NoteOn', pitch) for pitch in pitches)
        tokens.extend(('NoteOff', pitch) for pitch in pitches)
        if use_all_off_event:
            tokens.append(('NoteOff', '*'))
        if use_drum_events:
            tokens.extend(('DrumOn', pitch) for pitch in pitches)
            tokens.extend(('DrumOff', pitch) for pitch in pitches)
        tokens.extend(('SetTime', step) for step in steps)
        tokens.extend(('SetTimeNext', step) for step in steps)

        if use_velocity:
            # Number of velocity bins needed to cover 128 values (rounded up
            # for a positive integer velocity_unit); bins are numbered from 1.
            num_velocity_bins = (128 + velocity_unit - 1) // velocity_unit
            tokens.extend(('SetVelocity', level)
                          for level in range(1, num_velocity_bins + 1))

        self.vocabulary = Vocabulary(tokens)
コード例 #2
0
    def __init__(self, logdir, train_mode):
        """Build the encodings, style vocabulary, model, trainer and data pipeline.

        Everything is configured from ``self._cfg``.

        Args:
            logdir: Passed to the trainer as its ``logdir``.
            train_mode: Passed to the model as ``train_mode`` and to the
                trainer as ``write_summaries``. When true, the training and
                validation data are also prepared.
        """
        set_random_seed(self._cfg.get('random_seed', None))

        self.input_encoding = self._cfg['input_encoding'].configure()
        self.output_encoding = self._cfg['output_encoding'].configure()

        # Read the style list path from the configuration. Fall back to the
        # path that used to be hard-coded here so that configurations without
        # a 'style_list' entry keep working.
        style_list_path = (self._cfg.get('style_list', None)
                           or "./seq2seq/data/parallel/styles")
        with open(style_list_path) as f:
            style_list = [line.rstrip('\n') for line in f]
        # The style vocabulary contains only the listed styles, without any
        # padding/start/end tokens.
        self.style_vocabulary = Vocabulary(style_list,
                                           pad_token=None,
                                           start_token=None,
                                           end_token=None)

        # Shapes and types of one example, in the order in which the
        # components are yielded by the data generators.
        self.input_shapes = ([self.input_encoding.num_rows,
                              None], [], [None], [None])
        self.input_types = (tf.float32, tf.int32, tf.int32, tf.int32)
        # The dataset manager gets the shapes with an extra leading unknown
        # dimension (presumably the batch axis -- confirm with DatasetManager).
        self.dataset_manager = DatasetManager(
            output_types=self.input_types,
            output_shapes=tuple([None, *shape] for shape in self.input_shapes))

        self.model = self._cfg['model'].configure(
            CNNRNNSeq2Seq,
            dataset_manager=self.dataset_manager,
            train_mode=train_mode,
            vocabulary=self.output_encoding.vocabulary,
            style_vocabulary=self.style_vocabulary)
        self.trainer = self._cfg['trainer'].configure(
            BasicTrainer,
            dataset_manager=self.dataset_manager,
            training_ops=self.model.training_ops,
            logdir=logdir,
            write_summaries=train_mode)

        # Keyword arguments shared by every load_data call.
        self._load_data_kwargs = dict(input_encoding=self.input_encoding,
                                      output_encoding=self.output_encoding,
                                      style_vocabulary=self.style_vocabulary)

        if train_mode:
            # Configure the dataset manager with the training and validation data.
            self._cfg['data_prep'].configure(
                prepare_train_and_val_data,
                dataset_manager=self.dataset_manager,
                train_generator=self._cfg['train_data'].configure(
                    load_data, log=True, **self._load_data_kwargs),
                val_generator=self._cfg['val_data'].configure(
                    load_data, **self._load_data_kwargs),
                output_types=self.input_types,
                output_shapes=self.input_shapes)
コード例 #3
0
    def __init__(self,
                 time_unit=0.01,
                 max_shift_units=100,
                 velocity_unit=4,
                 use_velocity=True,
                 use_all_off_event=False,
                 use_drum_events=False,
                 use_magenta=False,
                 errors='remove',
                 warn_on_errors=False):
        """Store the encoding options and build the token vocabulary.

        Args:
            time_unit: Stored as ``self._time_unit``.
            max_shift_units: Number of ``TimeShift`` tokens created (values
                ``1 .. max_shift_units``).
            velocity_unit: Width of one velocity bin; determines how many
                ``SetVelocity`` tokens are created.
            use_velocity: If true, ``SetVelocity`` tokens are added and the
                default velocity is 0; otherwise the default velocity is 127.
            use_all_off_event: If true, the ``('NoteOff', '*')`` token is
                added.
            use_drum_events: If true, ``DrumOn``/``DrumOff`` tokens are added
                for all 128 pitches. Requires ``use_magenta``.
            use_magenta: Stored as ``self._use_magenta``.
            errors: Stored as ``self._errors``; not interpreted here.
            warn_on_errors: Stored as ``self._warn_on_errors``; not
                interpreted here.
        """
        self._time_unit = time_unit
        self._max_shift_units = max_shift_units
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._use_magenta = use_magenta
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        # Drum events are only permitted together with use_magenta.
        assert use_magenta or not use_drum_events

        # Number of velocity bins needed to cover 128 values (rounded up for
        # a positive integer velocity_unit). Computed even when velocity
        # tokens are disabled.
        num_velocity_bins = (128 + velocity_unit - 1) // velocity_unit

        # The order of the tokens below defines their position in the
        # vocabulary, so it must not be changed.
        pitches = range(128)
        tokens = ['<pad>', '<s>', '</s>']
        tokens += [('NoteOn', pitch) for pitch in pitches]
        tokens += [('NoteOff', pitch) for pitch in pitches]
        if use_all_off_event:
            tokens.append(('NoteOff', '*'))
        if use_drum_events:
            tokens += [('DrumOn', pitch) for pitch in pitches]
            tokens += [('DrumOff', pitch) for pitch in pitches]
        tokens += [('TimeShift', shift)
                   for shift in range(1, max_shift_units + 1)]
        if use_velocity:
            tokens += [('SetVelocity', level)
                       for level in range(1, num_velocity_bins + 1)]

        self._default_velocity = 0 if use_velocity else 127

        self.vocabulary = Vocabulary(tokens)
def _init(cfg, logdir, train_mode, **kwargs):
    """Create the model, trainer, encoding and style vocabulary from ``cfg``.

    Args:
        cfg: Configuration object; the keys read here are ``random_seed``,
            ``encoding``, ``style_list``, ``model``, ``trainer`` and, in
            training mode, ``data_prep``, ``train_data`` and ``val_data``.
        logdir: Passed to the trainer as its ``logdir``.
        train_mode: Passed to the model as ``train_mode`` and to the trainer
            as ``write_summaries``. When true, the training and validation
            data are also prepared.
        **kwargs: Accepted but not used.

    Returns:
        The tuple ``(model, trainer, encoding, style_vocabulary)``.
    """
    set_random_seed(cfg.get('random_seed', None))

    encoding = cfg['encoding'].configure()

    # One style name per line; only the trailing newline is stripped.
    with open(cfg.get('style_list')) as style_file:
        styles = [entry.rstrip('\n') for entry in style_file]
    style_vocabulary = Vocabulary(
        styles, pad_token=None, start_token=None, end_token=None)

    model = cfg['model'].configure(
        RNNSeq2Seq,
        train_mode=train_mode,
        vocabulary=encoding.vocabulary,
        style_vocabulary=style_vocabulary)
    trainer = cfg['trainer'].configure(
        BasicTrainer,
        dataset_manager=model.dataset_manager,
        training_ops=model.training_ops,
        logdir=logdir,
        write_summaries=train_mode)

    if not train_mode:
        return model, trainer, encoding, style_vocabulary

    # Configure the dataset manager with the training and validation data.
    # The same encoding is used for both the inputs and the outputs.
    data_prep_cfg = cfg['data_prep']
    shared_kwargs = {
        'input_encoding': encoding,
        'output_encoding': encoding,
        'style_vocabulary': style_vocabulary,
    }
    train_generator = cfg['train_data'].configure(
        load_data, log=True, **shared_kwargs)
    val_generator = cfg['val_data'].configure(load_data, **shared_kwargs)
    data_prep_cfg.configure(
        prepare_train_and_val_data,
        dataset_manager=model.dataset_manager,
        train_generator=train_generator,
        val_generator=val_generator,
        output_types=(tf.int32, tf.int32, tf.int32, tf.int32),
        output_shapes=([None], [], [None], [None]))

    return model, trainer, encoding, style_vocabulary
class TranslationExperiment:
    """Experiment combining two encodings, a CNNRNNSeq2Seq model and a trainer.

    NOTE(review): ``self._cfg`` is read but never assigned in this class; it
    is presumably injected from outside (e.g. by a decorator) -- confirm.
    """

    def __init__(self, logdir, train_mode):
        """Build the encodings, style vocabulary, model, trainer and data pipeline.

        Args:
            logdir: Passed to the trainer as its ``logdir``.
            train_mode: Passed to the model as ``train_mode`` and to the
                trainer as ``write_summaries``. When true, the training and
                validation data are also prepared.
        """
        cfg = self._cfg
        set_random_seed(cfg.get('random_seed', None))

        self.input_encoding = cfg['input_encoding'].configure()
        self.output_encoding = cfg['output_encoding'].configure()

        # One style name per line; only the trailing newline is stripped.
        with open(cfg.get('style_list')) as style_file:
            styles = [entry.rstrip('\n') for entry in style_file]
        self.style_vocabulary = Vocabulary(
            styles, pad_token=None, start_token=None, end_token=None)

        # Shapes and types of one example, in the order in which the
        # components are yielded by the data generators.
        self.input_shapes = (
            [self.input_encoding.num_rows, None],
            [],
            [None],
            [None],
        )
        self.input_types = (tf.float32, tf.int32, tf.int32, tf.int32)

        # The dataset manager gets the shapes with an extra leading unknown
        # dimension (presumably the batch axis -- confirm with DatasetManager).
        manager_shapes = tuple([None, *shape] for shape in self.input_shapes)
        self.dataset_manager = DatasetManager(
            output_types=self.input_types, output_shapes=manager_shapes)

        self.model = cfg['model'].configure(
            CNNRNNSeq2Seq,
            dataset_manager=self.dataset_manager,
            train_mode=train_mode,
            vocabulary=self.output_encoding.vocabulary,
            style_vocabulary=self.style_vocabulary)
        self.trainer = cfg['trainer'].configure(
            BasicTrainer,
            dataset_manager=self.dataset_manager,
            training_ops=self.model.training_ops,
            logdir=logdir,
            write_summaries=train_mode)

        # Keyword arguments shared by every load_data call.
        self._load_data_kwargs = {
            'input_encoding': self.input_encoding,
            'output_encoding': self.output_encoding,
            'style_vocabulary': self.style_vocabulary,
        }

        if not train_mode:
            return

        # Configure the dataset manager with the training and validation data.
        data_prep_cfg = cfg['data_prep']
        train_generator = cfg['train_data'].configure(
            load_data, log=True, **self._load_data_kwargs)
        val_generator = cfg['val_data'].configure(
            load_data, **self._load_data_kwargs)
        data_prep_cfg.configure(
            prepare_train_and_val_data,
            dataset_manager=self.dataset_manager,
            train_generator=train_generator,
            val_generator=val_generator,
            output_types=self.input_types,
            output_shapes=self.input_shapes)

    def train(self, args):
        """Log a start message and run the trainer; ``args`` is not used."""
        LOGGER.info('Starting training.')
        self.trainer.train()

    def run(self, args):
        """Run the model on pickled input and pickle the decoded outputs.

        Args:
            args: Object providing ``checkpoint``, ``input_file``,
                ``target_style``, ``batch_size``, ``sample``,
                ``softmax_temperature`` and ``output_file``.
        """
        self.trainer.load_variables(checkpoint_file=args.checkpoint)
        # NOTE(review): unpickling can execute arbitrary code, so the input
        # file must come from a trusted source.
        data = pickle.load(args.input_file)

        def encoded_examples():
            # Every example is paired with the same target style. The two
            # empty lists fill the remaining components of an example.
            style_id = self.style_vocabulary.to_id(args.target_style)
            for _, notes in data:
                yield self.input_encoding.encode(notes), style_id, [], []

        dataset = make_simple_dataset(
            encoded_examples,
            output_types=self.input_types,
            output_shapes=self.input_shapes,
            batch_size=args.batch_size)

        output_ids = self.model.run(
            self.trainer.session,
            dataset,
            args.sample,
            args.softmax_temperature)

        # Pair every decoded output with the segment ID of its input example.
        outputs = []
        for seq, (segment_id, _) in zip(output_ids, data):
            outputs.append((segment_id, self.output_encoding.decode(seq)))

        pickle.dump(outputs, args.output_file)
コード例 #6
0
class BeatRelativeEncoding:
    """Event-based encoding of note sequences with beat-relative timing.

    Time is expressed with ``SetTime`` tokens (a position within the current
    beat) and ``SetTimeNext`` tokens (a position within the following beat).
    """

    def __init__(self,
                 units_per_beat=12,
                 velocity_unit=4,
                 use_velocity=True,
                 default_velocity=127,
                 use_all_off_event=False,
                 use_drum_events=False,
                 errors='remove',
                 warn_on_errors=False):
        """Store the encoding options and build the token vocabulary.

        Args:
            units_per_beat: Number of time steps per beat.
            velocity_unit: Width of one velocity bin.
            use_velocity: If true, ``SetVelocity`` tokens are part of the
                vocabulary and are emitted by `encode`.
            default_velocity: Velocity substituted for out-of-range note
                velocities when encoding, and the initial velocity when
                decoding.
            use_all_off_event: If true, the ``('NoteOff', '*')`` token is
                part of the vocabulary and note-offs are compressed.
            use_drum_events: If true, drum notes are encoded with
                ``DrumOn``/``DrumOff`` instead of ``NoteOn``/``NoteOff``.
            errors: ``'remove'`` drops hanging notes when decoding; any other
                value ends them at the end of the sequence.
            warn_on_errors: If true, decoding problems are reported as
                ``RuntimeWarning``; otherwise they are logged at debug level.
        """
        self._units_per_beat = units_per_beat
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._default_velocity = default_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        # The order of the tokens below defines their position in the
        # vocabulary, so it must not be changed.
        pitches = range(128)
        steps = range(units_per_beat)

        tokens = ['<pad>', '<s>', '</s>']
        tokens.extend(('NoteOn', pitch) for pitch in pitches)
        tokens.extend(('NoteOff', pitch) for pitch in pitches)
        if use_all_off_event:
            tokens.append(('NoteOff', '*'))
        if use_drum_events:
            tokens.extend(('DrumOn', pitch) for pitch in pitches)
            tokens.extend(('DrumOff', pitch) for pitch in pitches)
        tokens.extend(('SetTime', step) for step in steps)
        tokens.extend(('SetTimeNext', step) for step in steps)

        if use_velocity:
            # Number of velocity bins needed to cover 128 values (rounded up
            # for a positive integer velocity_unit); bins are numbered from 1.
            num_velocity_bins = (128 + velocity_unit - 1) // velocity_unit
            tokens.extend(('SetVelocity', level)
                          for level in range(1, num_velocity_bins + 1))

        self.vocabulary = Vocabulary(tokens)

    def encode(self, sequence, as_ids=True, add_start=False, add_end=False):
        """Convert a note sequence to a list of events or token IDs.

        Args:
            sequence: The note sequence to encode; it is tempo-normalized
                first.
            as_ids: If true, return vocabulary IDs instead of event tokens.
            add_start: If true, prepend the vocabulary's start token.
            add_end: If true, append the vocabulary's end token.

        Returns:
            A list of token IDs if ``as_ids`` is true, otherwise the list of
            event tokens.
        """
        sequence = note_sequence_utils.normalize_tempo(sequence)

        units = self._units_per_beat
        queue = _NoteEventQueue(sequence, quantization_step=1 / units)

        events = []
        if add_start:
            events.append(self.vocabulary.start_token)

        current_beat = 0
        previous_time = 0
        current_velocity_bin = None
        for time, note, is_onset in queue:
            if time > previous_time:
                beat, step_in_beat = divmod(time, units)

                # Step over any whole beats in between, one at a time.
                while current_beat < beat - 1:
                    events.append(('SetTimeNext', 0))
                    current_beat += 1

                if beat != current_beat:  # beat == current_beat + 1
                    events.append(('SetTimeNext', step_in_beat))
                    current_beat += 1
                else:
                    events.append(('SetTime', step_in_beat))
                assert beat == current_beat

                previous_time = time

            if is_onset:
                note_velocity = note.velocity
                if note_velocity < 1 or note_velocity > 127:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                velocity_bin = note_velocity // self._velocity_unit + 1
                # The current bin is tracked even when velocity tokens are
                # disabled; a token is emitted only when the bin changes.
                if velocity_bin != current_velocity_bin:
                    current_velocity_bin = velocity_bin
                    if self._use_velocity:
                        events.append(('SetVelocity', velocity_bin))

            if note.is_drum and self._use_drum_events:
                event_name = 'DrumOn' if is_onset else 'DrumOff'
            else:
                event_name = 'NoteOn' if is_onset else 'NoteOff'
            events.append((event_name, note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)

        if add_end:
            events.append(self.vocabulary.end_token)

        return self.vocabulary.to_ids(events) if as_ids else events

    def decode(self, tokens):
        """Convert a sequence of tokens (or token IDs) to a ``NoteSequence``.

        Times are computed as ``beat + step / units_per_beat``. Time tokens
        that do not move time forward and note-offs without a matching
        note-on are counted as errors and reported via `_log_errors`.

        Args:
            tokens: Iterable of event tokens and/or integer token IDs.

        Returns:
            The decoded ``music_pb2.NoteSequence``.

        Raises:
            RuntimeError: If a token is not in the vocabulary.
        """
        sequence = music_pb2.NoteSequence()
        sequence.ticks_per_quarter = STANDARD_PPQ

        # Maps a pitch to the notes that were started but not yet ended.
        active = defaultdict(list)
        error_count = 0

        now = 0.
        beat = 0
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                # Non-tuple tokens (e.g. '<pad>') carry no event.
                continue
            event, value = token

            if event in ('SetTime', 'SetTimeNext'):
                if event == 'SetTimeNext':
                    beat += 1
                target = beat + value / self._units_per_beat
                if target > now:
                    now = target
                else:
                    # Time must move strictly forward.
                    error_count += 1
            elif event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ('NoteOn', 'DrumOn'):
                started = sequence.notes.add()
                started.start_time = now
                started.pitch = value
                started.velocity = velocity
                started.is_drum = (event == 'DrumOn')
                active[started.pitch].append(started)
            elif event in ('NoteOff', 'DrumOff'):
                if value == '*':
                    assert self._use_all_off_event

                    if not any(active.values()):
                        error_count += 1

                    # End every active note at the current time.
                    for pending in active.values():
                        for open_note in pending:
                            open_note.end_time = now
                        pending.clear()
                else:
                    # End the most recently started note of this pitch.
                    pending = active[value]
                    if pending:
                        pending.pop().end_time = now
                    else:
                        error_count += 1
        sequence.total_time = now

        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        # Handle hanging notes
        hanging = [note for pending in active.values() for note in pending]
        num_hanging = len(hanging)
        if num_hanging:
            if self._errors == 'remove':
                self._log_errors(f'Removing {num_hanging} hanging note(s)')
                kept = [note for note in sequence.notes
                        if note not in hanging]
                del sequence.notes[:]
                sequence.notes.extend(kept)
            else:  # 'fix'
                self._log_errors(f'Ending {num_hanging} hanging note(s)')
                for note in hanging:
                    note.end_time = sequence.total_time

        return sequence

    def _log_errors(self, message):
        """Emit ``message`` as a RuntimeWarning or as a debug log record."""
        if not self._warn_on_errors:
            _LOGGER.debug(message)
            return
        warnings.warn(message, RuntimeWarning)
コード例 #7
0
class PerformanceEncoding:
    """Event-based note sequence encoding modelled on Magenta's PerformanceRNN.

    The event set closely resembles the way MIDI works.
    See https://magenta.tensorflow.org/performance-rnn.
    """

    def __init__(self,
                 time_unit=0.01,
                 max_shift_units=100,
                 velocity_unit=4,
                 use_velocity=True,
                 use_all_off_event=False,
                 use_drum_events=False,
                 use_magenta=False,
                 errors='remove',
                 warn_on_errors=False):
        """Store the encoding options and build the token vocabulary.

        Args:
            time_unit: Duration of one time step; used as the quantization
                step when encoding and as the time-shift scale when decoding.
            max_shift_units: Largest number of steps a single ``TimeShift``
                token can express.
            velocity_unit: Width of one velocity bin.
            use_velocity: If true, ``SetVelocity`` tokens are part of the
                vocabulary and the default velocity is 0; otherwise the
                default velocity is 127.
            use_all_off_event: If true, the ``('NoteOff', '*')`` token is
                part of the vocabulary and note-offs are compressed.
            use_drum_events: If true, ``DrumOn``/``DrumOff`` tokens are part
                of the vocabulary. Requires ``use_magenta``.
            use_magenta: If true, `decode` returns a ``NoteSequence`` instead
                of a list of ``pretty_midi.Note`` objects.
            errors: ``'remove'`` drops hanging notes when decoding; any other
                value leaves them in place.
            warn_on_errors: If true, decoding problems are reported as
                ``RuntimeWarning``; otherwise they are logged at debug level.
        """
        self._time_unit = time_unit
        self._max_shift_units = max_shift_units
        self._velocity_unit = velocity_unit
        self._use_velocity = use_velocity
        self._use_all_off_event = use_all_off_event
        self._use_drum_events = use_drum_events
        self._use_magenta = use_magenta
        self._errors = errors
        self._warn_on_errors = warn_on_errors

        # Drum events are only permitted together with use_magenta.
        assert use_magenta or not use_drum_events

        # Number of velocity bins needed to cover 128 values (rounded up for
        # a positive integer velocity_unit). Computed even when velocity
        # tokens are disabled.
        num_velocity_bins = (128 + velocity_unit - 1) // velocity_unit

        # The order of the tokens below defines their position in the
        # vocabulary, so it must not be changed.
        pitches = range(128)
        tokens = ['<pad>', '<s>', '</s>']
        tokens += [('NoteOn', pitch) for pitch in pitches]
        tokens += [('NoteOff', pitch) for pitch in pitches]
        if use_all_off_event:
            tokens.append(('NoteOff', '*'))
        if use_drum_events:
            tokens += [('DrumOn', pitch) for pitch in pitches]
            tokens += [('DrumOff', pitch) for pitch in pitches]
        tokens += [('TimeShift', shift)
                   for shift in range(1, max_shift_units + 1)]
        if use_velocity:
            tokens += [('SetVelocity', level)
                       for level in range(1, num_velocity_bins + 1)]

        self._default_velocity = 0 if use_velocity else 127

        self.vocabulary = Vocabulary(tokens)

    def encode(self, notes, as_ids=True, add_start=False, add_end=False):
        """Convert notes to a list of events or token IDs.

        Args:
            notes: Either a ``music_pb2.NoteSequence`` or a collection of
                notes accepted by ``_NoteEventQueue``. For a ``NoteSequence``
                the drum flag of its first note decides whether the whole
                input is treated as drums.
            as_ids: If true, return vocabulary IDs instead of event tokens.
            add_start: If true, prepend the vocabulary's start token.
            add_end: If true, append the vocabulary's end token.

        Returns:
            A list of token IDs if ``as_ids`` is true, otherwise the list of
            event tokens.
        """
        is_drum = False
        if isinstance(notes, music_pb2.NoteSequence):
            source_notes = notes.notes
            is_drum = (len(source_notes) > 0 and source_notes[0].is_drum)
            notes = [
                pretty_midi.Note(start=n.start_time,
                                 end=n.end_time,
                                 pitch=n.pitch,
                                 velocity=n.velocity)
                for n in source_notes
            ]

        queue = _NoteEventQueue(notes, quantization_step=self._time_unit)

        events = []
        if add_start:
            events.append(self.vocabulary.start_token)

        drum_events = is_drum and self._use_drum_events
        current_time = 0
        current_velocity = self._default_velocity
        for time, note, is_onset in queue:
            # Advance time, splitting long gaps into several shifts of at
            # most max_shift_units steps each.
            while current_time < time:
                shift_amount = min(time - current_time, self._max_shift_units)
                current_time += shift_amount
                events.append(('TimeShift', shift_amount))

            if is_onset:
                note_velocity = note.velocity
                if note_velocity < 1 or note_velocity > 127:
                    warnings.warn(f'Invalid velocity value: {note_velocity}')
                    note_velocity = self._default_velocity
                # The raw (unquantized) velocities are compared here.
                if note_velocity != current_velocity:
                    current_velocity = note_velocity
                    if self._use_velocity:
                        velocity_bin = (
                            current_velocity // self._velocity_unit + 1)
                        events.append(('SetVelocity', velocity_bin))

            if drum_events:
                event_name = 'DrumOn' if is_onset else 'DrumOff'
            else:
                event_name = 'NoteOn' if is_onset else 'NoteOff'
            events.append((event_name, note.pitch))

        if self._use_all_off_event:
            events = _compress_note_offs(events)

        if add_end:
            events.append(self.vocabulary.end_token)

        return self.vocabulary.to_ids(events) if as_ids else events

    def decode(self, tokens):
        """Convert a sequence of tokens (or token IDs) back to notes.

        Note-offs without a matching note-on are counted as errors and
        reported via `_log_errors`. Notes that are never ended are removed
        if ``errors == 'remove'``; otherwise they are kept with ``end=None``.

        Args:
            tokens: Iterable of event tokens and/or integer token IDs.

        Returns:
            A ``music_pb2.NoteSequence`` if ``use_magenta`` was set,
            otherwise a list of ``pretty_midi.Note`` objects.

        Raises:
            RuntimeError: If a token is not in the vocabulary.
        """
        notes = []
        # Maps a pitch to the notes that were started but not yet ended.
        active = defaultdict(list)
        error_count = 0
        is_drum = False

        now = 0
        velocity = self._default_velocity
        for token in tokens:
            if isinstance(token, (int, np.integer)):
                token = self.vocabulary.from_id(token)
            if token not in self.vocabulary:
                raise RuntimeError(f'Invalid token: {token}')
            if not isinstance(token, tuple):
                # Non-tuple tokens (e.g. '<pad>') carry no event.
                continue
            event, value = token

            if event == 'TimeShift':
                now += value * self._time_unit
            elif event == 'SetVelocity':
                velocity = (value - 1) * self._velocity_unit
            elif event in ('NoteOn', 'DrumOn'):
                started = pretty_midi.Note(start=now,
                                           end=None,
                                           pitch=value,
                                           velocity=velocity)
                notes.append(started)
                active[value].append(started)
                # A single drum onset marks the whole output as drums.
                if event == 'DrumOn':
                    is_drum = True
            elif event in ('NoteOff', 'DrumOff'):
                if value == '*':
                    assert self._use_all_off_event

                    if not any(active.values()):
                        error_count += 1

                    # End every active note at the current time.
                    for pending in active.values():
                        for open_note in pending:
                            open_note.end = now
                        pending.clear()
                else:
                    # End the most recently started note of this pitch.
                    pending = active[value]
                    if pending:
                        pending.pop().end = now
                    else:
                        error_count += 1

        if error_count:
            self._log_errors('Encountered {} errors'.format(error_count))

        hanging = [note for pending in active.values() for note in pending]
        if hanging:
            if self._errors == 'remove':
                self._log_errors('Removing {} hanging note(s)'.format(
                    len(hanging)))
                for note in hanging:
                    notes.remove(note)
            else:  # 'ignore'
                self._log_errors('Ignoring {} hanging note(s)'.format(
                    len(hanging)))

        if not self._use_magenta:
            return notes

        sequence = music_pb2.NoteSequence()
        sequence.ticks_per_quarter = STANDARD_PPQ
        for decoded in notes:
            note = sequence.notes.add()
            note.start_time = decoded.start
            note.end_time = decoded.end
            note.pitch = decoded.pitch
            note.velocity = decoded.velocity
            note.is_drum = is_drum
        return sequence

    def _log_errors(self, message):
        """Emit ``message`` as a RuntimeWarning or as a debug log record."""
        if not self._warn_on_errors:
            logger.debug(message)
            return
        warnings.warn(message, RuntimeWarning)