import jax.numpy as jnp
from jax import jit

from jaxdsp.param import Param

NAME = "Clip"

# Lower and upper clipping bounds, each declared with the range [-1.0, 1.0].
PARAMS = [
    Param("min", -1.0, -1.0, 1.0),
    Param("max", 1.0, -1.0, 1.0),
]
PRESETS = {}


def init_state():
    """Return the processor state; clipping keeps no state, so it is empty."""
    return {}


@jit
def tick(carry, x):
    """Clip ``x`` into ``[params["min"], params["max"]]``.

    ``carry`` is a ``(params, state)`` pair and is handed back untouched, so
    the result has the ``(carry, y)`` form used for per-sample processing.
    """
    params, _state = carry
    low = params["min"]
    high = params["max"]
    clipped = jnp.clip(x, low, high)
    return carry, clipped


@jit
def tick_buffer(carry, X):
    """Clip a whole buffer at once.

    ``jnp.clip`` is elementwise, so the per-sample ``tick`` is reused
    directly on the array.
    """
    return tick(carry, X)
import jax.numpy as jnp
from jax import jit, lax

from jaxdsp.param import Param

NAME = "FIR Filter"
# TODO how to handle array params in UI?
# Default coefficients: a 5-tap identity filter (1, 0, 0, 0, 0).
PARAMS = [Param("B", jnp.concatenate([jnp.array([1.0]), jnp.zeros(4)]))]
PRESETS = {}


def init_state(length=4):
    """Return the FIR state: the most recent ``length`` input samples.

    ``tick`` computes ``B @ state["inputs"]``, so ``length`` must equal
    ``B.size``.
    NOTE(review): the default ``length`` (4) does not match the 5-tap default
    ``B`` in ``PARAMS``. ``tick`` with both defaults hits a shape mismatch.
    Confirm that callers pass ``length=B.size``.
    """
    return {"inputs": jnp.zeros(length)}


@jit
def tick(carry, x):
    """Filter one sample: ``y = B @ [x, x[n-1], ...]``.

    ``carry`` is ``(params, state)``. The input history in ``state`` is
    updated in place and ``carry`` is returned, giving the ``(carry, y)``
    form used by ``lax.scan``.
    """
    params, state = carry
    B = params["B"]
    # Shift the new sample into the front of the history (newest first),
    # dropping the oldest one.
    state["inputs"] = jnp.concatenate([jnp.array([x]), state["inputs"][0:-1]])
    y = B @ state["inputs"]
    return carry, y


@jit
def tick_buffer(carry, X):
    """Filter a whole buffer ``X`` with the coefficients ``params["B"]``.

    The full convolution has ``X.size + B.size - 1`` samples. Keeping the
    first ``X.size`` gives the causal output aligned with the input, and this
    also works for a single-tap ``B``. A ``[:-(B.size - 1)]`` slice would be
    ``[:-0]`` there, which is empty.

    The filter history in ``state`` is neither read nor updated here: the
    convolution implicitly starts from zero history for every buffer.
    """
    params, _ = carry
    B = params["B"]
    return carry, jnp.convolve(X, B)[: X.size]
    # Impossibly, the following seems to perform about the exact same or even faster for large N?! O_O
    # return lax.scan(tick, carry, X)[1]
# http://www.earlevel.com/main/2012/11/26/biquad-c-source-code/ import jax.numpy as jnp from jax import jit, lax from jaxdsp.param import Param NAME = "BiQuad Lowpass Filter" PARAMS = [Param("resonance", 0.7), Param("cutoff", 0.49, 0.0, 0.49)] PRESETS = {} def init_state(): return { "a0": 1.0, "a1": 0.0, "a2": 0.0, "b1": 0.0, "b2": 0.0, "z1": 0.0, "z2": 0.0, } @jit def tick(carry, x): _, state = carry out = x * state["a0"] + state["z1"] state["z1"] = x * state["a1"] + state["z2"] - state["b1"] * out state["z2"] = x * state["a2"] - state["b2"] * out
import jax.numpy as jnp
from jax import jit, lax
from jax.ops import index_update

from jaxdsp.param import Param

NAME = "Allpass Filter"
PARAMS = [Param("feedback", 0.0)]
PRESETS = {}


def init_state(buffer_size=20):
    """Build the initial allpass state.

    The state holds a zeroed delay buffer of ``buffer_size`` samples and the
    position that is read and then overwritten on every tick.
    ``filter_store`` is initialised here but is not read by ``tick``.
    """
    state = dict(
        buffer=jnp.zeros(buffer_size),
        buffer_index=0,
        filter_store=0.0,
    )
    return state


