Example #1
0
def measure_cd(x, sr, start=-0.25, end=0.25, bins=2000, wavlen=1550e-9):
    '''
    Blindly estimate accumulated chromatic dispersion (CD) by sweeping the
    fractional Fourier transform (FrFT) order.

    For every candidate order ``p`` in ``linspace(start, end, bins)`` the metric
    ``sum(|FrFT_{-1}(|FrFT_p(x)|**2)|**2)`` is evaluated. The order that
    minimises it is mapped to an accumulated-dispersion value.

    Args:
        x: received signal. Its first axis is used as the sample axis
            (``N = x.shape[0]``). Presumably a complex baseband waveform;
            confirm against callers.
        sr: sampling rate (presumably in Hz).
        start, end: range of FrFT orders to sweep.
        bins: number of FrFT orders in the sweep.
        wavlen: carrier wavelength (presumably in metres).

    Returns:
        (Dz_hat, L, Dz_set):
            Dz_hat: the estimated accumulated CD.
            L: the metric evaluated at each swept order.
            Dz_set: the CD value corresponding to each swept order.

    References:
        Zhou, H., Li, B., Tang, et. al, 2016. Fractional fourier transformation-based
        blind chromatic dispersion estimation for coherent optical communications.
        Journal of Lightwave Technology, 34(10), pp.2371-2380.
    '''
    c = 299792458.  # speed of light in vacuum [m/s]
    p = jnp.linspace(start, end, bins)  # candidate FrFT orders
    N = x.shape[0]

    def f(_, pi):
        # Metric for a single FrFT order `pi`. No carry is needed.
        return None, jnp.sum(jnp.abs(xop.frft(jnp.abs(xop.frft(x, pi))**2, -1))**2)

    # Use `scan` instead of `vmap` to avoid a potentially large memory allocation.
    # `scan` scales surprisingly well to large `bins`, but its runtime has a
    # lower bound (e.g. ~600 ms at bins=1). This is possibly related to how
    # `frft` was ported from GitHub (could `frft` be jitted in theory?).
    # TODO review `frft`
    _, L = xop.scan(f, None, p)

    # Map each FrFT order to `B2z` (presumably beta2 * z) and then to
    # accumulated CD.
    B2z = jnp.tan(jnp.pi/2 - (p - 1) / 2 * jnp.pi)/(sr * 2 * jnp.pi / N * sr)
    Dz_set = -B2z / wavlen**2 * 2 * jnp.pi * c  # the swept set of CD metrics
    Dz_hat = Dz_set[jnp.argmin(L)]  # estimated accumulated CD

    return Dz_hat, L, Dz_set
Example #2
0
def dbp_timedomain(y, h, c, mode='SAME', homosteps=True, scansteps=True, conv=xop.fftconvolve):
    '''
    Time-domain back-propagation: alternate a linear filtering step (D) and a
    nonlinear, intensity-dependent phase-rotation step (N) over the signal.

    Args:
        y: input signal. ``dims = y.shape[-1]``, and each column along axis 1 is
            filtered independently, so presumably shape (samples, dims).
        h: per-step filter taps. ``h[i]`` is mapped along its axis 1 together
            with ``y``, so presumably shape (steps, taps, dims); confirm.
        c: per-step nonlinear coefficients. The phase rotation is
            ``abs(y)**2 @ c[i]``, so presumably shape (steps, dims, dims); confirm.
        mode: convolution mode, e.g. 'SAME' or 'VALID'.
        homosteps: if True, every step uses 'SAME' convolution (the length stays
            fixed). For mode 'valid', the edges are then trimmed once at the end.
        scansteps: if True (and `homosteps`), iterate the steps with
            ``xop.scan`` instead of a Python loop.
        conv: convolution function, called as ``conv(y, h, mode=...)``.

    Returns:
        The back-propagated signal with the initial power scaling undone.
    '''
    y = device_put(y)
    h = device_put(h)
    c = device_put(c)

    dims = y.shape[-1]

    # Normalise by sqrt(dims). The scaling is undone on return.
    optpowscale = jnp.sqrt(dims)
    y /= optpowscale

    md = 'SAME' if homosteps else mode

    # Linear step: convolve each column of y with the matching column of the taps.
    D = jit(vmap(lambda y, h: conv(y, h, mode=md), in_axes=1, out_axes=1))
    # Nonlinear step: intensity-dependent phase rotation.
    N = jit(lambda y, c: y * jnp.exp(1j * (abs(y)**2 @ c)))

    T = h.shape[1] - 1  # samples removed by one 'valid' convolution with these taps
    K = h.shape[0]      # number of steps

    if homosteps and scansteps:  # homogeneous steps are faster on the first jitted run
        # scan needs a fixed carry shape, so it only works with 'SAME'
        # convolution. homosteps guarantees md == 'SAME' here.
        y = xop.scan(lambda x, p: (N(D(x, p[0]), p[1]), 0.), y, (h, c))[0]
    else:
        steps = c.shape[0]
        for i in range(steps):
            y = D(y, h[i])
            y = N(y, c[i])

    if homosteps and mode.lower() == 'valid':
        # Trim the K*T samples that K successive 'valid' convolutions would
        # have removed: floor half from the head, the rest from the tail.
        # Use an explicit end index rather than a negative stop: when
        # K*T == 0, `y[0:-0]` would be empty instead of the whole signal.
        trim = K * T
        head = trim // 2
        y = y[head: y.shape[0] - (trim - head)]

    return y * optpowscale
Example #3
0
def _power_local(y, frame_size, frame_step, sps):
    '''
    Estimate the local mean power of `y` frame by frame, then interpolate it
    onto a grid of ``y.shape[0] * sps`` points spaced ``1 / sps`` apart.

    The interpolation is vmapped over the last axis of the per-frame power,
    so `y` presumably has shape (samples, dims). Confirm against callers.
    '''
    n_samples = y.shape[0]
    framed = xop.frame(y, frame_size, frame_step, True)
    n_frames = framed.shape[0]

    def frame_power(carry, frame):
        # The carry is unused. Emit this frame's mean power along axis 0.
        return carry, jnp.mean(jnp.abs(frame)**2, axis=0)

    power = xop.scan(frame_power, None, framed)[1]

    # Frame centres (in input-sample units) and the output query grid.
    centres = jnp.arange(n_frames) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps

    def interp_column(query, knots, values):
        return jnp.interp(query, knots, values)

    interp = vmap(interp_column, in_axes=(None, None, -1), out_axes=-1)

    return interp(grid, centres, power)
Example #4
0
def iterate(update: UpdateFn,
            step0: Step,
            state: AFState,
            signal: Signal,
            truth=None,
            truth_ndim=2,
            device=None):
    '''
    Run `update` once per row of `signal` using ``xop.scan``.

    Args:
        update: called as ``update(step, state, (signal_i, truth_i))``.
            Presumably it returns the scan pair ``(new_state, output)``.
        step0: index of the first step.
        state: initial state (the scan carry).
        signal: input, iterated along axis 0.
        truth: optional reference, truncated to ``signal.shape[0]`` rows and
            zero-padded along axis 0 if shorter. ``None`` means an all-zero
            truth.
        truth_ndim: number of dimensions of `truth`. When ``truth is None``,
            the dummy truth takes its trailing ``truth_ndim - 1`` dimensions
            from `signal`.
        device: forwarded to ``xop.scan`` as ``jit_device``.

    Returns:
        (last_step, scan_result): the index of the last step processed and
        whatever ``xop.scan`` returns. NOTE(review): ``steps[-1]`` assumes
        `signal` is non-empty.
    '''
    steps = step0 + jnp.arange(signal.shape[0])
    if truth is None:
        # Dummy zero-length truth, zero-padded to full length below.
        # For truth_ndim == 1 there are no data axes, so the trailing shape
        # must be (). `signal.shape[1 - truth_ndim:]` would instead be the
        # whole signal shape, and the pad would then broadcast to every axis.
        data_shape = signal.shape[1 - truth_ndim:] if truth_ndim > 1 else ()
        truth = jnp.zeros((0, *data_shape), dtype=signal.dtype)
    else:
        truth = truth[:signal.shape[0]]
    padw_data_axes = ((0, 0), ) * (truth_ndim - 1)
    truth = jnp.pad(truth,
                    ((0, signal.shape[0] - truth.shape[0]), *padw_data_axes))
    xs = (steps, signal, truth)
    return steps[-1], xop.scan(lambda c, xs: update(xs[0], c, xs[1:]),
                               state,
                               xs,
                               jit_device=device)
