Code Example #1
File: test_module.py  Project: Doddox-UV/torchsynth
    def test_noise(self):
        # Here we create three noise modules with different batch sizes.
        # We call each noise module a number of times to obtain an equal number
        # of noise signals. All these noise samples should equal each other.
        # i.e., noise should be returned deterministically regardless of the
        # batch size.
        synthconfig32 = SynthConfig(32)
        synthconfig64 = SynthConfig(64)
        synthconfig128 = SynthConfig(128)

        noise32 = synthmodule.Noise(synthconfig32, seed=0)
        noise64 = synthmodule.Noise(synthconfig64, seed=0)
        noise128 = synthmodule.Noise(synthconfig128, seed=0)
        # A different seed should give a different result
        noise128_diff = synthmodule.Noise(synthconfig128, seed=42)

        out1 = torch.vstack((noise32(), noise32(), noise32(), noise32()))
        out2 = torch.vstack((noise64(), noise64()))
        out3 = noise128()
        out4 = noise128_diff()

        assert torch.all(out1 == out2)
        assert torch.all(out2 == out3)
        assert torch.all(out3 != out4)

        with pytest.raises(ValueError):
            # If the batch size is not a multiple of BASE_REPRODUCIBLE_BATCH_SIZE,
            # it should raise an error in reproducible mode
            synthconfigwrong = SynthConfig(BASE_REPRODUCIBLE_BATCH_SIZE + 1)
            synthmodule.Noise(synthconfigwrong, seed=0)
Code Example #2
    def run_batch_size_test(self, device):
        # Runs test for reproducibility across batch sizes on a device
        voice256 = Voice(SynthConfig(batch_size=256)).to(device)
        out256 = voice256(0)

        voice128 = Voice(SynthConfig(batch_size=128)).to(device)
        out128 = torch.vstack([voice128(0), voice128(1)])

        voice64 = Voice(SynthConfig(batch_size=64)).to(device)
        out64 = torch.vstack([voice64(0), voice64(1), voice64(2), voice64(3)])

        voice32 = Voice(SynthConfig(batch_size=32)).to(device)
        out32 = torch.vstack([
            voice32(0),
            voice32(1),
            voice32(2),
            voice32(3),
            voice32(4),
            voice32(5),
            voice32(6),
            voice32(7),
        ])

        # TODO there are some unexpected, very small numerical
        # errors between some, but not all, batch sizes
        # See https://github.com/torchsynth/torchsynth/issues/326
        assert torch.all(torch.isclose(out256, out128))
        assert torch.all(torch.isclose(out256, out64))
        assert torch.all(torch.isclose(out256, out32))
Code Example #3
File: synth.py  Project: Doddox-UV/torchsynth
    def __init__(self,
                 synthconfig: Optional[SynthConfig] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        if synthconfig is not None:
            self.synthconfig = synthconfig
        else:
            # Use the default
            self.synthconfig = SynthConfig()
Code Example #4
def test_synth_config():
    synthconfig = SynthConfig(DEFAULT_BATCH_SIZE)
    assert synthconfig.batch_size == DEFAULT_BATCH_SIZE

    # Test passing in specific values
    synthconfig = SynthConfig(
        batch_size=65,
        sample_rate=16000,
        buffer_size_seconds=0.5,
        control_rate=1000,
        reproducible=False,
    )
    assert synthconfig.control_rate == 1000
    assert synthconfig.sample_rate == 16000
    assert synthconfig.buffer_size_seconds == 0.5
    assert synthconfig.buffer_size == 8000
    assert synthconfig.control_buffer_size == 500
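A note on the derived sizes asserted above: they are just the configured duration expressed in samples at each rate. A minimal sketch of that arithmetic in plain Python (independent of how SynthConfig computes it internally):

sample_rate = 16000
control_rate = 1000
buffer_size_seconds = 0.5

buffer_size = int(sample_rate * buffer_size_seconds)            # 8000 audio samples
control_buffer_size = int(control_rate * buffer_size_seconds)   # 500 control samples
assert buffer_size == 8000 and control_buffer_size == 500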
Code Example #5
File: profile.py  Project: Doddox-UV/torchsynth
def main():

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("module", help="module to profile", type=str)
    parser.add_argument(
        "--batch-size",
        "-b",
        help="Batch size to run profiling at",
        type=int,
        default=64,
    )
    parser.add_argument("--num_batches",
                        "-n",
                        help="Number of batches to run",
                        type=int,
                        default=64)
    parser.add_argument(
        "--profile",
        "-p",
        help="Whether to run cProfile",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--save",
        "-s",
        help="File to save profiler results. If this is left out then profiling "
        "results are printed. ",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--device",
        "-d",
        help="Device to run on. Default is None, which selects cuda if "
        "available and otherwise falls back to cpu",
        type=str,
        default=None,
    )

    args = parser.parse_args()
    if args.save is not None and not args.profile:
        raise SystemExit(
            "Profile (-p) flag must be set in order for profile results to be saved"
        )

    # Try to create the synth module that is being profiled
    synthconfig = SynthConfig(args.batch_size, reproducible=False)
    module = instantiate_module(args.module, synthconfig)

    run_lightning_module(module, args.batch_size, args.num_batches, args.save,
                         args.profile, args.device)
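Stripped of the argument parsing, the profiling flow is just the last few lines of main(). A minimal sketch of driving it directly (the module name "VCA" is an illustrative assumption; instantiate_module and run_lightning_module are assumed to be defined elsewhere in this script, as they are used above):

# Hypothetical equivalent of: python profile.py VCA -b 128 -n 16
batch_size = 128
synthconfig = SynthConfig(batch_size, reproducible=False)
module = instantiate_module("VCA", synthconfig)
# Arguments mirror the call in main():
# (module, batch_size, num_batches, save, profile, device)
run_lightning_module(module, batch_size, 16, None, False, None)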
Code Example #6
File: test_module.py  Project: Doddox-UV/torchsynth
    def test_properties(self):
        synthconfig = SynthConfig(2, reproducible=False)
        adsr = synthmodule.ADSR(synthconfig)

        # Accessing sample_rate and buffer_size should raise errors
        with pytest.raises(NotImplementedError):
            adsr.sample_rate

        with pytest.raises(NotImplementedError):
            adsr.buffer_size

        expected_buffer = synthconfig.buffer_size / synthconfig.sample_rate
        expected_buffer *= adsr.control_rate
        assert adsr.control_buffer_size == expected_buffer
Code Example #7
    def test_voice_reproducibility(self):
        synthconfig = SynthConfig()
        voice_1 = Voice(synthconfig)
        voice_2 = Voice(synthconfig)

        # Randomly initialized voices will have different results
        with pytest.raises(AssertionError):
            self.compare_voices(voice_1, voice_2)

        # Seeding only one of the voices should still give different results
        voice_1.randomize(1)
        with pytest.raises(AssertionError):
            self.compare_voices(voice_1, voice_2)

        # Seeding the second voice with the same seed should now give the same results
        voice_2.randomize(1)
        self.compare_voices(voice_1, voice_2)

        # Running voice twice in a row with the same parameters should
        # lead to the same results
        voice_1.randomize(234)
        self.compare_voices(voice_1, voice_1)
Code Example #8
File: synth.py  Project: Doddox-UV/torchsynth
class AbstractSynth(LightningModule):
    """
    Base class for synthesizers that combine one or more SynthModules
    to create a full synth architecture.

    Args:
        synthconfig (Optional[SynthConfig]): global configuration for this synth
            and its child SynthModules. If None, a default SynthConfig is used.
    """
    def __init__(self,
                 synthconfig: Optional[SynthConfig] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        if synthconfig is not None:
            self.synthconfig = synthconfig
        else:
            # Use the default
            self.synthconfig = SynthConfig()

    @property
    def batch_size(self) -> T:
        assert self.synthconfig.batch_size.ndim == 0
        return self.synthconfig.batch_size

    @property
    def sample_rate(self) -> T:
        assert self.synthconfig.sample_rate.ndim == 0
        return self.synthconfig.sample_rate

    @property
    def buffer_size(self) -> T:
        assert self.synthconfig.buffer_size.ndim == 0
        return self.synthconfig.buffer_size

    @property
    def buffer_size_seconds(self) -> T:
        assert self.synthconfig.buffer_size_seconds.ndim == 0
        return self.synthconfig.buffer_size_seconds

    def add_synth_modules(
        self, modules: List[Tuple[str, SynthModule, Optional[Dict[str, Any]]]]
    ):
        """
        Add a set of named children TorchSynthModules to this synth. Registers them
        with the torch nn.Module so that all parameters are recognized.

        Args:
            modules (List[Tuple[str, SynthModule, Optional[Dict[str, Any]]]]): A list of
                SynthModule classes with their names and any parameters to pass to
                their constructor.
        """

        for module_tuple in modules:
            if len(module_tuple) == 3:
                name, module, params = module_tuple
            else:
                name, module = module_tuple
                params = {}

            if not issubclass(module, SynthModule):
                raise TypeError(f"{module} is not a SynthModule")

            self.add_module(
                name, module(self.synthconfig, device=self.device, **params))

    def get_parameters(
        self,
        include_frozen: bool = False
    ) -> OrderedDictType[Tuple[str, str], ModuleParameter]:
        """
        Returns a dictionary of ModuleParameters for this synth keyed
        on a tuple of the SynthModule name and the parameter name
        """
        parameters = []

        # Each parameter in this synth will have a unique combination of module name
        # and parameter name -- create a dictionary keyed on that.
        for module_name, module in sorted(self.named_modules()):
            # Make sure this is a SynthModule, b/c we are using ParameterDict
            # and ParameterDict is a module, we get those returned as well
            # TODO: see https://github.com/torchsynth/torchsynth/issues/213
            if isinstance(module, SynthModule):
                for parameter in module.parameters():
                    if include_frozen or not ModuleParameter.is_parameter_frozen(
                            parameter):
                        parameters.append(
                            ((module_name, parameter.parameter_name),
                             parameter))

        return OrderedDict(parameters)

    def set_parameters(self,
                       params: Dict[Tuple, T],
                       freeze: Optional[bool] = False):
        """
        Set parameters for synth by passing in a dictionary of modules and parameters.
        Can optionally freeze a parameter at this value to prevent further updates.
        """
        for (module_name, param_name), value in params.items():
            module = getattr(self, module_name)
            module.set_parameter(param_name, value.to(self.device))
            # Freeze this parameter at this value now if freeze is True
            if freeze:
                module.get_parameter(param_name).frozen = True

    def set_frozen_parameters(self, params: Dict[Tuple, float]):
        """
        Sets specific parameters within this Synth. All params within the batch
        will be set to the same value and frozen to prevent further updates.
        """
        params = {
            key: tensor([value] * self.batch_size, device=self.device)
            for key, value in params.items()
        }
        self.set_parameters(params, freeze=True)

    def freeze_parameters(self, params: List[Tuple]):
        """
        Freeze a set of parameters by passing in a tuple of the module and param name
        """
        for module_name, param_name in params:
            module = getattr(self, module_name)
            module.get_parameter(param_name).frozen = True

    def unfreeze_all_parameters(self):
        """
        Unfreeze all parameters in this synth
        """
        for param in self.parameters():
            if isinstance(param, ModuleParameter):
                param.frozen = False

    def _forward(self, *args: Any,
                 **kwargs: Any) -> Signal:  # pragma: no cover
        """
        Each AbstractSynth should override this.
        """
        raise NotImplementedError("Derived classes must override this method")

    def forward(self,
                batch_idx: Optional[int] = None,
                *args: Any,
                **kwargs: Any) -> Signal:  # pragma: no cover
        """
        Each AbstractSynth should override this.

        Args:
            batch_idx (Optional[int]): If provided, we set the parameters of this
                synth for reproducibility, in a deterministic random way. If None
                (default), we just use the current module parameter settings.
        """
        if self.synthconfig.reproducible and batch_idx is None:
            raise ValueError("Reproducible mode is on, you must "
                             "pass a batch index when calling this synth")
        if self.synthconfig.no_grad:
            with torch.no_grad():
                if batch_idx is not None:
                    self.randomize(seed=batch_idx)
                return self._forward(*args, **kwargs)
        else:
            if batch_idx is not None:
                self.randomize(seed=batch_idx)
            return self._forward(*args, **kwargs)

    def test_step(self, batch, batch_idx):
        """
        This is boilerplate for lightning -- this is required by lightning Trainer
        when calling test, which we use to forward Synths on multi-gpu platforms
        """
        return 0.0

    @property
    def hyperparameters(self) -> OrderedDictType[Tuple[str, str, str], Any]:
        """
        Returns a dictionary of curve and symmetry hyperparameter values keyed
        on a tuple of the module name, parameter name, and hyperparameter name
        """
        hparams = []
        for (module_name,
             parameter_name), parameter in self.get_parameters().items():
            hparams.append((
                (module_name, parameter_name, "curve"),
                parameter.parameter_range.curve,
            ))
            hparams.append((
                (module_name, parameter_name, "symmetric"),
                parameter.parameter_range.symmetric,
            ))

        return OrderedDict(hparams)

    def set_hyperparameter(self, hyperparameter: Tuple[str, str, str],
                           value: Any):
        """
        Set a hyperparameter. Pass in the module name, parameter name, and
        hyperparameter to set, and the value to set it to.
        """
        module = getattr(self, hyperparameter[0])
        parameter = module.get_parameter(hyperparameter[1])
        assert not ModuleParameter.is_parameter_frozen(parameter)
        setattr(parameter.parameter_range, hyperparameter[2], value)

    def save_hyperparameters(self, filename: str, indent=True) -> None:
        """
        Save hyperparameters to a JSON file
        """
        # Render all hyperparameters as JSON
        hp = [{
            "name": key,
            "value": val
        } for key, val in self.hyperparameters.items()]
        with open(os.path.abspath(filename), "w") as fp:
            json.dump(hp, fp, indent=indent)

    def load_hyperparameters(self, nebula: str) -> None:
        """
        Load hyperparameters from a JSON file

        Args:
            nebula: nebula to load. This can either be the name of a nebula that is
                included in torchsynth, or the filename of a nebula json file to load.

        TODO add nebula list in docs
        See https://github.com/torchsynth/torchsynth/issues/324
        """

        # Try to load nebulae from package resources, otherwise, try
        # to load from a filename
        try:
            synth = type(self).__name__.lower()
            nebulae_str = f"nebulae/{synth}/{nebula}.json"
            data = pkg_resources.resource_string(__name__, nebulae_str)
            hyperparameters = json.loads(data)
        except FileNotFoundError:
            with open(os.path.abspath(nebula), "r") as fp:
                hyperparameters = json.load(fp)

        # Update all hyperparameters in this synth
        for hp in hyperparameters:
            self.set_hyperparameter(hp["name"], hp["value"])

    def randomize(self, seed: Optional[int] = None):
        """
        Randomize all parameters
        """
        parameters = [param for _, param in sorted(self.named_parameters())]
        if seed is not None:
            # Generate batch_size x parameter number of random values
            # Reseed the random number generator for every item in the batch
            cpu_rng = torch.Generator(device="cpu")
            new_values = []
            for i in range(self.batch_size):
                cpu_rng.manual_seed(seed * self.batch_size.numpy().item() + i)
                new_values.append(
                    torch.rand((len(parameters), ),
                               device="cpu",
                               generator=cpu_rng))

            # Move to device if necessary
            new_values = torch.stack(new_values, dim=1)
            if self.device.type != "cpu":
                new_values = new_values.pin_memory().to(self.device,
                                                        non_blocking=True)

            # Set parameter data
            for i, parameter in enumerate(parameters):
                if not ModuleParameter.is_parameter_frozen(parameter):
                    parameter.data = new_values[i]
        else:
            assert not self.synthconfig.reproducible
            for parameter in parameters:
                if not ModuleParameter.is_parameter_frozen(parameter):
                    parameter.data.uniform_(0, 1)

        # Add seed to all modules
        for module in self._modules:
            self._modules[module].seed = seed

    def on_post_move_to_device(self) -> None:
        """
        LightningModule trigger after this Synth has been moved to a different device.
        Use this to update children SynthModules device settings
        """
        self.synthconfig.to(self.device)
        for module in self.modules():
            if isinstance(module, SynthModule):
                # TODO look into performance of calling to instead
                module.update_device(self.device)
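To make the wiring concrete, here is a minimal sketch of a hypothetical AbstractSynth subclass. The class name and module selection are assumptions for illustration; it relies only on the API shown above (add_synth_modules, _forward, randomize, forward) plus the Noise and ADSR modules from the test examples, with imports assumed.

class NoiseSynth(AbstractSynth):
    """Hypothetical example synth built from a Noise source and an ADSR."""

    def __init__(self, synthconfig: Optional[SynthConfig] = None, *args, **kwargs):
        super().__init__(synthconfig, *args, **kwargs)
        # Register child modules so their ModuleParameters are tracked
        # by the parent nn.Module.
        self.add_synth_modules([
            ("noise", Noise, {"seed": 0}),
            ("adsr", ADSR),
        ])

    def _forward(self) -> Signal:
        # Return a batch of noise; a fuller synth would combine its modules
        # (oscillators, envelopes, VCAs, ...) here.
        return self.noise()


# Usage sketch: with reproducible=False the synth can be called without a
# batch index, using whatever the current (randomized) parameters are.
synth = NoiseSynth(SynthConfig(batch_size=64, reproducible=False))
synth.randomize(seed=0)
audio = synth()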
Code Example #9
    def test_voice_nonreproducibility(self):
        with pytest.raises(ValueError):
            SynthConfig(batch_size=BASE_REPRODUCIBLE_BATCH_SIZE + 1)
Code Example #10
def test_synth_config_debug():
    synthconfig = SynthConfig()
    assert synthconfig.debug
Code Example #11
def stft_plot(signal, sample_rate=44100):
    if isnotebook():  # pragma: no cover
        X = librosa.stft(signal)
        Xdb = librosa.amplitude_to_db(abs(X))
        plt.figure(figsize=(5, 5))
        librosa.display.specshow(Xdb,
                                 sr=sample_rate,
                                 x_axis="time",
                                 y_axis="log")
        plt.show()


# ## Globals
# We'll generate 2 sounds at once, 4 seconds each
synthconfig = SynthConfig(batch_size=2,
                          reproducible=False,
                          sample_rate=44100,
                          buffer_size_seconds=4.0)

# For a few examples, we'll only generate one sound
synthconfig1 = SynthConfig(batch_size=1,
                           reproducible=False,
                           sample_rate=44100,
                           buffer_size_seconds=4.0)

# And a short one sound
synthconfig1short = SynthConfig(batch_size=1,
                                reproducible=False,
                                sample_rate=44100,
                                buffer_size_seconds=0.1)
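
# As a quick sanity check, any of these configs can be used to render and
# plot a short signal. A minimal sketch (assumption: the Noise module has
# been imported, e.g. from torchsynth.module):
noise = Noise(synthconfig1short, seed=0)
noise_out = noise()
stft_plot(noise_out[0].detach().cpu().numpy())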

# ## The Envelope