def _ekfcpr(y, x, const):
    """Run a per-channel bank of ``af.cpane_ekf`` filters over ``y``.

    ``af.array`` builds one filter per channel; the channel count is taken
    from the last axis of ``y``. The bank is driven by ``af.iterate``, and
    its phase trajectory is applied back to ``y`` through the bank's map
    function.

    Args:
        y: input samples; the last axis is treated as the channel axis.
        x: reference samples forwarded to ``af.iterate`` (presumably
            training/truth data -- TODO confirm whether ``None`` is allowed).
        const: forwarded to ``af.cpane_ekf`` as ``const=`` (presumably the
            constellation).

    Returns:
        ``(xhat, phi)``: the map-function output and the phase trajectory
        taken from the ``af.iterate`` outputs.
    """
    n_ch = y.shape[-1]
    ekf_bank = af.array(af.cpane_ekf, n_ch)
    init_fn, step_fn, apply_fn = ekf_bank(beta=0.6, const=const)
    initial = init_fn()
    # Element [1] of the iterate result unpacks as (state, (phase, aux));
    # only the phase trajectory is kept.
    _final, (phase, _aux) = af.iterate(step_fn, 0, initial, y, x)[1]
    corrected = apply_fn(phase, y)
    return corrected, phase
def mimoaf(scope: Scope,
           signal,
           taps=32,
           rtap=None,
           dims=2,
           sps=2,
           train=False,
           mimofn=af.ddlms,
           mimokwargs=None,
           mimoinitargs=None):
    """Adaptive MIMO filter layer whose adaptive-filter state persists across calls.

    ``signal`` is unpacked as ``(x, t)``.
    - The time descriptor ``t`` is transformed by ``conv1d_t`` in ``'valid'``
      mode and cached in the ``'const'`` collection.
    - ``x`` is framed with ``xop.frame(x, taps, sps)``.
    - The filter built by ``mimofn`` is advanced with ``af.iterate``, starting
      from the ``(step, state)`` pair stored in
      ``scope.variable('af_state', 'mimoaf')``.
    - The updated pair is written back after the iteration.

    An optional reference sequence is read from the ``'aux_inputs'``/``'truth'``
    variable (default ``None``). When present, it is sliced to
    ``[t.start : len + t.stop]`` before being passed to the update function.

    Args:
        scope: module scope that provides variables/collections.
        signal: a ``(samples, time)`` pair (a ``Signal``).
        taps: filter length; forwarded to ``conv1d_t``, ``xop.frame`` and
            ``mimo_init``.
        rtap: forwarded to ``conv1d_t``.
        dims: forwarded to ``mimo_init``.
        sps: frame step passed to ``xop.frame``.
        train: forwarded to ``mimofn``.
        mimofn: factory returning ``(init, update, apply)`` functions.
        mimokwargs: extra keyword arguments for ``mimofn``; ``None`` means ``{}``.
        mimoinitargs: extra keyword arguments for ``mimo_init``; ``None``
            means ``{}``.

    Returns:
        ``Signal(y, t)``, where ``y = mimo_apply(weights, frames)``.
    """
    # ``None`` defaults avoid one dict object being shared by every call.
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    x, t = signal
    # NOTE(review): the literal ``2`` passed to ``conv1d_t`` (presumably a
    # stride/decimation factor -- TODO confirm) only matches the ``sps`` used
    # by ``xop.frame`` below for the default ``sps=2``.
    t = scope.variable('const', 't', conv1d_t, t, taps, rtap, 2, 'valid').value
    x = xop.frame(x, taps, sps)
    mimo_init, mimo_update, mimo_apply = mimofn(train=train, **mimokwargs)
    # Persistent (step counter, adaptive-filter state); initialized on first use.
    state = scope.variable(
        'af_state', 'mimoaf',
        lambda *_: (0, mimo_init(dims=dims, taps=taps, **mimoinitargs)), ())
    truth_var = scope.variable('aux_inputs', 'truth', lambda *_: None, ())
    truth = truth_var.value
    if truth is not None:
        # Trim the reference to the span described by the transformed ``t``
        # (assumes ``t.stop`` is a non-positive offset from the end -- confirm).
        truth = truth[t.start:truth.shape[0] + t.stop]
    af_step, af_stats = state.value
    af_step, (af_stats, (af_weights, _)) = af.iterate(
        mimo_update, af_step, af_stats, x, truth)
    y = mimo_apply(af_weights, x)
    state.value = (af_step, af_stats)
    return Signal(y, t)
def _rdemimo(signal, truth, lr, Rs, sps, taps):
    """Equalize ``signal`` with an ``af.rde`` adaptive MIMO filter.

    Args:
        signal: input samples (converted with ``jnp.asarray``).
        truth: reference samples (converted with ``jnp.asarray``) and forwarded
            to ``af.iterate``.
        lr: learning rate for ``af.rde``.
        Rs: forwarded to ``af.rde`` as ``Rs=``.
        sps: frame step for ``af.frame``.
        taps: filter length for ``af.frame`` and the initial state.

    Returns:
        ``(xhat, loss)``: the filter output computed from the per-step states,
        and the loss sequence taken from ``af.iterate``.
    """
    rx = jnp.asarray(signal)
    ref = jnp.asarray(truth)
    init_fn, step_fn, apply_fn = af.rde(lr=lr, Rs=Rs)
    frames = af.frame(rx, taps, sps)
    initial = init_fn(taps=taps, dtype=rx.dtype)
    trace = af.iterate(step_fn, 0, initial, frames, ref)[1]
    _final, (states, loss) = trace
    return apply_fn(states, frames), loss
def _framekfcpr(y, x, n, w0, const):
    """Frame-wise ``af.frame_cpr_kf`` processing of ``y``, one filter per channel.

    ``y`` (and ``x``, when given) is framed with ``xop.frame(·, n, n)``. A
    per-channel filter bank is iterated over the frames, and the frame-shaped
    outputs are flattened back to ``(-1, channels)``.

    Args:
        y: input samples; the last axis is the channel axis.
        x: optional reference samples; ``None`` is passed through unframed.
        n: frame size/step for ``xop.frame``.
        w0: initial value for ``init(w0=...)``. When ``None``, it is estimated
            as ``xcomm.foe_mpowfftmax(y[:5000])[0].mean()``.
        const: forwarded to ``af.frame_cpr_kf`` as ``const=``.

    Returns:
        ``(xhat, (fo, phi))``: ``xhat`` and ``phi`` are reshaped to
        ``(-1, channels)``; ``fo`` is returned as produced by ``af.iterate``.
    """
    n_ch = y.shape[-1]
    init_fn, step_fn, apply_fn = af.array(af.frame_cpr_kf, n_ch)(alpha=0.98, const=const)
    y_frames = xop.frame(y, n, n)
    x_frames = None if x is None else xop.frame(x, n, n)
    if w0 is None:
        # Seed from at most the first 5000 samples.
        w0 = xcomm.foe_mpowfftmax(y[:5000])[0].mean()
    initial = init_fn(w0=w0)
    _final, (fo, phase_frames) = af.iterate(
        step_fn, 0, initial, y_frames, x_frames, truth_ndim=3)[1]
    flat = (-1, n_ch)
    xhat = apply_fn(phase_frames, y_frames).reshape(flat)
    return xhat, (fo, phase_frames.reshape(flat))
def _lmsmimo(signal, truth, sps, taps, mu_w, mu_f, mu_s, mu_b, beta):
    """Equalize ``signal`` with an ``af.ddlms`` adaptive MIMO filter.

    The filter starts from a ``'centralspike'`` initialization.

    Args:
        signal: input samples (converted with ``jnp.asarray``).
        truth: reference samples (converted with ``jnp.asarray``) and forwarded
            to ``af.iterate``.
        sps: frame step for ``af.frame``.
        taps: filter length.
        mu_w, mu_f, mu_s, mu_b: forwarded as ``lr_w``, ``lr_f``, ``lr_s`` and
            ``lr_b`` to ``af.ddlms``.
        beta: forwarded to ``af.ddlms``.

    Returns:
        ``(xhat, last_state, loss)``: the filter output, the final element of
        the per-step state sequence, and the first auxiliary output (loss).
    """
    rx = jnp.asarray(signal)
    ref = jnp.asarray(truth)
    rates = dict(lr_w=mu_w, lr_f=mu_f, lr_s=mu_s, lr_b=mu_b)
    init_fn, step_fn, apply_fn = af.ddlms(**rates, beta=beta)
    frames = af.frame(rx, taps, sps)
    initial = init_fn(taps, mimoinit='centralspike')
    _final, (states, (loss, *_extra)) = af.iterate(step_fn, 0, initial, frames, ref)[1]
    xhat = apply_fn(states, frames)
    return xhat, states[-1], loss
