Пример #1
0
def _ekfcpr(y, x, const):
    """Run the ``af.cpane_ekf`` carrier-phase filter, one instance per dimension.

    Args:
        y: received samples; the size of the last axis is the number of
            parallel filter instances created via ``af.array``.
        x: reference samples forwarded to ``af.iterate`` alongside ``y``.
        const: constellation forwarded to ``af.cpane_ekf``.

    Returns:
        ``(xhat, phi)``: the output of the filter's map function applied to
        ``phi`` and ``y``, and the per-step output ``phi`` of the update
        (presumably the phase estimates — confirm against ``af.cpane_ekf``).
    """
    n_dims = y.shape[-1]
    ekf_factory = af.array(af.cpane_ekf, n_dims)
    init_fn, update_fn, map_fn = ekf_factory(beta=0.6, const=const)
    initial_state = init_fn()
    # af.iterate(...) yields (step, (final_state, outputs)); only the outputs
    # are needed here, and the second output element is discarded.
    _final_state, outputs = af.iterate(update_fn, 0, initial_state, y, x)[1]
    phi, _ = outputs
    xhat = map_fn(phi, y)
    return xhat, phi
Пример #2
0
def mimoaf(scope: Scope,
           signal,
           taps=32,
           rtap=None,
           dims=2,
           sps=2,
           train=False,
           mimofn=af.ddlms,
           mimokwargs=None,
           mimoinitargs=None):
    """Apply a stateful MIMO adaptive filter to ``signal``.

    Args:
        scope: module scope. It stores the output time descriptor (collection
            ``'const'``) and the filter state (collection ``'af_state'``). It
            also supplies the optional reference signal
            (collection ``'aux_inputs'``, name ``'truth'``).
        signal: ``(x, t)`` pair of samples and their time descriptor.
        taps: filter length; also the frame length used to window ``x``.
        rtap: reference tap forwarded to ``conv1d_t``.
        dims: forwarded to the filter's init function.
        sps: input samples per output sample. It is used as the frame hop for
            ``x`` and as the stride of the output time descriptor.
        train: forwarded to ``mimofn``.
        mimofn: factory returning ``(init, update, apply)`` functions.
        mimokwargs: extra keyword arguments for ``mimofn`` (``None`` = none).
        mimoinitargs: extra keyword arguments for the init function
            (``None`` = none).

    Returns:
        ``Signal(y, t)``: the filtered samples and their time descriptor.
    """
    # None defaults avoid sharing one mutable dict across calls.
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    x, t = signal
    # The time descriptor must be decimated with the same stride used to
    # frame `x` below. A hard-coded stride of 2 was only correct for sps=2.
    t = scope.variable('const', 't', conv1d_t, t, taps, rtap, sps, 'valid').value
    x = xop.frame(x, taps, sps)
    mimo_init, mimo_update, mimo_apply = mimofn(train=train, **mimokwargs)
    # Filter state is (step counter, filter stats), created lazily on first use.
    state = scope.variable(
        'af_state', 'mimoaf', lambda *_:
        (0, mimo_init(dims=dims, taps=taps, **mimoinitargs)), ())
    truth_var = scope.variable('aux_inputs', 'truth', lambda *_: None, ())
    truth = truth_var.value
    if truth is not None:
        # Trim the reference to the span described by `t`
        # (assumes t.stop <= 0 for 'valid' mode — confirm with conv1d_t).
        truth = truth[t.start:truth.shape[0] + t.stop]
    af_step, af_stats = state.value
    af_step, (af_stats, (af_weights, _)) = af.iterate(mimo_update, af_step,
                                                      af_stats, x, truth)
    y = mimo_apply(af_weights, x)
    state.value = (af_step, af_stats)
    return Signal(y, t)
Пример #3
0
def _rdemimo(signal, truth, lr, Rs, sps, taps):
    """Equalize ``signal`` with the ``af.rde`` MIMO adaptive filter.

    Args:
        signal: received samples (converted with ``jnp.asarray``).
        truth: reference samples (converted with ``jnp.asarray``) forwarded to
            the update step.
        lr: learning rate forwarded to ``af.rde``.
        Rs: forwarded to ``af.rde``.
        sps: hop between successive frames.
        taps: frame length and number of filter taps.

    Returns:
        ``(xhat, loss)``: the filter output over all frames, and the per-step
        loss reported by the update function.
    """
    rx = jnp.asarray(signal)
    ref = jnp.asarray(truth)
    init_fn, update_fn, map_fn = af.rde(lr=lr, Rs=Rs)
    frames = af.frame(rx, taps, sps)
    weights0 = init_fn(taps=taps, dtype=rx.dtype)
    # Keep only the per-step outputs (weights history, loss) of the run.
    _final, (weights_hist, loss) = af.iterate(update_fn, 0, weights0, frames, ref)[1]
    return map_fn(weights_hist, frames), loss
