def _relative_loss_change(current, previous):
    """Return ``|current - previous| / |previous|``, or 0 if ``previous`` is 0.

    A previous loss of exactly zero is treated as "no change" instead of
    dividing by zero (float division by 0.0 raises ZeroDivisionError).
    """
    if previous == 0:
        return 0
    return abs(current - previous) / abs(previous)


def _predict(model, Xs, mobius, is_projected_output):
    """Return ``model(Xs)``, projected with ``mobius.proj2ball`` if requested."""
    out = model(Xs)
    return mobius.proj2ball(out) if is_projected_output else out


def _init_target(Xs, Xt, *, ys, yt, reg_ot, match_targets, manifold,
                 is_hyperbolic, type_ini):
    """Build the points the model is pre-trained to output for ``Xs``.

    Args:
        type_ini: Strategy for the target points:
            * ``'bary'``: barycentric mapping of ``Xt`` under the coupling
              returned by ``compute_transport``.
            * ``'rot_s2t'``: centred ``Xs`` rotated (orthogonal Procrustes)
              onto centred ``Xt``, then shifted to the mean of ``Xt``.
            * ``'rot_t2s'``: centred ``Xt`` rotated onto centred ``Xs``, then
              shifted to the mean of ``Xs``. This yields one row per target
              point. NOTE(review): the distance to ``model(Xs)`` then
              presumably needs ``len(Xs) == len(Xt)``.
            * ``'id'``: ``Xs`` itself.
        The remaining arguments are forwarded to ``compute_transport`` or
        used for the barycentric mapping.

    Returns:
        The initial target points.

    Raises:
        ValueError: If ``type_ini`` is not one of the strategies above.
    """
    if type_ini == 'bary':
        # The coupling is only needed for the barycentric target.
        _, _, _, coupling = compute_transport(Xs=Xs,
                                              Xt=Xt,
                                              ys=ys,
                                              yt=yt,
                                              reg_ot=reg_ot,
                                              match_targets=match_targets,
                                              manifold=manifold,
                                              is_hyperbolic=is_hyperbolic)
        return manifold.barycenter_mapping(Xt, coupling)

    if type_ini in ('rot_s2t', 'rot_t2s'):
        # NOTE(review): ``.numpy()`` requires CPU tensors.
        xs_np = Xs.data.numpy()
        xt_np = Xt.data.numpy()
        xs_mean = xs_np.mean(0)
        xt_mean = xt_np.mean(0)
        xs_centered = xs_np - xs_mean
        xt_centered = xt_np - xt_mean
        if type_ini == 'rot_s2t':
            rotation, _ = orthogonal_procrustes(xs_centered, xt_centered)
            return torch.FloatTensor(xs_centered.dot(rotation) + xt_mean)
        rotation, _ = orthogonal_procrustes(xt_centered, xs_centered)
        return torch.FloatTensor(xt_centered.dot(rotation) + xs_mean)

    if type_ini == 'id':
        return Xs

    raise ValueError(
        "type_ini must be one of 'bary', 'rot_s2t', 'rot_t2s' or 'id', "
        "got %r" % (type_ini,))


def _fit_to_target(model, Xs, target, *, optimizer, manifold, mobius,
                   max_iter, stop_thr, is_projected_output):
    """Pre-train ``model`` so that its output on ``Xs`` approaches ``target``.

    Minimises ``manifold.distance(model(Xs), target).mean()`` for at most
    ``max_iter`` (assumed > 0) steps. The loop also stops on a NaN loss or
    when the relative loss change drops below ``stop_thr``. The optimizer
    step is still applied on the iteration that triggers stopping.
    """
    # Seed with stop_thr so the first iteration has a previous value to
    # compare against.
    losses = [stop_thr]
    it = 0
    while True:
        it += 1
        optimizer.zero_grad()
        pred = _predict(model, Xs, mobius, is_projected_output)
        loss = manifold.distance(pred, target).mean()
        losses.append(loss.item())
        done = (it >= max_iter or np.isnan(losses[-1])
                or _relative_loss_change(losses[-1], losses[-2]) < stop_thr)
        loss.backward()
        optimizer.step()
        if done:
            return


def _sinkhorn_loss(X_pred, Xt, *, ys, yt, reg_ot, n_iter, match_targets,
                   is_hyperbolic, is_sinkhorn_normalized):
    """Return the Sinkhorn loss between ``X_pred`` and ``Xt``.

    Uses ``sinkhorn_normalized`` if ``is_sinkhorn_normalized`` is set.
    Otherwise it uses ``sinkhorn_cost`` and discards the first value that
    function returns.
    """
    if is_sinkhorn_normalized:
        return sinkhorn_normalized(X_pred,
                                   Xt,
                                   reg_ot=reg_ot,
                                   n_iter=n_iter,
                                   ys=ys,
                                   yt=yt,
                                   match_targets=match_targets,
                                   is_hyperbolic=is_hyperbolic)
    _, loss = sinkhorn_cost(X_pred,
                            Xt,
                            reg_ot=reg_ot,
                            n_iter=n_iter,
                            match_targets=match_targets,
                            ys=ys,
                            yt=yt,
                            is_hyperbolic=is_hyperbolic)
    return loss


def minimize_sinkhorn(Xs,
                      Xt,
                      model,
                      ys=None,
                      yt=None,
                      lr=1e-2,
                      reg_ot=1e-2,
                      is_hyperbolic=True,
                      match_targets=True,
                      n_iter_sinkhorn=100,
                      max_iter=10000,
                      stop_thr=1e-5,
                      max_iter_map=1000,
                      is_sinkhorn_normalized=False,
                      every_n=100,
                      rate=1e-2,
                      is_projected_output=False,
                      type_ini='bary',
                      is_init_map=False,
                      verbose=False):
    """Train a copy of ``model`` so that ``model(Xs)`` matches ``Xt``.

    If ``is_init_map`` is set (and ``max_iter_map > 0``), the copy is first
    pre-trained with learning rate ``lr``. That phase minimises the mean
    ``manifold.distance`` to an initial target chosen by ``type_ini``. The
    copy is then trained with learning rate ``lr * rate`` on a Sinkhorn loss
    between ``model(Xs)`` and ``Xt``. Both phases use ``RiemannianAdam`` when
    ``is_hyperbolic`` is set and ``Adam`` otherwise. The ``model`` argument
    itself is never modified.

    Both phases stop at their iteration cap, on a NaN loss, or when the
    relative loss change between consecutive iterations drops below
    ``stop_thr``. The optimizer step is still applied on the stopping
    iteration.

    Args:
        Xs: Source points fed to the model.
        Xt: Target points.
        model: Trainable callable exposing ``parameters()``; a deep copy is
            trained and returned.
        ys, yt: Optional labels forwarded to the transport/Sinkhorn helpers.
        lr: Learning rate of the initialisation phase.
        reg_ot: Regularisation forwarded to ``compute_transport`` and the
            Sinkhorn losses.
        is_hyperbolic: Use ``Mobius``/``RiemannianAdam`` rather than
            ``Euclidean``/``Adam``; also forwarded to the helpers.
        match_targets: Forwarded to the transport/Sinkhorn helpers.
        n_iter_sinkhorn: ``n_iter`` passed to the Sinkhorn loss.
        max_iter: Iteration cap of the main phase (no training if <= 0).
        stop_thr: Relative-change stopping threshold of both phases.
        max_iter_map: Iteration cap of the initialisation phase (skipped if
            <= 0).
        is_sinkhorn_normalized: Use ``sinkhorn_normalized`` instead of
            ``sinkhorn_cost``.
        every_n: Print the loss every ``every_n`` main iterations when
            ``verbose`` is set.
        rate: Multiplier applied to ``lr`` for the main phase.
        is_projected_output: Project model outputs with
            ``Mobius.proj2ball``.
        type_ini: Initial target, one of ``'bary'``, ``'rot_s2t'``,
            ``'rot_t2s'`` or ``'id'``.
        is_init_map: Run the initialisation phase.
        verbose: Print main-phase progress.

    Returns:
        The trained copy of ``model``.

    Raises:
        ValueError: If the initialisation phase runs and ``type_ini`` is
            unknown.
    """
    mobius = Mobius()

    # Initialisation phase: fit the model to an approximate target map.
    # Skipped entirely when it would perform no optimisation step.
    if is_init_map and max_iter_map > 0:
        model = deepcopy(model)
        if is_hyperbolic:
            optimizer_init = RiemannianAdam(model.parameters(), lr=lr)
            manifold = Mobius()
        else:
            optimizer_init = Adam(model.parameters(), lr=lr)
            manifold = Euclidean()

        x_approx = _init_target(Xs,
                                Xt,
                                ys=ys,
                                yt=yt,
                                reg_ot=reg_ot,
                                match_targets=match_targets,
                                manifold=manifold,
                                is_hyperbolic=is_hyperbolic,
                                type_ini=type_ini)
        _fit_to_target(model,
                       Xs,
                       x_approx,
                       optimizer=optimizer_init,
                       manifold=manifold,
                       mobius=mobius,
                       max_iter=max_iter_map,
                       stop_thr=stop_thr,
                       is_projected_output=is_projected_output)

    # Main phase: minimise the Sinkhorn loss with a reduced learning rate.
    this_model = deepcopy(model)
    lr_mapping = lr * rate
    if is_hyperbolic:
        optimizer = RiemannianAdam(this_model.parameters(), lr=lr_mapping)
    else:
        optimizer = Adam(this_model.parameters(), lr=lr_mapping)

    if not max_iter > 0:
        return this_model

    # Seed with stop_thr so the first iteration has a previous value to
    # compare against.
    vloss = [stop_thr]
    it = 0
    while True:
        it += 1
        optimizer.zero_grad()
        X_pred = _predict(this_model, Xs, mobius, is_projected_output)
        loss = _sinkhorn_loss(X_pred,
                              Xt,
                              ys=ys,
                              yt=yt,
                              reg_ot=reg_ot,
                              n_iter=n_iter_sinkhorn,
                              match_targets=match_targets,
                              is_hyperbolic=is_hyperbolic,
                              is_sinkhorn_normalized=is_sinkhorn_normalized)
        vloss.append(loss.item())
        relative_error = _relative_loss_change(vloss[-1], vloss[-2])

        if verbose and it % every_n == 0:
            print("\t \t it: %s similarity loss: %.3f" % (it, vloss[-1]))

        done = (it >= max_iter or np.isnan(vloss[-1])
                or relative_error < stop_thr)
        loss.backward()
        optimizer.step()
        if done:
            break

    return this_model
class HyperbolicSinkhornTransport(BaseEstimator):
    """Map source points towards a target set through an OT coupling.

    ``fit`` computes a coupling between training source and target points
    with ``compute_transport``. ``transform`` then handles arbitrary source
    points in three steps. First, each point is matched to its nearest
    training source point, using the cost from ``compute_cost``. Next, the
    offset from that neighbour is formed with the manifold's ``add`` as
    ``(-neighbour) (+) point``. Finally, the offset is added to the
    neighbour's barycentric image under the coupling.

    Args:
        reg_ot: Regularisation forwarded to ``compute_transport``.
        ot_solver: Solver name forwarded to ``compute_transport``.
        batch_size: Number of rows processed at once in ``transform``.
        wrapped_function: Forwarded to ``compute_transport`` when fitting.
        normalization: Cost normalisation forwarded to ``compute_transport``
            and ``compute_cost``.
        is_hyperbolic: Use the ``Mobius`` manifold (True) or ``Euclidean``
            (False); also forwarded to the cost/transport helpers.
        match_targets: Forwarded to ``compute_transport``.
    """

    def __init__(self,
                 reg_ot=1e-1,
                 ot_solver='sinkhorn_knopp',
                 batch_size=128,
                 wrapped_function=None,
                 normalization='max',
                 is_hyperbolic=True,
                 match_targets=False):
        self.reg_ot = reg_ot
        self.ot_solver = ot_solver
        self.batch_size = batch_size
        self.wrapped_function = wrapped_function
        self.normalization = normalization
        self.is_hyperbolic = is_hyperbolic
        self.match_targets = match_targets
        self.manifold = Mobius() if is_hyperbolic else Euclidean()

    def fit(self, Xs, Xt, ys=None, yt=None):
        """Compute and store the coupling between ``Xs`` and ``Xt``.

        Stores deep copies of the training data in ``Xs_train_`` and
        ``Xt_train_``, and the coupling in ``coupling_``.

        Returns:
            ``self``.
        """
        # Rebuild the manifold from the current flag: BaseEstimator.set_params
        # only updates ``is_hyperbolic``, so the instance created in __init__
        # may be stale. Without this, the manifold and the helpers (which
        # receive ``is_hyperbolic``) could disagree.
        self.manifold = Mobius() if self.is_hyperbolic else Euclidean()
        self.Xs_train_ = deepcopy(Xs)
        self.Xt_train_ = deepcopy(Xt)

        _, _, _, coupling = compute_transport(
            Xs=Xs,
            Xt=Xt,
            ys=ys,
            yt=yt,
            reg_ot=self.reg_ot,
            ot_solver=self.ot_solver,
            wrapped_function=self.wrapped_function,
            normalization=self.normalization,
            is_hyperbolic=self.is_hyperbolic,
            detach_x=False,
            detach_y=False,
            match_targets=self.match_targets,
        )
        self.coupling_ = coupling
        return self

    def transform(self, Xs):
        """Map ``Xs`` with the fitted coupling (see the class docstring).

        Must be called after ``fit``. ``Xs`` is processed in chunks of
        ``batch_size`` rows, and the per-chunk results are concatenated
        along dim 0 with ``torch.cat``.
        """
        n_samples = Xs.shape[0]
        if n_samples == 0:
            # torch.cat rejects an empty list; return an empty slice instead.
            return Xs[:0]

        X_bary = self.manifold.barycenter_mapping(self.Xt_train_,
                                                  self.coupling_)
        transp_Xs = []
        for start in range(0, n_samples, self.batch_size):
            X_batch = Xs[start:start + self.batch_size]
            # Nearest training source point of each row. The identity wrapper
            # keeps the raw (hyperbolic when is_hyperbolic) distance as cost.
            M0 = compute_cost(
                X_batch,
                self.Xs_train_,
                normalization=self.normalization,
                wrapped_function=lambda x: x,
                is_hyperbolic=self.is_hyperbolic)
            idx = M0.argmin(dim=1)
            # Offset of each row from its neighbour, re-applied to the
            # neighbour's barycentric image.
            diff = self.manifold.add(-self.Xs_train_[idx, :], X_batch)
            transp_Xs.append(self.manifold.add(X_bary[idx, :], diff))

        return torch.cat(transp_Xs, dim=0)

    def fit_transport(self, Xs, Xt, ys=None, yt=None):
        """Fit on ``(Xs, Xt)`` and return ``transform(Xs)``."""
        self.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
        return self.transform(Xs)