示例#1
0
def _framekfcpr(y, x, n, w0, const):
    """Run the frame-wise ``af.frame_cpr_kf`` filter over ``y``.

    Args:
        y: Input samples; the size of its last axis sets the number of
            parallel filters built by ``af.array``.
        x: Optional reference samples, framed the same way as ``y``; passed
            through unchanged when ``None``.
        n: Frame length, also used as the frame step (frames do not overlap).
        w0: Initial value handed to the filter initialiser. When ``None`` it
            is taken from ``xcomm.foe_mpowfftmax`` on the first 5000 samples
            of ``y``, averaged with ``.mean()``.
        const: Forwarded to ``af.frame_cpr_kf`` as ``const``.

    Returns:
        ``(xhat, (fo, phi))`` where ``xhat`` and ``phi`` are reshaped to
        ``(-1, dims)`` and ``fo`` is returned as produced by ``af.iterate``.
    """
    n_dims = y.shape[-1]
    kf_init, kf_update, kf_map = af.array(af.frame_cpr_kf, n_dims)(
        alpha=0.98, const=const)

    # Frame length and step are both n, so consecutive frames do not overlap.
    y_frames = xop.frame(y, n, n)
    if x is None:
        x_frames = x
    else:
        x_frames = xop.frame(x, n, n)

    if w0 is None:
        w0 = xcomm.foe_mpowfftmax(y[:5000])[0].mean()

    initial_state = kf_init(w0=w0)
    iterated = af.iterate(kf_update, 0, initial_state, y_frames, x_frames,
                          truth_ndim=3)
    # Element 1 of the iterate result holds (final state, per-frame outputs).
    _, (fo, phi_frames) = iterated[1]

    xhat = kf_map(phi_frames, y_frames).reshape((-1, n_dims))
    phi = phi_frames.reshape((-1, n_dims))
    return xhat, (fo, phi)
示例#2
0
def mimoaf(scope: Scope,
           signal,
           taps=32,
           rtap=None,
           dims=2,
           sps=2,
           train=False,
           mimofn=af.ddlms,
           mimokwargs={},
           mimoinitargs={}):
    """Adaptive MIMO filter layer whose state lives in ``scope``.

    Args:
        scope: Scope providing the ``'const'``, ``'af_state'`` and
            ``'aux_inputs'`` variable collections.
        signal: Pair ``(x, t)`` of samples and their time descriptor.
        taps: Frame length passed to ``xop.frame`` and to the filter init.
        rtap: Forwarded to ``conv1d_t``.
        dims: Forwarded to the filter initialiser.
        sps: Frame step passed to ``xop.frame``.
        train: Forwarded to ``mimofn``.
        mimofn: Factory returning ``(init, update, apply)`` functions.
        mimokwargs: Extra keyword arguments for ``mimofn`` (only unpacked).
        mimoinitargs: Extra keyword arguments for the init function
            (only unpacked).

    Returns:
        ``Signal(y, t)`` with the filtered samples and updated time descriptor.
    """
    samples, t_in = signal
    # NOTE(review): the stride given to conv1d_t is the literal 2, not `sps`;
    # confirm this is intended when sps != 2.
    t_out = scope.variable('const', 't', conv1d_t, t_in, taps, rtap, 2,
                           'valid').value
    frames = xop.frame(samples, taps, sps)

    init_fn, update_fn, apply_fn = mimofn(train=train, **mimokwargs)

    def _fresh_state(*_):
        # Step counter starts at 0 alongside the freshly initialised filter.
        return 0, init_fn(dims=dims, taps=taps, **mimoinitargs)

    state = scope.variable('af_state', 'mimoaf', _fresh_state, ())

    truth = scope.variable('aux_inputs', 'truth', lambda *_: None, ()).value
    if truth is not None:
        # Trim the reference with the start/stop offsets of the output time.
        truth = truth[t_out.start:truth.shape[0] + t_out.stop]

    step, stats = state.value
    step, (stats, (weights, _)) = af.iterate(update_fn, step, stats, frames,
                                             truth)
    filtered = apply_fn(weights, frames)
    state.value = (step, stats)
    return Signal(filtered, t_out)
