def transform_org2talmon(source, target_train, target_test):
    """Re-centre source and target covariances on a shared reference point.

    The Riemannian mean of the source covariances and the Riemannian mean of
    the target training covariances are computed. The point returned by
    ``geodesic_riemann(..., alpha=0.5)`` between them is used as a common
    reference; presumably this is the geodesic midpoint (assumes pyriemann
    semantics).

    Each domain is then passed through its own ``ParallelTransport``, built
    with ``reference_old`` set to that domain's mean and ``reference_new`` set
    to the common reference. The target test covariances are mapped with the
    transformer fitted on the target training covariances.

    Parameters
    ----------
    source, target_train, target_test : dict
        Each must provide ``'covs'`` and ``'labels'`` entries.

    Returns
    -------
    tuple of dict
        Three new dicts (source, target train, target test). Each holds the
        original ``'labels'`` and the transported ``'covs'``.
    """
    src_covs = source['covs']
    tgt_train_covs = target_train['covs']
    tgt_test_covs = target_test['covs']

    mean_src = mean_riemann(src_covs)
    mean_tgt = mean_riemann(tgt_train_covs)
    # Shared reference between the two domain means.
    reference = geodesic_riemann(mean_src, mean_tgt, alpha=0.5)

    src_transport = ParallelTransport(reference_old=mean_src, reference_new=reference)
    src_moved = src_transport.fit_transform(src_covs)

    tgt_transport = ParallelTransport(reference_old=mean_tgt, reference_new=reference)
    tgt_train_moved = tgt_transport.fit_transform(tgt_train_covs)
    # Test data reuses the transformer fitted on target-train only.
    tgt_test_moved = tgt_transport.transform(tgt_test_covs)

    return (
        {'labels': source['labels'], 'covs': src_moved},
        {'labels': target_train['labels'], 'covs': tgt_train_moved},
        {'labels': target_test['labels'], 'covs': tgt_test_moved},
    )
def transform_org2talmon_p300(source, target_train, target_test): covs_source = source['covs'] covs_target_train = target_train['covs'] covs_target_test = target_test['covs'] weights = np.ones(len(source['labels'])) weights[source['labels'] == 2] = 5 M_source = mean_riemann(covs_source, sample_weight=weights) weights = np.ones(len(target_train['labels'])) weights[target_train['labels'] == 2] = 5 M_target_train = mean_riemann(covs_target_train, sample_weight=weights) M = geodesic_riemann(M_source, M_target_train, alpha=0.5) Gamma = ParallelTransport(reference_old=M_source, reference_new=M) covs_source_transp = Gamma.fit_transform(covs_source) Gamma = ParallelTransport(reference_old=M_target_train, reference_new=M) covs_target_train_transp = Gamma.fit_transform(covs_target_train) covs_target_test_transp = Gamma.transform(covs_target_test) source_talmon = {} source_talmon['labels'] = source['labels'] source_talmon['covs'] = covs_source_transp target_talmon_train = {} target_talmon_train['labels'] = target_train['labels'] target_talmon_train['covs'] = covs_target_train_transp target_talmon_test = {} target_talmon_test['labels'] = target_test['labels'] target_talmon_test['covs'] = covs_target_test_transp return source_talmon, target_talmon_train, target_talmon_test
def _sq_riemann_distance_matrix(A, B):
    """Return D with ``D[i, j] = distance_riemann(A[i], B[j]) ** 2``."""
    D = np.zeros((len(A), len(B)))
    for i, A_i in enumerate(A):
        for j, B_j in enumerate(B):
            D[i, j] = distance_riemann(A_i, B_j) ** 2
    return D


