def measure_cd(x, sr, start=-0.25, end=0.25, bins=2000, wavlen=1550e-9):
    '''Blindly estimate accumulated chromatic dispersion (CD) by sweeping FrFT orders.

    For every fractional order ``p`` in ``linspace(start, end, bins)`` the metric
    ``sum(|FrFT_{-1}(|FrFT_p(x)|**2)|**2)`` is evaluated. The order with the smallest
    metric is mapped to an accumulated-dispersion value.

    Args:
        x: received signal; ``x.shape[0]`` is used as the number of samples ``N``.
        sr: sampling rate used to convert a FrFT order into a dispersion value
            (presumably in Hz — confirm with callers).
        start: first fractional order of the sweep.
        end: last fractional order of the sweep.
        bins: number of fractional orders swept.
        wavlen: carrier wavelength (default 1550e-9, i.e. 1550 nm if given in metres).

    Returns:
        ``(Dz_hat, L, Dz_set)``: the estimated accumulated CD, the metric for every
        swept order, and the accumulated CD corresponding to every swept order.

    References:
        Zhou, H., Li, B., Tang, et. al, 2016. Fractional fourier transformation-based blind
        chromatic dispersion estimation for coherent optical communications. Journal of
        Lightwave Technology, 34(10), pp.2371-2380.
    '''
    c = 299792458.  # speed of light [m/s]
    p = jnp.linspace(start, end, bins)
    N = x.shape[0]

    def f(_, order):
        # metric for one fractional order: energy of the inverse-FT of |FrFT_order(x)|^2
        return None, jnp.sum(jnp.abs(xop.frft(jnp.abs(xop.frft(x, order))**2, -1))**2)

    # Use `scan` instead of `vmap` here to avoid allocating every order's transform at once.
    # Although the run time of `scan` scales well with `bins`, there is a fixed overhead
    # (about 600 ms even at bins=1 in the author's measurement), possibly related to the
    # port of `frft` from GitHub.
    # TODO: review `frft` (could it be jitted?).
    _, L = xop.scan(f, None, p)

    # map each fractional order to an accumulated GVD value beta2*z
    B2z = jnp.tan(jnp.pi/2 - (p - 1) / 2 * jnp.pi)/(sr * 2 * jnp.pi / N * sr)
    # D*z = -2*pi*c / lambda^2 * beta2*z: the swept set of CD values
    Dz_set = -B2z / wavlen**2 * 2 * jnp.pi * c
    Dz_hat = Dz_set[jnp.argmin(L)]  # estimated accumulated CD

    return Dz_hat, L, Dz_set
def dbp_timedomain(y, h, c, mode='SAME', homosteps=True, scansteps=True, conv=xop.fftconvolve):
    '''Time-domain digital back-propagation: alternate linear (FIR) and nonlinear steps.

    Step ``i`` convolves each channel of ``y`` with ``h[i]`` (``D``) and then applies the
    phase rotation ``y * exp(1j * (|y|**2 @ c[i]))`` (``N``).

    Args:
        y: input signal. The convolution is vmapped over axis 1 and ``y.shape[-1]`` is
            used as the channel count, so presumably shape (samples, channels).
        h: per-step filter taps. ``h.shape[0]`` is the number of steps and
            ``h.shape[1]`` the number of taps; axis 1 of ``h[i]`` is vmapped together
            with the channel axis of ``y``.
        c: per-step nonlinear coefficients; ``|y|**2`` is right-multiplied by ``c[i]``.
        mode: convolution mode. With ``homosteps`` only 'valid' (case-insensitive)
            triggers a final edge trim; otherwise it is passed to ``conv`` per step.
        homosteps: if True, every step uses 'SAME' convolution so the length stays fixed.
            When ``mode`` is 'valid', the ``K * T`` edge samples accumulated over all
            steps are removed once at the end.
        scansteps: if True (and ``homosteps``), run the steps inside ``xop.scan``;
            otherwise use a Python loop.
        conv: convolution function called as ``conv(a, b, mode=...)``.

    Returns:
        The back-propagated signal, rescaled by ``sqrt(channels)``.
    '''
    y = device_put(y)
    h = device_put(h)
    c = device_put(c)

    # scale down by sqrt(channels) for the nonlinear step; undone on return
    dims = y.shape[-1]
    optpowscale = jnp.sqrt(dims)
    y /= optpowscale

    md = 'SAME' if homosteps else mode
    D = jit(vmap(lambda y, h: conv(y, h, mode=md), in_axes=1, out_axes=1))
    N = jit(lambda y, c: y * jnp.exp(1j * (abs(y)**2 @ c)))
    T = h.shape[1] - 1  # filter order (taps - 1)
    K = h.shape[0]      # number of steps

    if homosteps and scansteps:
        # Homogeneous steps are faster on the first jitted run. `scan` requires a fixed
        # carry shape, which only holds because every step uses 'SAME' convolution here.
        y = xop.scan(lambda x, p: (N(D(x, p[0]), p[1]), 0.), y, (h, c))[0]
    else:
        steps = c.shape[0]
        for i in range(steps):
            y = D(y, h[i])
            y = N(y, c[i])

    if homosteps and mode.lower() == 'valid':
        # Remove the K*T edge samples a chain of 'valid' convolutions would drop: floor
        # half at the front, ceil half at the back. An explicit end index is used because
        # `y[0:-0]` would return an empty array when K*T == 0.
        trim = K * T
        front = trim // 2
        back = trim - front
        y = y[front: y.shape[0] - back]

    return y * optpowscale
