Exemplo n.º 1
0
def initialisation_gw(p, q, Cs, U, prior=0, nb_init=5, nb_dummies=1):
    """Initialisation of GW. From a barycenter, it gives a first shot for
    the transport matrix

    ``U`` is coarsened into ``nb_init`` atoms with ``wass_bary_coarsening``.
    For every atom, the points of ``U`` whose coefficient for that atom is
    above ``1e-5`` are selected, a Gromov-Wasserstein plan is computed between
    the source and those points (uniform target weights), rescaled to sum to
    one, and written into the matching columns of an all-zero matrix.

    Parameters
    ----------
    p: array of len (n_p)
        weights of the source (positives !! with no dummies !!)

    q: array of len (n_u)
        weights of the target (unlabeled); only its length is used here.
        NOTE(review): column indices come from rows of ``U``, so this
        assumes len(q) == U.shape[0] -- confirm with callers.

    Cs: array of shape (n_p,n_p)
        intra-source (P) cost matrix (!! no dummies !!)

    U: table of shape (n_u,d_u)
        U dataset; must support ``.shape``, ``.iloc`` row selection and
        ``np.array(U)`` (e.g. a pandas DataFrame)

    prior: percentage of positives on the dataset (s)
        only used when ``nb_init <= 2``, where ``[prior, 1 - prior]`` is
        passed as ``pb`` to ``wass_bary_coarsening``

    nb_init: number of atoms on the barycenter

    nb_dummies: number of dummy points, default: 1
        (to avoid numerical instabilities of POT); the dummy rows are left
        at zero. ``0`` is accepted.

    Returns
    -------
    list of numpy.array of shape (n_p + nb_dummies, n_u)
        list of potential transport matrix initialisations, one per atom

    """
    n_u = U.shape[0]
    bary_kwargs = {"pt": np.ones(n_u) / n_u}
    if nb_init <= 2:
        # With at most two atoms the barycenter weights are tied to the prior.
        bary_kwargs["pb"] = [prior, 1 - prior]
    res, _, _ = wass_bary_coarsening(nb_init, np.array(U), **bary_kwargs)

    n_p = len(p)
    l_gamma = []
    for atom in range(nb_init):
        # Points of U that carry a non-negligible coefficient for this atom.
        idx = np.where(res[atom, :] > 1e-5)[0]
        gamma = np.zeros((n_p + nb_dummies, len(q)))

        Ct_0 = cdist(U.iloc[idx], U.iloc[idx])
        gamma1 = gromov_wasserstein(Cs, Ct_0, p,
                                    np.ones(Ct_0.shape[0]) / Ct_0.shape[0],
                                    'square_loss')
        gamma1 /= np.sum(gamma1)
        # Rows are addressed as [:n_p] rather than [:-nb_dummies]: the latter
        # is the empty slice [:0] when nb_dummies == 0.
        gamma[:n_p, idx] = gamma1
        l_gamma.append(gamma)
    return l_gamma
Exemplo n.º 2
0
def _compute_gw(pts_1, pts_2, wts_1, wts_2):
    """Compute a Gromov-Wasserstein coupling between two weighted point sets.

    Each weight vector is divided by its sum, an intra-set distance matrix is
    built for each point set with ``_normalized_dist_mtx`` (metric
    ``'sqeuclidean'``), and the result of ``gromov_wasserstein`` called with
    the ``'square_loss'`` loss and ``log=True`` is returned unchanged.
    """
    # Rescale both weight vectors so that each sums to one.
    mass_1, mass_2 = (w / np.sum(w) for w in (wts_1, wts_2))

    # Pairwise distances inside each point set.
    cost_1, cost_2 = (
        _normalized_dist_mtx(pts, pts, metric='sqeuclidean')
        for pts in (pts_1, pts_2)
    )

    return gromov_wasserstein(cost_1, cost_2, mass_1, mass_2,
                              'square_loss', log=True)
    # NOTE(review): the `def` that owns these lines is not visible. As laid
    # out they follow a `return` at the same indentation and would never run;
    # they appear to belong to a separate example. `n1`, `n_clust`,
    # `compute_balanced`, `path` and `solver` are defined outside this view.

    # Generate gaussian mixtures translated from each other
    a, x, b, y = generate_data(n1, 0.7)
    clf = KMeans(n_clusters=n_clust)
    clf.fit(x)
    # For each cluster column i of clf.transform(x), keep the index of the
    # sample with the smallest value (presumably the sample closest to that
    # cluster centre, if KMeans is scikit-learn's -- confirm).
    idx = np.zeros(n_clust)
    for i in range(n_clust):
        d = clf.transform(x)[:, i]
        idx[i] = np.argmin(d)
    idx = idx.astype(int)  # float buffer -> integer indices

    # Generate costs and transport plan
    Cx, Cy = euclid_dist(x, x), euclid_dist(y, y)

    # Optional reference: GW plan on the same data, plotted and saved
    # under `path`.
    if compute_balanced:
        pi_b = gromov_wasserstein(Cx, Cy, a, b, loss_fun='square_loss')
        plot_density_matching(pi_b, a, x, b, y, idx, alpha=1., linewidth=.5)
        plt.legend()
        plt.savefig(path + '/fig_matching_plan_balanced.png')
        plt.show()

    # From here on the cost matrices are torch tensors.
    Cx, Cy = torch.from_numpy(Cx), torch.from_numpy(Cy)

    rho_list = [0.1]
    # Base-10 exponents: eps takes the values 10**2 down to 10**-3.
    peps_list = [2, 1, 0, -1, -2, -3]
    for rho in rho_list:
        solver.rho = rho
        pi = None  # reset for every rho; its later use is not visible here
        for p in peps_list:
            eps = 10 ** p
            solver.eps = eps
Exemplo n.º 4
0
def g_wasserstein(x_src, x_tgt, C2):
    """Couple source and target with Gromov-Wasserstein, then run procrustes.

    The source cost matrix is ``GWmatrix(x_src)`` and ``C2`` is used as the
    target cost matrix. Both marginals are all-ones vectors whose length is
    the number of rows of ``x_src``. The coupling is applied to ``x_tgt`` by
    a matrix product and the result of ``procrustes`` on that product and
    ``x_src`` is returned.
    """
    n_src = x_src.shape[0]
    src_cost = GWmatrix(x_src)
    src_weights = np.ones(n_src)
    tgt_weights = np.ones(n_src)
    # Options once tried here: epsilon=0.55, max_iter=100, tol=1e-4
    coupling = gromov_wasserstein(src_cost, C2, src_weights, tgt_weights,
                                  "square_loss")
    mapped_tgt = np.dot(coupling, x_tgt)
    return procrustes(mapped_tgt, x_src)
    # NOTE(review): the `def` that owns these lines is not visible. As laid
    # out they follow a `return` at the same indentation and would never run;
    # they appear to belong to a separate example. `a`, `b`, `pi`, `x`, `y`,
    # `Gx`, `Gy`, `rho`, `eps`, `Cx`, `Cy` and `normalize_proba` are defined
    # outside this view.

    # Move the torch tensors back to NumPy before plotting.
    a, b = a.cpu().data.numpy(), b.cpu().data.numpy()
    pi = pi.cpu().data.numpy()
    # The backslash before "epsilon" is written as "\\" so the literal has no
    # invalid "\e" escape; the resulting title text is unchanged.
    plot_density_matching(
        pi,
        a,
        x,
        b,
        y,
        Gx,
        Gy,
        titlename=f'UGW matching, ($\\rho$,$\\epsilon$)={rho, eps}')
    plt.legend()
    plt.show()

    # Reference plot: GW plan on the same costs, computed only when
    # `normalize_proba` is truthy.
    if normalize_proba:
        pi_b = gromov_wasserstein(Cx.cpu().numpy(),
                                  Cy.cpu().numpy(),
                                  a,
                                  b,
                                  loss_fun='square_loss')
        plot_density_matching(pi_b,
                              a,
                              x,
                              b,
                              y,
                              Gx,
                              Gy,
                              titlename='GW matching')
        plt.legend()
        plt.show()
Exemplo n.º 6
0
# NOTE(review): script fragment. `M`, `C1`, `C2`, `p`, `q`, `alpha`, `Got`,
# `ot` and `pl` are defined outside the visible lines.

# Fused Gromov-Wasserstein plan; log=True makes the call return a
# (plan, log) pair, unpacked here.
Gwg, logw = fused_gromov_wasserstein(M,
                                     C1,
                                     C2,
                                     p,
                                     q,
                                     loss_fun='square_loss',
                                     alpha=alpha,
                                     verbose=True,
                                     log=True)
# presumably pairs with an earlier ot.tic() that is not visible -- confirm
ot.toc()

#%reload_ext WGW
# Gromov-Wasserstein plan on the same structure matrices and weights.
Gg, log = gromov_wasserstein(C1,
                             C2,
                             p,
                             q,
                             loss_fun='square_loss',
                             verbose=True,
                             log=True)

##############################################################################
# Visualize transport matrices
# ---------

#%% visu OT matrix
cmap = 'Blues'
fs = 15  # not used in the visible lines; presumably a font size -- confirm
pl.figure(2, (13, 5))
pl.clf()
# First panel of a 1x3 grid; the remaining panels are not visible here.
pl.subplot(1, 3, 1)
pl.imshow(Got, cmap=cmap, interpolation='nearest')
    # NOTE(review): the `def` that owns these lines is not visible.
    # `normalize_proba`, `compare_with_gw`, `a`, `b`, `cx`, `cy`, `x`, `y`,
    # `Gx` and `Gy` are defined outside this view. The block below is
    # disabled (commented-out) partial-GW code, kept as found.

    # for m in list_mass_pgw:  # Compute partial GW plans, initialized with simulated annealing
    #     pi = a[:,None] * b[None,:]
    #     for eps in [10 ** e for e in [2., 1.5, 1]]:  # Simulated annealing loop
    #         pi = entropic_partial_gromov_wasserstein(cx, cy, a, b, eps, m=m, G0=pi)
    #     pi = partial_gromov_wasserstein(cx, cy, a, b, m=m, G0=pi)
    #     cost = partial_gromov_wasserstein2(cx, cy, a, b, m=m, G0=pi)
    #     # Initialize with partial OT plan
    #     M = sp.spatial.distance.cdist(x, y)
    #     gam = partial_wasserstein(a, b, M, m=m)
    #     gam = partial_gromov_wasserstein(cx, cy, a, b, m=m, G0=gam)
    #     if partial_gromov_wasserstein2(cx, cy, a, b, m=m, G0=gam) < cost:
    #         pi = gam
    #     # Initialize with OT plan
    #     gam = emd(a, b, M)
    #     gam = partial_gromov_wasserstein(cx, cy, a, b, m=m, G0=gam)
    #     if partial_gromov_wasserstein2(cx, cy, a, b, m=m, G0=gam) < cost:
    #         pi = gam
    #
    #     # Plot matchings between measures --> Partial GW
    #     plot_density_matching(pi, a, x, b, y, Gx, Gy, titlefile=f'PGW_mass{m:.3f}')
    #     plt.legend()
    #     plt.show()

    # NOTE(review): `&` is the bitwise operator. It matches `and` only when
    # both flags are real booleans -- confirm, or switch to `and`.
    if normalize_proba & compare_with_gw:  # Plot the behaviour of GW as reference
        # Outer product of the two weight vectors. NOTE(review): `pi` is not
        # read again in the visible lines and is not passed to the solver.
        pi = a[:, None] * b[None, :]
        pi_gw = gromov_wasserstein(cx, cy, a, b, loss_fun='square_loss')
        plot_density_matching(pi_gw, a, x, b, y, Gx, Gy, titlefile='GW')
        plt.legend()
        plt.show()