Code example #1
0
    def __init__(self,
                 max_bars=None,
                 slice_bars=None,
                 gap_bars=1.0,
                 pitch_classes=None,
                 add_end_token=False,
                 steps_per_quarter=4,
                 quarters_per_bar=4,
                 pad_to_total_time=False,
                 roll_input=False,
                 roll_output=False,
                 max_tensors_per_notesequence=5,
                 presplit_on_time_changes=True):
        """Initialize the drums <-> tensor converter.

        Args:
          max_bars: If set, truncate extracted drum tracks to this many bars.
          slice_bars: If set, slice sequences into chunks of this many bars.
          gap_bars: Gap (in bars) at which to split drum tracks; a falsy value
            disables gap-splitting (treated as an infinite gap).
          pitch_classes: List of lists of MIDI pitches, one inner list per drum
            class. Defaults to `REDUCED_DRUM_PITCH_CLASSES`.
          add_end_token: Whether to append an end token to each sequence.
          steps_per_quarter: Quantization steps per quarter note.
          quarters_per_bar: Quarter notes per bar.
          pad_to_total_time: Whether to pad sequences to the source total time.
          roll_input: Whether inputs are pianoroll (multi-hot) vectors rather
            than one-hot drum-event encodings.
          roll_output: Whether outputs are pianoroll (multi-hot) vectors.
          max_tensors_per_notesequence: Max tensors extracted per NoteSequence.
          presplit_on_time_changes: Whether to split sequences on time changes
            before extraction.
        """
        self._pitch_classes = pitch_classes or REDUCED_DRUM_PITCH_CLASSES
        # Map each individual MIDI pitch to the index of its drum class.
        self._pitch_class_map = {
            p: i
            for i, pitches in enumerate(self._pitch_classes) for p in pitches
        }

        self._steps_per_quarter = steps_per_quarter
        self._steps_per_bar = steps_per_quarter * quarters_per_bar
        self._slice_steps = self._steps_per_bar * slice_bars if slice_bars else None
        self._pad_to_total_time = pad_to_total_time
        self._roll_input = roll_input
        self._roll_output = roll_output

        self._drums_extractor_fn = functools.partial(
            mm.extract_drum_tracks,
            min_bars=1,
            # A falsy gap disables splitting (equivalent to an infinite gap).
            gap_bars=gap_bars or float('inf'),
            # Parenthesized for clarity: the whole product is conditional.
            max_steps_truncate=(self._steps_per_bar * max_bars
                                if max_bars else None),
            pad_end=True)

        num_classes = len(self._pitch_classes)

        self._pr_encoder_decoder = mm.PianorollEncoderDecoder(
            input_size=num_classes + add_end_token)
        # Use pitch classes as `drum_type_pitches` since we have already done the
        # mapping.
        self._oh_encoder_decoder = mm.MultiDrumOneHotEncoding(
            drum_type_pitches=[(i,) for i in range(num_classes)])

        output_depth = (num_classes if self._roll_output else
                        self._oh_encoder_decoder.num_classes) + add_end_token
        super(DrumsConverter, self).__init__(
            input_depth=(num_classes + 1 if self._roll_input else
                         self._oh_encoder_decoder.num_classes) + add_end_token,
            # `np.bool` was a deprecated alias for the builtin `bool` and was
            # removed in NumPy 1.24; using `bool` is identical in behavior and
            # keeps this working on modern NumPy.
            input_dtype=bool,
            output_depth=output_depth,
            output_dtype=bool,
            end_token=output_depth - 1 if add_end_token else None,
            presplit_on_time_changes=presplit_on_time_changes,
            max_tensors_per_notesequence=max_tensors_per_notesequence)
Code example #2
0
 def setUp(self):
   """Build the minimal pianoroll RNN config shared by these tests."""
   super(PianorollPipelineTest, self).setUp()
   encoder_decoder = mm.PianorollEncoderDecoder(88)
   hparams = contrib_training.HParams()
   self.config = events_rnn_model.EventSequenceRnnConfig(
       None, encoder_decoder, hparams)
Code example #3
0
      primer track).
    """
        return self._generate_events(num_steps=num_steps,
                                     primer_events=primer_sequence,
                                     temperature=None,
                                     beam_size=beam_size,
                                     branch_factor=branch_factor,
                                     steps_per_iteration=steps_per_iteration)


default_configs = {
    'rnn-nade':
    events_rnn_model.EventSequenceRnnConfig(
        magenta.protobuf.generator_pb2.GeneratorDetails(
            id='rnn-nade', description='RNN-NADE'),
        mm.PianorollEncoderDecoder(),
        tf.contrib.training.HParams(batch_size=64,
                                    rnn_layer_sizes=[128, 128, 128],
                                    nade_hidden_units=128,
                                    dropout_keep_prob=0.5,
                                    clip_norm=5,
                                    learning_rate=0.001)),
    'rnn-nade_attn':
    events_rnn_model.EventSequenceRnnConfig(
        magenta.protobuf.generator_pb2.GeneratorDetails(
            id='rnn-nade_attn', description='RNN-NADE with attention.'),
        mm.PianorollEncoderDecoder(),
        tf.contrib.training.HParams(batch_size=48,
                                    rnn_layer_sizes=[128, 128],
                                    attn_length=32,
                                    nade_hidden_units=128,
Code example #4
0
 def setUp(self):
     """Create a test config: 88-key pianoroll encoder, default hparams."""
     encoder_decoder = mm.PianorollEncoderDecoder(88)
     hparams = tf.contrib.training.HParams()
     self.config = events_rnn_model.EventSequenceRnnConfig(
         None, encoder_decoder, hparams)
Code example #5
0
 def setUp(self):
     """Create a test config: 88-key pianoroll encoder, default hparams."""
     encoder_decoder = mm.PianorollEncoderDecoder(88)
     hparams = magenta.common.HParams()
     self.config = events_rnn_model.EventSequenceRnnConfig(
         None, encoder_decoder, hparams)