# Example no. 1
# 0
import jax.numpy as jnp
from jax import jit

from jaxdsp.param import Param

NAME = "Clip"
# Lower and upper clamp bounds.
# NOTE(review): Param args look like (name, default, min, max); confirm against jaxdsp.param.Param.
PARAMS = [
    Param("min", -1.0, -1.0, 1.0),
    Param("max", 1.0, -1.0, 1.0),
]
PRESETS = {}


def init_state():
    """Return a fresh, empty state dict (clipping keeps no state between samples)."""
    state = dict()
    return state


@jit
def tick(carry, x):
    """Clamp ``x`` elementwise to ``[params["min"], params["max"]]``.

    ``carry`` is ``(params, state)``. It is returned unchanged because clipping
    does not touch the state.
    """
    params, _unused_state = carry
    lower = params["min"]
    upper = params["max"]
    clipped = jnp.clip(x, lower, upper)
    return carry, clipped


@jit
def tick_buffer(carry, X):
    """Clip a whole buffer ``X`` in one call.

    Clipping is elementwise, so the per-sample ``tick`` applies directly to the buffer.
    """
    new_carry, Y = tick(carry, X)
    return new_carry, Y
# Example no. 2
# 0
import jax.numpy as jnp
from jax import jit, lax

from jaxdsp.param import Param

NAME = "FIR Filter"
# TODO how to handle array params in UI?
# Default taps: 1.0 followed by four zeros (5 taps), so the output equals the newest input.
PARAMS = [
    Param(
        "B",
        jnp.concatenate([jnp.array([1.0]), jnp.zeros(4)]),
    ),
]
PRESETS = {}


def init_state(length=5):
    """Return the initial FIR state: a zeroed history of the last ``length`` inputs.

    Args:
        length: history size. It must equal the number of taps in ``params["B"]``
            because ``tick`` computes ``B @ state["inputs"]``. The default of 5
            matches the default ``B`` param (1.0 plus four zeros). A default of 4
            would make ``tick`` fail with a shape mismatch.

    Returns:
        ``{"inputs": jnp.zeros(length)}``
    """
    return {"inputs": jnp.zeros(length)}


@jit
def tick(carry, x):
    """Push one sample into the input history and return the FIR output.

    ``carry`` is ``(params, state)``. ``state["inputs"]`` holds the most recent
    inputs, newest first. The output is the dot product of the taps ``params["B"]``
    with that history.
    """
    params, state = carry
    taps = params["B"]
    # Shift the history by one: the new sample goes to the front and the oldest is dropped.
    history = jnp.concatenate([jnp.array([x]), state["inputs"][0:-1]])
    state["inputs"] = history
    y = taps @ history
    return carry, y


@jit
def tick_buffer(carry, X):
    """Filter a whole buffer ``X`` with the taps ``params["B"]`` via convolution.

    Returns ``(carry, Y)``. ``Y`` has the same length as ``X`` and holds the
    leading ``X.size`` samples of the full convolution.

    Perf note from the original author: ``lax.scan(tick, carry, X)[1]`` was observed
    to perform about the same as this, or even faster, for large N.

    NOTE(review): ``carry`` is returned unchanged, so ``state["inputs"]`` is neither
    read nor updated here. Each buffer is filtered as if preceded by silence;
    confirm this is intended.
    """
    params, _ = carry
    B = params["B"]
    # Slice by X.size rather than ``[:-(B.size - 1)]``. For a single-tap filter the
    # latter becomes ``[:-0]`` == ``[:0]``, which returns an empty array.
    return carry, jnp.convolve(X, B)[: X.size]
# Example no. 3
# 0
# http://www.earlevel.com/main/2012/11/26/biquad-c-source-code/

import jax.numpy as jnp
from jax import jit, lax

from jaxdsp.param import Param

