Example #1
0
def get_config(batch_size, data_path):
    """Build a MusicVAE `configs.Config` for 2-bar one-hot melodies.

    Args:
        batch_size: training batch size, forwarded into the hyperparameters.
        data_path: examples path used for BOTH the train and eval sets.

    Returns:
        A `configs.Config` wiring together the model, hyperparameters,
        note-sequence augmentation, and melody data conversion.
    """
    return configs.Config(
        model=MusicVAE(lstm_models.BidirectionalLstmEncoder(),
                       lstm_models.CategoricalLstmDecoder()),
        hparams=merge_hparams(
            lstm_models.get_default_hparams(),
            HParams(
                # Fix: honor the caller-supplied batch_size. The original
                # hard-coded 512 here, silently ignoring the parameter.
                batch_size=batch_size,
                max_seq_len=32,  # 2 bars w/ 16 steps per bar
                z_size=512,
                enc_rnn_size=[2048],
                dec_rnn_size=[2048, 2048, 2048],
                free_bits=0,
                max_beta=0.5,
                beta_rate=0.99999,
                sampling_schedule='inverse_sigmoid',
                sampling_rate=1000,
            )),
        note_sequence_augmenter=data.NoteSequenceAugmenter(
            transpose_range=(-5, 5)),
        data_converter=data.OneHotMelodyConverter(
            valid_programs=data.MEL_PROGRAMS,
            skip_polyphony=False,
            max_bars=100,  # Truncate long melodies before slicing.
            slice_bars=2,
            steps_per_quarter=4),
        train_examples_path=data_path,
        eval_examples_path=data_path,
    )
Example #2
0
    def testTfAugment(self):
        """Check tf_augment applies the fixed transpose (-3) and 2x stretch.

        The NoteSequence is fed through a tf.string placeholder so the
        graph-mode augmentation path is exercised, then the serialized
        result is parsed back and compared against a hand-built sequence.
        """
        augmenter = data.NoteSequenceAugmenter(transpose_range=(-3, -3),
                                               stretch_range=(2.0, 2.0))
        # NOTE(review): the original also called augmenter.augment(...) here,
        # but the result was immediately overwritten below — dead code removed.
        with self.test_session() as sess:
            sequence_str = tf.placeholder(tf.string)
            augmented_sequence_str_ = augmenter.tf_augment(sequence_str)
            augmented_sequence_str = sess.run(
                [augmented_sequence_str_],
                feed_dict={sequence_str: self.sequence.SerializeToString()})
        augmented_sequence = music_pb2.NoteSequence.FromString(
            augmented_sequence_str[0])

        # Expected output: tempo halved to 30 qpm by the 2x stretch (sibling
        # tests suggest a 60 qpm base — confirm against setUp), melody pitches
        # shifted down 3 semitones, all times doubled; drum pitches unchanged.
        expected_sequence = music_pb2.NoteSequence()
        expected_sequence.tempos.add(qpm=30)
        testing_lib.add_track_to_sequence(expected_sequence, 0,
                                          [(29, 100, 4, 8), (30, 100, 12, 22),
                                           (31, 100, 22, 26),
                                           (32, 100, 34, 36)])
        testing_lib.add_track_to_sequence(expected_sequence,
                                          1, [(57, 80, 8, 8.2),
                                              (58, 80, 24, 24.2)],
                                          is_drum=True)
        testing_lib.add_chords_to_sequence(expected_sequence, [('N.C.', 0),
                                                               ('A', 16),
                                                               ('Gbm', 32)])

        self.assertEqual(expected_sequence, augmented_sequence)
Example #3
0
  def testAugmentStretch(self):
    """A fixed 0.5 stretch halves note times and raises tempo to 120 qpm."""
    augmenter = data.NoteSequenceAugmenter(stretch_range=(0.5, 0.5))
    stretched = augmenter.augment(self.sequence)

    expected = music_pb2.NoteSequence()
    expected.tempos.add(qpm=120)
    melody_notes = [
        (32, 100, 1, 2),
        (33, 100, 3, 5.5),
        (34, 100, 5.5, 6.5),
        (35, 100, 8.5, 9),
    ]
    testing_lib.add_track_to_sequence(expected, 0, melody_notes)
    drum_notes = [(57, 80, 2, 2.05), (58, 80, 6, 6.05)]
    testing_lib.add_track_to_sequence(expected, 1, drum_notes, is_drum=True)
    expected_chords = [('N.C.', 0), ('C', 4), ('Am', 8)]
    testing_lib.add_chords_to_sequence(expected, expected_chords)

    self.assertEqual(expected, stretched)
Example #4
0
  def testAugmentTranspose(self):
    """A fixed +2 transpose shifts melody pitches and chords up two semitones."""
    augmenter = data.NoteSequenceAugmenter(transpose_range=(2, 2))
    transposed = augmenter.augment(self.sequence)

    expected = music_pb2.NoteSequence()
    expected.tempos.add(qpm=60)
    melody_notes = [
        (34, 100, 2, 4),
        (35, 100, 6, 11),
        (36, 100, 11, 13),
        (37, 100, 17, 18),
    ]
    testing_lib.add_track_to_sequence(expected, 0, melody_notes)
    # Drum pitches are expected to stay put (57, 58) under transposition.
    drum_notes = [(57, 80, 4, 4.1), (58, 80, 12, 12.1)]
    testing_lib.add_track_to_sequence(expected, 1, drum_notes, is_drum=True)
    expected_chords = [('N.C.', 0), ('D', 8), ('Bm', 16)]
    testing_lib.add_chords_to_sequence(expected, expected_chords)

    self.assertEqual(expected, transposed)
Example #5
0
                   lstm_models.CategoricalLstmDecoder()),
    hparams=merge_hparams(
        lstm_models.get_default_hparams(),
        HParams(
            batch_size=512,
            max_seq_len=32,  # 2 bars w/ 16 steps per bar
            z_size=256,
            enc_rnn_size=[512],
            dec_rnn_size=[256, 256],
            free_bits=0,
            max_beta=0.2,
            beta_rate=0.99999,
            sampling_schedule='inverse_sigmoid',
            sampling_rate=1000,
        )),
    note_sequence_augmenter=data.NoteSequenceAugmenter(transpose_range=(-5,
                                                                        5)),
    data_converter=data.OneHotMelodyConverter(
        valid_programs=data.MEL_PROGRAMS,
        skip_polyphony=False,
        max_bars=100,  # Truncate long melodies before slicing.
        slice_bars=2,
        steps_per_quarter=4),
    train_examples_path=None,
    eval_examples_path=None,
)

CONFIG_MAP['cat-mel_2bar_big'] = Config(
    model=MusicVAE(lstm_models.BidirectionalLstmEncoder(),
                   lstm_models.CategoricalLstmDecoder()),
    hparams=merge_hparams(
        lstm_models.get_default_hparams(),