Пример #4
0
def _framekfcpr(y, x, n, w0, const):
    """Frame-wise carrier recovery with ``af.frame_cpr_kf``, one filter per dimension.

    Args:
        y: received samples; the size of the last axis (``dims``) is the
            number of parallel filter instances.
        x: reference samples framed like ``y``, or ``None``. In that case
            ``None`` is passed to the filter as the reference.
        n: frame length, also used as the frame hop.
        w0: initial value for the filter's init function. When ``None``, it is
            estimated with ``xcomm.foe_mpowfftmax`` on the first 5000 samples.
        const: constellation forwarded to ``af.frame_cpr_kf``.

    Returns:
        ``(xhat, (fo, phi))``. ``xhat`` and ``phi`` are reshaped back to
        ``(-1, dims)``. ``fo`` is the first per-frame output of the filter
        (presumably a frequency-offset track — confirm).
    """
    n_dims = y.shape[-1]
    kf_factory = af.array(af.frame_cpr_kf, n_dims)
    init_fn, update_fn, map_fn = kf_factory(alpha=0.98, const=const)

    y_frames = xop.frame(y, n, n)
    x_frames = None if x is None else xop.frame(x, n, n)

    if w0 is None:
        # Data-driven initial guess from a 5000-sample prefix of y.
        w0 = xcomm.foe_mpowfftmax(y[:5000])[0].mean()

    init_state = init_fn(w0=w0)
    # truth_ndim=3 presumably reflects the extra frame axis of x_frames.
    run = af.iterate(update_fn, 0, init_state, y_frames, x_frames, truth_ndim=3)
    _final_state, (fo, phi_frames) = run[1]

    def unframe(arr):
        # Flatten the frame axis back into the sample axis.
        return arr.reshape((-1, n_dims))

    xhat = unframe(map_fn(phi_frames, y_frames))
    return xhat, (fo, unframe(phi_frames))
Пример #5
0
def _lmsmimo(signal, truth, sps, taps, mu_w, mu_f, mu_s, mu_b, beta):
    """Equalize ``signal`` with the ``af.ddlms`` MIMO adaptive filter.

    Args:
        signal: received samples (converted with ``jnp.asarray``).
        truth: reference samples (converted with ``jnp.asarray``).
        sps: hop between successive frames.
        taps: frame length and number of filter taps.
        mu_w, mu_f, mu_s, mu_b: forwarded to ``af.ddlms`` as
            ``lr_w``, ``lr_f``, ``lr_s`` and ``lr_b`` respectively.
        beta: forwarded to ``af.ddlms``.

    Returns:
        ``(xhat, last, loss)``: the filter output over all frames, the last
        entry of the per-step state history, and the first element of the
        per-step auxiliary outputs (the loss).
    """
    rx = jnp.asarray(signal)
    ref = jnp.asarray(truth)
    rates = {'lr_w': mu_w, 'lr_f': mu_f, 'lr_s': mu_s, 'lr_b': mu_b}
    init_fn, update_fn, map_fn = af.ddlms(**rates, beta=beta)
    frames = af.frame(rx, taps, sps)
    state0 = init_fn(taps, mimoinit='centralspike')
    _final, (history, aux) = af.iterate(update_fn, 0, state0, frames, ref)[1]
    loss, *_ = aux
    xhat = map_fn(history, frames)
    return xhat, history[-1], loss