Example #5
0
def _foe_local(y, frame_size, frame_step, sps):
    '''
    Estimate the frequency offset of `y` frame by frame with
    ``foe_mpowfftmax``. The estimates are divided by `sps` and interpolated
    onto a grid of ``y.shape[0] * sps`` points spaced ``1 / sps`` apart.

    The interpolation is vmapped over the last axis of the per-frame
    estimates, so they presumably carry a trailing channel axis. Confirm.
    '''
    n_samples = y.shape[0]
    framed = xop.frame(y, frame_size, frame_step, True)
    n_frames = framed.shape[0]

    def estimate(carry, frame):
        # Keep only the first output of foe_mpowfftmax; the carry is unused.
        offset, _ = foe_mpowfftmax(frame)
        return carry, offset

    offsets = xop.scan(estimate, None, framed)[1] / sps

    # Frame centres (in input-sample units) and the output query grid.
    centres = jnp.arange(n_frames) * frame_step + frame_size // 2
    grid = jnp.arange(n_samples * sps) / sps

    def interp_column(query, knots, values):
        return jnp.interp(query, knots, values)

    interp = vmap(interp_column, in_axes=(None, None, -1), out_axes=-1)

    return interp(grid, centres, offsets)
Example #6
0
def corr_local(y, x, frame_size=2000, L=None, device=gpus[0]):
    '''
    Estimate the local lag between `y` and `x` frame by frame.

    Both signals are cut into non-overlapping frames of `frame_size` samples.
    For each frame and each column, the lag at the peak of the
    cross-correlation magnitude is returned.

    Args:
        y, x: signals to compare. They are framed along axis 0 and correlated
            column-wise along the last axis (presumably shape (samples, dims);
            confirm).
        frame_size: frame length, also used as the hop (no overlap).
        L: unused. Kept only for backward compatibility; it was previously
            filled with ``len(np.unique(x))`` and then discarded.
        device: device the inputs are placed on before processing.

    Returns:
        Peak lags, one row per frame and one column per channel.
    '''
    y = device_put(y, device)
    x = device_put(x, device)

    # NOTE: `L` is intentionally not derived from `x`. The old
    # `len(np.unique(x))` result was never used and forced a copy of `x` to
    # the host.

    Y = xop.frame(y, frame_size, frame_size, True)
    X = xop.frame(x, frame_size, frame_size, True)

    # `frame_size` lags centred on zero. This presumably matches the output
    # ordering of xop.correlate_fft; verify.
    lag = jnp.arange(-(frame_size - 1) // 2, (frame_size + 1) // 2)

    corr_v = vmap(lambda a, b: xop.correlate_fft(a, b),
                  in_axes=-1,
                  out_axes=-1)

    def f(carry, z):
        # The carry is unused; emit the peak lag per column for this frame.
        y, x = z
        c = jnp.abs(corr_v(y, x))
        return carry, lag[jnp.argmax(c, axis=0)]

    _, ret = xop.scan(f, None, (Y, X))

    return ret
Example #7
0
def iterate(update, state, signal, truth=None, train=None, device=cpus[0]):
    '''
    Scan `update` from `state` over the inputs that ``make_train_argin``
    builds from `signal`, `truth` and `train`.

    `device` is forwarded to ``xop.scan`` as ``jit_device``. The return value
    is whatever ``xop.scan`` returns.
    '''
    return xop.scan(update,
                    state,
                    make_train_argin(signal, truth, train),
                    jit_device=device)