예제 #1
0
def rls_cma(y_f, w_init, beta=0.999, delta=1e-4, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Recursive-least-squares constant modulus algorithm (RLS-CMA) MIMO equalizer.

    The recursion itself lives in ``step_rls_cma``. This wrapper:
      - prepares the state layout,
      - runs the recursion with ``scan``,
      - restores the DxDxT tap layout of the returned weights.

    Args:
        y_f: input frames. Axis 0 indexes the update steps, and each step's
            remaining axes are flattened (Fortran order) into one vector.
        w_init: initial MIMO taps laid out as DxDxT.
        beta: forwarded unchanged to ``step_rls_cma`` (presumably the RLS
            forgetting factor).
        delta: each of the D initial P matrices is ``delta * I``.
        const: constellation from which the CMA radius
            R2 = E|c|^4 / E|c|^2 is derived.
        device: device the scan state and inputs are placed on.

    Returns:
        ``(l, w)``:
          - ``l``: the first scanned output of ``step_rls_cma``, unchanged.
          - ``w``: the per-step weights reshaped to NxDxDxT.

    References:
    [1] Faruk, M.S. and Savory, S.J., 2017. Digital signal processing for coherent
    transceivers employing multilevel formats. Journal of Lightwave Technology, 35(5), pp.1125-1141.
    '''
    # CMA dispersion constant
    R2 = jnp.mean(abs(const)**4) / jnp.mean(abs(const)**2)

    n_steps = y_f.shape[0]
    dims = w_init.shape[0]
    taps = w_init.shape[-1]
    width = dims * taps

    # DxDxT -> DTxD, conjugated; undone after the scan
    w0 = jnp.reshape(w_init.conj(), (dims, width)).T
    # one DTxDT matrix per output channel, stacked along the last axis
    eye = jnp.eye(width, dtype=y_f.dtype)
    P0 = delta * jnp.tile(eye[..., None], (1, 1, dims))
    frames = jnp.reshape(y_f, (n_steps, -1), order='F')

    state = device_put((w0, P0, R2, beta), device)
    xs = device_put((frames,), device)

    _, (l, w) = scan(step_rls_cma, state, xs)

    # NxDTxD -> NxDxDT
    w = jnp.moveaxis(w, 0, -1).T
    # NxDxDT -> NxDxDxT, undoing the conjugation applied to w_init
    w = jnp.reshape(w, (n_steps, dims, dims, taps)).conj()

    return l, w
예제 #2
0
def framekfcpr(signal, truth=None, n=100, w0=None, modformat='16QAM', const=None, backend='cpu'):
    '''Frame-based Kalman-filter carrier phase recovery (jit-compiled wrapper).

    Converts the inputs to jax arrays and resolves the reference constellation.
    It then runs ``_framekfcpr`` compiled for ``backend``, with ``n``
    (argument position 2) marked static.

    Args:
        signal: received symbols.
        truth: optional reference symbols; passed through as None when absent.
        n: forwarded as a static argument to ``_framekfcpr``.
        w0: forwarded unchanged to ``_framekfcpr``.
        modformat: modulation format used to build ``const`` when it is None.
        const: reference constellation (normalized ``modformat`` if None).
        backend: jax backend to compile for.
    '''
    y = jnp.asarray(signal)
    x = None if truth is None else jnp.asarray(truth)
    if const is None:
        const = comm.const(modformat, norm=True)
    ref = jnp.asarray(const)
    compiled = jit(_framekfcpr, backend=backend, static_argnums=2)
    return compiled(y, x, n, w0, ref)
예제 #3
0
def ekfcpr(signal, truth=None, modformat='16QAM', const=None, backend='cpu'):
    '''Extended-Kalman-filter carrier phase recovery (jit-compiled wrapper).

    Args:
        signal: received symbols.
        truth: optional reference symbols; passed through as None when absent.
        modformat: modulation format used to build ``const`` when it is None.
        const: reference constellation (normalized ``modformat`` if None).
        backend: jax backend ``_ekfcpr`` is compiled for.
    '''
    y = jnp.asarray(signal)
    x = None if truth is None else jnp.asarray(truth)
    if const is None:
        const = comm.const(modformat, norm=True)
    ref = jnp.asarray(const)
    compiled = jit(_ekfcpr, backend=backend)
    return compiled(y, x, ref)
예제 #4
0
def rde(lr: Union[float, Schedule] = 2**-15,
        train: Union[bool, Schedule] = False,
        Rs: Array = jnp.unique(jnp.abs(comm.const("16QAM", norm=True))),
        const: Optional[Array] = None) -> AdaptiveFilter:
    """Radius Directed adaptive Equalizer

    Args:
      lr: learning rate. scalar or Schedule
      train: schedule of the training mode. It can be a bool for global control
        within one call, or an array of bool to switch training on an iteration
        basis. In training mode the target squared radius is ``|x|**2`` of the
        reference symbol, instead of the nearest radius in ``Rs``.
      Rs: the radii of the target constellation
      const: Optional; when given, its distinct magnitudes replace ``Rs``

    Returns:
      an ``AdaptiveFilter`` object

    References:
      - [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and
        carrier phase recovery in a 16-QAM optical coherent system. Journal
        of lightwave technology, 27(15), pp.3042-3049.
    """
    lr = cxopt.make_schedule(lr)
    train = cxopt.make_schedule(train)

    if const is not None:
        Rs = jnp.array(jnp.unique(jnp.abs(const)))

    def init(dims=2, w0=None, taps=32, dtype=np.complex64):
        # default: identity at the centre tap, zeros elsewhere
        if w0 is None:
            w0 = np.zeros((dims, dims, taps), dtype=dtype)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        return w0

    def loss_fn(w, u, x, i):
        v = r2c(mimo(w, u)[None, :])
        # Blind mode: for each channel, choose the radius whose point on v's ray
        # (Rs * v/|v|) lies closest to v.
        # NOTE: v / |v| is undefined (NaN) when v == 0.
        R2 = jnp.where(
            train(i),
            jnp.abs(x)**2,
            Rs[jnp.argmin(jnp.abs(Rs[:, None] * v / jnp.abs(v) - v),
                          axis=0)]**2)
        l = jnp.sum(jnp.abs(R2 - jnp.abs(v[0, :])**2))
        return l

    def update(i, w, inp):
        u, x = inp
        l, g = jax.value_and_grad(loss_fn)(w, u, x, i)
        # report the weights used for this step (pre-update) and its loss
        out = (w, l)
        w = w - lr(i) * g.conj()
        return w, out

    def apply(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, apply)
예제 #5
0
def lms_cpane(signal, w_init, data=None, train=None, lr=1e-4, beta=0.7, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''LMS MIMO equalization scanned jointly with CPANE phase tracking.

    Args:
        signal: received samples. Axis 0 is scanned and axis -1 is the channel axis.
        w_init: initial LMS MIMO weights.
        data: reference symbols for training mode. Replaced by zeros when
            ``train`` is None.
        train: per-step training flags; all False (blind) when None.
        lr: LMS learning rate.
        beta: CPANE smoothing factor, replicated per channel.
        const: reference constellation. It is now honoured; previously it was
            silently replaced by normalized 16QAM regardless of the argument.
        device: device the state and inputs are placed on.

    Returns:
        The scanned outputs of ``step_lms_cpane``.
    '''
    const = np.asarray(const)

    if train is None:
        train = np.full((signal.shape[0],), False)
        data = np.full((signal.shape[0],), 0, dtype=const.dtype)

    dims = signal.shape[-1]

    params_lms = (w_init, lr)
    # Per-channel CPANE state, presumably (Q, R, phase, P, averaged phase, beta)
    # as in cpane_ekf -- TODO confirm against step_lms_cpane.
    cpane_init = [1e-5 * (1.+1j), 1e-2 * (1.+1j), 0j, 1j, 0j, beta]
    params_cpane = tuple(np.tile(p, dims) for p in cpane_init) + (const,)
    params = (params_lms, params_cpane)
    inputs = (signal, data, train)

    params = device_put(params, device)
    inputs = device_put(inputs, device)

    _, ret = scan(step_lms_cpane, params, inputs)

    return ret
예제 #6
0
def cma_2sec(y_f, h_init, w_init, mu1=1e-4, mu2=1e-1, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Two-section CMA equalizer driven by ``step_cma_2sec``.

    Args:
        y_f: input frames, scanned along axis 0.
        h_init, w_init: initial weights forwarded to ``step_cma_2sec``
            (presumably one per section).
        mu1, mu2: step sizes forwarded to ``step_cma_2sec``.
        const: constellation from which the CMA radius
            R2 = E|c|^4 / E|c|^2 is derived.
        device: device the scan state and inputs are placed on.

    Returns:
        The scanned outputs of ``step_cma_2sec``.
    '''
    fourth_moment = jnp.mean(abs(const)**4)
    second_moment = jnp.mean(abs(const)**2)
    R2 = fourth_moment / second_moment

    state = device_put((h_init, w_init, mu1, mu2, R2), device)
    xs = device_put((y_f,), device)

    return scan(step_cma_2sec, state, xs)[1]
예제 #7
0
def ddlms(mu_w=1/2**10, mu_f=1/2**6, mu_s=0., grad_max=(50., 50.), eps=1e-8,
          const=comm.const("16QAM", norm=True)):
    '''Decision-directed LMS equalizer with two cascaded phase trackers.

    The output is ``mimo(w, u) * f * s``. ``w`` are the MIMO taps, and ``f``
    and ``s`` are per-channel complex trackers updated by normalized LMS with
    clipped gradients.

    Impl. follows Fig. 6 in [1]
    References:
    [1] Mori, Y., Zhang, C. and Kikuchi, K., 2012. Novel configuration of
        finite-impulse-response filters tolerant to carrier-phase fluctuations
        in digital coherent optical receivers for higher-order quadrature
        amplitude modulation signals. Optics express, 20(24), pp.26236-26251.
    '''
    def _clip(g, limit):
        # rescale entries whose magnitude exceeds ``limit`` down to magnitude ``limit``
        return jnp.where(jnp.abs(g) > limit, g / jnp.abs(g) * limit, g)

    def init(state0):
        return state0

    def update(state, inp):
        w, f, s = state
        u, x, train = inp

        v = mimo(w, u)
        z = v * f * s

        # nearest constellation point to each channel's output (blind mode)
        nearest = const[jnp.argmin(jnp.abs(const[:, None] - z[None, :]), axis=0)]
        d = jnp.where(train, x, nearest)

        # unit-modulus conjugate phase of f and s, removed from d before updating w
        psi_hat = jnp.abs(f) / f * jnp.abs(s) / s
        e_p = d * psi_hat - v
        e_f = d - f * v
        e_s = d - s * f * v

        gw = -e_p[:, None, None] * u.conj().T[None, ...]
        # the f and s gradients are less regulated than w's, so they are clipped
        gf = _clip(-1. / (jnp.abs(v)**2 + eps) * e_f * v.conj(), grad_max[0])
        gs = _clip(-1. / (jnp.abs(f * v)**2 + eps) * e_s * (f * v).conj(), grad_max[1])

        out = (w, f, s, d)
        next_state = (w - mu_w * gw, f - mu_f * gf, s - mu_s * gs)
        return next_state, out

    def static_map(ps, yf):
        ws, fs, ss = ps
        return jax.vmap(mimo)(ws, yf) * fs * ss

    return AdaptiveFilter(init, update, static_map)
예제 #8
0
def mu_cma(y_f, w_init, lr=1e-4, alpha=0.9999, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Run ``step_mu_cma`` over ``y_f`` with ``scan`` and return the scanned outputs.

    Args:
        y_f: input samples. Axis 0 is scanned and axis -1 is the channel axis.
        w_init: initial MIMO taps; the last axis is the tap axis.
        lr: learning rate forwarded to ``step_mu_cma``.
        alpha: forwarded to ``step_mu_cma``.
        const: constellation from which the CMA radius
            d = E|c|^4 / E|c|^2 is derived.
        device: device the inputs are placed on.
    '''
    d = jnp.mean(abs(const)**4) / jnp.mean(abs(const)**2)

    n_taps = w_init.shape[-1]
    n_ch = y_f.shape[-1]
    # zero-initialised auxiliary states, built before the inputs are moved to the device
    z = jnp.zeros((n_taps, n_ch), dtype=y_f.dtype)
    c = jnp.zeros((n_ch, n_ch, n_taps), dtype=y_f.dtype)

    y_f, w_init, lr, alpha, d = (device_put(a, device) for a in (y_f, w_init, lr, alpha, d))

    carry = (w_init, d, c, z, alpha, lr)
    _, w = scan(step_mu_cma, carry, y_f)
    return w
예제 #9
0
def cpane_ekf(beta=0.8,
              Q=1e-5 * (1.+1j),
              R=1e-2 * (1.+1j),
              const=comm.const("16QAM", norm=True)):
    '''Carrier phase and amplitude noise estimator (CPANE) using an extended Kalman filter.

    Args:
        beta: smoothing factor of the averaged phase used for symbol decisions.
        Q: variance added to the error covariance at every prediction step.
        R: variance added in the denominator of the Kalman gain.
        const: reference constellation for decisions in blind mode.

    References:
    [1] Pakala, L. and Schmauss, B., 2016. Extended Kalman filtering for joint mitigation
    of phase and amplitude noise in coherent QAM systems. Optics express, 24(6), pp.6391-6401.
    '''
    const = jnp.array(const)

    def init(p0=0j):
        # (phase estimate, error covariance, smoothed phase)
        return (p0, 1j, 0j)

    def update(state, inp):
        psi, P, psi_avg = state
        y, x, train = inp

        # prediction: the phase is carried over and the covariance grows by Q
        psi_pred = psi
        P_pred = P + Q

        # exponential moving average of the phase, only used for the decision below
        psi_avg = beta * psi_avg + (1 - beta) * psi

        nearest = const[jnp.argmin(jnp.abs(const - y * jnp.exp(-1j * psi_avg)))]
        d = jnp.where(train, x, nearest)

        rot = jnp.exp(1j * psi_pred)
        H = 1j * d * rot
        K = P_pred * H.conj() / (H * P_pred * H.conj() + R)
        innov = y - d * rot

        out = (psi, d)

        next_state = (psi_pred + K * innov, (1. - K * H) * P_pred, psi_avg)
        return next_state, out

    def static_map(Psi, ys):
        return ys * jnp.exp(-1j * Psi)

    return AdaptiveFilter(init, update, static_map)
예제 #10
0
def modulusmimo(signal, sps=2, taps=32, lr=2**-14, cma_samples=20000, modformat='16QAM', const=None, backend='cpu'):
    '''
    Adaptive MIMO equalizer for M-QAM signal

    Derives the CMA radius ``R2 = E|c|^4 / E|c|^2`` and the distinct radii
    ``Rs`` from the constellation. It then runs ``_modulusmimo`` compiled for
    ``backend``, with ``sps``, ``taps`` and ``cma_samples`` marked static.

    Args:
        signal: received samples; axis 0 is the sample axis.
        sps: samples per symbol (static).
        taps: number of equalizer taps (static).
        lr: learning rate.
        cma_samples: forwarded to ``_modulusmimo`` (static). Must not exceed the
            number of samples in ``signal``.
        modformat: modulation format used to build ``const`` when it is None.
        const: reference constellation (normalized ``modformat`` if None).
        backend: jax backend to compile for.

    Raises:
        ValueError: if ``signal`` has fewer than ``cma_samples`` samples.
    '''
    y = jnp.asarray(signal)

    if y.shape[0] < cma_samples:
        raise ValueError(
            f'cma_samples ({cma_samples}) must not exceed the number of '
            f'samples ({y.shape[0]})')

    if const is None:
        const = comm.const(modformat, norm=True)

    R2 = np.array(np.mean(abs(const)**4) / np.mean(abs(const)**2))
    Rs = np.array(np.unique(np.abs(const)))

    return jit(_modulusmimo,
               static_argnums=(3, 4, 5),
               backend=backend)(y, R2, Rs, sps, taps, cma_samples, lr)
예제 #11
0
def rde(lr=1e-4, Rs=jnp.unique(jnp.abs(comm.const("16QAM", norm=True)))):
    '''Radius-directed equalizer (RDE) as an ``AdaptiveFilter``.

    Args:
        lr: learning rate.
        Rs: radii of the target constellation.

    References:
    [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and
        carrier phase recovery in a 16-QAM optical coherent system. Journal
        of lightwave technology, 27(15), pp.3042-3049.
    '''
    import warnings

    def init(w0=None, taps=19, dims=2, unitarize=False):
        if w0 is None:
            # identity at the centre tap, sized by ``dims`` (was hard-coded to 2x2,
            # which broke the indexing below for dims != 2)
            w0 = np.zeros((dims, dims, taps), dtype=np.complex64)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        elif unitarize:
            try:
                w0 = unitarize_mimo_weights(w0)
            except Exception as err:
                # best effort: keep the given weights, but do not fail silently
                warnings.warn(f'unitarize_mimo_weights failed ({err!r}); using w0 as given')
        return w0

    def update(w, inp):
        u, Rx, train = inp

        def loss_fn(w, u):
            v = mimo(w, u)[None,:]
            # Training mode: target radius Rx. Blind mode: per channel, the radius
            # whose point on v's ray lies closest to v (NaN when v == 0).
            R2 = jnp.where(train,
                           Rx**2,
                           Rs[jnp.argmin(
                               jnp.abs(Rs[:,None] * v / jnp.abs(v) - v),
                               axis=0)]**2)
            l = jnp.sum(jnp.abs(R2 - jnp.abs(v[0,:])**2))
            return l

        l, g = jax.value_and_grad(loss_fn)(w, u)
        # loss and weights of this step (pre-update)
        out = (l, w)
        w = w - lr * g.conj()
        return w, out

    def static_map(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, static_map)
예제 #12
0
def cpane_ekf(signal, data=None, train=None, beta=0.8, device=cpus[0], const=None):
    '''Carrier phase and amplitude noise estimation (CPANE) with an extended Kalman filter.

    Args:
        signal: received symbols, scanned along axis 0.
        data: reference symbols for training mode. Replaced by zeros when
            ``train`` is None.
        train: per-step training flags; all False (blind) when None.
        beta: forwarded to ``_cpane_ekf`` (presumably the smoothing factor of
            the averaged phase used for decisions).
        device: device the state and inputs are placed on.
        const: reference constellation. Defaults to normalized 16QAM, which was
            previously hard-coded.

    Returns:
        Whatever ``_cpane_ekf`` returns (presumably the phase estimates).

    References:
    [1] Pakala, L. and Schmauss, B., 2016. Extended Kalman filtering for joint mitigation
    of phase and amplitude noise in coherent QAM systems. Optics express, 24(6), pp.6391-6401.
    '''
    if const is None:
        const = comm.const("16QAM", norm=True)
    const = np.asarray(const)

    if train is None:
        train = np.full((signal.shape[0],), False)
        data = np.full((signal.shape[0],), 0, dtype=const.dtype)

    # (Q, R, initial phase, initial covariance, initial averaged phase, beta, const)
    # -- layout presumed from cpane_ekf's AdaptiveFilter variant; confirm with _cpane_ekf
    params = (1e-5 * (1.+1j), 1e-2 * (1.+1j), 0j, 1j, 0j, beta, const)
    inputs = (signal, data, train)

    params = device_put(params, device)
    inputs = device_put(inputs, device)

    psi_hat = _cpane_ekf(params, inputs)

    return psi_hat
예제 #13
0
def lms(
    lr: Union[float, Schedule] = 1e-4,
    const: Optional[Array] = comm.const('16QAM', norm=True)
) -> AdaptiveFilter:
    """LMS MIMO adaptive filter.

    The loss is the squared distance between the equalizer output and its hard
    decision ``decision(const, v)``.

    Args:
      lr: Optional; learning rate, scalar or Schedule
      const: reference constellation used for the hard decisions. It is passed
        directly to ``decision``; a None value is not handled here (TODO confirm
        whether ``decision`` accepts it).

    Returns:
      an ``AdaptiveFilter`` object
    """
    lr = cxopt.make_schedule(lr)
    if const is not None:
        const = jnp.asarray(const)

    def init(w0=None, taps=19, dims=2, dtype=np.complex64):
        # default: identity at the centre tap, zeros elsewhere
        if w0 is None:
            w0 = np.zeros((dims, dims, taps), dtype=dtype)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        return w0.astype(dtype)

    def loss_fn(w, u):
        v = r2c(mimo(w, u)[None, :])[0, :]
        d = decision(const, v)
        loss = jnp.sum(jnp.abs(d - v)**2)
        return loss

    def update(i, w, u):
        l, g = jax.value_and_grad(loss_fn)(w, u)
        # weights used for this step (pre-update) and its loss
        out = (w, l)
        w = w - lr(i) * g.conj()
        return w, out

    def apply(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, apply)
예제 #14
0
def rde(signal, w_init, data=None, train=None, lr=1e-4, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Radius-directed equalizer (RDE) run with ``scan`` over ``step_rde``.

    Args:
        signal: received samples. Axis 0 is scanned and axis -1 is the channel axis.
        w_init: initial MIMO weights.
        data: reference symbols for training mode, with the same leading length
            as ``signal``. Replaced by zeros when ``train`` is None.
        train: per-step training flags; all False (blind) when None.
        lr: learning rate.
        const: constellation whose distinct magnitudes are the target radii.
        device: device the state and inputs are placed on.

    Returns:
        The scanned outputs of ``step_rde``.

    Raises:
        ValueError: if ``train`` is given without ``data``, or if their lengths
            do not match ``signal``.

    References:
    [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and carrier phase recovery
        in a 16-QAM optical coherent system. Journal of lightwave technology, 27(15), pp.3042-3049.
    '''
    if train is None:
        train = np.full((signal.shape[0],), False)
        data = np.full((signal.shape[0], signal.shape[-1]), 0, dtype=signal.dtype)
    elif data is None:
        raise ValueError('data must be given together with train')
    elif train.shape[0] != signal.shape[0] or data.shape[0] != signal.shape[0]:
        raise ValueError('invalid shape')

    dims = signal.shape[-1]
    Rs = np.unique(np.abs(const))

    # one complex zero per channel (previously hard-coded to two channels);
    # built the same way as before so the dtype is unchanged
    params = (w_init, jnp.array([0j] * dims), lr, Rs)
    inputs = (signal, data, train)

    params = device_put(params, device)
    inputs = device_put(inputs, device)

    _, ret = scan(step_rde, params, inputs)
    return ret
예제 #15
0
def dd_lms(signal, w_init, f_init=None, s_init=None, data=None, train=None, lr_w=1/2**10, lr_f=1/2**6, lr_s=0.,
           grad_max=(50., 50.), const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Decision-directed LMS equalizer with two cascaded phase trackers.

    Args:
        signal: received samples. Axis 0 is the time axis and axis -1 is the
            channel axis.
        w_init: initial MIMO weights.
        f_init, s_init: initial per-channel phase-tracker values (1+0j if None).
        data: reference symbols. If ``train`` is None, the first ``len(data)``
            steps are training steps, and ``data`` is zero-padded to the length
            of ``signal`` so all scanned inputs share the same leading length.
        train: per-step training flags. When given, ``data`` must be given too
            and both must match ``signal``'s length.
        lr_w, lr_f, lr_s: learning rates of the MIMO taps and the two trackers.
        grad_max: clipping thresholds of the two trackers' gradients.
        const: reference constellation for decisions.
        device: device the state and inputs are placed on.

    Returns:
        Whatever ``_dd_lms`` returns.

    Raises:
        ValueError: on inconsistent ``data``/``train``/``signal`` lengths.

    Impl. follows Fig. 6 in [1]
    References:
    [1] Mori, Y., Zhang, C. and Kikuchi, K., 2012. Novel configuration of finite-impulse-response
    filters tolerant to carrier-phase fluctuations in digital coherent optical receivers for
    higher-order quadrature amplitude modulation signals. Optics express, 20(24), pp.26236-26251.
    '''
    n = signal.shape[0]

    if train is None:
        if data is None:
            data = np.full((n, signal.shape[-1]), 0, dtype=signal.dtype)
            train = np.full((n,), False)
        else:
            n_train = data.shape[0]
            if n_train > n:
                raise ValueError('invalid shape')
            train = np.concatenate([np.full((n_train,), True),
                                    np.full((n - n_train,), False)])
            if n_train < n:
                # pad the reference with zeros (unused in blind steps) so that data,
                # signal and train all have the same leading length
                pad = np.zeros((n - n_train,) + tuple(data.shape[1:]), dtype=data.dtype)
                data = np.concatenate([data, pad])
    else:
        if data is None or train.shape[0] != n or data.shape[0] != n:
            raise ValueError('invalid shape')

    dims = signal.shape[-1]
    if f_init is None:
        f_init = np.full((dims,), 1+0j, dtype=signal.dtype) # dummy initial value
    if s_init is None:
        s_init = np.full((dims,), 1+0j, dtype=signal.dtype) # dummy initial value

    # 1e-8: eps of the normalized LMS updates
    params = (w_init, f_init, s_init, lr_w, lr_f, lr_s, grad_max, 1e-8, const)
    inputs = (signal, data, train)

    params = device_put(params, device)
    inputs = device_put(inputs, device)

    ret = _dd_lms(params, inputs)

    return ret
예제 #16
0
def cpr_ekf(signal, init_states=None, device=cpus[0]):
    '''Symbol-by-symbol carrier phase recovery with an extended Kalman filter.

    The state is the scalar phase. The complex observation is split into its
    real and imaginary parts, so the measurement model is 2-dimensional.

    Args:
        signal: received symbols, scanned along axis 0. Each step appears to
            expect a single complex symbol (TODO confirm).
        init_states: optional initial filter state ``(Q, R, phi_corr, P_corr)``
            with shapes (1, 1), (2, 2), (1, 1) and (1, 1). When None, the defaults
            are Q = 3e-5, R = 2e-2 * I, phi = 0 and P = 1. Previously any value
            passed here was silently discarded.
        device: device the state and constellation are placed on.

    Returns:
        The predicted (a-priori) phase of every step.
    '''
    const = comm.const("16QAM", norm=True)

    if init_states is None:
        Q = 3.0e-5 * np.eye(1)
        R = 2.0e-2 * np.eye(2)
        P_corr = np.array([[1.]])
        phi_corr = np.array([[0.]])
        init_states = (Q, R, phi_corr, P_corr)

    init_states = device_put(init_states, device)
    const = device_put(const, device)

    @jit
    def step(states, r):
        Q, R, phi_corr, P_corr = states
        # prediction: constant-phase model
        phi_pred = phi_corr
        P_pred = P_corr + Q

        phi_pred_C = jnp.exp(1j * phi_pred[0,0])
        s_hat = const[jnp.argmin(jnp.abs(const - r * phi_pred_C.conj()))]
        r_hat_pred = s_hat * phi_pred_C

        # Jacobian of the observation w.r.t. the phase, as [real; imag]
        H_C = 1j * r_hat_pred
        H = jnp.array([[H_C.real], [H_C.imag]])
        I = jnp.array([[(r - r_hat_pred).real], [(r - r_hat_pred).imag]])
        S = H @ P_pred @ H.T + R
        K = P_pred @ H.T @ jnp.linalg.inv(S)
        phi_corr = phi_pred + K @ I
        P_corr = P_pred - K @ H @ P_pred

        return (Q, R, phi_corr, P_corr), phi_pred[0,0]

    _, ret = scan(step, init_states, signal)

    return ret
예제 #17
0
def rdemimo(signal, truth, lr=1/2**13, sps=2, taps=31, backend='cpu',
            const=comm.const("16QAM", norm=True)):
    '''Radius-directed MIMO equalizer (jit-compiled wrapper around ``_rdemimo``).

    Args:
        signal: received samples.
        truth: reference symbols forwarded to ``_rdemimo``.
        lr: learning rate.
        sps: samples per symbol; a static argument of the compiled function.
        taps: number of taps; a static argument of the compiled function.
        backend: jax backend to compile for.
        const: constellation whose distinct magnitudes are the target radii.
    '''
    radii = np.array(np.unique(np.abs(const)))
    equalize = jit(_rdemimo, static_argnums=(4, 5), backend=backend)
    return equalize(signal, truth, lr, radii, sps, taps)
예제 #18
0
def cpane_ekf(
    train: Union[bool, Schedule] = False,
    alpha: float = 0.99,
    beta: float = 0.6,
    Q: complex = 1e-4 + 0j,
    R: complex = 1e-2 + 0j,
    akf: bool = True,
    const: Array = comm.const("16QAM", norm=True)
) -> AdaptiveFilter:
    """Carrier Phase and Amplitude Noise Estimator

    Symbol-by-symbol fine carrier phase recovery using an extended Kalman filter.

    Args:
      train: scheduler for training mode
      alpha: smoothing factor of the adaptive (AKF) updates of Q and R
      beta: smoothing factor of the averaged phase used for symbol decisions
      Q: initial process-noise variance of the phase (added to P at every
        prediction step)
      R: initial measurement-noise variance (added in the denominator of the
        Kalman gain)
      akf: adaptive control of Q and R, a.k.a. AKF [2]
      const: reference constellation

    Returns:
      a ``AdaptiveFilter`` object

    References:
      - [1] Pakala, L. and Schmauss, B., 2016. Extended Kalman filtering for joint mitigation
        of phase and amplitude noise in coherent QAM systems. Optics express, 24(6), pp.6391-6401.
      - [2] Akhlaghi, Shahrokh, Ning Zhou, and Zhenyu Huang. "Adaptive adjustment of noise
        covariance in Kalman filter for dynamic state estimation." 2017 IEEE power & energy
        society general meeting. IEEE, 2017.
    """
    const = jnp.asarray(const)
    train = cxopt.make_schedule(train)

    def init(p0=0j):
        # (phase, error covariance, averaged phase, Q, R)
        state0 = (p0, 1j, 0j, Q, R)
        return state0

    def update(i, state, inp):
        Psi_c, P_c, Psi_a, Q, R = state
        y, x = inp

        Psi_p = Psi_c
        P_p = P_c + Q
        # exponential moving average
        Psi_a = beta * Psi_a + (1 - beta) * Psi_c

        d = jnp.where(
            train(i), x,
            const[jnp.argmin(jnp.abs(const - y * jnp.exp(-1j * Psi_a)))])

        H = 1j * d * jnp.exp(1j * Psi_p)
        K = P_p * H.conj() / (H * P_p * H.conj() + R)
        v = y - d * jnp.exp(1j * Psi_p)

        # phase of this step (before the correction) and the current Q, R
        out = (Psi_c, (Q, R))

        Psi_c = Psi_p + K * v
        P_c = (1. - K * H) * P_p
        # residual after the correction, used by the AKF update of R
        e = y - d * jnp.exp(1j * Psi_c)
        Q = alpha * Q + (1 - alpha) * K * v * v.conj() * K.conj() if akf else Q
        # NOTE(review): [2] appears to use the posterior covariance (P_c) in this
        # term; the prior P_p is used here -- confirm which is intended.
        R = alpha * R + (1 - alpha) * (e * e.conj() +
                                       H * P_p * H.conj()) if akf else R

        state = (Psi_c, P_c, Psi_a, Q, R)

        return state, out

    def apply(Psi, ys):
        return ys * jnp.exp(-1j * Psi)

    return AdaptiveFilter(init, update, apply)
예제 #19
0
def frame_cpr_kf(
        Q: Array = jnp.array([[0, 0],
                              [0, 1e-9]]),  # 1e-8 is better if akf is False
        R: Array = jnp.array([[1e-2, 0], [0, 1e-3]]),
        const: Array = comm.const("16QAM", norm=True),
        train: Union[bool, Schedule] = False,
        akf: Schedule = cxopt.piecewise_constant([10, 500],
                                                 [False, True, False]),
        alpha: float = 0.999) -> AdaptiveFilter:
    """Block based estimator of carrier frequency offset

    Frame-by-frame coarse carrier phase recovery using a Kalman filter. It can
    tolerate a frequency offset of 0.1 * baudrate [1].

    The state is ``[phase at the frame centre, phase increment per symbol]``.
    It is propagated with ``A = [[1, N], [0, 1]]`` for frames of N symbols.

    Args:
        Q: process-noise covariance of the state (initial value when AKF is on)
        R: measurement-noise covariance of the frame-wise error ``e``
        const: reference constellation used in the decision stage
        train: scheduler of training mode
        akf: scheduler of AKF
        alpha: smoothing factor used in AKF

    Returns:
        A ``AdaptiveFilter`` object

    Caution:
        needs proper initialization of FO[1]

    References:
        - [1] Inoue, Takashi, and Shu Namiki. "Carrier recovery for M-QAM signals based on
          a block estimation process with Kalman filter." Optics express 22.13 (2014): 15376-15387.
        - [2] Akhlaghi, Shahrokh, Ning Zhou, and Zhenyu Huang. "Adaptive adjustment of noise
          covariance in Kalman filter for dynamic state estimation." 2017 IEEE power & energy
          society general meeting. IEEE, 2017.
    """
    const = jnp.asarray(const)
    train = cxopt.make_schedule(train)
    akf = cxopt.make_schedule(akf)

    def init(w0=0):
        z0 = jnp.array([[0], [w0]], dtype=jnp.float32)
        P0 = jnp.zeros((2, 2), dtype=jnp.float32)
        state0 = (z0, P0, Q)
        return state0

    def update(i, state, inp):
        z_c, P_c, Q = state
        y, x = inp

        N = y.shape[0]  # frame size
        A = jnp.array([[1, N], [0, 1]])
        I = jnp.eye(2)
        # symbol index relative to the frame centre
        n = (jnp.arange(N) - (N - 1) / 2)

        z_p = A @ z_c
        P_p = A @ P_c @ A.T + Q
        phi_p = z_p[0, 0] + n * z_p[1, 0]  # linear approx.
        s_p = y * jnp.exp(-1j * phi_p)
        d = jnp.where(
            train(i), x,
            const[jnp.argmin(jnp.abs(const[None, :] - s_p[:, None]), axis=-1)])
        scd_p = s_p * d.conj()
        sumscd_p = jnp.sum(scd_p)
        # measured [phase error; phase-slope error] of this frame
        e = jnp.array([[jnp.arctan(sumscd_p.imag / sumscd_p.real)],
                       [(jnp.sum(n * scd_p)).imag /
                        (jnp.sum(n * n * scd_p)).real]])

        G = P_p @ jnp.linalg.pinv((P_p + R))
        z_c = z_p + G @ e
        P_c = (I - G) @ P_p

        # AKF [2]: Q <- alpha * Q + (1 - alpha) * G e e^T G^T. The transpose on the
        # trailing G keeps Q symmetric; G = P (P + R)^-1 is not symmetric in general.
        Q = jnp.where(akf(i), alpha * Q + (1 - alpha) * (G @ e @ e.T @ G.T), Q)

        out = (z_p[1, 0], phi_p)
        state = (z_c, P_c, Q)

        return state, out

    def apply(phis, ys):
        return jax.vmap(lambda y, phi: y * jnp.exp(-1j * phi))(ys, phis)

    return AdaptiveFilter(init, update, apply)
예제 #20
0
def ddlms(
    lr_w: Union[float, Schedule] = 1 / 2**6,
    lr_f: Union[float, Schedule] = 1 / 2**7,
    lr_s: Union[float, Schedule] = 0.,
    lr_b: Union[float, Schedule] = 1 / 2**11,
    train: Union[bool, Schedule] = False,
    grad_max: Tuple[float, float] = (30., 30.),
    eps: float = 1e-8,
    beta: float = 0.,
    const: Array = comm.const("16QAM", norm=True)
) -> AdaptiveFilter:
    """Decision-Directed Least Mean Square adaptive equalizer

    Args:
      lr_w: learning rate of MIMO(butterfly part)'s weights
      lr_f: learning rate of stage-I phase tracker
      lr_s: learning rate of stage-II phase tracker
      lr_b: learning rate of bias term
      train: controlling flag of training mode. It can be a bool for global
        control within one call, or an array of bool to switch training on an
        iteration basis.
      grad_max: clipping thresholds of the gradients of the two phase trackers
      eps: perturbative term to stabilize normalized LMS
      beta: smoothing factor of ``fshat``, the averaged tracker product used for decisions
      const: reference constellation used for hard decisions outside training mode

    Returns:
      an ``AdaptiveFilter`` object

    Notes:
      - add bias term to handle varying DC component

    References:
      - [1] Mori, Y., Zhang, C. and Kikuchi, K., 2012. Novel configuration of
        finite-impulse-response filters tolerant to carrier-phase fluctuations
        in digital coherent optical receivers for higher-order quadrature
        amplitude modulation signals. Optics express, 20(24), pp.26236-26251.
    """
    const = jnp.asarray(const)
    lr_w = cxopt.make_schedule(lr_w)
    lr_f = cxopt.make_schedule(lr_f)
    lr_s = cxopt.make_schedule(lr_s)
    lr_b = cxopt.make_schedule(lr_b)
    train = cxopt.make_schedule(train)

    def init(taps=31, dims=2, dtype=jnp.complex64, mimoinit='zeros'):
        w0 = mimoinitializer(taps, dims, dtype, mimoinit)
        f0 = jnp.full((dims, ), 1., dtype=dtype)
        s0 = jnp.full((dims, ), 1., dtype=dtype)
        b0 = jnp.full((dims, ), 0., dtype=dtype)
        fshat0 = jnp.full((dims, ), 1., dtype=dtype)
        return (w0, f0, s0, b0, fshat0)

    def update(i, state, inp):
        w, f, s, b, fshat = state
        u, x = inp

        v = mimo(w, u)
        k = v * f
        c = k * s
        z = c + b
        # decisions use the smoothed tracker product fshat
        q = v * fshat + b
        d = jnp.where(train(i), x, decision(const, q))
        l = jnp.sum(jnp.abs(z - d)**2)

        psi_hat = jnp.abs(f) / f * jnp.abs(s) / s
        e_w = (d - b) * psi_hat - v
        e_f = d - b - k
        e_s = d - b - c
        e_b = d - z
        # normalized LMS gradients
        gw = -1. / ((jnp.abs(u)**2).sum() +
                    eps) * e_w[:, None, None] * u.conj().T[None, ...]
        gf = -1. / (jnp.abs(v)**2 + eps) * e_f * v.conj()
        gs = -1. / (jnp.abs(k)**2 + eps) * e_s * k.conj()
        gb = -e_b

        # bound the grads of f and s, which are less regulated than w;
        # this empirically stabilizes the algorithm
        gf = jnp.where(
            jnp.abs(gf) > grad_max[0], gf / jnp.abs(gf) * grad_max[0], gf)
        gs = jnp.where(
            jnp.abs(gs) > grad_max[1], gs / jnp.abs(gs) * grad_max[1], gs)

        out = ((w, f, s, b), (l, d))

        # update
        w = w - lr_w(i) * gw
        f = f - lr_f(i) * gf
        s = s - lr_s(i) * gs
        b = b - lr_b(i) * gb
        fshat = beta * fshat + (1 - beta) * (f * s)

        state = (w, f, s, b, fshat)

        return state, out

    def apply(ps, yf):
        ws, fs, ss, bs = ps
        return jax.vmap(mimo)(ws, yf) * fs * ss + bs

    return AdaptiveFilter(init, update, apply)
예제 #21
0
def cpr_foe_ekf(signal, init_states=None, device=cpus[0], const=None, beta=.99):
    '''Joint carrier phase and frequency offset tracking with an extended Kalman filter.

    The state is ``[phase; phase increment per symbol]``, propagated with
    ``A = [[1, 1], [0, 1]]``. The complex observation is split into its real
    and imaginary parts. Q is adapted on the fly [3]; R stays fixed.

    Args:
        signal: received symbols, scanned along axis 0 (one symbol per step,
            judging by the scalar decision below -- TODO confirm).
        init_states: optional ``(Q, R, x, P)``, all 2x2 except the 2x1 state x.
        device: device the inputs and state are placed on.
        const: reference constellation; normalized 16QAM if None (the previously
            hard-coded value).
        beta: smoothing factor of the adaptive Q update (previously hard-coded
            to 0.99).

    Returns:
        A 2xN array of the predicted states: row 0 is the phase, and row 1 is
        the phase increment per symbol.

    References:
    [1] Jain, A., Krishnamurthy, P.K., Landais, P. and Anandarajah, P.M., 2017.
        EKF for joint mitigation of phase noise, frequency offset and nonlinearity
        in 400 Gb/s PM-16-QAM and 200 Gb/s PM-QPSK systems. IEEE Photonics Journal,
        9(1), pp.1-10.
    [2] Lin, W.T. and Chang, D.C., 2006, May. The extended Kalman filtering algorithm
        for carrier synchronization and the implementation. In 2006 IEEE International
        Symposium on Circuits and Systems (pp. 4-pp). IEEE.
    [3] Akhlaghi, Shahrokh, Ning Zhou, and Zhenyu Huang. "Adaptive adjustment of noise
        covariance in Kalman filter for dynamic state estimation." 2017 IEEE power & energy
        society general meeting. IEEE, 2017.
    '''

    if init_states is None:
        init_states = (
          jnp.array([[1e-2,  0],
                     [0,  1e-5]]),
          1e-1 * jnp.eye(2),
          jnp.array([[0.],
                     [0.]]),
          1. * jnp.eye(2)
        )

    A = jnp.array([[1, 1],
                   [0, 1]])

    if const is None:
        const = comm.const("16QAM", norm=True)

    signal = device_put(signal, device)
    init_states = device_put(init_states, device)
    A = device_put(A, device)
    const = device_put(const, device)

    @jit
    def step(states, r):
        Q, R, x_c, P_c = states

        x_p = A @ x_c
        P_p = A @ P_c @ A.T + Q

        p_p = jnp.exp(1j * x_p[0,0])
        s_hat_p = const[jnp.argmin(jnp.abs(const - r * p_p.conj()))]
        r_hat_p = s_hat_p * p_p

        d = r - r_hat_p
        # Jacobian of [Re; Im] of the observation w.r.t. [phase; frequency]
        H = jnp.array([[-r_hat_p.imag, 0],
                       [ r_hat_p.real, 0]])
        I = jnp.array([[d.real],
                       [d.imag]])
        S = H @ P_p @ H.T + R
        K = P_p @ H.T @ jnp.linalg.inv(S)
        x_c = x_p + K @ I
        P_c = P_p - K @ H @ P_p

        # adapt Q [3]; R is intentionally kept fixed
        Q = beta * Q + (1. - beta) * K @ I @ I.T @ K.T

        return (Q, R, x_c, P_c), x_p[:,0]

    _, ret = scan(step, init_states, signal)

    return ret.T