Пример #6
0
def _modulusmimo(y, R2, Rs, sps, taps, cma_samples, lr):
    """Two-stage MIMO equalization: ``af.mucma`` warm-up, then ``af.rde``.

    The first ``cma_samples`` frames adapt with MU-CMA. The remaining frames
    adapt with RDE, starting from the weights MU-CMA ended with.

    Args:
        y: received samples; the size of the last axis is passed as ``dims``.
        R2: forwarded to ``af.mucma``.
        Rs: forwarded to ``af.rde``.
        sps: hop between successive frames.
        taps: frame length and number of filter taps.
        cma_samples: number of frames handled by the MU-CMA stage.
        lr: learning rate forwarded to ``af.rde``.

    Returns:
        ``(x_hat, ws, loss)``: the RDE map applied to all frames, and the
        weight and loss histories of both stages concatenated along axis 0.
    """
    y = jnp.asarray(y)
    n_dims = y.shape[-1]

    cma_init, cma_update, _ = af.mucma(R2=R2, dims=n_dims)
    _, rde_update, rde_map = af.rde(Rs=Rs, lr=lr)

    # Framing enables parallel evaluation of the filter output (jax.vmap).
    frames = af.frame(y, taps, sps)
    head, tail = frames[:cma_samples], frames[cma_samples:]

    cma_state0 = cma_init(taps=taps, dtype=y.dtype)
    # Stage 1: MU-CMA initializes the MIMO weights to avoid singularity.
    (w_cma, *_), (ws_cma, loss_cma) = af.iterate(cma_update, 0, cma_state0, head)[1]
    # Stage 2: switch to RDE, seeded with MU-CMA's final weights.
    _, (ws_rde, loss_rde) = af.iterate(rde_update, 0, w_cma, tail)[1]

    ws = jnp.concatenate([ws_cma, ws_rde], axis=0)
    loss = jnp.concatenate([loss_cma, loss_rde], axis=0)
    return rde_map(ws, frames), ws, loss
Пример #7
0
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs=None,
              mimofn=af.rde,
              mimokwargs=None,
              mimoinitargs=None):
    """Estimate a phase track on a MIMO-equalized copy and derotate ``signal``.

    Steps:

    1. A child ``mimoaf`` module (named ``'MIMO4FOE'``) equalizes
       ``preslicer(signal)``.
    2. Its output is cut into frames of ``framesize`` and tracked with
       ``af.frame_cpr_kf`` (one instance per dimension).
    3. The per-frame outputs, averaged over dimensions, are interpolated to
       the input sample rate and accumulated into a phase ``psi``.
    4. The phase is extended to cover the full time span of ``signal``, and
       ``signal`` is multiplied by ``exp(-1j * psi_ext)``.

    The last phase value is carried across calls in the ``'af_state'``
    collection, together with the filter step and stats.

    Args:
        scope: module scope holding the child module and the filter state.
        signal: input signal with a ``.t`` time descriptor.
        framesize: frame length (and hop) for the frame-wise tracker.
        w0: initial value passed to the tracker's init function.
        train: forwarded to the child ``mimoaf``.
        preslicer: transform applied to ``signal`` before the child MIMO.
        foekwargs: keyword arguments for ``af.frame_cpr_kf``
            (``None`` = none).
        mimofn: adaptive-filter factory for the child ``mimoaf``.
        mimokwargs: forwarded to the child ``mimoaf`` (``None`` = none).
        mimoinitargs: forwarded to the child ``mimoaf`` (``None`` = none).

    Returns:
        ``signal`` derotated by the extended phase track.
    """
    # None defaults avoid sharing one mutable dict across calls. Concrete
    # dicts are passed on so any version of mimoaf accepts them.
    foekwargs = {} if foekwargs is None else foekwargs
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    sps = 2
    dims = 2
    tx = signal.t
    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig  # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    # State: (last accumulated phase, filter step, filter stats).
    state = scope.variable('af_state', 'framefoeaf', lambda *_:
                           (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value

    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats,
                                              yf)
    # Average the per-frame outputs over dimensions.
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    # Interpolate from frame centres (in y-sample units) to every input
    # sample, then scale the per-sample increment by 1/sps.
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2, wp) / sps
    # psi[i] = phi + w[0] + ... + w[i]
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # Apply FOE to the original input signal via linear extrapolation of psi.
    # The head uses slope w[0] and ends one step before psi[0], i.e. at phi.
    # The tail uses slope w[-1] and starts one step after psi[-1].
    # (Previously the head ended at phi - w[0] and the tail repeated psi[-1],
    # an off-by-one discontinuity.) Segment lengths are unchanged.
    n_head = ty.start * sps - tx.start
    n_tail = tx.stop - ty.stop * sps
    psi_ext = jnp.concatenate([
        psi[0] + w[0] * jnp.arange(-n_head, 0),
        psi,
        psi[-1] + w[-1] * jnp.arange(1, n_tail + 1),
    ])

    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal