Example #1
0
    def __init__(self,
                 max_bars=None,
                 slice_bars=None,
                 gap_bars=1.0,
                 pitch_classes=None,
                 add_end_token=False,
                 steps_per_quarter=4,
                 quarters_per_bar=4,
                 pad_to_total_time=False,
                 roll_input=False,
                 roll_output=False,
                 max_tensors_per_notesequence=5,
                 presplit_on_time_changes=True):
        """Configure conversion between drum tracks and tensors.

        Args:
          max_bars: If set, extracted drum tracks are truncated to
            `max_bars * steps_per_bar` steps (passed to
            `mm.extract_drum_tracks` as `max_steps_truncate`).
          slice_bars: If set, stored as a slice length in steps
            (`slice_bars * steps_per_bar`); otherwise `None` is stored.
          gap_bars: Passed to `mm.extract_drum_tracks`. A falsy value
            (0 or None) is replaced with an infinite gap.
          pitch_classes: Sequence of pitch groups; every pitch in group `i`
            is mapped to drum class `i`. Falls back to
            `REDUCED_DRUM_PITCH_CLASSES` when None or empty.
          add_end_token: If true, one extra dimension is added to both the
            input and output depth, and the last output index is used as
            the end token.
          steps_per_quarter: Steps per quarter note; together with
            `quarters_per_bar` it defines the number of steps per bar.
          quarters_per_bar: Quarter notes per bar.
          pad_to_total_time: Stored on the instance for later use.
          roll_input: If true, the input depth is `num_classes + 1`;
            otherwise it is the multi-drum one-hot size.
          roll_output: If true, the output depth is `num_classes`;
            otherwise it is the multi-drum one-hot size.
          max_tensors_per_notesequence: Forwarded to the base converter.
          presplit_on_time_changes: Forwarded to the base converter.
        """
        self._pitch_classes = pitch_classes or REDUCED_DRUM_PITCH_CLASSES
        # Invert the grouping: pitch -> index of the class that contains it.
        self._pitch_class_map = {
            p: i
            for i, pitches in enumerate(self._pitch_classes) for p in pitches
        }

        self._steps_per_quarter = steps_per_quarter
        self._steps_per_bar = steps_per_quarter * quarters_per_bar
        self._slice_steps = (
            self._steps_per_bar * slice_bars if slice_bars else None)
        self._pad_to_total_time = pad_to_total_time
        self._roll_input = roll_input
        self._roll_output = roll_output

        # Computed up front so the conditional expression is not split across
        # a line break inside the call below (evaluation is unchanged).
        max_steps_truncate = (
            self._steps_per_bar * max_bars if max_bars else None)
        self._drums_extractor_fn = functools.partial(
            mm.extract_drum_tracks,
            min_bars=1,
            # A falsy gap (0/None) is replaced with an infinite gap.
            gap_bars=gap_bars or float('inf'),
            max_steps_truncate=max_steps_truncate,
            pad_end=True)

        num_classes = len(self._pitch_classes)

        # `add_end_token` is used arithmetically here (False -> 0, True -> 1).
        self._pr_encoder_decoder = mm.PianorollEncoderDecoder(
            input_size=num_classes + add_end_token)
        # Use pitch classes as `drum_type_pitches` since we have already done the
        # mapping.
        self._oh_encoder_decoder = mm.MultiDrumOneHotEncoding(
            drum_type_pitches=[(i, ) for i in range(num_classes)])

        output_depth = (num_classes if self._roll_output else
                        self._oh_encoder_decoder.num_classes) + add_end_token
        input_depth = (num_classes + 1 if self._roll_input else
                       self._oh_encoder_decoder.num_classes) + add_end_token
        # `np.bool` was merely an alias for the builtin `bool`; it was
        # deprecated in NumPy 1.20 and removed in 1.24, so use `bool`.
        super(DrumsConverter, self).__init__(
            input_depth=input_depth,
            input_dtype=bool,
            output_depth=output_depth,
            output_dtype=bool,
            end_token=output_depth - 1 if add_end_token else None,
            presplit_on_time_changes=presplit_on_time_changes,
            max_tensors_per_notesequence=max_tensors_per_notesequence)
Example #2
0
    def __init__(self,
                 max_bars=None,
                 slice_bars=None,
                 gap_bars=1.0,
                 steps_per_quarter=4,
                 quarters_per_bar=4,
                 pad_to_total_time=False,
                 add_end_token=False,
                 max_tensors_per_notesequence=5,
                 binary_input=False,
                 presplit_on_time_changes=True):
        """Build a one-hot drums converter on top of the base event converter.

        Args:
          max_bars: If set, extracted drum tracks are truncated to
            `max_bars * steps_per_bar` steps (passed to
            `mm.extract_drum_tracks` as `max_steps_truncate`).
          slice_bars: Forwarded to the base converter.
          gap_bars: Passed to `mm.extract_drum_tracks`; a falsy value
            (0 or None) is replaced with an infinite gap.
          steps_per_quarter: Steps per quarter note; used to build each
            `mm.DrumTrack` and forwarded to the base converter.
          quarters_per_bar: Quarter notes per bar; with `steps_per_quarter`
            it defines `steps_per_bar`.
          pad_to_total_time: Forwarded to the base converter.
          add_end_token: Forwarded to the base converter; also used
            arithmetically (0/1) when computing the binary input depth.
          max_tensors_per_notesequence: Forwarded to the base converter.
          binary_input: If true, `self._input_depth` is overridden with a
            log2-sized binary encoding depth after base initialization.
          presplit_on_time_changes: Forwarded to the base converter.
        """
        steps_per_bar = steps_per_quarter * quarters_per_bar
        # None when max_bars is unset (or falsy).
        max_steps_truncate = steps_per_bar * max_bars if max_bars else None
        self._binary_input = binary_input

        # Factory for new DrumTrack objects using this converter's timing.
        # NOTE(review): presumably invoked by the base class; confirm there.
        def drumtrack_fn():
            return mm.DrumTrack(steps_per_bar=steps_per_bar,
                                steps_per_quarter=steps_per_quarter)

        drums_extractor_fn = functools.partial(
            mm.extract_drum_tracks,
            min_bars=1,
            gap_bars=gap_bars or float('inf'),
            max_steps_truncate=max_steps_truncate,
            pad_end=True)
        super(OneHotDrumsConverter, self).__init__(
            drumtrack_fn,
            drums_extractor_fn,
            mm.MultiDrumOneHotEncoding(),
            add_end_token=add_end_token,
            slice_bars=slice_bars,
            pad_to_total_time=pad_to_total_time,
            steps_per_quarter=steps_per_quarter,
            quarters_per_bar=quarters_per_bar,
            max_tensors_per_notesequence=max_tensors_per_notesequence,
            presplit_on_time_changes=presplit_on_time_changes)

        if binary_input:
            # Binary input tensor is log_2 of the output depth after removing the end
            # token plus an additional dimension to represent the "empty" label.
            # NOTE(review): `self._output_depth` is presumably set by the base
            # class __init__ above. int() truncates the log2, so the result is
            # exact only when that depth (minus the end token) is a power of 2.
            self._input_depth = (
                int(np.log2(self._output_depth - add_end_token)) + 1 +
                add_end_token)