@jit
def tick(carry, x):
    """Run one sample through the allpass.

    Steps:
    - Read back the sample stored ``buffer.size`` ticks ago.
    - Overwrite that slot with ``x + delayed * feedback``.
    - Advance the position with wrap-around.
    - Emit ``delayed - x``.

    ``carry`` is ``(params, state)``. The state dict is updated in place and
    ``carry`` is returned, giving the ``(carry, y)`` form used by
    ``lax.scan``.
    """
    params, state = carry
    position = state["buffer_index"]
    delayed = state["buffer"][position]

    written = x + delayed * params["feedback"]
    state["buffer"] = index_update(state["buffer"], position, written)

    next_position = position + 1
    state["buffer_index"] = next_position % state["buffer"].size

    return carry, delayed - x
class RmsProp:
    """RMSProp optimizer definition.

    Binds ``optimizers.rmsprop`` to its tunable parameters: the shared step
    size and ``gamma`` (default 0.9).
    """

    NAME = "RMSProp"
    PARAMS = [
        step_size_param,
        Param("gamma", 0.9),
    ]
    FUNCTION = optimizers.rmsprop
import jax.numpy as jnp
from jax import jit, lax
from jax.ops import index_update

from jaxdsp.param import Param

NAME = "Lowpass Feedback Comb Filter"
PARAMS = [Param("feedback", 0.0), Param("damp", 0.0)]
PRESETS = {}


def init_state(buffer_size=20):
    # State contents:
    # - "buffer": a zeroed delay buffer of `buffer_size` samples.
    # - "buffer_index": the position tick reads from and writes to.
    # - "filter_store": the memory of the one-pole lowpass that tick applies
    #   to the delayed sample.
    return {
        "buffer": jnp.zeros(buffer_size),
        "buffer_index": 0,
        "filter_store": 0.0,
    }