def _power_local(y, frame_size, frame_step, sps):
    '''Local mean power of ``y``, measured per frame and interpolated to a dense grid.

    ``y`` is split with ``xop.frame(y, frame_size, frame_step, True)``. The mean of
    ``|.|**2`` over axis 0 of each frame is taken, and the per-frame values are linearly
    interpolated (per channel, i.e. along the last axis) onto ``y.shape[0] * sps`` points
    spaced ``1 / sps`` input samples apart.
    '''
    framed = xop.frame(y, frame_size, frame_step, True)
    n_samples = y.shape[0]
    n_frames = framed.shape[0]

    def frame_power(carry, frame):
        # mean |.|^2 over axis 0 of a single frame
        return carry, jnp.mean(jnp.abs(frame)**2, axis=0)

    _, power = xop.scan(frame_power, None, framed)

    # anchor of each frame (presumably its centre, in input-sample units)
    anchors = jnp.arange(n_frames) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps

    interp_each_channel = vmap(jnp.interp, in_axes=(None, None, -1), out_axes=-1)
    return interp_each_channel(grid, anchors, power)
def iterate(update: UpdateFn,
            step0: Step,
            state: AFState,
            signal: Signal,
            truth=None,
            truth_ndim=2,
            device=None):
    '''Scan an adaptive-filter ``update`` over the samples of ``signal``.

    ``update`` is called as ``update(step, state, (signal_i, truth_i))`` for every index
    ``i`` along axis 0 of ``signal``, with ``step = step0 + i``.

    Args:
        update: per-sample update; receives the step, the carried state and
            ``(signal_i, truth_i)`` (return convention is whatever ``xop.scan`` expects).
        step0: step counter of the first sample.
        state: initial state (scan carry).
        signal: input samples, iterated along axis 0.
        truth: optional reference samples. It is truncated to ``signal.shape[0]`` rows and
            zero-padded to that length, so a shorter truth covers only a prefix. ``None``
            means all-zero truth.
        truth_ndim: rank of ``truth`` (axis 0 plus ``truth_ndim - 1`` per-sample axes).
            When ``truth`` is None, the dummy truth takes the last ``truth_ndim - 1``
            axes of ``signal`` as its per-sample shape.
        device: forwarded to ``xop.scan`` as ``jit_device``.

    Returns:
        ``(steps[-1], scan_result)`` where ``steps[-1] == step0 + signal.shape[0] - 1``.
    '''
    num_samples = signal.shape[0]
    steps = step0 + jnp.arange(num_samples)

    data_ndim = truth_ndim - 1  # number of per-sample (non-time) axes of truth
    if truth is None:
        # Dummy truth with zero rows, padded to full length below. Slice relative to
        # len(shape) so that truth_ndim=1 gives an empty per-sample shape; the form
        # `shape[1 - truth_ndim:]` would return the WHOLE signal shape in that case.
        data_shape = signal.shape[len(signal.shape) - data_ndim:]
        truth = jnp.zeros((0, *data_shape), dtype=signal.dtype)
    else:
        truth = truth[:num_samples]

    # zero-pad the time axis up to the signal length; per-sample axes are untouched
    pad_width = ((0, num_samples - truth.shape[0]),) + ((0, 0),) * data_ndim
    truth = jnp.pad(truth, pad_width)

    xs = (steps, signal, truth)
    return steps[-1], xop.scan(lambda carry, x: update(x[0], carry, x[1:]), state, xs, jit_device=device)
def _foe_local(y, frame_size, frame_step, sps):
    '''Frame-wise frequency-offset estimate of ``y``, interpolated to a dense grid.

    ``y`` is split with ``xop.frame(y, frame_size, frame_step, True)``. Each frame's
    estimate from ``foe_mpowfftmax`` (first return value) is divided by ``sps`` and
    linearly interpolated (per channel, along the last axis) onto ``y.shape[0] * sps``
    points spaced ``1 / sps`` input samples apart.
    '''
    framed = xop.frame(y, frame_size, frame_step, True)
    n_samples = y.shape[0]
    n_frames = framed.shape[0]

    def estimate(carry, frame):
        offset, _ = foe_mpowfftmax(frame)
        return carry, offset

    _, offsets = xop.scan(estimate, None, framed)

    # anchor of each frame (presumably its centre, in input-sample units)
    anchors = jnp.arange(n_frames) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps
    # rescale by sps (presumably converting to per-output-sample units — confirm)
    offsets = offsets / sps

    interp_each_channel = vmap(jnp.interp, in_axes=(None, None, -1), out_axes=-1)
    return interp_each_channel(grid, anchors, offsets)
def corr_local(y, x, frame_size=2000, L=None, device=gpus[0]):
    '''Per-frame lag of maximum |cross-correlation| between ``y`` and ``x``.

    Both signals are split into non-overlapping frames of ``frame_size``. For each
    frame and each channel (last axis), ``|xop.correlate_fft(y_frame, x_frame)|`` is
    computed and the lag at its maximum is returned.

    Args:
        y: first signal.
        x: second signal (same framing as ``y``).
        frame_size: frame length; also the number of candidate lags.
        L: accepted for API compatibility but not used by this function.
        device: device both inputs are placed on via ``device_put``.

    Returns:
        Array of lags, one row per frame and one column per channel.
    '''
    y = device_put(y, device)
    x = device_put(x, device)
    # NOTE: `L` is intentionally not computed here. Deriving it (e.g. len(np.unique(x)))
    # would force a device->host copy and a sort of `x` for a value nothing uses.

    Y = xop.frame(y, frame_size, frame_size, True)
    X = xop.frame(x, frame_size, frame_size, True)
    # frame_size candidate lags around zero, e.g. -1000..999 for frame_size=2000
    lag = jnp.arange(-(frame_size - 1) // 2, (frame_size + 1) // 2)
    corr_v = vmap(lambda a, b: xop.correlate_fft(a, b), in_axes=-1, out_axes=-1)

    def f(carry, z):
        y_frame, x_frame = z
        c = jnp.abs(corr_v(y_frame, x_frame))
        return carry, lag[jnp.argmax(c, axis=0)]

    _, ret = xop.scan(f, None, (Y, X))
    return ret
def iterate(update, state, signal, truth=None, train=None, device=cpus[0]):
    '''Scan ``update`` over the inputs built by ``make_train_argin``.

    Args:
        update: scan body, called with the carried ``state`` and one slice of the inputs.
        state: initial scan carry.
        signal: input samples.
        truth: optional reference samples, forwarded to ``make_train_argin``.
        train: optional training flag/schedule, forwarded to ``make_train_argin``.
        device: forwarded to ``xop.scan`` as ``jit_device``.

    Returns:
        Whatever ``xop.scan`` returns.
    '''
    return xop.scan(update, state, make_train_argin(signal, truth, train), jit_device=device)