示例#3
0
def foe_daxcorr(y, x, L=100):
    """Estimate a phase increment from the correlation of adjacent frames.

    The data phase is removed by multiplying ``y`` with ``conj(x)``; the
    result is cut into frames of length ``L`` (step ``L``), each frame is
    multiplied element-wise with the conjugate of the previous one, and the
    mean angle of those products is divided by ``L``.

    Args:
        y: Received samples; ``y.shape[0]`` is the signal length.
        x: Reference samples multiplied (conjugated) onto ``y``.
        L: Frame length, i.e. the sample distance between correlated points.

    Returns:
        ``mean(angle(frame[k+1] * conj(frame[k]))) / L``.

    Raises:
        TypeError: If the signal is shorter than ``L``. The exception type is
            kept for backward compatibility with existing callers.
    """
    N = y.shape[0]
    if N < L:
        raise TypeError('signal length %d is less than xcorr length %d' % (N, L))
    s = y * x.conj()  # remove modulated data phase
    sf = xop.frame(s, L, L)
    # NOTE(review): if xop.frame yields a single frame (presumably when
    # L <= N < 2 * L), both slices below are empty and the mean is NaN;
    # confirm whether callers rely on that.
    sf2 = sf[:-1]
    sf1 = sf[1:]
    return jnp.mean(jnp.angle(sf1 * sf2.conj())) / L
示例#4
0
def corr_local(y, x, frame_size=2000, L=None, device=gpus[0]):
    """Find, per frame and per column, the lag of the correlation peak.

    Both inputs are moved to ``device``, cut into frames of ``frame_size``
    (step ``frame_size``), and each pair of frames is correlated column-wise
    with ``xop.correlate_fft``. The lag at which the correlation magnitude is
    largest is returned for every frame.

    Args:
        y: First signal.
        x: Second signal, framed identically to ``y``.
        frame_size: Frame length and step.
        L: Unused. Kept in the signature for backward compatibility; the
            previous implementation derived it from ``np.unique(x)`` but
            never read the result.
        device: Target device for ``device_put``.

    Returns:
        The stacked per-frame lags produced by ``xop.scan``.
    """
    y = device_put(y, device)
    x = device_put(x, device)

    Y = xop.frame(y, frame_size, frame_size, True)
    X = xop.frame(x, frame_size, frame_size, True)

    # Lag axis with frame_size entries; argmax indices are mapped through it.
    lag = jnp.arange(-(frame_size - 1) // 2, (frame_size + 1) // 2)

    # Correlate matching columns (last axis) of the two frames independently.
    corr_v = vmap(lambda a, b: xop.correlate_fft(a, b),
                  in_axes=-1,
                  out_axes=-1)

    def f(carry, z):
        y_frame, x_frame = z
        c = jnp.abs(corr_v(y_frame, x_frame))
        return carry, lag[jnp.argmax(c, axis=0)]

    _, ret = xop.scan(f, None, (Y, X))

    return ret
示例#5
0
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs={},
              mimofn=af.rde,
              mimokwargs={},
              mimoinitargs={}):
    """Estimate a per-sample phase ramp from a MIMO-filtered copy and remove it.

    An auxiliary ``mimoaf`` child layer (named ``'MIMO4FOE'``) filters the
    pre-sliced signal; its output is framed and fed to ``af.frame_cpr_kf``.
    The per-frame estimates are interpolated to the input grid, accumulated
    into a phase ``psi``, extended at both ends, and applied to the original
    ``signal`` as ``exp(-1j * psi_ext)``.

    Args:
        scope: Scope holding the ``'af_state'`` variable ``'framefoeaf'`` and
            the child layer.
        signal: Input signal; must expose ``.t`` and support multiplication.
        framesize: Frame length and step used on the auxiliary output.
        w0: Initial value passed to ``foe_init``.
        train: Forwarded to the child ``mimoaf``.
        preslicer: Callable applied to ``signal`` before the child layer.
        foekwargs: Keyword arguments for ``af.frame_cpr_kf`` (only unpacked).
        mimofn: Forwarded to the child ``mimoaf``.
        mimokwargs: Forwarded to the child ``mimoaf``.
        mimoinitargs: Forwarded to the child ``mimoaf``.

    Returns:
        ``signal`` multiplied by the conjugate phase rotation.
    """

    # Hard-coded in this layer; the child mimoaf is called with its defaults.
    sps = 2
    dims = 2
    tx = signal.t
    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig  # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    # State is (accumulated phase, step counter, filter state).
    state = scope.variable('af_state', 'framefoeaf', lambda *_:
                           (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value

    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats,
                                              yf)
    # Average the per-frame estimates over the `dims` axis.
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    # Interpolate from frame centres onto a grid with sps points per sample
    # of y, then scale by 1/sps.
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2, wp) / sps
    # Accumulate onto the phase carried over from the previous call.
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    # Head uses slope w[0] ending at phi; tail uses slope w[-1] starting at
    # psi[-1]; lengths come from the start/stop offsets of tx versus ty * sps.
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi, psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1]
    ])

    # Broadcast the phase over the trailing axis of the signal.
    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal
示例#6
0
def _power_local(y, frame_size, frame_step, sps):
    """Compute frame-wise mean power and interpolate it to a finer grid.

    Args:
        y: Input samples; ``y.shape[0]`` is the number of samples.
        frame_size: Frame length.
        frame_step: Step between frame starts.
        sps: Number of output grid points per input sample.

    Returns:
        Power interpolated at ``arange(N * sps) / sps``, one column per
        entry of the last axis of the per-frame power.
    """
    n_samples = y.shape[0]
    frames = xop.frame(y, frame_size, frame_step, True)
    n_frames = frames.shape[0]

    def _mean_power(carry, frame):
        # Mean of |frame|^2 along the first axis of each frame.
        return carry, jnp.mean(jnp.abs(frame)**2, axis=0)

    _, power = xop.scan(_mean_power, None, frames)

    # Each frame's value is anchored half a frame after the frame start.
    anchors = jnp.arange(n_frames) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps

    interp_columns = vmap(lambda q, pts, vals: jnp.interp(q, pts, vals),
                          in_axes=(None, None, -1),
                          out_axes=-1)
    return interp_columns(grid, anchors, power)
示例#7
0
def _foe_local(y, frame_size, frame_step, sps):
    """Run ``foe_mpowfftmax`` per frame and interpolate to a finer grid.

    Args:
        y: Input samples; ``y.shape[0]`` is the number of samples.
        frame_size: Frame length.
        frame_step: Step between frame starts.
        sps: Number of output grid points per input sample; the per-frame
            estimates are also divided by this value.

    Returns:
        The per-frame estimates, divided by ``sps`` and interpolated at
        ``arange(N * sps) / sps``, one column per entry of the last axis.
    """
    n_samples = y.shape[0]
    frames = xop.frame(y, frame_size, frame_step, True)
    n_frames = frames.shape[0]

    def _estimate(carry, frame):
        # Only the first of the two returned values is kept.
        estimate, _ = foe_mpowfftmax(frame)
        return carry, estimate

    _, fo_hat = xop.scan(_estimate, None, frames)
    fo_hat /= sps

    # Each frame's value is anchored half a frame after the frame start.
    anchors = jnp.arange(n_frames) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps

    interp_columns = vmap(lambda q, pts, vals: jnp.interp(q, pts, vals),
                          in_axes=(None, None, -1),
                          out_axes=-1)
    return interp_columns(grid, anchors, fo_hat)
示例#8
0
def frame(y, taps, sps, rtap=None):
    """Pad ``y`` with ``mimozerodelaypads`` and cut it into frames.

    Args:
        y: Input samples.
        taps: Frame length; also forwarded to ``mimozerodelaypads``.
        sps: Frame step; also forwarded to ``mimozerodelaypads``.
        rtap: Forwarded to ``mimozerodelaypads``.

    Returns:
        The framed, padded input as a ``jnp`` array.
    """
    pad_widths = mimozerodelaypads(taps=taps, sps=sps, rtap=rtap)
    padded = jnp.pad(y, pad_widths)
    framed = xop.frame(padded, taps, sps)
    return jnp.array(framed)