def _move_along_reference(C, R_old, R_new):
    """Re-express ``C`` relative to ``R_new`` instead of ``R_old``.

    Takes the log map of ``C`` at ``R_old``, moves that tangent vector from
    ``R_old`` to ``R_new`` via ``E @ eta @ E.T`` with
    ``E = geodesic_riemann(R_new, R_old, 0.5) @ inv(R_old)``, and applies the
    exp map at ``R_new``. Assuming ``geodesic_riemann(., ., 0.5)`` is the
    affine-invariant midpoint, ``E`` equals ``(R_new @ inv(R_old)) ** 1/2``,
    which gives the affine-invariant parallel transport.
    """
    R_old_sqrt = sqrtm(R_old)
    R_old_invsqrt = invsqrtm(R_old)
    # Tangent vector at R_old pointing towards C.
    eta_old = R_old_sqrt @ logm(R_old_invsqrt @ C @ R_old_invsqrt) @ R_old_sqrt

    midpoint = geodesic_riemann(R_new, R_old, alpha=0.5)
    R_old_inv = np.linalg.inv(R_old)
    eta_new = midpoint @ R_old_inv @ eta_old @ R_old_inv @ midpoint

    R_new_sqrt = sqrtm(R_new)
    R_new_invsqrt = invsqrtm(R_new)
    return R_new_sqrt @ expm(R_new_invsqrt @ eta_new @ R_new_invsqrt) @ R_new_sqrt


def transform_org2opt(source, target_train, target_test, *, reg=1.0):
    """Map target covariances onto the source domain using optimal transport.

    Steps:

    1. Build a cost matrix of squared Riemannian distances between the source
       covariances and the target-train covariances.
    2. Solve for a transport plan with ``sinkhorn_lpl1_mm`` (presumably POT's
       class-regularised Sinkhorn), using uniform marginals and the source
       labels.
    3. Replace each target-train matrix ``j`` with the ``mean_riemann`` of the
       source matrices, weighted by column ``gamma[:, j]`` of the plan.
    4. For each target-test matrix, find its nearest target-train matrix
       (squared Riemannian distance). Then move the test matrix from that
       neighbour's original position to the neighbour's mapped position
       (see ``_move_along_reference``).

    Parameters
    ----------
    source, target_train, target_test : dict
        Each must provide ``'covs'`` (array-like supporting ``.shape``) and
        ``'labels'``.
    reg : float, keyword-only, default 1.0
        Regularisation passed as ``reg`` to ``sinkhorn_lpl1_mm``.

    Returns
    -------
    tuple
        ``(source, target_train_mapped, target_test_mapped)``. ``source`` is
        returned as the same object, unchanged. The two target dicts are new
        and keep the original ``'labels'``.
    """
    Cs = source['covs']
    ys = source['labels']
    Ct_train = target_train['covs']
    Ct_test = target_test['covs']

    # 1-2. OT plan between source and target-train samples.
    cost = _sq_riemann_distance_matrix(Cs, Ct_train)
    mu_s = distribution_estimation_uniform(Cs)
    mu_t = distribution_estimation_uniform(Ct_train)
    gamma = sinkhorn_lpl1_mm(mu_s, ys, mu_t, cost, reg=reg)

    # 3. Barycentric mapping of target-train onto the source samples.
    # NOTE(review): a column of gamma summing to ~0 would give degenerate
    # weights here; not handled (same as before).
    Ct_train_transported = np.zeros(Ct_train.shape)
    for j in range(len(Ct_train)):
        Ct_train_transported[j] = mean_riemann(Cs, sample_weight=gamma[:, j])

    # 4. Out-of-sample extension through the nearest target-train neighbour.
    nearest = np.argmin(_sq_riemann_distance_matrix(Ct_test, Ct_train), axis=1)
    Ct_test_transported = np.zeros(Ct_test.shape)
    for i, j in enumerate(nearest):
        Ct_test_transported[i] = _move_along_reference(
            Ct_test[i], Ct_train[j], Ct_train_transported[j])

    target_opt_train = {'labels': target_train['labels'],
                        'covs': Ct_train_transported}
    target_opt_test = {'labels': target_test['labels'],
                       'covs': Ct_test_transported}
    return source, target_opt_train, target_opt_test