示例#1
0
    def testToNoteSequence(self):
        """Decodes one-hot trio outputs back into a single NoteSequence.

        The fourth expected slice (index 3) of melody/bass/drums labels is
        re-encoded as one-hot blocks of depth 90/90/512, concatenated along
        the last axis and decoded. Exactly one sequence must come back, and
        it must match the expected proto.
        """
        trio_converter = data.TrioConverter(
            steps_per_quarter=1, slice_bars=2, max_tensors_per_notesequence=1)

        track_labels = self.expected_sliced_labels[3]
        # One one-hot block per track, in output order: melody, bass, drums.
        track_one_hots = [
            data.np_onehot(track_labels[0], 90),
            data.np_onehot(track_labels[1], 90),
            data.np_onehot(track_labels[2], 512),
        ]
        decoded = trio_converter.to_notesequences(
            [np.concatenate(track_one_hots, axis=-1)])
        self.assertEqual(1, len(decoded))

        expected_proto = """
        ticks_per_quarter: 220
        tempos < qpm: 120 >
        notes <
          instrument: 0 pitch: 52 start_time: 2.0 end_time: 4.0 program: 0
          velocity: 80
        >
        notes <
          instrument: 1 pitch: 50 start_time: 1.0 end_time: 2.5 program: 33
          velocity: 80
        >
        notes <
          instrument: 9 pitch: 36 start_time: 0.0 end_time: 0.5 velocity: 80
          is_drum: True
        >
        notes <
          instrument: 9 pitch: 38 start_time: 2.0 end_time: 2.5 velocity: 80
          is_drum: True
        >
        total_time: 4.0
        """
        self.assertProtoEquals(expected_proto, decoded[0])
示例#2
0
  def testSliced(self):
    """Checks that 2-bar slicing of `self.sequence` yields the expected labels.

    Every output tensor is the concatenation of three one-hot blocks
    (melody: 90, bass: 90, drums: the remainder). Each is split at those
    boundaries and argmax-decoded back to per-track labels. The resulting
    label stacks are compared, as a set, with `self.expected_sliced_labels`.
    `max_tensors_per_notesequence=None` presumably keeps every slice rather
    than sampling a fixed number of them -- TODO confirm against
    TrioConverter.
    """
    converter = data.TrioConverter(
        steps_per_quarter=1, gap_bars=1, slice_bars=2,
        max_tensors_per_notesequence=None)
    in_tensors, out_tensors = converter.to_tensors(self.sequence)
    # Inputs are expected to be the same set of tensors as the outputs.
    self.assertArraySetsEqual(in_tensors, out_tensors)
    # Split at the melody/bass (90) and bass/drums (180) boundaries and
    # argmax each block back to class labels. np.stack requires a sequence:
    # passing a generator was deprecated in NumPy 1.16, and newer releases
    # raise TypeError for it.
    actual_sliced_labels = [
        np.stack([np.argmax(s, axis=-1)
                  for s in np.split(t, [90, 180], axis=-1)])
        for t in out_tensors]

    self.assertArraySetsEqual(self.expected_sliced_labels, actual_sliced_labels)
示例#3
0
  def testSlicedChordConditioned(self):
    """Checks trio slicing when a chord encoding supplies control tensors.

    Output tensors are split into melody (90), bass (90) and drums
    (remainder) one-hot blocks and argmax-decoded to labels. Each control
    tensor in `tensors.controls` is argmax-decoded to chord labels. Both are
    compared, as sets, with the expected slice and chord labels.
    """
    converter = data.TrioConverter(
        steps_per_quarter=1, gap_bars=1, slice_bars=2,
        max_tensors_per_notesequence=None,
        chord_encoding=mm.MajorMinorChordOneHotEncoding())
    tensors = converter.to_tensors(self.sequence)
    # Inputs are expected to be the same set of tensors as the outputs.
    self.assertArraySetsEqual(tensors.inputs, tensors.outputs)
    # np.stack requires a sequence: passing a generator was deprecated in
    # NumPy 1.16, and newer releases raise TypeError for it.
    actual_sliced_labels = [
        np.stack([np.argmax(s, axis=-1)
                  for s in np.split(t, [90, 180], axis=-1)])
        for t in tensors.outputs]
    actual_sliced_chord_labels = [
        np.argmax(t, axis=-1) for t in tensors.controls]

    self.assertArraySetsEqual(self.expected_sliced_labels, actual_sliced_labels)
    self.assertArraySetsEqual(
        self.expected_sliced_chord_labels, actual_sliced_chord_labels)