NAME = "BiQuad Lowpass Filter"
PARAMS = [
    Param("resonance", 0.7),
    # NOTE(review): cutoff looks like a fraction of the sample rate kept below 0.5; confirm.
    Param("cutoff", 0.49, 0.0, 0.49),
]
PRESETS = {}


def init_state():
    """Return the initial biquad state.

    The coefficients start as a pass-through (a0 = 1, all others 0) and the
    delay registers z1/z2 start at zero.
    """
    return dict(
        a0=1.0,
        a1=0.0,
        a2=0.0,
        b1=0.0,
        b2=0.0,
        z1=0.0,
        z2=0.0,
    )


@jit
def tick(carry, x):
    """Run one input sample through the biquad section held in ``state``.

    ``carry`` is ``(params, state)``. Only ``state`` is used here: the output is
    ``x * a0 + z1``, and the delay registers z1/z2 are then updated from ``x``
    and that output. The structure presumably follows the transposed direct form
    II biquad from the earlevel.com reference at the top of the module; TODO confirm.

    NOTE(review): this snippet ends without a return statement, and the
    "resonance"/"cutoff" params are never read here. The rest of the function
    (and the coefficient computation) appears to be cut off; confirm against the
    full source.
    """
    _, state = carry
    out = x * state["a0"] + state["z1"]
    # Advance the two delay registers using the current input and output.
    state["z1"] = x * state["a1"] + state["z2"] - state["b1"] * out
    state["z2"] = x * state["a2"] - state["b2"] * out
# Example no. 4
# 0
import jax.numpy as jnp
from jax import jit, lax
from jax.ops import index_update

from jaxdsp.param import Param

NAME = "Allpass Filter"
# Gain applied to the delayed sample before it is written back into the buffer.
PARAMS = [
    Param("feedback", 0.0),
]
PRESETS = {}


def init_state(buffer_size=20):
    """Return a zeroed circular delay buffer of ``buffer_size`` samples.

    The state also holds a read/write index starting at 0 and a ``filter_store``
    of 0.0.
    NOTE(review): ``filter_store`` is not read by this module's ``tick``; it may be
    a leftover shared with the comb filter. Confirm before removing.
    """
    delay_buffer = jnp.zeros(buffer_size)
    return dict(buffer=delay_buffer, buffer_index=0, filter_store=0.0)


@jit
def tick(carry, x):
    """Process one sample through the allpass delay.

    ``carry`` is ``(params, state)``. Each call:
      * reads ``buffer_out``, the value written to the current slot ``buffer.size`` ticks ago;
      * writes ``x + buffer_out * params["feedback"]`` into that slot;
      * advances ``buffer_index``, wrapping at the buffer length;
      * outputs ``buffer_out - x``.

    Returns:
        ``(carry, out)``, with the updated ``state`` inside the returned carry.
    """
    params, state = carry
    buffer = state["buffer"]
    index = state["buffer_index"]

    buffer_out = buffer[index]
    # ``.at[...].set`` is the functional indexed update that replaced the deprecated
    # (and later removed) ``jax.ops.index_update``. Like index_update, it returns an updated copy.
    state["buffer"] = buffer.at[index].set(x + buffer_out * params["feedback"])
    state["buffer_index"] = (index + 1) % buffer.size

    out = -x + buffer_out
    return carry, out

# Example no. 5
# 0
class RmsProp:
    """Optimizer definition for ``optimizers.rmsprop``.

    It provides a display name, the tunable hyperparameters (the shared step
    size plus ``gamma``, default 0.9) and the optimizer factory.
    """

    NAME = "RMSProp"
    PARAMS = [
        step_size_param,
        Param("gamma", 0.9),
    ]
    FUNCTION = optimizers.rmsprop
# Example no. 6
# 0
import jax.numpy as jnp
from jax import jit, lax
from jax.ops import index_update

from jaxdsp.param import Param

NAME = "Lowpass Feedback Comb Filter"
# "damp" sets the mix between the new buffer output and the previous filter_store
# value (see tick). "feedback" is the loop gain.
PARAMS = [
    Param("feedback", 0.0),
    Param("damp", 0.0),
]
PRESETS = {}


