Beispiel #1
0
  def __call__(self, *inputs, **kwargs):
    """Wrap the layer's __call__() with dictionary inputs and outputs.

    IMPORTANT: If no input_keys are provided to the constructor, they are
    inferred from the argument names in call(). If no output_keys are provided
    to the constructor, they are inferred from return annotation of call()
    (a list of strings).

    Example:
    ```
    def call(self, f0_hz, loudness) -> ['amps', 'frequencies']:
      ...
      return amps, frequencies
    ```
    Will infer `self.input_keys = ['f0_hz', 'loudness']` and
    `self.output_keys = ['amps', 'frequencies']`. If input_keys, or output_keys
    are provided to the constructor they will override these inferred values.

    Example Usage:
    The example above works with both tensor inputs `layer(f0_hz, loudness)`
    or a dictionary of tensors `layer({'f0_hz':..., 'loudness':...})`, and in
    both cases will return a dictionary of tensors
    `{'amps':..., 'frequencies':...}`.

    Args:
      *inputs: Arguments passed on to call(). If any arguments are dicts, they
        will be merged (later dicts win on duplicate keys) and self.input_keys
        will be read out of them and passed to call() while other args will be
        ignored. If the merged dictionary is empty, the arguments are forwarded
        to call() unchanged.
      **kwargs: Keyword arguments passed on to call().

    Returns:
      outputs: A dictionary of layer outputs from call(). If the layer call()
        returns a dictionary it will be returned directly, otherwise the output
        tensors will be wrapped in a dictionary {output_key: output_tensor}.

    Raises:
      ValueError: If call() returns non-dict outputs whose count differs from
        the number of self.output_keys.
    """
    # Merge all dictionaries provided in inputs.
    input_dict = {}
    for v in inputs:
      if isinstance(v, dict):
        input_dict.update(v)

    # If any dicts provided, lookup input tensors from those dicts.
    # Otherwise, just use inputs list as input tensors.
    if input_dict:
      inputs = [core.nested_lookup(key, input_dict) for key in self.input_keys]

    # Run input tensors through the model.
    outputs = super().__call__(*inputs, **kwargs)

    # Return dict if call() returns it.
    if isinstance(outputs, dict):
      return outputs

    # Otherwise make a dict from output_keys.
    outputs = core.make_iterable(outputs)
    if len(self.output_keys) != len(outputs):
      # Adjacent string literals are concatenated verbatim, so the separating
      # space must be part of the first literal.
      raise ValueError(f'Output keys ({self.output_keys}) must have the same '
                       f'length as outputs ({outputs})')
    return dict(zip(self.output_keys, outputs))
Beispiel #2
0
    def __init__(self,
                 rnn_channels=512,
                 rnn_type="gru",
                 ch=512,
                 layers_per_stack=3,
                 input_keys=None,
                 output_splits=(("amps", 1), ("harmonic_distribution", 40)),
                 name="multi_input_rnn_fc_decoder"):
        """Create the per-input stacks, the RNN layer(s) and the output layers.

        Args:
            rnn_channels: Passed through make_iterable(); each entry is handed
                to nn.rnn() to build one RNN layer, in order.
            rnn_type: Second argument of every nn.rnn() call.
            ch: First argument of every nn.fc_stack() call.
            layers_per_stack: Second argument of every nn.fc_stack() call.
            input_keys: Stored as self.input_keys. When None (the default) a
                fresh ["f0_scaled", "osc_scaled"] list is used.
            output_splits: Forwarded to the base-class constructor.
            name: Forwarded to the base-class constructor.
        """
        super().__init__(output_splits=output_splits, name=name)

        # Build the default list per instance: a list literal in the signature
        # is created once and would be shared (and mutable) across every
        # instance that relies on the default.
        if input_keys is None:
            input_keys = ["f0_scaled", "osc_scaled"]
        self.input_keys = input_keys

        def make_stack():
            return nn.fc_stack(ch, layers_per_stack)

        # Layers. Creation order is kept: input stacks, RNNs, output stack,
        # output dense layer.
        # self.n_in and self.n_out are read from the instance; they are
        # defined outside this method.
        self.stacks = [make_stack() for _ in range(self.n_in)]

        rnn_channels = make_iterable(rnn_channels)
        self.n_rnn = len(rnn_channels)
        self.rnn = [nn.rnn(rnn_channels[0], rnn_type)]
        for i in range(1, self.n_rnn):
            self.rnn.append(nn.rnn(rnn_channels[i], rnn_type))

        self.out_stack = make_stack()
        self.dense_out = nn.dense(self.n_out)