示例#4
0
    def testToNoteSequenceChordConditioned(self):
        """Decodes one-hot trio outputs with chord controls into a NoteSequence.

        The converter is built with a major/minor chord one-hot encoding. The
        fourth expected chord slice (index 3) is passed as controls alongside
        the one-hot melody/bass/drums outputs. The decoded sequence must
        match the expected proto, including its CHORD_SYMBOL text
        annotations.
        """
        chord_converter = data.TrioConverter(
            steps_per_quarter=1,
            slice_bars=2,
            max_tensors_per_notesequence=1,
            chord_encoding=mm.MajorMinorChordOneHotEncoding())

        track_labels = self.expected_sliced_labels[3]
        # One one-hot block per track, in output order: melody, bass, drums.
        track_one_hots = [
            data.np_onehot(track_labels[0], 90),
            data.np_onehot(track_labels[1], 90),
            data.np_onehot(track_labels[2], 512),
        ]
        # 25 chord classes; presumably no-chord plus 12 major and 12 minor
        # chords for this encoding -- TODO confirm.
        chord_one_hot = data.np_onehot(
            self.expected_sliced_chord_labels[3], 25)

        decoded = chord_converter.to_notesequences(
            [np.concatenate(track_one_hots, axis=-1)], [chord_one_hot])
        self.assertEqual(1, len(decoded))

        expected_proto = """
        ticks_per_quarter: 220
        tempos < qpm: 120 >
        notes <
          instrument: 0 pitch: 52 start_time: 2.0 end_time: 4.0 program: 0
          velocity: 80
        >
        notes <
          instrument: 1 pitch: 50 start_time: 1.0 end_time: 2.5 program: 33
          velocity: 80
        >
        notes <
          instrument: 9 pitch: 36 start_time: 0.0 end_time: 0.5 velocity: 80
          is_drum: True
        >
        notes <
          instrument: 9 pitch: 38 start_time: 2.0 end_time: 2.5 velocity: 80
          is_drum: True
        >
        text_annotations <
          text: 'N.C.' annotation_type: CHORD_SYMBOL
        >
        text_annotations <
          time: 2.0 text: 'C' annotation_type: CHORD_SYMBOL
        >
        total_time: 4.0
        """
        self.assertProtoEquals(expected_proto, decoded[0])
示例#5
0
        )),
    note_sequence_augmenter=None,
    data_converter=data.DrumsConverter(
        max_bars=100,  # Truncate long drum sequences before slicing.
        pitch_classes=data.FULL_DRUM_PITCH_CLASSES,
        slice_bars=2,
        steps_per_quarter=4,
        roll_input=True,
        roll_output=True),
    train_examples_path=None,
    eval_examples_path=None,
)

# Trio Models
# Shared trio converter: 16-bar slices at 4 steps per quarter note.
# gap_bars=2 presumably splits input at gaps of at least 2 bars before
# slicing -- confirm against TrioConverter.
trio_16bar_converter = data.TrioConverter(
    slice_bars=16,
    gap_bars=2,
    steps_per_quarter=4,
)

CONFIG_MAP['flat-trio_16bar'] = Config(
    model=MusicVAE(
        lstm_models.BidirectionalLstmEncoder(),
        lstm_models.MultiOutCategoricalLstmDecoder(output_depths=[
            90,  # melody
            90,  # bass
            512,  # drums
        ])),
    hparams=merge_hparams(
        lstm_models.get_default_hparams(),
        HParams(
            batch_size=256,
            max_seq_len=256,
示例#6
0
            90,  # melody
            90,  # bass
            512,  # drums
        ])),
    hparams=merge_hparams(
        lstm_models.get_default_hparams(),
        HParams(
            batch_size=256,
            max_seq_len=256,
            z_size=512,
            enc_rnn_size=[2048, 2048],
            dec_rnn_size=[2048, 2048, 2048],
        )),
    note_sequence_augmenter=None,
    note_sequence_converter=data.TrioConverter(steps_per_quarter=4,
                                               slice_bars=16,
                                               gap_bars=2),
    train_examples_path=None,
    eval_examples_path=None,
)

config_map['hiercat-trio_16bar_big'] = Config(
    model=MusicVAE(
        lstm_models.BidirectionalLstmEncoder(),
        lstm_models.HierarchicalMultiOutLstmDecoder(
            core_decoders=[
                lstm_models.CategoricalLstmDecoder(),
                lstm_models.CategoricalLstmDecoder(),
                lstm_models.CategoricalLstmDecoder()
            ],
            output_depths=[