def init_state(buffer_size=20):
    """Return the comb filter's initial state.

    The state holds a zeroed circular buffer of ``buffer_size`` samples, a
    read/write index starting at 0, and the one-pole lowpass memory
    ``filter_store`` starting at 0.0.
    """
    state = {}
    state["buffer"] = jnp.zeros(buffer_size)
    state["buffer_index"] = 0
    state["filter_store"] = 0.0
    return state


@jit
def tick(carry, x):
    params, state = carry

    out = state["buffer"][state["buffer_index"]]
    state["filter_store"] = (
        out * (1 - params["damp"]) + state["filter_store"] * params["damp"]
    )

    state["buffer"] = index_update(
        state["buffer"],
        state["buffer_index"],
# Example no. 7
# 0
class Adamax:
    """Optimizer definition for ``optimizers.adamax``.

    The hyperparameters are the shared step size plus the decay rates ``b1``
    (default 0.9) and ``b2`` (default 0.999).
    """

    NAME = "Adamax"
    PARAMS = [
        step_size_param,
        Param("b1", 0.9),
        Param("b2", 0.999),
    ]
    FUNCTION = optimizers.adamax
# Example no. 8
# 0
class Nesterov:
    """Optimizer definition for ``optimizers.nesterov``.

    The hyperparameters are the shared step size plus ``mass`` (the momentum
    coefficient, default 0.5).
    """

    NAME = "Nesterov"
    PARAMS = [
        step_size_param,
        Param("mass", 0.5),
    ]
    FUNCTION = optimizers.nesterov
# Example no. 9
# 0
class AdaGrad:
    """Optimizer definition for ``optimizers.adagrad``.

    The hyperparameters are the shared step size plus ``momentum`` (default 0.9).
    """

    NAME = "AdaGrad"
    PARAMS = [
        step_size_param,
        Param("momentum", 0.9),
    ]
    FUNCTION = optimizers.adagrad
# Example no. 10
# 0
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.experimental.optimizers import optimizer
from jax.tree_util import tree_map

from jaxdsp.param import Param

# Step-size hyperparameter shared as the first PARAMS entry of the optimizer definitions.
# NOTE(review): args look like (name, default, min, max); confirm against jaxdsp.param.Param.
step_size_param = Param(
    "step_size",
    0.05,
    1e-9,
    0.2,
)

# TODO should also add a loss component to strongly encourage params into this range
@optimizer
def param_clipping_optimizer(init, update, get_params):
    """Wrap an (init, update, get_params) triple so read-back params are clipped to [0.0, 1.0].

    Params are normalized to unit scale before being passed to the gradient fn,
    so [0.0, 1.0] is their valid range. ``init`` and ``update`` pass through
    untouched; only ``get_params`` is wrapped.

    Note that ``get_params`` returns the params being optimized, NOT the
    optimizer's own hyperparameters.
    """

    def _clip_to_unit(leaf):
        return jnp.clip(leaf, 0.0, 1.0)

    def get_clipped_params(state):
        return tree_map(_clip_to_unit, get_params(state))

    return init, update, get_clipped_params


class Optimizer:
    """Bind an optimizer definition to concrete hyperparameter values.

    ``definition`` must expose a ``PARAMS`` iterable whose items have ``name``
    and ``default_value`` attributes.
    """

    def __init__(self, definition, param_values=None):
        self.definition = definition
        self.set_param_values(param_values)

    def set_param_values(self, param_values=None):
        """Resolve ``self.param_values`` for every param in ``definition.PARAMS``.

        Args:
            param_values: optional mapping of param name -> value. A name that is
                missing, or explicitly mapped to ``None``, gets ``param.default_value``.
                Any other value is kept, including falsy ones such as ``0.0``.
        """
        param_values = param_values or {}
        resolved = {}
        for param in self.definition.PARAMS:
            value = param_values.get(param.name)
            # Check for None explicitly instead of relying on truthiness. A legitimate
            # 0.0 (e.g. zero momentum or mass) must not be replaced by the default,
            # and array-valued params have no unambiguous truth value.
            resolved[param.name] = param.default_value if value is None else value
        self.param_values = resolved
# Example no. 11
# 0
import jax.numpy as jnp
from jax import jit, lax

from jaxdsp.param import Param
from jaxdsp.processors import allpass_filter as allpass
from jaxdsp.processors import lowpass_feedback_comb_filter as comb

NAME = "Freeverb"
PARAMS = [
    Param("wet", 0.0),
    Param("dry", 1.0),
    Param("width", 0.0),
    Param("damp", 0.0),
    # room_size is the only param with an explicit range here; presumably (min, max) = (0.0, 1.2).
    Param("room_size", 0.0, 0.0, 1.2),
]
# Named parameter settings, keyed by preset name.
PRESETS = dict(
    flat_space=dict(
        wet=0.3,
        dry=0.6,
        width=0.5,
        damp=0.3,
        room_size=1.055,
    ),
    expanding_space=dict(
        wet=0.33,
        dry=0.0,
        width=0.5,
        damp=0.1,
        room_size=1.078,
    ),
)
# Example no. 12
# 0
import jax.numpy as jnp
from jax import jit, lax
from jax.ops import index, index_update

from jaxdsp.param import Param
from jaxdsp.processors.constants import DEFAULT_SAMPLE_RATE

import math  # needed for math.ceil below

# Stay inside the range of the lowest frequency resolved by the FFT in the spectral optimizers.
MAX_DELAY_LENGTH_MS = 40.0
# Delay-line capacity in samples at the default sample rate. It is rounded up so
# that a delay of MAX_DELAY_LENGTH_MS always fits.
MAX_DELAY_LENGTH_SAMPLES = math.ceil(
    DEFAULT_SAMPLE_RATE * (MAX_DELAY_LENGTH_MS / 1000.0)
)

NAME = "Delay Line"
PARAMS = [
    Param("wet", 1.0),
    Param("delay_ms", MAX_DELAY_LENGTH_MS / 2, 0.0, MAX_DELAY_LENGTH_MS),
]
PRESETS = {}


def init_state():
    """Return the delay line's initial state.

    The state holds a zeroed buffer sized for the maximum delay, read/write
    positions starting at 0.0, and the default sample rate.
    """
    return dict(
        delay_line=jnp.zeros(MAX_DELAY_LENGTH_SAMPLES),
        read_sample=0.0,
        write_sample=0.0,
        sample_rate=DEFAULT_SAMPLE_RATE,
    )


@jit
# Example no. 13
# 0
import jax.numpy as jnp
from jax import jit, lax

from jaxdsp.param import Param

NAME = "IIR Filter"
# Feedforward (B) and feedback (A) coefficients. Each defaults to 1.0 followed by
# four zeros (5 coefficients).
PARAMS = [
    Param("B", jnp.concatenate([jnp.array([1.0]), jnp.zeros(4)])),
    Param("A", jnp.concatenate([jnp.array([1.0]), jnp.zeros(4)])),
]
PRESETS = {}


def init_state(length=5):
    """Return zeroed input/output histories for the IIR filter.

    ``inputs`` holds ``length`` samples. ``outputs`` holds one fewer, because the
    leading feedback coefficient is not applied to past outputs.
    """
    input_history = jnp.zeros(length)
    output_history = jnp.zeros(length - 1)
    return {"inputs": input_history, "outputs": output_history}


@jit
def tick(carry, x):
    params, state = carry
    B = params["B"]
    A = params["A"]
    state["inputs"] = jnp.concatenate([jnp.array([x]), state["inputs"][0:-1]])
    y = B @ state["inputs"]
    if state["outputs"].size > 0:
        y -= A[1:] @ state["outputs"]