Beispiel #3
0
    def __call__(self, *inputs, **kwargs):
        """Wrap the layer's __call__() with dictionary inputs and outputs.

        IMPORTANT: If no input_keys are provided to the constructor, they are
        inferred from the argument names in call(). If no output_keys are
        provided to the constructor, they are inferred from return annotation
        of call() (a list of strings).

        Example:
        ========
        ```
        def call(self, f0_hz, loudness, power=None) -> ['amps', 'frequencies']:
          ...
          return amps, frequencies
        ```
        Will infer `self.input_keys = ['f0_hz', 'loudness']` and
        `self.output_keys = ['amps', 'frequencies']`. If input_keys, or
        output_keys are provided to the constructor they will override these
        inferred values. It will also infer
        `self.default_input_keys = ['power']`, which it will try to look up in
        the inputs, but use the default values and not throw an error if the
        key is not in the input dictionary.

        Example Usage:
        ==============
        The example above works with tensor inputs `layer(f0_hz, loudness)`
        or `layer(f0_hz, loudness, power)` or a dictionary of tensors
        `layer({'f0_hz':..., 'loudness':...})`, or
        `layer({'f0_hz':..., 'loudness':..., 'power':...})` and in all cases
        will return a dictionary of tensors `{'amps':..., 'frequencies':...}`.

        Args:
          *inputs: Arguments passed on to call(). Dict arguments are merged
            (later dicts win on duplicate keys) and removed from the positional
            list; self.input_keys and self.default_input_keys are then looked
            up in the merged dict and appended after the remaining non-dict
            arguments.
          **kwargs: Keyword arguments passed on to call(). Any keyword whose
            name is in self.all_input_keys is moved into the merged input
            dictionary instead of being forwarded as a keyword.

        Returns:
          outputs: A dictionary of layer outputs from call(). If the layer
            call() returns a dictionary it will be returned directly, otherwise
            the output tensors will be wrapped in a dictionary
            {output_key: output_tensor}.

        Raises:
          TypeError: If the number of assembled input tensors differs from
            self.n_inputs.
          ValueError: If call() returns non-dict outputs whose count differs
            from the number of self.output_keys.
        """
        # Construct a list of input tensors equal in length and order to the
        # `call` input signature.
        # -- Start first with any tensor arguments.
        # -- Then lookup tensors from input dictionaries.
        # -- Use default values if not found.

        # Start by merging all dictionaries of tensors from the input.
        input_dict = {}
        for v in inputs:
            if isinstance(v, dict):
                input_dict.update(v)

        # And then strip all dictionaries from the input.
        inputs = [v for v in inputs if not isinstance(v, dict)]

        # Add any tensors from kwargs.
        for key in self.all_input_keys:
            if key in kwargs:
                input_dict[key] = kwargs[key]

        # And strip from kwargs.
        kwargs = {
            k: v
            for k, v in kwargs.items() if k not in self.all_input_keys
        }

        # Look up further inputs from the dictionaries.
        for key in self.input_keys:
            try:
                # If key is present use the input_dict value.
                inputs.append(core.nested_lookup(key, input_dict))
            except KeyError:
                # Skip if not present; a genuinely missing input is reported
                # by the length check below.
                pass

        # Add default arguments.
        for key, value in zip(self.default_input_keys,
                              self.default_input_values):
            try:
                # If key is present, use the input_dict value.
                inputs.append(core.nested_lookup(key, input_dict))
            except KeyError:
                # Otherwise use the default value if not supplied as non-dict
                # input.
                if len(inputs) < self.n_inputs:
                    inputs.append(value)

        # Run input tensors through the model.
        if len(inputs) != self.n_inputs:
            # Adjacent string literals are concatenated verbatim, so each
            # fragment carries its own trailing space.
            raise TypeError(
                f'{len(inputs)} input tensors extracted from inputs '
                '(including default args) but the layer expects '
                f'{self.n_inputs} tensors.\n'
                f'Input keys: {self.input_keys}\n'
                f'Default keys: {self.default_input_keys}\n'
                f'Default values: {self.default_input_values}\n'
                f'Input dictionaries: {input_dict}\n'
                f'Input Tensors (Args, Dicts, and Defaults): {inputs}\n')
        outputs = super().__call__(*inputs, **kwargs)

        # Return dict if call() returns it.
        if isinstance(outputs, dict):
            return outputs

        # Otherwise make a dict from output_keys.
        outputs = core.make_iterable(outputs)
        if len(self.output_keys) != len(outputs):
            raise ValueError(
                f'Output keys ({self.output_keys}) must have the same '
                f'length as outputs ({outputs})')
        return dict(zip(self.output_keys, outputs))
Beispiel #4
0
 def get_return_annotations(self, method):
     """Return the return annotation of a named method as an iterable.

     Args:
       method: Name of an attribute of this object; it is resolved with
         getattr() and inspected with inspect.getfullargspec().

     Returns:
       The 'return' entry of the method's annotations, passed through
       core.make_iterable().
     """
     target = getattr(self, method)
     annotations = inspect.getfullargspec(target).annotations
     return core.make_iterable(annotations['return'])