@jit
def tick(carry, x):
    params, state = carry
    # Read the delayed sample at the current buffer position.
    out = state["buffer"][state["buffer_index"]]
    # One-pole lowpass on the delayed sample: `damp` weights the previous
    # filter output and `1 - damp` the new sample, so damp = 0 disables it.
    state["filter_store"] = (
        out * (1 - params["damp"]) + state["filter_store"] * params["damp"]
    )
    state["buffer"] = index_update(
        state["buffer"],
        state["buffer_index"],
class Adamax:
    """Adamax optimizer definition.

    Binds ``optimizers.adamax`` to its tunable parameters: the shared step
    size plus the ``b1`` (default 0.9) and ``b2`` (default 0.999) decay rates.
    """

    NAME = "Adamax"
    PARAMS = [
        step_size_param,
        Param("b1", 0.9),
        Param("b2", 0.999),
    ]
    FUNCTION = optimizers.adamax
class Nesterov:
    """Nesterov momentum optimizer definition.

    Binds ``optimizers.nesterov`` to its tunable parameters: the shared step
    size and ``mass`` (default 0.5).
    """

    NAME = "Nesterov"
    PARAMS = [
        step_size_param,
        Param("mass", 0.5),
    ]
    FUNCTION = optimizers.nesterov
class AdaGrad:
    """AdaGrad optimizer definition.

    Binds ``optimizers.adagrad`` to its tunable parameters: the shared step
    size and ``momentum`` (default 0.9).
    """

    NAME = "AdaGrad"
    PARAMS = [
        step_size_param,
        Param("momentum", 0.9),
    ]
    FUNCTION = optimizers.adagrad
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.experimental.optimizers import optimizer
from jax.tree_util import tree_map

from jaxdsp.param import Param

# Step-size parameter shared by the optimizer definitions.
step_size_param = Param("step_size", 0.05, 1e-9, 0.2)


# Clip all params to [0.0, 1.0] (params are all normalized to unit scale before passing to gradient fn).
# TODO should also add a loss component to strongly encourage params into this range
@optimizer
def param_clipping_optimizer(init, update, get_params):
    """Return ``(init, update, get_params)`` with ``get_params`` clipping
    every leaf of the returned params tree to ``[0.0, 1.0]``.

    ``init`` and ``update`` are passed through unchanged.
    """

    # Note that these are the params being optimized, NOT the optimizer params :)
    def get_clipped_params(state):
        params = get_params(state)
        return tree_map(lambda param: jnp.clip(param, 0.0, 1.0), params)

    return init, update, get_clipped_params


class Optimizer:
    """An optimizer definition paired with concrete values for its params.

    ``definition`` must expose a ``PARAMS`` list of params, each with
    ``name`` and ``default_value`` attributes.
    """

    def __init__(self, definition, param_values=None):
        self.definition = definition
        self.set_param_values(param_values)

    def set_param_values(self, param_values=None):
        """Resolve a value for every param declared by the definition.

        Args:
            param_values: optional mapping from param name to value. Names
                that are missing or mapped to ``None`` fall back to the
                param's ``default_value``. Any other value, including falsy
                ones such as ``0.0``, is kept as given.
        """
        provided = param_values or {}
        resolved = {}
        for param in self.definition.PARAMS:
            value = provided.get(param.name)
            # Test for None explicitly. A plain `or` would discard legitimate
            # zero values and would raise on multi-element arrays.
            resolved[param.name] = param.default_value if value is None else value
        self.param_values = resolved
import jax.numpy as jnp
from jax import jit, lax

from jaxdsp.param import Param
from jaxdsp.processors import allpass_filter as allpass
from jaxdsp.processors import lowpass_feedback_comb_filter as comb

NAME = "Freeverb"

# Mix, stereo-width, damping and room-size controls. room_size is declared
# with the range 0.0-1.2, presumably (name, default, min, max) — the other
# params use the Param defaults.
PARAMS = [
    Param("wet", 0.0),
    Param("dry", 1.0),
    Param("width", 0.0),
    Param("damp", 0.0),
    Param("room_size", 0.0, 0.0, 1.2),
]

# Named parameter snapshots, keyed by preset name.
PRESETS = dict(
    flat_space=dict(
        wet=0.3,
        dry=0.6,
        width=0.5,
        damp=0.3,
        room_size=1.055,
    ),
    expanding_space=dict(
        wet=0.33,
        dry=0.0,
        width=0.5,
        damp=0.1,
        room_size=1.078,
    ),
)
import jax.numpy as jnp from jax import jit, lax from jax.ops import index, index_update from jaxdsp.param import Param from jaxdsp.processors.constants import DEFAULT_SAMPLE_RATE # Stay inside the range of the the lowest frequency resolved by the fft in the spectral optimizers MAX_DELAY_LENGTH_MS = 40.0 MAX_DELAY_LENGTH_SAMPLES = math.ceil( DEFAULT_SAMPLE_RATE * (MAX_DELAY_LENGTH_MS / 1000.0) ) NAME = "Delay Line" PARAMS = [ Param("wet", 1.0), Param("delay_ms", MAX_DELAY_LENGTH_MS / 2, 0.0, MAX_DELAY_LENGTH_MS), ] PRESETS = {} def init_state(): return { "delay_line": jnp.zeros(MAX_DELAY_LENGTH_SAMPLES), "read_sample": 0.0, "write_sample": 0.0, "sample_rate": DEFAULT_SAMPLE_RATE, } @jit
import jax.numpy as jnp
from jax import jit, lax

from jaxdsp.param import Param

NAME = "IIR Filter"
# Coefficient arrays: in tick, B multiplies the input history and A[1:] the
# output history. Both default to a 5-tap (1, 0, 0, 0, 0), for which tick
# outputs its input unchanged.
PARAMS = [
    Param("B", jnp.concatenate([jnp.array([1.0]), jnp.zeros(4)])),
    Param("A", jnp.concatenate([jnp.array([1.0]), jnp.zeros(4)])),
]
PRESETS = {}


def init_state(length=5):
    # State contents:
    # - "inputs": the `length` most recent input samples (newest first).
    # - "outputs": `length - 1` past output samples, matched against A[1:].
    # NOTE(review): `length` presumably must equal B.size and A.size. Confirm.
    return {
        "inputs": jnp.zeros(length),
        "outputs": jnp.zeros(length - 1),
    }


@jit
def tick(carry, x):
    params, state = carry
    B = params["B"]
    A = params["A"]
    # Shift the new sample into the front of the input history, dropping the oldest.
    state["inputs"] = jnp.concatenate([jnp.array([x]), state["inputs"][0:-1]])
    # Feedforward part: numerator coefficients dotted with the input history.
    y = B @ state["inputs"]
    # Feedback part. `.size` is a static shape, so this branch is decided at
    # trace time rather than inside the compiled computation.
    if state["outputs"].size > 0:
        y -= A[1:] @ state["outputs"]