def _framekfcpr(y, x, n, w0, const):
    """Run frame-wise Kalman-filter CPR (``af.frame_cpr_kf``) and de-rotate ``y``.

    ``y`` (and ``x`` when given) are cut into frames of ``n`` samples with a
    hop of ``n``. The filter, wrapped by ``af.array`` over ``dims`` channels,
    is iterated over those frames.

    Args:
        y: received signal; its last axis is taken as the channel axis.
        x: optional reference symbols, framed like ``y``; may be None.
        n: frame length, also used as the frame step.
        w0: initial value handed to the filter's init function. When None,
            it is the mean of ``xcomm.foe_mpowfftmax(y[:5000])[0]``.
        const: forwarded to the filter factory as ``const``.

    Returns:
        ``(xhat, (fo, phi))``. ``xhat`` (from ``cpr_map``) and ``phi`` are
        reshaped to ``(-1, dims)``. ``fo`` is returned as produced by the
        iteration (presumably a frequency-offset track; confirm in ``af``).
    """
    dims = y.shape[-1]
    cpr = af.array(af.frame_cpr_kf, dims)(alpha=0.98, const=const)
    cpr_init, cpr_update, cpr_map = cpr

    y_frames = xop.frame(y, n, n)
    x_frames = None if x is None else xop.frame(x, n, n)

    if w0 is None:
        # Coarse initial estimate taken from a prefix of the signal.
        w0 = xcomm.foe_mpowfftmax(y[:5000])[0].mean()

    init_state = cpr_init(w0=w0)
    iterated = af.iterate(cpr_update, 0, init_state, y_frames, x_frames, truth_ndim=3)
    _, aux_out = iterated[1]
    fo, phi_frames = aux_out

    xhat = cpr_map(phi_frames, y_frames).reshape((-1, dims))
    return xhat, (fo, phi_frames.reshape((-1, dims)))
def mimoaf(scope: Scope,
           signal,
           taps=32,
           rtap=None,
           dims=2,
           sps=2,
           train=False,
           mimofn=af.ddlms,
           mimokwargs=None,
           mimoinitargs=None):
    """Adaptive MIMO equalizer layer with persistent filter state.

    The input samples are framed into windows of ``taps`` with hop ``sps``.
    The adaptive filter built by ``mimofn`` is iterated over the frames, and
    the resulting per-step weights are applied to produce the output.

    Args:
        scope: functional ``Scope`` holding the variables used below.
        signal: ``(samples, t)`` pair; ``t`` is the time descriptor.
        taps: equalizer window length in samples.
        rtap: forwarded to ``conv1d_t`` (time bookkeeping).
        dims: forwarded to the filter's init function.
        sps: frame hop (decimation) in samples.
        train: forwarded to ``mimofn``.
        mimofn: factory returning ``(init, update, apply)``.
        mimokwargs: extra keyword args for ``mimofn``; None means none.
        mimoinitargs: extra keyword args for the filter init; None means none.

    Returns:
        ``Signal(y, t)`` with the equalized samples and the updated time
        descriptor.

    Side effects:
        Writes ``(af_step, af_stats)`` back to the ``'af_state'/'mimoaf'``
        variable. Reads an optional reference from ``'aux_inputs'/'truth'``.
    """
    # None defaults avoid sharing a single dict object across all calls.
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    x, t = signal
    # NOTE(review): conv1d_t receives a literal 2 here rather than `sps`;
    # confirm this is intended when sps != 2.
    t = scope.variable('const', 't', conv1d_t, t, taps, rtap, 2, 'valid').value
    x = xop.frame(x, taps, sps)
    mimo_init, mimo_update, mimo_apply = mimofn(train=train, **mimokwargs)
    state = scope.variable(
        'af_state', 'mimoaf',
        lambda *_: (0, mimo_init(dims=dims, taps=taps, **mimoinitargs)), ())
    truth_var = scope.variable('aux_inputs', 'truth', lambda *_: None, ())
    truth = truth_var.value
    if truth is not None:
        # Trim the reference to the output time span (presumably t.stop <= 0).
        truth = truth[t.start:truth.shape[0] + t.stop]
    af_step, af_stats = state.value
    af_step, (af_stats, (af_weights, _)) = af.iterate(mimo_update, af_step, af_stats, x, truth)
    y = mimo_apply(af_weights, x)
    state.value = (af_step, af_stats)
    return Signal(y, t)
def foe_daxcorr(y, x, L=100):
    """Estimate a per-sample phase increment from adjacent-frame phase drift.

    ``y * conj(x)`` removes the reference data phase. The product is cut into
    non-overlapping frames of ``L`` samples. The mean angle of
    ``frame[k+1] * conj(frame[k])`` is then divided by ``L``.

    Args:
        y: received signal; ``y.shape[0]`` is the time axis.
        x: reference symbols aligned with ``y``.
        L: frame length in samples.

    Returns:
        Scalar mean phase increment per sample, in radians. The mean is taken
        over all frame pairs and any trailing axes.

    Raises:
        TypeError: if ``y`` has fewer than ``L`` samples. ValueError would be
            more conventional; TypeError is kept for existing callers.
    """
    N = y.shape[0]
    if N < L:
        raise TypeError('signal length %d is less than xcorr length %d' % (N, L))
    s = y * x.conj()  # remove modulated data phase
    sf = xop.frame(s, L, L)
    # NOTE(review): if framing yields a single frame (presumably L <= N < 2L),
    # both slices below are empty and the mean is NaN. Confirm xop.frame's
    # frame count before tightening the guard.
    sf2 = sf[:-1]
    sf1 = sf[1:]
    return jnp.mean(jnp.angle(sf1 * sf2.conj())) / L
def corr_local(y, x, frame_size=2000, L=None, device=gpus[0]):
    """Frame-wise lag estimate between ``y`` and ``x`` via FFT cross-correlation.

    Both inputs are moved to ``device`` and cut into frames of
    ``frame_size`` samples with hop ``frame_size``. For each frame, the
    magnitude of ``xop.correlate_fft`` is computed per channel (vmapped over
    the last axis). The lag at its maximum is returned.

    Args:
        y: signal, time on axis 0, channels on the last axis.
        x: reference signal, same layout as ``y``.
        frame_size: frame length and hop in samples.
        L: unused. Kept only for backward compatibility; it used to be
            filled with ``len(np.unique(x))`` but was never read.
        device: target device for ``device_put``.

    Returns:
        Array of lags indexed by frame (and channel). Values are drawn from
        ``arange(-(frame_size - 1) // 2, (frame_size + 1) // 2)``. This
        assumes correlate_fft's output index maps onto that lag range
        (TODO confirm).
    """
    y = device_put(y, device)
    x = device_put(x, device)
    Y = xop.frame(y, frame_size, frame_size, True)
    X = xop.frame(x, frame_size, frame_size, True)
    lag = jnp.arange(-(frame_size - 1) // 2, (frame_size + 1) // 2)
    corr_v = vmap(lambda a, b: xop.correlate_fft(a, b), in_axes=-1, out_axes=-1)

    def frame_lag(carry, frames):
        y_frame, x_frame = frames
        c = jnp.abs(corr_v(y_frame, x_frame))
        return carry, lag[jnp.argmax(c, axis=0)]

    _, ret = xop.scan(frame_lag, None, (Y, X))
    return ret
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs=None,
              mimofn=af.rde,
              mimokwargs=None,
              mimoinitargs=None):
    """Frequency-offset compensation driven by an auxiliary MIMO equalizer.

    The processing steps are:

    1. ``preslicer(signal)`` is fed through a child ``mimoaf`` named
       ``'MIMO4FOE'``. It is not given ``sps``, so it runs with its default
       of 2, matching the hard-coded ``sps = 2`` here.
    2. The equalizer output is cut into frames of ``framesize`` samples.
       ``af.frame_cpr_kf`` (per channel via ``af.array``) produces per-frame
       values ``wf``, which are averaged across channels.
    3. Those values are linearly interpolated onto the input sample grid and
       divided by ``sps``. They are accumulated from the stored phase ``phi``
       to give a phase track ``psi``.
    4. ``psi`` is linearly extrapolated to cover the time span of the original
       ``signal``, and ``signal`` is de-rotated by ``exp(-1j * psi_ext)``.

    Args:
        scope: functional ``Scope`` holding the FOE and child MIMO state.
        signal: input signal with a ``.t`` time descriptor.
        framesize: frame length for the frame-wise FOE.
        w0: passed to the FOE filter init.
        train: forwarded to the child ``mimoaf``.
        preslicer: applied to ``signal`` before the child MIMO.
        foekwargs: keyword args for the ``af.frame_cpr_kf`` factory; None means none.
        mimofn, mimokwargs, mimoinitargs: forwarded to the child ``mimoaf``.
            None for the kwargs dicts means none.

    Returns:
        The original ``signal`` multiplied by ``exp(-1j * psi_ext)[:, None]``.

    Side effects:
        Stores ``(psi[-1], af_step, af_stats)`` in ``'af_state'/'framefoeaf'``,
        so the phase continues across calls.
    """
    # None defaults avoid sharing one dict object across all calls.
    foekwargs = {} if foekwargs is None else foekwargs
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    sps = 2
    dims = 2
    tx = signal.t
    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig
    # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    state = scope.variable('af_state', 'framefoeaf',
                           lambda *_: (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value

    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats, yf)
    # Average the per-frame outputs across channels.
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    # Interpolate from frame centres onto the sps-times finer sample grid.
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2, wp) / sps
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    # NOTE(review): psi[0] == phi + w[0], but the head segment ends at
    # phi - w[0], and the tail segment starts by repeating psi[-1]. This may
    # be a one-sample offset in the extrapolation; confirm intended indexing.
    # The head is empty unless tx.start < ty.start * sps (assumed the usual
    # case; TODO confirm).
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi,
        psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1]
    ])

    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal
def _power_local(y, frame_size, frame_step, sps):
    """Windowed mean power of ``y``, interpolated onto an ``sps``-times finer grid.

    Args:
        y: signal with time on axis 0. Presumably 2-D with channels on the
            last axis, since interpolation is vmapped over axis -1.
        frame_size: window length in samples.
        frame_step: hop between windows in samples.
        sps: number of output points per input sample.

    Returns:
        Per-channel mean ``|y|**2``, linearly interpolated between window
        anchors (``k * frame_step + frame_size // 2``). Evaluated at
        ``arange(y.shape[0] * sps) / sps``.
    """
    n_samples = y.shape[0]
    windows = xop.frame(y, frame_size, frame_step, True)
    n_windows = windows.shape[0]

    def window_power(carry, win):
        return carry, jnp.mean(jnp.abs(win)**2, axis=0)

    _, power = xop.scan(window_power, None, windows)

    # Anchor points and the query grid, both in input-sample units.
    anchors = jnp.arange(n_windows) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps

    def interp_1d(xq, xp, fp):
        return jnp.interp(xq, xp, fp)

    per_channel = vmap(interp_1d, in_axes=(None, None, -1), out_axes=-1)
    return per_channel(grid, anchors, power)
def _foe_local(y, frame_size, frame_step, sps):
    """Frame-wise ``foe_mpowfftmax`` estimates, interpolated onto a finer grid.

    Args:
        y: signal with time on axis 0. Presumably 2-D with channels on the
            last axis, since interpolation is vmapped over axis -1.
        frame_size: window length in samples.
        frame_step: hop between windows in samples.
        sps: number of output points per input sample.

    Returns:
        The first output of ``foe_mpowfftmax`` for each window, divided by
        ``sps`` and linearly interpolated between window anchors
        (``k * frame_step + frame_size // 2``). Evaluated at
        ``arange(y.shape[0] * sps) / sps``.
    """
    n_samples = y.shape[0]
    blocks = xop.frame(y, frame_size, frame_step, True)
    n_blocks = blocks.shape[0]

    def estimate(carry, block):
        estimate_fo, _unused = foe_mpowfftmax(block)
        return carry, estimate_fo

    _, fo_hat = xop.scan(estimate, None, blocks)

    anchors = jnp.arange(n_blocks) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps
    # Rescale for the sps-times finer output grid.
    fo_hat /= sps

    def interp_1d(xq, xp, fp):
        return jnp.interp(xq, xp, fp)

    per_channel = vmap(interp_1d, in_axes=(None, None, -1), out_axes=-1)
    return per_channel(grid, anchors, fo_hat)
def frame(y, taps, sps, rtap=None):
    """Pad ``y`` for a zero-delay MIMO window, then frame it.

    Padding widths come from ``mimozerodelaypads(taps=taps, sps=sps, rtap=rtap)``
    and are applied with ``jnp.pad``'s default constant (zero) mode. The
    padded signal is cut by ``xop.frame`` into windows of ``taps`` samples
    with hop ``sps``. The result is returned as a JAX array.
    """
    pad_width = mimozerodelaypads(taps=taps, sps=sps, rtap=rtap)
    padded = jnp.pad(y, pad_width)
    return jnp.array(xop.frame(padded, taps, sps))