def _modulusmimo(y, R2, Rs, sps, taps, cma_samples, lr):
    """Two-stage blind MIMO equalization: MU-CMA warm-up followed by RDE.

    The first ``cma_samples`` frames are processed with ``af.mucma``. The
    first element of its final state then seeds ``af.rde`` on the remaining
    frames. Per-step weights and losses from both stages are concatenated
    along axis 0, and the RDE map function is applied to all frames.

    Args:
        y: input samples; the last axis is the channel axis.
        R2: forwarded to ``af.mucma``.
        Rs: forwarded to ``af.rde``.
        sps: frame step for ``af.frame``.
        taps: filter length.
        cma_samples: number of leading frames handled by MU-CMA.
        lr: learning rate for ``af.rde``.

    Returns:
        ``(x_hat, ws, loss)``.
    """
    rx = jnp.asarray(y)
    cma_init, cma_step, _ = af.mucma(R2=R2, dims=rx.shape[-1])
    _rde_init, rde_step, rde_apply = af.rde(Rs=Rs, lr=lr)
    # Frame the signal so the filters can be applied in parallel (jax.vmap).
    frames = af.frame(rx, taps, sps)
    cma_state0 = cma_init(taps=taps, dtype=rx.dtype)

    # Stage 1: MU-CMA over the leading frames to avoid a singular MIMO solution.
    (warm_start, *_), (cma_ws, cma_loss) = af.iterate(
        cma_step, 0, cma_state0, frames[:cma_samples])[1]
    # Stage 2: RDE over the remaining frames, seeded from the CMA final state.
    _, (rde_ws, rde_loss) = af.iterate(
        rde_step, 0, warm_start, frames[cma_samples:])[1]

    def _join(head, tail):
        return jnp.concatenate([head, tail], axis=0)

    loss = _join(cma_loss, rde_loss)
    ws = _join(cma_ws, rde_ws)
    return rde_apply(ws, frames), ws, loss
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs=None,
              mimofn=af.rde,
              mimokwargs=None,
              mimoinitargs=None):
    """Frequency-offset compensation driven by an auxiliary MIMO stage.

    Processing steps:
    1. A child ``mimoaf`` module (named ``'MIMO4FOE'``) processes
       ``preslicer(signal)``.
    2. Its output ``y`` is split with ``xop.frame(y, framesize, framesize)``.
    3. An ``af.array(af.frame_cpr_kf, dims)`` bank is iterated over those
       frames.
    4. The per-frame outputs are averaged across channels, linearly
       interpolated from frame centres onto a grid ``sps`` times denser than
       ``y``, and divided by ``sps``.
    5. The result is accumulated into a phase trajectory ``psi`` that
       continues from the stored phase ``phi``.
    6. ``psi`` is extended linearly to the full time span of ``signal``, and
       ``signal`` is multiplied by ``exp(-1j * psi_ext)``.

    Persistent state ``'af_state'``/``'framefoeaf'`` holds
    ``(phi, step, filter_state)``; ``phi`` is updated to ``psi[-1]``.

    Args:
        scope: module scope providing variables and ``child``.
        signal: input ``Signal``; its ``.t`` gives the time span to cover.
        framesize: frame length and step for the frame-wise filter.
        w0: initial value passed to the filter bank's ``init``.
        train: forwarded to the child ``mimoaf``.
        preslicer: applied to ``signal`` before the MIMO stage.
        foekwargs: keyword arguments for ``af.frame_cpr_kf``; ``None`` means ``{}``.
        mimofn: adaptive-filter factory for the child ``mimoaf``.
        mimokwargs: forwarded to ``mimoaf``; ``None`` means ``{}``.
        mimoinitargs: forwarded to ``mimoaf``; ``None`` means ``{}``.

    Returns:
        ``signal`` de-rotated by the extended phase trajectory.
    """
    # ``None`` defaults avoid one dict object being shared by every call.
    foekwargs = {} if foekwargs is None else foekwargs
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    # Fixed here; they must match what the child ``mimoaf`` uses (its defaults
    # are also sps=2, dims=2).
    sps = 2
    dims = 2
    tx = signal.t

    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig
    # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)
    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    state = scope.variable('af_state', 'framefoeaf',
                           lambda *_: (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value
    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats, yf)
    # Channel-averaged per-frame estimate (presumably a frequency offset -- confirm).
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    # Interpolate from frame centres onto the denser grid and rescale per sample.
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2,
        wp) / sps
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    # NOTE(review): since psi[k] = phi + w[0] + ... + w[k], a continuous
    # extension would be phi + (j + 1) * w[0] for j < 0 and
    # psi[-1] + (k + 1) * w[-1] after the end; the expressions below are one
    # step of ``w`` lower at each seam -- confirm whether this is intended.
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi,
        psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1]
    ])
    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal