def solve_map(transport_map,
              coupling,
              loss,
              lr=1e-2,
              max_iter=2000,
              stop_thr=1e-4,
              verbose=False,
              every_n=200,
              is_hyperbolic=True,
              results_path=None):
    """Fit transport_map to a fixed transport plan `coupling` by minimizing
    loss.similarity_coupling_fix with (Riemannian) Adam, until max_iter or the
    relative-error threshold stop_thr is reached (results_path is currently
    unused)."""

    vloss = [stop_thr]
    # init loop
    loop = 1 if max_iter > 0 else 0
    it = 0
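    # RiemannianAdam takes Riemannian gradient steps for manifold-valued
    # parameters (e.g. points on the Poincare ball); plain Adam is the
    # Euclidean fallback.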
    if is_hyperbolic:
        optimizer_map = RiemannianAdam(transport_map.parameters(), lr=lr)
    else:
        optimizer_map = Adam(transport_map.parameters(), lr=lr)

    while loop:
        it += 1
        optimizer_map.zero_grad()
        f_loss = loss.similarity_coupling_fix(transport_map, coupling)
        f_loss.backward()
        optimizer_map.step()

        vloss.append(to_float(f_loss))
        # Guard against division by zero when the previous loss is exactly 0.
        relative_error = (abs(vloss[-1] - vloss[-2]) / abs(vloss[-2])
                          if vloss[-2] != 0 else 0)

        if verbose and (it % every_n == 0):
            print("\t \t it: %s similarity loss: %.3f" % (it, vloss[-1]))

        if (it >= max_iter) or (np.isnan(vloss[-1])):
            loop = 0

        if relative_error < stop_thr:
            loop = 0
    return transport_map
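

# Example usage of solve_map (a sketch -- HyperbolicMLP and SimilarityLoss are
# placeholders for the project's map and loss classes, and the uniform plan is
# only an illustrative coupling):
#
#   transport_map = HyperbolicMLP(dim=2)
#   loss = SimilarityLoss(Xs, Xt)
#   coupling = torch.full((Xs.shape[0], Xt.shape[0]),
#                         1.0 / (Xs.shape[0] * Xt.shape[0]))
#   transport_map = solve_map(transport_map, coupling, loss,
#                             lr=1e-2, max_iter=2000, verbose=True)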
def solve_hyperbolic_nn_model(transport_map,
                              coupling,
                              loss,
                              max_iter=100,
                              stop_thr=1e-3,
                              lr=1e-2,
                              display_every=100,
                              verbose=False,
                              is_hyperbolic=True):
    """Train transport_map against a fixed coupling, as in solve_map, but the
    stopping criterion is evaluated before the gradient step."""
    if is_hyperbolic:
        optimizer = RiemannianAdam(transport_map.parameters(), lr=lr)
    else:
        optimizer = Adam(transport_map.parameters(), lr=lr)

    vloss = [stop_thr]
    # init loop
    loop = 1 if max_iter > 0 else 0
    it = 0
    while loop:
        it += 1
        optimizer.zero_grad()
        map_loss = loss.similarity_coupling_fix(transport_map, coupling)
        vloss.append(to_float(map_loss))

        if (it >= max_iter) or (np.isnan(vloss[-1])):
            loop = 0

        # Guard against division by zero when the previous loss is exactly 0.
        relative_error = (abs(vloss[-1] - vloss[-2]) / abs(vloss[-2])
                          if vloss[-2] != 0 else 0)

        if relative_error < stop_thr:
            loop = 0

        if verbose and (it % display_every == 0):
            print("\t\t it: %s loss map: %.4f" % (it, vloss[-1]))

        map_loss.backward()
        optimizer.step()

    return transport_map

def minimize_sinkhorn(Xs,
                      Xt,
                      model,
                      ys=None,
                      yt=None,
                      lr=1e-2,
                      reg_ot=1e-2,
                      is_hyperbolic=True,
                      match_targets=True,
                      n_iter_sinkhorn=100,
                      max_iter=10000,
                      stop_thr=1e-5,
                      max_iter_map=1000,
                      is_sinkhorn_normalized=False,
                      every_n=100,
                      rate=1e-2,
                      is_projected_output=False,
                      type_ini='bary',
                      is_init_map=False,
                      verbose=False):
    """Learn a mapping from Xs to Xt by minimizing the entropic Sinkhorn cost
    between the mapped sources and the targets, optionally after initializing
    the map (barycentric projection, Procrustes rotation, or identity,
    controlled by is_init_map / type_ini). Returns the trained model."""

    mobius = Mobius()
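    # Mobius provides Poincare-ball operations; proj2ball is used below to keep
    # the mapped points inside the ball when is_projected_output is set.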
    # Initialize: Barycenter approximation
    if is_init_map:
        # Work on a copy so the caller's model is left untouched.
        model = deepcopy(model)

        if is_hyperbolic:
            optimizer_init = RiemannianAdam(model.parameters(), lr=lr)
            manifold = Mobius()
        else:
            optimizer_init = Adam(model.parameters(), lr=lr)
            manifold = Euclidean()

        ### Compute cost
        _, _, _, coupling = compute_transport(Xs=Xs,
                                              Xt=Xt,
                                              ys=ys,
                                              yt=yt,
                                              reg_ot=reg_ot,
                                              match_targets=match_targets,
                                              manifold=manifold,
                                              is_hyperbolic=is_hyperbolic)
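        # `coupling` is the entropic OT plan between Xs and Xt, used by the
        # barycentric initialization below.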

        if type_ini == 'bary':
            # Barycentric projection of the targets through the coupling.
            x_approx = manifold.barycenter_mapping(Xt, coupling)
        elif type_ini in ('rot_s2t', 'rot_t2s'):
            # Procrustes rotation between the centered point clouds.
            xs_np = Xs.data.numpy()
            xt_np = Xt.data.numpy()
            xs_mean = xs_np.mean(0)
            xt_mean = xt_np.mean(0)
            xs_centered = xs_np - xs_mean
            xt_centered = xt_np - xt_mean
            if type_ini == 'rot_s2t':
                P, _ = orthogonal_procrustes(xs_centered, xt_centered)
                x_approx = torch.FloatTensor(xs_centered.dot(P) + xt_mean)
            else:
                P, _ = orthogonal_procrustes(xt_centered, xs_centered)
                x_approx = torch.FloatTensor(xt_centered.dot(P) + xs_mean)
        elif type_ini == 'id':
            x_approx = Xs
        else:
            raise ValueError("Unknown type_ini: %s" % type_ini)

        loop_map = 1 if max_iter_map > 0 else 0
        vloss_map = [stop_thr]
        it = 0
        while loop_map:
            it += 1
            optimizer_init.zero_grad()
            X_pred = mobius.proj2ball(
                model(Xs)) if is_projected_output else model(Xs)
            loss_map = manifold.distance(X_pred, x_approx).mean()
            vloss_map.append(loss_map.item())
            relative_error = (abs(vloss_map[-1] - vloss_map[-2]) /
                              abs(vloss_map[-2]) if vloss_map[-2] != 0 else 0)
            if (it >= max_iter_map) or (np.isnan(vloss_map[-1])):
                loop_map = 0

            if relative_error < stop_thr:
                loop_map = 0

            loss_map.backward()
            optimizer_init.step()

    # Fine-tune a copy of the (possibly pre-initialized) map directly on the
    # Sinkhorn objective, with a reduced learning rate lr * rate.
    this_model = deepcopy(model)
    lr_mapping = lr * rate

    if is_hyperbolic:
        optimizer = RiemannianAdam(this_model.parameters(), lr=lr_mapping)
    else:
        optimizer = Adam(this_model.parameters(), lr=lr_mapping)

    vloss = [stop_thr]

    loop = 1 if max_iter > 0 else 0
    it = 0
    while loop:
        it += 1
        optimizer.zero_grad()
        X_pred = mobius.proj2ball(
            this_model(Xs)) if is_projected_output else this_model(Xs)

        if is_sinkhorn_normalized:
            loss = sinkhorn_normalized(X_pred,
                                       Xt,
                                       reg_ot=reg_ot,
                                       n_iter=n_iter_sinkhorn,
                                       ys=ys,
                                       yt=yt,
                                       match_targets=match_targets,
                                       is_hyperbolic=is_hyperbolic)
        else:
            G, loss = sinkhorn_cost(
                X_pred,
                Xt,
                reg_ot=reg_ot,
                n_iter=n_iter_sinkhorn,
                match_targets=match_targets,
                #wrapped_function=lambda x: -torch.cosh(x),
                ys=ys,
                yt=yt,
                is_hyperbolic=is_hyperbolic)

        vloss.append(loss.item())

        relative_error = (abs(vloss[-1] - vloss[-2]) /
                          abs(vloss[-2]) if vloss[-2] != 0 else 0)

        if verbose and (it % every_n == 0):
            print("\t \t it: %s similarity loss: %.3f" % (it, vloss[-1]))

        if (it >= max_iter) or (np.isnan(vloss[-1])):
            loop = 0

        if relative_error < stop_thr:
            loop = 0

        loss.backward()
        optimizer.step()

    return this_model
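

# Example usage of minimize_sinkhorn (a sketch -- HyperbolicMLP is a placeholder
# model class; Xs and Xt are assumed to be (n, d) and (m, d) tensors on the
# Poincare ball):
#
#   model = HyperbolicMLP(dim=2)
#   model = minimize_sinkhorn(Xs, Xt, model, reg_ot=1e-2, is_hyperbolic=True,
#                             is_init_map=True, type_ini='bary', verbose=True)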


# NOTE: assumed definition -- BirkhoffPoly keeps a square `Matrix` parameter on
# geoopt's BirkhoffPolytope manifold of doubly-stochastic matrices and applies
# it to batched column vectors (torch.nn as nn and geoopt assumed imported).
class BirkhoffPoly(nn.Module):

    def __init__(self, n):
        super().__init__()
        # Start from an interior, invertible doubly-stochastic matrix.
        init = 0.9 * torch.eye(n) + torch.full((n, n), 0.1 / n)
        self.Matrix = geoopt.ManifoldParameter(
            init, manifold=geoopt.BirkhoffPolytope())

    def forward(self, x):
        # out = P^T x
        out = torch.matmul(self.Matrix.unsqueeze(0), x)
        return out


if __name__ == "__main__":
    # Smoke test: learn a doubly-stochastic matrix that undoes a row swap.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    birh = BirkhoffPoly(10).to(device)
    rie_adam = RiemannianAdam(birh.parameters())
    inputs = torch.randn(2, 10, 1, device=device)
    inputs_permute = inputs.clone()
    # Swap rows 1 and 2 of every sample in the batch.
    inputs_permute[:, 1], inputs_permute[:, 2] = inputs[:, 2].clone(), inputs[:, 1].clone()
    print(inputs)
    print(inputs_permute)
    print('before {}'.format(birh.Matrix))
    print(torch.inverse(birh.Matrix))
    for it in range(20):
        rie_adam.zero_grad()
        # Push Matrix towards a permutation-like (orthogonal) point, with an
        # additional L1 penalty on its entries.
        loss = 10 * ortho_loss(birh.Matrix) + 10 * torch.sum(torch.abs(birh.Matrix))
        print(loss)
        loss.backward()
        rie_adam.step()
    print('after {}'.format(birh.Matrix))
    for it in range(10):
        rie_adam.zero_grad()
        out = birh(inputs_permute)
        # Fit the matrix so that applying it to the permuted inputs recovers
        # the originals, while keeping it close to an orthogonal
        # (permutation-like) point on the Birkhoff polytope.
        loss = F.mse_loss(out, inputs) + ortho_loss(birh.Matrix)
        loss.backward()
        rie_adam.step()
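
    # Report the learned matrix and how well it undoes the row swap.
    print('learned {}'.format(birh.Matrix))
    print('recovery error {:.6f}'.format(
        F.mse_loss(birh(inputs_permute), inputs).item()))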