import numpy as np
from scipy import linalg

# utils (providing CholInv and replace_state_dict), compute_class_weights and
# compute_lda_model are defined elsewhere in this module.


def init_with_plda(self, x, speaker_ids, init_params, domain_ids=None):

    # Compute a PLDA model with the input data, approximating the
    # model with just the usual initialization, without the EM iterations
    weights = compute_class_weights(speaker_ids, domain_ids, init_params.get('balance_by_domain'))

    # Debugging of weight usage in PLDA:
    # Repeat the data from the first 5000 speakers twice, either explicitly or through the weights.
    # These two models should be identical (and they are!)
    # sela = speaker_ids < 5000
    # selb = speaker_ids >= 5000
    # x2 = np.concatenate([x[sela], x[sela], x[selb]])
    # speaker_ids2 = np.concatenate((speaker_ids[sela], speaker_ids[sela] + np.max(speaker_ids) + 1, speaker_ids[selb]))
    # weights2 = np.ones(len(np.unique(speaker_ids2)))
    # BCov2, WCov2, mu2 = compute_2cov_plda_model(x2, speaker_ids2, weights2, 10)
    # weights3 = weights.copy()
    # weights3[0:5000] *= 2
    # BCov3, WCov3, mu3 = compute_2cov_plda_model(x, speaker_ids, weights3, 10)
    # assert np.allclose(BCov2, BCov3)
    # assert np.allclose(WCov2, WCov3)
    # assert np.allclose(mu2, mu3)

    # Bi and Wi are the between- and within-class covariance matrices and mu is the global (weighted) mean
    Bi, Wi, mu = compute_2cov_plda_model(x, speaker_ids, weights, init_params.get('plda_em_its', 0))

    # Equations (14) and (16) in Cumani's paper.
    # Note that the paper has an error in the formula for k (a 1/2 is missing before k_tilde),
    # which is fixed in the equations below.
    # To compute L_tilde and G_tilde we use the following equality:
    # inv( inv(C) + n*inv(D) ) == C @ inv(D + n*C) @ D == C @ solve(D + n*C, D)
    B = utils.CholInv(Bi)
    W = utils.CholInv(Wi)
    Bmu = B @ mu.T
    L_tilde = Bi @ np.linalg.solve(Wi + 2 * Bi, Wi)
    G_tilde = Bi @ np.linalg.solve(Wi + Bi, Wi)
    WtGL = W @ (L_tilde - G_tilde)

    logdet_L_tilde = np.linalg.slogdet(L_tilde)[1]
    logdet_G_tilde = np.linalg.slogdet(G_tilde)[1]
    logdet_B = B.logdet()
    k_tilde = -2.0 * logdet_G_tilde + logdet_L_tilde - logdet_B + mu @ Bmu

    k = 0.5 * k_tilde + 0.5 * Bmu.T @ (L_tilde - 2 * G_tilde) @ Bmu
    L = 0.5 * (W @ (W @ L_tilde)).T
    G = 0.5 * (W @ WtGL).T
    C = WtGL @ Bmu

    state_dict = {'L': L, 'G': G, 'C': C, 'k': k.squeeze()}
    utils.replace_state_dict(self, state_dict)
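# Standalone numerical check (illustrative, not part of the module) of the
# identity used above to compute L_tilde and G_tilde without explicitly
# inverting Bi or Wi. Here C and D play the roles of Bi and Wi, which are
# symmetric positive-definite covariance matrices; n is 2 for L_tilde and
# 1 for G_tilde. All names below are local to this sketch.
def _check_solve_identity(dim=4, seed=0):
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((dim, dim))
    C = A @ A.T + dim * np.eye(dim)  # random SPD matrix (role of Bi)
    A = rng.standard_normal((dim, dim))
    D = A @ A.T + dim * np.eye(dim)  # random SPD matrix (role of Wi)
    for n in (1, 2):
        lhs = np.linalg.inv(np.linalg.inv(C) + n * np.linalg.inv(D))
        rhs = C @ np.linalg.solve(D + n * C, D)
        assert np.allclose(lhs, rhs)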
def compute_2cov_plda_model(x, class_ids, class_weights, em_its=0):
    """ Follows the "EM for SPLDA" document from Niko Brummer:
    https://sites.google.com/site/nikobrummer/EMforSPLDA.pdf
    """

    BCov, WCov, GCov, mu, muc, stats = compute_lda_model(x, class_ids, class_weights)
    W = utils.CholInv(WCov)

    v, e = linalg.eig(BCov)
    V = e * np.sqrt(np.real(v))

    def mstep(R, T, S, n_samples):
        V = (utils.CholInv(R) @ T).T
        Winv = 1 / n_samples * (S - V @ T)
        W = utils.CholInv(Winv)
        return V, W, Winv

    def estep(stats, V, W):
        VtW = (W @ V).T
        VtWV = VtW @ V
        VtWf = VtW @ stats.F.T
        y_hat = np.zeros_like(stats.F).T
        R = np.zeros_like(V)
        llk = 0.0
        for n in np.unique(stats.N):
            idxs = np.where(stats.N == n)[0]
            L = n * VtWV + np.eye(V.shape[0])
            Linv = utils.CholInv(L)
            y_hat[:, idxs] = Linv @ VtWf[:, idxs]
            # The expression below is a robust way of computing
            # yy_sum = len(idxs) * Linv + np.dot(y_hat[:, idxs], y_hat[:, idxs].T)
            # while avoiding the explicit inverse of L, which can create numerical issues
            n_spkrs = np.sum(stats.weights[idxs])
            yy_sum = np.linalg.solve(
                L, n_spkrs * np.eye(V.shape[0]) + L @ y_hat[:, idxs] @ (y_hat[:, idxs].T * stats.weights[idxs]))
            R += n * yy_sum
            llk += 0.5 * n_spkrs * Linv.logdet()
        T = y_hat @ (stats.F * stats.weights)
        llk += 0.5 * np.trace(T @ VtW.T) + 0.5 * np.sum(
            stats.N * stats.weights) * W.logdet() - 0.5 * np.trace(W @ stats.S)
        return R, T, llk / np.sum(stats.N * stats.weights)

    prev_llk = 0.0
    for it in range(em_its):
        R, T, llk = estep(stats, V, W)
        V, W, WCov = mstep(R, T, stats.S, np.sum(stats.N * stats.weights))
        print("EM it %d LLK = %.5f (+%.5f)" % (it, llk, llk - prev_llk))
        prev_llk = llk

    BCov = V @ V.T
    return BCov, WCov, np.atleast_2d(mu)
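# Hypothetical usage sketch for compute_2cov_plda_model, with illustrative
# sizes and noise scale. It assumes compute_lda_model (called internally) and
# utils.CholInv are available from this module; none of the data below comes
# from the original code.
def _demo_2cov_plda(n_spk=100, n_per_spk=10, dim=32, seed=0):
    rng = np.random.default_rng(seed)
    spk_means = rng.standard_normal((n_spk, dim))  # between-class variability
    x = np.repeat(spk_means, n_per_spk, axis=0) + \
        0.5 * rng.standard_normal((n_spk * n_per_spk, dim))  # within-class noise
    class_ids = np.repeat(np.arange(n_spk), n_per_spk)
    class_weights = np.ones(n_spk)  # one weight per class, all equal
    # em_its=0 keeps the LDA-based initialization; em_its > 0 refines it with EM
    BCov, WCov, mu = compute_2cov_plda_model(x, class_ids, class_weights, em_its=5)
    return BCov, WCov, mu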
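# utils.CholInv is not shown in this section. The sketch below is an
# assumption about its interface, inferred only from how it is used above:
# (CholInv(M) @ B) acts as M^{-1} @ B, and .logdet() returns log|M^{-1}|.
# It is not the module's actual implementation.
class CholInvSketch:
    """Acts as the inverse of a symmetric positive-definite matrix M
    without ever forming the explicit inverse."""

    def __init__(self, M):
        # Cache the lower-triangular Cholesky factor of M
        self.L = linalg.cholesky(M, lower=True)

    def __matmul__(self, B):
        # Solve M x = B using the cached factorization
        return linalg.cho_solve((self.L, True), B)

    def logdet(self):
        # log|M^{-1}| = -log|M| = -2 * sum(log(diag(L)))
        return -2.0 * np.sum(np.log(np.diag(self.L)))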