def rls_cma(y_f, w_init, beta=0.999, delta=1e-4, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Recursive-least-squares constant modulus algorithm (RLS-CMA) MIMO equalizer.

    Runs ``step_rls_cma`` over all input rows with ``scan`` and reshapes the
    resulting weight history back to MIMO-tap layout.

    Args:
        y_f: input frames; reshaped (Fortran order) to ``(N, -1)`` before the scan.
            NOTE(review): presumably each frame flattens to length ``dims * taps`` to
            match the ``DT x D`` weight layout -- confirm against callers.
        w_init: initial MIMO taps; ``shape[0]`` is read as ``dims`` and ``shape[-1]``
            as ``taps`` (i.e. ``dims x dims x taps``).
        beta: forgetting factor forwarded to ``step_rls_cma``.
        delta: scale of the initial inverse-correlation matrices ``P = delta * I``
            (note: the textbook RLS initialisation is ``I / delta``; here delta multiplies).
        const: reference constellation; only used to derive the CMA modulus ``R2``.
        device: device the parameters and inputs are placed on.

    Returns:
        ``(l, w)``: ``l`` is the first per-step output of ``step_rls_cma``
        (presumably a loss/error term -- confirm), ``w`` is the weight history
        reshaped to ``(N, dims, dims, taps)``.

    References:
        [1] Faruk, M.S. and Savory, S.J., 2017. Digital signal processing for coherent transceivers employing multilevel formats. Journal of Lightwave Technology, 35(5), pp.1125-1141.
    '''
    # CMA modulus: R2 = E|s|^4 / E|s|^2 of the reference constellation
    R2 = jnp.mean(abs(const)**4) / jnp.mean(abs(const)**2)
    N = y_f.shape[0]
    taps = w_init.shape[-1]
    dims = w_init.shape[0]
    # w_init: DxDxT -> DTxD
    w_init = jnp.reshape(w_init.conj(), (dims, dims * taps)).T
    cI = jnp.eye(dims * taps, dtype=y_f.dtype)
    # one DT x DT matrix per output dimension, stacked on the last axis -> DT x DT x D
    P_init = delta * jnp.tile(cI[...,None], (1, 1, dims))
    y_f = jnp.reshape(y_f, (N, -1), order='F')
    params = (w_init, P_init, R2, beta)
    inputs = (y_f,)
    params = device_put(params, device)
    inputs = device_put(inputs, device)
    _, ret = scan(step_rls_cma, params, inputs)
    l, w = ret
    # w: NxDTxD -> NxDxDT
    w = jnp.moveaxis(w, 0, -1).T
    # w: NxDxDT -> NxDxDxT
    w = jnp.reshape(w, (N, dims, dims, taps)).conj()
    return l, w
def framekfcpr(signal, truth=None, n=100, w0=None, modformat='16QAM', const=None, backend='cpu'):
    '''Frame-based Kalman-filter carrier phase recovery (jit-compiled wrapper).

    Args:
        signal: received symbols; converted with ``jnp.asarray``.
        truth: optional known symbols; forwarded as ``None`` when absent.
        n: third argument of ``_framekfcpr``. It is static for ``jit``, so every
            distinct value triggers a recompile. Presumably the frame length -- confirm.
        w0: initial value forwarded to ``_framekfcpr``.
        modformat: modulation format used to build the constellation when
            ``const`` is None.
        const: optional explicit reference constellation.
        backend: XLA backend passed to ``jit``.

    Returns:
        whatever ``_framekfcpr`` returns.
    '''
    if const is None:
        const = comm.const(modformat, norm=True)
    ref = jnp.asarray(const)
    y = jnp.asarray(signal)
    x = None if truth is None else jnp.asarray(truth)
    compiled = jit(_framekfcpr, backend=backend, static_argnums=2)
    return compiled(y, x, n, w0, ref)
def ekfcpr(signal, truth=None, modformat='16QAM', const=None, backend='cpu'):
    '''Extended-Kalman-filter carrier phase recovery (jit-compiled wrapper).

    Args:
        signal: received symbols; converted with ``jnp.asarray``.
        truth: optional known symbols; forwarded as ``None`` when absent.
        modformat: modulation format used to build the constellation when
            ``const`` is None.
        const: optional explicit reference constellation.
        backend: XLA backend passed to ``jit``.

    Returns:
        whatever ``_ekfcpr`` returns.
    '''
    ref = jnp.asarray(comm.const(modformat, norm=True) if const is None else const)
    y = jnp.asarray(signal)
    x = None if truth is None else jnp.asarray(truth)
    return jit(_ekfcpr, backend=backend)(y, x, ref)
def rde(lr: Union[float, Schedule] = 2**-15,
        train: Union[bool, Schedule] = False,
        Rs: Array = jnp.unique(jnp.abs(comm.const("16QAM", norm=True))),
        const: Optional[Array] = None) -> AdaptiveFilter:
    """Radius Directed adaptive Equalizer (RDE).

    Each update minimises ``sum |R^2 - |v|^2|`` over the MIMO output ``v``. The
    target radius ``R`` is the modulus of the training symbol while ``train(i)``
    is active, and otherwise the radius in ``Rs`` closest to ``|v|``.

    Args:
        lr: learning rate, a scalar or a Schedule.
        train: training-mode schedule; a bool for global control within one call
            or an array of bool to switch training on an iteration basis.
        Rs: radii of the target constellation.
        const: Optional; when given, ``Rs`` is replaced by the distinct moduli of
            this constellation.

    Returns:
        an ``AdaptiveFilter`` object; its ``update`` emits ``(w, loss)`` taken
        before the weight update.

    References:
        - [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and
          carrier phase recovery in a 16-QAM optical coherent system. Journal
          of lightwave technology, 27(15), pp.3042-3049.
    """
    step_size = cxopt.make_schedule(lr)
    is_training = cxopt.make_schedule(train)
    radii = Rs if const is None else jnp.array(jnp.unique(jnp.abs(const)))

    def init(dims=2, w0=None, taps=32, dtype=np.complex64):
        if w0 is not None:
            return w0
        center = (taps + 1) // 2 - 1
        w0 = np.zeros((dims, dims, taps), dtype=dtype)
        diag = np.arange(dims)
        w0[diag, diag, center] = 1.
        return w0

    def loss_fn(w, u, x, i):
        v = r2c(mimo(w, u)[None, :])
        # distance of every candidate radius to |v|, measured along v's direction
        dist = jnp.abs(radii[:, None] * v / jnp.abs(v) - v)
        nearest = radii[jnp.argmin(dist, axis=0)]
        target = jnp.where(is_training(i), jnp.abs(x)**2, nearest**2)
        return jnp.sum(jnp.abs(target - jnp.abs(v[0, :])**2))

    def update(i, w, inp):
        u, x = inp
        loss, grad = jax.value_and_grad(loss_fn)(w, u, x, i)
        return w - step_size(i) * grad.conj(), (w, loss)

    def apply(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, apply)
def lms_cpane(signal, w_init, data=None, train=None, lr=1e-4, beta=0.7, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''LMS MIMO equalizer coupled with an EKF carrier phase/amplitude noise estimator.

    Args:
        signal: received samples; ``signal.shape[0]`` is the number of scan steps and
            ``signal.shape[-1]`` the number of dimensions.
        w_init: initial MIMO taps, forwarded to ``step_lms_cpane``.
        data: known symbols used in training mode. When ``train`` is None, this is
            replaced by a 1-D array of zeros of length ``signal.shape[0]``.
        train: per-step training flags; defaults to all False.
        lr: LMS step size.
        beta: smoothing factor, tiled once per dimension.
        const: reference constellation handed to ``step_lms_cpane``.
        device: device the parameters and inputs are placed on.

    Returns:
        the stacked per-step outputs of ``step_lms_cpane``.
    '''
    # Honour the caller's constellation (it used to be overwritten with 16QAM here).
    const = np.asarray(const)
    n_steps = signal.shape[0]
    if train is None:
        train = np.full((n_steps,), False)
        data = np.full((n_steps,), 0, dtype=const.dtype)
    dims = signal.shape[-1]
    params_lms = (w_init, lr)
    # Per-dimension CPANE parameters/initial state, each tiled over `dims`.
    # NOTE(review): presumably (Q, R, initial phase, initial covariance,
    # initial averaged phase, beta) -- confirm against step_lms_cpane.
    # The constellation is appended untiled.
    params_cpane = tuple(map(lambda x: np.tile(x, dims), [1e-5 * (1.+1j), 1e-2 * (1.+1j), 0j, 1j, 0j, beta])) + (const,)
    params = (params_lms, params_cpane)
    inputs = (signal, data, train)
    params = device_put(params, device)
    inputs = device_put(inputs, device)
    _, ret = scan(step_lms_cpane, params, inputs)
    return ret
def cma_2sec(y_f, h_init, w_init, mu1=1e-4, mu2=1e-1, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''CMA equalizer with two sets of taps, scanned with ``step_cma_2sec``.

    NOTE(review): the name suggests a two-section structure (``h`` and ``w`` with
    their own step sizes ``mu1``/``mu2``) -- confirm against ``step_cma_2sec``.

    Args:
        y_f: input frames scanned along axis 0.
        h_init: initial taps of the first section.
        w_init: initial taps of the second section.
        mu1: step size forwarded to ``step_cma_2sec`` (first of two).
        mu2: step size forwarded to ``step_cma_2sec`` (second of two).
        const: reference constellation, used only to derive the CMA modulus.
        device: device the parameters and inputs are placed on.

    Returns:
        the stacked per-step outputs of ``step_cma_2sec``.
    '''
    # CMA modulus E|s|^4 / E|s|^2
    modulus = jnp.mean(abs(const)**4) / jnp.mean(abs(const)**2)
    carry = device_put((h_init, w_init, mu1, mu2, modulus), device)
    xs = device_put((y_f,), device)
    _, outputs = scan(step_cma_2sec, carry, xs)
    return outputs
def ddlms(mu_w=1/2**10, mu_f=1/2**6, mu_s=0., grad_max=(50., 50.), eps=1e-8, const=comm.const("16QAM", norm=True)):
    '''Decision-directed LMS equalizer with two cascaded phase trackers.

    Impl. follows Fig. 6 in [1]: a MIMO FIR stage ``w`` followed by two
    complex multipliers ``f`` and ``s``.

    Args:
        mu_w: step size of the MIMO taps.
        mu_f: step size of the first phase tracker ``f``.
        mu_s: step size of the second phase tracker ``s``.
        grad_max: magnitude limits for the gradients of ``f`` and ``s``.
        eps: regulariser added to the normalisation terms of those gradients.
        const: reference constellation used for hard decisions.

    Returns:
        ``AdaptiveFilter(init, update, static_map)``. ``init`` returns the given
        state unchanged; ``update`` consumes ``(u, x, train)`` and emits
        ``(w, f, s, d)`` taken before the update.

    References:
        [1] Mori, Y., Zhang, C. and Kikuchi, K., 2012. Novel configuration of finite-impulse-response filters tolerant to carrier-phase fluctuations in digital coherent optical receivers for higher-order quadrature amplitude modulation signals. Optics express, 20(24), pp.26236-26251.
    '''
    def _limit(grad, bound):
        # shrink entries whose magnitude exceeds `bound` to magnitude `bound`
        return jnp.where(jnp.abs(grad) > bound, grad / jnp.abs(grad) * bound, grad)

    def init(state0):
        return state0

    def update(state, inp):
        w, f, s = state
        u, x, train = inp
        v = mimo(w, u)
        fv = f * v
        z = v * f * s
        nearest = const[jnp.argmin(jnp.abs(const[:,None] - z[None,:]), axis=0)]
        d = jnp.where(train, x, nearest)
        psi_hat = jnp.abs(f)/f * jnp.abs(s)/s
        e_p = d * psi_hat - v
        e_f = d - fv
        e_s = d - s * f * v
        gw = -e_p[:, None, None] * u.conj().T[None, ...]
        # f and s are less regulated than w; bounding their gradients may
        # stabilise the algorithm in some corner cases
        gf = _limit(-1. / (jnp.abs(v)**2 + eps) * e_f * v.conj(), grad_max[0])
        gs = _limit(-1. / (jnp.abs(fv)**2 + eps) * e_s * fv.conj(), grad_max[1])
        next_state = (w - mu_w * gw, f - mu_f * gf, s - mu_s * gs)
        return next_state, (w, f, s, d)

    def static_map(ps, yf):
        ws, fs, ss = ps
        return jax.vmap(mimo)(ws, yf) * fs * ss

    return AdaptiveFilter(init, update, static_map)
def mu_cma(y_f, w_init, lr=1e-4, alpha=0.9999, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''CMA-type MIMO equalizer scanned with ``step_mu_cma``.

    NOTE(review): the name suggests multi-user CMA (MU-CMA) -- confirm against
    ``step_mu_cma``.

    Args:
        y_f: input frames scanned along axis 0; ``y_f.shape[-1]`` is the channel
            count and ``y_f.dtype`` sets the dtype of the zero-initialised states.
        w_init: initial taps; ``w_init.shape[-1]`` is the tap count.
        lr: step size forwarded to ``step_mu_cma``.
        alpha: forgetting factor forwarded to ``step_mu_cma``.
        const: reference constellation, used only to derive the CMA modulus.
        device: device for the inputs and scalar parameters.

    Returns:
        the stacked per-step outputs of ``step_mu_cma``.
    '''
    # CMA modulus E|s|^4 / E|s|^2
    modulus = jnp.mean(abs(const)**4) / jnp.mean(abs(const)**2)
    n_taps = w_init.shape[-1]
    n_ch = y_f.shape[-1]
    dt = y_f.dtype
    z0 = jnp.zeros((n_taps, n_ch), dtype=dt)
    c0 = jnp.zeros((n_ch, n_ch, n_taps), dtype=dt)
    y_f, w_init, lr, alpha, modulus = [
        device_put(arr, device) for arr in (y_f, w_init, lr, alpha, modulus)
    ]
    carry = (w_init, modulus, c0, z0, alpha, lr)
    _, w = scan(step_mu_cma, carry, y_f)
    return w
def cpane_ekf(beta=0.8, Q=1e-5 * (1.+1j), R=1e-2 * (1.+1j), const=comm.const("16QAM", norm=True)):
    '''Carrier Phase and Amplitude Noise Estimator using an extended Kalman filter.

    The filter state is ``(psi, P, psi_avg)``: the phase estimate, its error
    covariance, and an exponential moving average of the phase that is only
    used to make the hard decision.

    Args:
        beta: smoothing factor of the moving-average phase.
        Q: term added to the error covariance in the prediction step.
        R: term added to the innovation covariance in the Kalman gain.
        const: reference constellation for hard decisions.

    Returns:
        ``AdaptiveFilter(init, update, static_map)``; ``update`` consumes
        ``(y, x, train)`` and emits ``(psi, d)`` with ``psi`` taken before the update.

    References:
        [1] Pakala, L. and Schmauss, B., 2016. Extended Kalman filtering for joint mitigation of phase and amplitude noise in coherent QAM systems. Optics express, 24(6), pp.6391-6401.
    '''
    ref = jnp.array(const)

    def init(p0=0j):
        return (p0, 1j, 0j)

    def update(state, inp):
        psi_prev, cov_prev, psi_avg = state
        y, x, train = inp
        # prediction: phase carried over unchanged, covariance grows by Q
        psi_pred = psi_prev
        cov_pred = cov_prev + Q
        psi_avg = beta * psi_avg + (1 - beta) * psi_prev
        derotated = y * jnp.exp(-1j * psi_avg)
        decided = ref[jnp.argmin(jnp.abs(ref - derotated))]
        d = jnp.where(train, x, decided)
        rot = jnp.exp(1j * psi_pred)
        H = 1j * d * rot
        gain = cov_pred * H.conj() / (H * cov_pred * H.conj() + R)
        innovation = y - d * rot
        psi_new = psi_pred + gain * innovation
        cov_new = (1. - gain * H) * cov_pred
        return (psi_new, cov_new, psi_avg), (psi_prev, d)

    def static_map(Psi, ys):
        return ys * jnp.exp(-1j * Psi)

    return AdaptiveFilter(init, update, static_map)
def modulusmimo(signal, sps=2, taps=32, lr=2**-14, cma_samples=20000, modformat='16QAM', const=None, backend='cpu'):
    '''
    Adaptive MIMO equalizer for M-QAM signal

    Args:
        signal: received samples; ``signal.shape[0]`` is the sample count.
        sps: samples per symbol (static for ``jit``).
        taps: number of equalizer taps (static for ``jit``).
        lr: step size forwarded to ``_modulusmimo``.
        cma_samples: number of samples forwarded to ``_modulusmimo`` (static for
            ``jit``); presumably the length of an initial CMA stage -- confirm.
        modformat: modulation format used when ``const`` is None.
        const: optional explicit reference constellation.
        backend: XLA backend passed to ``jit``.

    Returns:
        whatever ``_modulusmimo`` returns.

    Raises:
        ValueError: if ``signal`` has fewer than ``cma_samples`` samples.
    '''
    y = jnp.asarray(signal)
    if y.shape[0] < cma_samples:
        # the message used to say the opposite of the condition being checked
        raise ValueError('cma_samples must be <= given samples length')
    if const is None:
        const = comm.const(modformat, norm=True)
    # CMA modulus E|s|^4 / E|s|^2 and the distinct constellation radii
    R2 = np.array(np.mean(abs(const)**4) / np.mean(abs(const)**2))
    Rs = np.array(np.unique(np.abs(const)))
    # positions 3-5 (sps, taps, cma_samples) are compile-time constants
    return jit(_modulusmimo, static_argnums=(3, 4, 5), backend=backend)(y, R2, Rs, sps, taps, cma_samples, lr)
def rde(lr=1e-4, Rs=jnp.unique(jnp.abs(comm.const("16QAM", norm=True)))):
    '''Radius Directed adaptive Equalizer (AdaptiveFilter interface).

    Each update minimises ``sum |R^2 - |v|^2|`` over the MIMO output ``v``. ``R`` is
    the given training radius ``Rx`` when ``train`` is set, else the radius in
    ``Rs`` closest to ``|v|``.

    Args:
        lr: learning rate.
        Rs: radii of the target constellation.

    Returns:
        ``AdaptiveFilter(init, update, static_map)``; ``update`` consumes
        ``(u, Rx, train)`` and emits ``(loss, w)`` with ``w`` taken before the update.

    References:
        [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and carrier phase recovery in a 16-QAM optical coherent system. Journal of lightwave technology, 27(15), pp.3042-3049.
    '''
    def init(w0=None, taps=19, dims=2, unitarize=False):
        if w0 is None:
            # size by `dims` (was hard-coded to 2, which broke dims != 2)
            w0 = np.zeros((dims, dims, taps), dtype=np.complex64)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        elif unitarize:
            try:
                w0 = unitarize_mimo_weights(w0)
            except Exception as err:
                # best effort: keep the given weights, but don't hide the failure
                import warnings
                warnings.warn(f'could not unitarize MIMO weights, using them as given: {err}')
        return w0

    def update(w, inp):
        u, Rx, train = inp

        def loss_fn(w, u):
            v = mimo(w, u)[None,:]
            R2 = jnp.where(train,
                           Rx**2,
                           Rs[jnp.argmin(
                               jnp.abs(Rs[:,None] * v / jnp.abs(v) - v),
                               axis=0)]**2)
            l = jnp.sum(jnp.abs(R2 - jnp.abs(v[0,:])**2))
            return l

        l, g = jax.value_and_grad(loss_fn)(w, u)
        out = (l, w)
        w = w - lr * g.conj()
        return w, out

    def static_map(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, static_map)
def cpane_ekf(signal, data=None, train=None, beta=0.8, device=cpus[0], const=None):
    '''EKF carrier phase and amplitude noise estimator (functional interface).

    Args:
        signal: received symbols; ``signal.shape[0]`` is the number of steps.
        data: known symbols used where ``train`` is True. When ``train`` is None,
            this is replaced by zeros.
        train: per-symbol training flags; defaults to all False.
        beta: smoothing factor forwarded to ``_cpane_ekf``.
        device: device the parameters and inputs are placed on.
        const: reference constellation. Defaults to normalised 16QAM, which was
            previously hard-coded.

    Returns:
        whatever ``_cpane_ekf`` returns (presumably the phase estimates -- confirm).

    References:
        [1] Pakala, L. and Schmauss, B., 2016. Extended Kalman filtering for joint mitigation of phase and amplitude noise in coherent QAM systems. Optics express, 24(6), pp.6391-6401.
    '''
    if const is None:
        const = comm.const("16QAM", norm=True)
    const = np.asarray(const)
    if train is None:
        train = np.full((signal.shape[0],), False)
        data = np.full((signal.shape[0],), 0, dtype=const.dtype)
    # (Q, R, initial phase, initial covariance, initial averaged phase, beta, const)
    # -- NOTE(review): ordering presumed from the EKF state layout; confirm in _cpane_ekf
    params = (1e-5 * (1.+1j), 1e-2 * (1.+1j), 0j, 1j, 0j, beta, const)
    inputs = (signal, data, train)
    params = device_put(params, device)
    inputs = device_put(inputs, device)
    psi_hat = _cpane_ekf(params, inputs)
    return psi_hat
def lms(
    lr: Union[float, Schedule] = 1e-4,
    const: Optional[Array] = comm.const('16QAM', norm=True)
) -> AdaptiveFilter:
    """Decision-directed LMS MIMO adaptive filter.

    The per-step loss is the squared distance between the equalised symbols and
    their hard decisions on ``const``.

    Args:
        lr: Optional; learning rate, a scalar or a Schedule.
        const: Optional; reference constellation used for the hard decisions.

    Returns:
        an ``AdaptiveFilter`` object; its ``update`` emits ``(w, loss)`` taken
        before the weight update.
    """
    schedule = cxopt.make_schedule(lr)
    ref = const if const is None else jnp.asarray(const)

    def init(w0=None, taps=19, dims=2, dtype=np.complex64):
        if w0 is None:
            w0 = np.zeros((dims, dims, taps), dtype=dtype)
            center = (taps + 1) // 2 - 1
            diag = np.arange(dims)
            w0[diag, diag, center] = 1.
        return w0.astype(dtype)

    def loss_fn(w, u):
        v = r2c(mimo(w, u)[None, :])[0, :]
        return jnp.sum(jnp.abs(decision(ref, v) - v)**2)

    def update(i, w, u):
        loss, grad = jax.value_and_grad(loss_fn)(w, u)
        w_next = w - schedule(i) * grad.conj()
        return w_next, (w, loss)

    def apply(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, apply)
def rde(signal, w_init, data=None, train=None, lr=1e-4, const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Radius Directed adaptive Equalizer (functional interface).

    Args:
        signal: received samples; ``signal.shape[0]`` is the number of scan steps.
        w_init: initial MIMO taps, forwarded to ``step_rde``.
        data: known symbols for training. Required when ``train`` is given; when
            ``train`` is None, this is replaced by zeros of shape
            ``(signal.shape[0], signal.shape[-1])``.
        train: per-step training flags; defaults to all False.
        lr: step size.
        const: reference constellation; its distinct moduli become the RDE radii.
        device: device the parameters and inputs are placed on.

    Returns:
        the stacked per-step outputs of ``step_rde``.

    Raises:
        ValueError: if ``train`` is given without ``data``, or if their lengths
            differ from ``signal.shape[0]``.

    References:
        [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and carrier phase recovery in a 16-QAM optical coherent system. Journal of lightwave technology, 27(15), pp.3042-3049.
    '''
    n_steps = signal.shape[0]
    if train is None:
        train = np.full((n_steps,), False)
        data = np.full((n_steps, signal.shape[-1]), 0, dtype=signal.dtype)
    else:
        # previously a missing `data` surfaced as an AttributeError
        if data is None:
            raise ValueError('data must be given when train is given')
        if train.shape[0] != n_steps or data.shape[0] != n_steps:
            raise ValueError('invalid shape')
    Rs = np.unique(np.abs(const))
    # NOTE(review): the second carry has a fixed length of 2, which presumably
    # assumes 2-dimensional (dual-polarisation) input -- confirm in step_rde
    params = (w_init, jnp.array([0j, 0j]), lr, Rs)
    inputs = (signal, data, train)
    params = device_put(params, device)
    inputs = device_put(inputs, device)
    _, ret = scan(step_rde, params, inputs)
    return ret
def dd_lms(signal, w_init, f_init=None, s_init=None, data=None, train=None, lr_w=1/2**10, lr_f=1/2**6, lr_s=0., grad_max=(50., 50.), const=comm.const("16QAM", norm=True), device=cpus[0]):
    '''Decision-directed LMS equalizer with two phase trackers (functional interface).

    Impl. follows Fig. 6 in [1]

    Args:
        signal: received samples; ``signal.shape[0]`` is the number of steps and
            ``signal.shape[-1]`` the number of dimensions.
        w_init: initial MIMO taps.
        f_init: initial first-stage phase tracker; defaults to ones per dimension.
        s_init: initial second-stage phase tracker; defaults to ones per dimension.
        data: known symbols. If ``train`` is None and ``data`` is given, the first
            ``len(data)`` steps train on it. ``data`` is then zero-padded to
            ``signal.shape[0]`` so every input has the same leading length.
        train: per-step training flags; requires ``data`` of the same length.
        lr_w: step size of the MIMO taps.
        lr_f: step size of the first phase tracker.
        lr_s: step size of the second phase tracker.
        grad_max: gradient magnitude limits for the phase trackers.
        const: reference constellation for hard decisions.
        device: device the parameters and inputs are placed on.

    Returns:
        whatever ``_dd_lms`` returns.

    Raises:
        ValueError: on missing ``data`` with ``train``, mismatched lengths, or
            ``data`` longer than ``signal``.

    References:
        [1] Mori, Y., Zhang, C. and Kikuchi, K., 2012. Novel configuration of finite-impulse-response filters tolerant to carrier-phase fluctuations in digital coherent optical receivers for higher-order quadrature amplitude modulation signals. Optics express, 20(24), pp.26236-26251.
    '''
    n_steps = signal.shape[0]
    if train is None:
        if data is None:
            data = np.full((n_steps, signal.shape[-1]), 0, dtype=signal.dtype)
            train = np.full((n_steps,), False)
        else:
            data = np.asarray(data)
            n_train = data.shape[0]
            if n_train > n_steps:
                raise ValueError('data must not be longer than signal')
            # train on the leading symbols covered by `data`, decision-directed afterwards
            train = np.concatenate([np.full((n_train,), True),
                                    np.full((n_steps - n_train,), False)])
            # pad `data` to the scan length; the padded rows sit where train is False
            pad = np.zeros((n_steps - n_train,) + data.shape[1:], dtype=data.dtype)
            data = np.concatenate([data, pad])
    else:
        if data is None:
            raise ValueError('data must be given when train is given')
        if train.shape[0] != n_steps or data.shape[0] != n_steps:
            raise ValueError('invalid shape')
    dims = signal.shape[-1]
    if f_init is None:
        f_init = np.full((dims,), 1+0j, dtype=signal.dtype) # dummy initial value
    if s_init is None:
        s_init = np.full((dims,), 1+0j, dtype=signal.dtype) # dummy initial value
    # 1e-8 is the regulariser of the normalised phase-tracker gradients
    params = (w_init, f_init, s_init, lr_w, lr_f, lr_s, grad_max, 1e-8, const)
    inputs = (signal, data, train)
    params = device_put(params, device)
    inputs = device_put(inputs, device)
    ret = _dd_lms(params, inputs)
    return ret
def cpr_ekf(signal, init_states=None, device=cpus[0]):
    '''Symbol-by-symbol carrier phase recovery with an extended Kalman filter.

    The scalar phase follows a random walk. The complex innovation is split into
    its real and imaginary parts, so the measurement update is 2-dimensional.

    Args:
        signal: received symbols, scanned one at a time along axis 0.
        init_states: optional ``(Q, R, phi, P)`` tuple holding the process noise
            (1x1), the measurement noise (2x2), the initial phase (1x1) and its
            error covariance (1x1). When None, defaults to ``Q = 3e-5``,
            ``R = 2e-2 * I``, ``phi = 0`` and ``P = 1``. This argument used to be
            silently ignored.
        device: device the states and the constellation are placed on.

    Returns:
        the stacked per-symbol predicted phases ``phi_pred[0, 0]``.
    '''
    const = comm.const("16QAM", norm=True)
    if init_states is None:
        Q = 3.0e-5 * np.eye(1)
        R = 2.0e-2 * np.eye(2)
        P_corr = np.array([[1.]])
        phi_corr = np.array([[0.]])
        init_states = (Q, R, phi_corr, P_corr)
    init_states = device_put(init_states, device)
    const = device_put(const, device)

    @jit
    def step(states, r):
        Q, R, phi_corr, P_corr = states
        # prediction (random walk)
        phi_pred = phi_corr
        P_pred = P_corr + Q
        phi_pred_C = jnp.exp(1j * phi_pred[0,0])
        # hard decision on the de-rotated symbol
        s_hat = const[jnp.argmin(jnp.abs(const - r * phi_pred_C.conj()))]
        r_hat_pred = s_hat * phi_pred_C
        # Jacobian of r_hat w.r.t. phi, as a real 2x1 vector
        H_C = 1j * r_hat_pred
        H = jnp.array([[H_C.real],
                       [H_C.imag]])
        I = jnp.array([[(r - r_hat_pred).real],
                       [(r - r_hat_pred).imag]])
        S = H @ P_pred @ H.T + R
        K = P_pred @ H.T @ jnp.linalg.inv(S)
        phi_corr = phi_pred + K @ I
        P_corr = P_pred - K @ H @ P_pred
        return (Q, R, phi_corr, P_corr), phi_pred[0,0]

    _, ret = scan(step, init_states, signal)
    return ret
def rdemimo(signal, truth, lr=1/2**13, sps=2, taps=31, backend='cpu', const=comm.const("16QAM", norm=True)):
    '''Radius-directed MIMO equalizer (jit-compiled wrapper around ``_rdemimo``).

    Args:
        signal: received samples.
        truth: known symbols forwarded to ``_rdemimo``.
        lr: step size.
        sps: samples per symbol (static for ``jit``).
        taps: number of equalizer taps (static for ``jit``).
        backend: XLA backend passed to ``jit``.
        const: reference constellation; its distinct moduli become the radii.

    Returns:
        whatever ``_rdemimo`` returns.
    '''
    radii = np.array(np.unique(np.abs(const)))
    # argument positions 4 and 5 (sps, taps) are compile-time constants
    kernel = jit(_rdemimo, static_argnums=(4, 5), backend=backend)
    return kernel(signal, truth, lr, radii, sps, taps)
def cpane_ekf(
    train: Union[bool, Schedule] = False,
    alpha: float = 0.99,
    beta: float = 0.6,
    Q: complex = 1e-4 + 0j,
    R: complex = 1e-2 + 0j,
    akf: bool = True,
    const: Array = comm.const("16QAM", norm=True)
) -> AdaptiveFilter:
    """Carrier Phase and Amplitude Noise Estimator

    symbol-by-symbol fine carrier phase recovery using extended Kalman filter

    Args:
        train: scheduler for training mode; while active, the known symbol ``x``
            replaces the hard decision
        alpha: smoothening factor of the adaptive Q/R updates (used when ``akf``)
        beta: smoothening factor of the moving-average phase used for decisions
        Q: initial process-noise variance; it is added to the error covariance
            in the prediction step
        R: initial measurement-noise variance; it enters the innovation
            covariance in the Kalman gain
        akf: adaptive controlling of Q and R, a.k.a, AKF
        const: reference constellation

    Returns:
        a ``AdaptiveFilter`` object. Its state is ``(Psi, P, Psi_avg, Q, R)``, and
        ``update`` emits ``(Psi, (Q, R))`` taken before the update.

    References:
        - [1] Pakala, L. and Schmauss, B., 2016. Extended Kalman filtering for joint mitigation of phase and amplitude noise in coherent QAM systems. Optics express, 24(6), pp.6391-6401.
        - [2] Akhlaghi, Shahrokh, Ning Zhou, and Zhenyu Huang. "Adaptive adjustment of noise covariance in Kalman filter for dynamic state estimation." 2017 IEEE power & energy society general meeting. IEEE, 2017.
    """
    const = jnp.asarray(const)
    train = cxopt.make_schedule(train)

    def init(p0=0j):
        state0 = (p0, 1j, 0j, Q, R)
        return state0

    def update(i, state, inp):
        Psi_c, P_c, Psi_a, Q, R = state
        y, x = inp
        # prediction: phase carried over unchanged, covariance grows by Q
        Psi_p = Psi_c
        P_p = P_c + Q
        # exponential moving average
        Psi_a = beta * Psi_a + (1 - beta) * Psi_c
        d = jnp.where(
            train(i), x,
            const[jnp.argmin(jnp.abs(const - y * jnp.exp(-1j * Psi_a)))])
        # linearised observation y ~ d * exp(1j * Psi): H = d/dPsi
        H = 1j * d * jnp.exp(1j * Psi_p)
        K = P_p * H.conj() / (H * P_p * H.conj() + R)
        v = y - d * jnp.exp(1j * Psi_p)
        out = (Psi_c, (Q, R))
        Psi_c = Psi_p + K * v
        P_c = (1. - K * H) * P_p
        # post-update residual, used by the AKF estimate of R
        e = y - d * jnp.exp(1j * Psi_c)
        Q = alpha * Q + (1 - alpha) * K * v * v.conj() * K.conj() if akf else Q
        R = alpha * R + (1 - alpha) * (e * e.conj() + H * P_p * H.conj()) if akf else R
        state = (Psi_c, P_c, Psi_a, Q, R)
        return state, out

    def apply(Psi, ys):
        return ys * jnp.exp(-1j * Psi)

    return AdaptiveFilter(init, update, apply)
def frame_cpr_kf(
    Q: Array = jnp.array([[0, 0],
                          [0, 1e-9]]),  # 1e-8 is better if akf is False
    R: Array = jnp.array([[1e-2, 0],
                          [0, 1e-3]]),
    const: Array = comm.const("16QAM", norm=True),
    train: Union[bool, Schedule] = False,
    akf: Schedule = cxopt.piecewise_constant([10, 500], [False, True, False]),
    alpha: float = 0.999) -> AdaptiveFilter:
    """Block based estimator of carrier frequency offset

    frame-by-frame coarse carrier phase recovery using Kalman filter, can tolerate 0.1 * baudrate
    frequency offset[1].

    The state is ``z = [phase at frame centre, phase increment per symbol]^T``, and it
    advances by ``A = [[1, N], [0, 1]]`` per frame of ``N`` symbols.

    Args:
        Q: process-noise covariance (added to the predicted state covariance); it
            is adapted over time while ``akf`` is active
        R: measurement-noise covariance (enters the gain ``P (P + R)^+``)
        const: reference constellation used in decision stage
        train: scheduler of training mode
        akf: scheduler of AKF
        alpha: smoothening factor used in AKF

    Returns:
        A ``AdaptiveFilter`` object; ``update`` emits the predicted per-symbol
        phase increment and the predicted phase of every symbol in the frame.

    Caution:
        needs proper initialization of FO[1]

    References:
        - [1] Inoue, Takashi, and Shu Namiki. "Carrier recovery for M-QAM signals based on a block estimation process with Kalman filter." Optics express 22.13 (2014): 15376-15387.
        - [2] Akhlaghi, Shahrokh, Ning Zhou, and Zhenyu Huang. "Adaptive adjustment of noise covariance in Kalman filter for dynamic state estimation." 2017 IEEE power & energy society general meeting. IEEE, 2017.
    """
    const = jnp.asarray(const)
    train = cxopt.make_schedule(train)
    akf = cxopt.make_schedule(akf)

    def init(w0=0):
        z0 = jnp.array([[0], [w0]], dtype=jnp.float32)
        P0 = jnp.zeros((2, 2), dtype=jnp.float32)
        state0 = (z0, P0, Q)
        return state0

    def update(i, state, inp):
        z_c, P_c, Q = state
        y, x = inp

        N = y.shape[0]  # frame size
        A = jnp.array([[1, N],
                       [0, 1]])
        I = jnp.eye(2)
        # symbol offsets relative to the frame centre
        n = (jnp.arange(N) - (N - 1) / 2)

        z_p = A @ z_c
        P_p = A @ P_c @ A.T + Q
        phi_p = z_p[0, 0] + n * z_p[1, 0]  # linear approx.
        s_p = y * jnp.exp(-1j * phi_p)
        d = jnp.where(
            train(i), x,
            const[jnp.argmin(jnp.abs(const[None, :] - s_p[:, None]), axis=-1)])
        scd_p = s_p * d.conj()
        sumscd_p = jnp.sum(scd_p)
        # residual phase and residual per-symbol phase slope of the frame
        e = jnp.array([[jnp.arctan(sumscd_p.imag / sumscd_p.real)],
                       [(jnp.sum(n * scd_p)).imag / (jnp.sum(n * n * scd_p)).real]])

        G = P_p @ jnp.linalg.pinv((P_p + R))
        z_c = z_p + G @ e
        P_c = (I - G) @ P_p

        # AKF [2]: Q <- alpha Q + (1 - alpha) G e e^T G^T; the right factor used to
        # be G instead of G^T, which made the adapted Q non-symmetric
        Q = jnp.where(akf(i), alpha * Q + (1 - alpha) * (G @ e @ e.T @ G.T), Q)

        out = (z_p[1, 0], phi_p)
        state = (z_c, P_c, Q)

        return state, out

    def apply(phis, ys):
        return jax.vmap(lambda y, phi: y * jnp.exp(-1j * phi))(ys, phis)

    return AdaptiveFilter(init, update, apply)
def ddlms(
    lr_w: Union[float, Schedule] = 1 / 2**6,
    lr_f: Union[float, Schedule] = 1 / 2**7,
    lr_s: Union[float, Schedule] = 0.,
    lr_b: Union[float, Schedule] = 1 / 2**11,
    train: Union[bool, Schedule] = False,
    grad_max: Tuple[float, float] = (30., 30.),
    eps: float = 1e-8,
    beta: float = 0.,
    const: Array = comm.const("16QAM", norm=True)
) -> AdaptiveFilter:
    """Decision-Directed Least Mean Square adaptive equalizer

    Args:
        lr_w: learning rate of MIMO(butterfly part)'s weights
        lr_f: learning rate of stage-I phase tracker
        lr_s: learning rate of stage-II phase tracker
        lr_b: learning rate of bias term
        train: controlling flag of training mode, which can be a bool for global control within one call
            or an array of bool to switch training on iteration basis
        grad_max: clipping thresholds of the gradients of the stage-I and stage-II phase trackers
        eps: perturbative term to stabilize normalized LMS
        beta: smoothening factor of phase trackers (weight of the previous ``fshat``
            in its moving average)
        const: reference constellation used for hard decisions

    Returns:
        an ``AdaptiveFilter`` object; ``update`` emits ``((w, f, s, b), (l, d))``
        taken before the parameter update, where ``l`` is the squared error
        between the output ``z`` and the decision ``d``

    Notes:
        - add bias term to handle varying DC component
        - the decision is made on ``q = v * fshat + b``, where ``fshat`` is a moving
          average of ``f * s``; ``apply`` uses ``f * s`` directly

    References:
        - [1] Mori, Y., Zhang, C. and Kikuchi, K., 2012. Novel configuration of
          finite-impulse-response filters tolerant to carrier-phase fluctuations
          in digital coherent optical receivers for higher-order quadrature
          amplitude modulation signals. Optics express, 20(24), pp.26236-26251.
    """
    const = jnp.asarray(const)
    lr_w = cxopt.make_schedule(lr_w)
    lr_f = cxopt.make_schedule(lr_f)
    lr_s = cxopt.make_schedule(lr_s)
    lr_b = cxopt.make_schedule(lr_b)
    train = cxopt.make_schedule(train)

    def init(taps=31, dims=2, dtype=jnp.complex64, mimoinit='zeros'):
        w0 = mimoinitializer(taps, dims, dtype, mimoinit)
        f0 = jnp.full((dims, ), 1., dtype=dtype)
        s0 = jnp.full((dims, ), 1., dtype=dtype)
        b0 = jnp.full((dims, ), 0., dtype=dtype)
        fshat0 = jnp.full((dims, ), 1., dtype=dtype)
        return (w0, f0, s0, b0, fshat0)

    def update(i, state, inp):
        w, f, s, b, fshat = state
        u, x = inp

        v = mimo(w, u)
        # v = r2c(mimo(w, u)[None, :])[0, :]
        # cascade: MIMO -> stage-I tracker -> stage-II tracker -> bias
        k = v * f
        c = k * s
        z = c + b
        q = v * fshat + b
        d = jnp.where(train(i), x, decision(const, q))
        l = jnp.sum(jnp.abs(z - d)**2)

        # inverse phase rotation of the two trackers, used to de-rotate the MIMO error
        psi_hat = jnp.abs(f) / f * jnp.abs(s) / s
        e_w = (d - b) * psi_hat - v
        e_f = d - b - k
        e_s = d - b - c
        e_b = d - z

        # MIMO gradient normalised by the input power (normalized LMS)
        gw = -1. / ((jnp.abs(u)**2).sum() + eps) * e_w[:, None, None] * u.conj().T[None, ...]
        # gw = -e_w[:, None, None] * u.conj().T[None, ...]
        gf = -1. / (jnp.abs(v)**2 + eps) * e_f * v.conj()
        gs = -1. / (jnp.abs(k)**2 + eps) * e_s * k.conj()
        gb = -e_b

        # bound the grads of f and s which are less regulated than w,
        # it may stabilize this algo. by experience
        gf = jnp.where(
            jnp.abs(gf) > grad_max[0], gf / jnp.abs(gf) * grad_max[0], gf)
        gs = jnp.where(
            jnp.abs(gs) > grad_max[1], gs / jnp.abs(gs) * grad_max[1], gs)

        out = ((w, f, s, b), (l, d))

        # update
        w = w - lr_w(i) * gw
        f = f - lr_f(i) * gf
        s = s - lr_s(i) * gs
        b = b - lr_b(i) * gb
        fshat = beta * fshat + (1 - beta) * (f * s)

        state = (w, f, s, b, fshat)

        return state, out

    def apply(ps, yf):
        ws, fs, ss, bs = ps
        return jax.vmap(mimo)(ws, yf) * fs * ss + bs

    return AdaptiveFilter(init, update, apply)
def cpr_foe_ekf(signal, init_states=None, device=cpus[0]):
    '''Joint carrier phase and frequency-offset tracking with an extended Kalman filter.

    The state vector is ``x = [phase, phase increment per symbol]^T``; it is
    predicted by ``A = [[1, 1], [0, 1]]``. Only the phase enters the observation
    (the second column of ``H`` is zero). ``Q`` is adapted online with a fixed
    smoothing factor of 0.99 [3]. ``R`` stays fixed (its adaptation is commented out).

    Args:
        signal: received symbols, scanned one at a time along axis 0.
        init_states: optional ``(Q, R, x, P)`` tuple: process noise (2x2),
            measurement noise (2x2), initial state (2x1) and its covariance (2x2).
        device: device the signal, states, ``A`` and the constellation are placed on.

    Returns:
        the transposed stack of predicted states ``x_p[:, 0]``: row 0 holds the
        predicted phases, row 1 the predicted per-symbol phase increments.

    References:
        [1] Jain, A., Krishnamurthy, P.K., Landais, P. and Anandarajah, P.M., 2017. EKF for joint mitigation of phase noise, frequency offset and nonlinearity in 400 Gb/s PM-16-QAM and 200 Gb/s PM-QPSK systems. IEEE Photonics Journal, 9(1), pp.1-10.
        [2] Lin, W.T. and Chang, D.C., 2006, May. The extended Kalman filtering algorithm for carrier synchronization and the implementation. In 2006 IEEE International Symposium on Circuits and Systems (pp. 4-pp). IEEE.
        [3] Akhlaghi, Shahrokh, Ning Zhou, and Zhenyu Huang. "Adaptive adjustment of noise covariance in Kalman filter for dynamic state estimation." 2017 IEEE power & energy society general meeting. IEEE, 2017.
    '''
    if init_states is None:
        init_states = (
            jnp.array([[1e-2, 0],
                       [0, 1e-5]]),
            1e-1 * jnp.eye(2),
            jnp.array([[0.],
                       [0.]]),
            1. * jnp.eye(2)
        )

    # state transition: phase += phase increment each symbol
    A = jnp.array([[1, 1],
                   [0, 1]])

    const = comm.const("16QAM", norm=True)

    signal = device_put(signal, device)
    init_states = device_put(init_states, device)
    A = device_put(A, device)
    const = device_put(const, device)

    @jit
    def step(states, r):
        Q, R, x_c, P_c = states

        x_p = A @ x_c
        P_p = A @ P_c @ A.T + Q

        p_p = jnp.exp(1j * x_p[0,0])
        # hard decision on the de-rotated symbol
        s_hat_p = const[jnp.argmin(jnp.abs(const - r * p_p.conj()))]
        r_hat_p = s_hat_p * p_p

        d = r - r_hat_p
        # Jacobian of [Re, Im] of r_hat w.r.t. the state; FO is not observed directly
        H = jnp.array([[-r_hat_p.imag, 0],
                       [ r_hat_p.real, 0]])
        I = jnp.array([[d.real],
                       [d.imag]])
        S = H @ P_p @ H.T + R
        K = P_p @ H.T @ jnp.linalg.inv(S)
        x_c = x_p + K @ I
        P_c = P_p - K @ H @ P_p

        # adapt Q and R
        beta = .99
        # NOTE: adaptation of R [3] is currently disabled; kept for reference
        # p_c = jnp.exp(1j * x_c[0,0])
        # e = r - s_hat_p * p_c
        # e_R = jnp.array([[e.real],
        #                  [e.imag]])
        # R = beta * R + (1 - beta) * (e_R @ e_R.T + H @ P_p @ H.T)
        Q = beta * Q + (1. - beta) * K @ I @ I.T @ K.T

        return (Q, R, x_c, P_c), x_p[:,0]

    _, ret = scan(step, init_states, signal)

    return ret.T