def solve_map(transport_map, coupling, loss, lr=1e-2, max_iter=2000,
              stop_thr=1e-4, verbose=False, every_n=200, is_hyperbolic=True,
              results_path=None):
    """Fit ``transport_map`` to a fixed coupling by gradient descent.

    Each iteration minimizes ``loss.similarity_coupling_fix(transport_map,
    coupling)`` with ``RiemannianAdam`` (``is_hyperbolic=True``) or ``Adam``.
    The parameters of ``transport_map`` are updated in place.

    Iteration stops when any of these holds:

    * ``max_iter`` iterations have run;
    * the loss is NaN (no gradient step is taken on a NaN loss, so the
      parameters are not overwritten with NaN);
    * the relative change of the loss w.r.t. the previous iteration is below
      ``stop_thr``. A previous loss of exactly 0 counts as converged instead
      of dividing by zero.

    Args:
        transport_map: module whose ``parameters()`` are optimized.
        coupling: passed unchanged to ``loss.similarity_coupling_fix``.
        loss: object providing ``similarity_coupling_fix(map, coupling)``.
        lr: optimizer learning rate.
        max_iter: maximum number of iterations; ``<= 0`` skips optimization.
        stop_thr: relative-change threshold; also seeds the loss history so
            the first relative change is defined.
        verbose: print progress every ``every_n`` iterations.
        every_n: print period used when ``verbose`` is true.
        is_hyperbolic: choose ``RiemannianAdam`` over ``Adam``.
        results_path: currently unused; kept for call compatibility.

    Returns:
        The same ``transport_map`` object, after optimization.
    """
    if is_hyperbolic:
        optimizer_map = RiemannianAdam(transport_map.parameters(), lr=lr)
    else:
        optimizer_map = Adam(transport_map.parameters(), lr=lr)

    # Seed the history with stop_thr so the first relative change is defined.
    vloss = [stop_thr]
    it = 0
    while it < max_iter:
        it += 1
        optimizer_map.zero_grad()
        f_loss = loss.similarity_coupling_fix(transport_map, coupling)
        vloss.append(to_float(f_loss))
        if verbose and (it % every_n == 0):
            print("\t \t it: %s similarity loss: %.3f" % (it, vloss[-1]))
        # Stop before back-propagating a NaN loss: an optimizer step on NaN
        # gradients would overwrite the parameters with NaN.
        if np.isnan(vloss[-1]):
            break
        f_loss.backward()
        optimizer_map.step()

        previous = vloss[-2]
        relative_error = (abs(vloss[-1] - previous) / abs(previous)
                          if previous != 0 else 0)
        if relative_error < stop_thr:
            break
    return transport_map
def solve_hyperbolic_nn_model(transport_map, coupling, loss, max_iter=100,
                              stop_thr=1e-3, lr=1e-2, display_every=100,
                              verbose=False, is_hyperbolic=True):
    """Optimize ``transport_map`` against ``loss.similarity_coupling_fix``.

    Uses ``RiemannianAdam`` when ``is_hyperbolic`` is true, else ``Adam``.
    The parameters of ``transport_map`` are updated in place.

    Iteration stops when any of these holds:

    * ``max_iter`` iterations have run;
    * the loss is NaN. The loop exits before back-propagating, so the
      parameters are not overwritten with NaN.
    * the relative change of the loss w.r.t. the previous iteration is below
      ``stop_thr``. A previous loss of exactly 0 counts as converged instead
      of dividing by zero.

    On a normal stop, the step for the final iteration is still applied.

    Args:
        transport_map: module whose ``parameters()`` are optimized.
        coupling: passed unchanged to ``loss.similarity_coupling_fix``.
        loss: object providing ``similarity_coupling_fix(map, coupling)``.
        max_iter: maximum number of iterations; ``<= 0`` skips optimization.
        stop_thr: relative-change threshold; also seeds the loss history.
        lr: optimizer learning rate.
        display_every: print period used only when ``verbose`` is true.
        verbose: print progress.
        is_hyperbolic: choose ``RiemannianAdam`` over ``Adam``.

    Returns:
        The same ``transport_map`` object, after optimization.
    """
    if is_hyperbolic:
        optimizer = RiemannianAdam(transport_map.parameters(), lr=lr)
    else:
        optimizer = Adam(transport_map.parameters(), lr=lr)

    # Seed the history with stop_thr so the first relative change is defined.
    vloss = [stop_thr]
    it = 0
    while it < max_iter:
        it += 1
        optimizer.zero_grad()
        map_loss = loss.similarity_coupling_fix(transport_map, coupling)
        vloss.append(to_float(map_loss))
        # Check `verbose` first so display_every is only used when printing.
        if verbose and (it % display_every == 0):
            print("\t\t it: %s loss map: %.4f" % (it, vloss[-1]))
        if np.isnan(vloss[-1]):
            break

        previous = vloss[-2]
        relative_error = (abs(vloss[-1] - previous) / abs(previous)
                          if previous != 0 else 0)
        map_loss.backward()
        optimizer.step()
        if relative_error < stop_thr:
            break
    return transport_map
def minimize_sinkhorn(Xs, Xt, model, ys=None, yt=None, lr=1e-2, reg_ot=1e-2,
                      is_hyperbolic=True, match_targets=True,
                      n_iter_sinkhorn=100, max_iter=10000, stop_thr=1e-5,
                      max_iter_map=1000, is_sinkhorn_normalized=False,
                      every_n=100, rate=1e-2, is_projected_output=False,
                      type_ini='bary', is_init_map=False, verbose=False):
    """Train a copy of ``model`` to map ``Xs`` onto ``Xt`` under a Sinkhorn loss.

    Optional pre-fit (``is_init_map=True``): a deep copy of ``model`` is first
    trained with learning rate ``lr`` so that ``model(Xs)`` approaches an
    initial target ``x_approx``. The target is measured with
    ``manifold.distance(...).mean()``, where ``manifold`` is ``Mobius()`` when
    ``is_hyperbolic`` and ``Euclidean()`` otherwise. ``type_ini`` picks the
    target:

    * ``'bary'``    -- ``manifold.barycenter_mapping(Xt, coupling)`` with the
      coupling returned by ``compute_transport``;
    * ``'rot_s2t'`` -- centered ``Xs`` rotated by orthogonal Procrustes onto
      centered ``Xt``, then shifted by the mean of ``Xt``;
    * ``'rot_t2s'`` -- the same construction with ``Xs`` and ``Xt`` swapped.
      NOTE(review): this target has ``Xt``'s rows, so it presumably requires
      ``len(Xs) == len(Xt)`` -- confirm;
    * ``'id'``      -- ``Xs`` itself.

    Main phase: the (possibly pre-fitted) model is deep-copied and optimized
    with learning rate ``lr * rate``. The loss is ``sinkhorn_normalized``
    when ``is_sinkhorn_normalized``, otherwise the cost returned by
    ``sinkhorn_cost``.

    When ``is_projected_output`` is true, model outputs are passed through
    ``Mobius().proj2ball`` in both phases.

    Both phases stop on whichever comes first:

    * their iteration budget (``max_iter_map`` / ``max_iter``);
    * a NaN loss (no step is taken on it);
    * a relative loss change below ``stop_thr``. A previous loss of exactly
      zero counts as converged.

    Args:
        Xs, Xt: source / target sample tensors.
        model: mapping module; never modified in place.
        ys, yt, match_targets: label information forwarded to the OT helpers.
        lr: learning rate of the pre-fit phase; the main phase uses
            ``lr * rate``.
        reg_ot: entropic regularization forwarded to the OT helpers.
        is_hyperbolic: use ``RiemannianAdam``/``Mobius`` instead of
            ``Adam``/``Euclidean``.
        n_iter_sinkhorn: Sinkhorn iterations per loss evaluation.
        max_iter, max_iter_map: iteration budgets of the main / pre-fit
            phases.
        stop_thr: relative-change threshold (also seeds the loss history).
        is_sinkhorn_normalized: choose ``sinkhorn_normalized`` as the loss.
        every_n: print period of the main phase when ``verbose``.
        rate: learning-rate multiplier for the main phase.
        is_projected_output: project model outputs onto the ball.
        type_ini: pre-fit target (see above).
        is_init_map: run the pre-fit phase.
        verbose: print main-phase progress.

    Returns:
        The trained deep copy of ``model``.

    Raises:
        ValueError: if ``is_init_map`` is true and ``type_ini`` is unknown.
    """
    valid_inits = ('bary', 'rot_s2t', 'rot_t2s', 'id')
    # Validate before any expensive work; previously an unknown value
    # surfaced later as an UnboundLocalError on `x_approx`.
    if is_init_map and type_ini not in valid_inits:
        raise ValueError("type_ini must be one of %s, got %r"
                         % (valid_inits, type_ini))

    mobius = Mobius()

    def _predict(net):
        # Forward pass on the source samples, optionally projected.
        out = net(Xs)
        return mobius.proj2ball(out) if is_projected_output else out

    def _relative_change(history):
        # Relative change of the last two losses; a zero previous loss is
        # treated as converged instead of dividing by zero.
        previous = history[-2]
        if previous == 0:
            return 0
        return abs(history[-1] - previous) / abs(previous)

    if is_init_map:
        model = deepcopy(model)  # keep the caller's model untouched
        if is_hyperbolic:
            optimizer_init = RiemannianAdam(model.parameters(), lr=lr)
            manifold = Mobius()
        else:
            optimizer_init = Adam(model.parameters(), lr=lr)
            manifold = Euclidean()

        _, _, _, coupling = compute_transport(
            Xs=Xs, Xt=Xt, ys=ys, yt=yt, reg_ot=reg_ot,
            match_targets=match_targets, manifold=manifold,
            is_hyperbolic=is_hyperbolic)

        if type_ini == 'bary':
            x_approx = manifold.barycenter_mapping(Xt, coupling)
        elif type_ini in ('rot_s2t', 'rot_t2s'):
            xs_np = Xs.detach().cpu().numpy()
            xt_np = Xt.detach().cpu().numpy()
            xs_mean = xs_np.mean(0)
            xt_mean = xt_np.mean(0)
            xs_centered = xs_np - xs_mean
            xt_centered = xt_np - xt_mean
            # NOTE(review): x_approx is a CPU float32 tensor; presumably the
            # model output lives on the same device -- confirm for GPU use.
            if type_ini == 'rot_s2t':
                rotation, _ = orthogonal_procrustes(xs_centered, xt_centered)
                x_approx = torch.FloatTensor(
                    xs_centered.dot(rotation) + xt_mean)
            else:
                rotation, _ = orthogonal_procrustes(xt_centered, xs_centered)
                x_approx = torch.FloatTensor(
                    xt_centered.dot(rotation) + xs_mean)
        else:  # 'id'
            x_approx = Xs

        vloss_map = [stop_thr]
        it = 0
        while it < max_iter_map:
            it += 1
            optimizer_init.zero_grad()
            loss_map = manifold.distance(_predict(model), x_approx).mean()
            vloss_map.append(loss_map.item())
            if np.isnan(vloss_map[-1]):
                break  # never step on NaN gradients
            relative_error = _relative_change(vloss_map)
            loss_map.backward()
            optimizer_init.step()
            if relative_error < stop_thr:
                break

    this_model = deepcopy(model)
    lr_mapping = lr * rate
    if is_hyperbolic:
        optimizer = RiemannianAdam(this_model.parameters(), lr=lr_mapping)
    else:
        optimizer = Adam(this_model.parameters(), lr=lr_mapping)

    vloss = [stop_thr]
    it = 0
    while it < max_iter:
        it += 1
        optimizer.zero_grad()
        X_pred = _predict(this_model)
        if is_sinkhorn_normalized:
            loss = sinkhorn_normalized(
                X_pred, Xt, reg_ot=reg_ot, n_iter=n_iter_sinkhorn,
                ys=ys, yt=yt, match_targets=match_targets,
                is_hyperbolic=is_hyperbolic)
        else:
            _, loss = sinkhorn_cost(
                X_pred, Xt, reg_ot=reg_ot, n_iter=n_iter_sinkhorn,
                match_targets=match_targets, ys=ys, yt=yt,
                is_hyperbolic=is_hyperbolic)
        vloss.append(loss.item())
        if verbose and (it % every_n == 0):
            print("\t \t it: %s similarity loss: %.3f" % (it, vloss[-1]))
        if np.isnan(vloss[-1]):
            break  # never step on NaN gradients
        relative_error = _relative_change(vloss)
        loss.backward()
        optimizer.step()
        if relative_error < stop_thr:
            break
    return this_model
# out = P^T x out = torch.matmul(self.Matrix.unsqueeze(0), x) return out if __name__ == "__main__": birh = BirkhoffPoly(10).cuda() rie_adam = RiemannianAdam(birh.parameters()) inputs = torch.randn(2,10, 1) inputs_permute = inputs.clone() inputs_permute[:, 1], inputs_permute[:, 2] = inputs[:, 2].clone(), inputs[:, 1].clone() print(inputs) print(inputs_permute) print('before {}'.format(birh.Matrix)) print(torch.inverse(birh.Matrix)) for iter in range(20): rie_adam.zero_grad() loss = 10*ortho_loss(birh.Matrix) + 10*torch.sum(torch.abs(birh.Matrix)) print(loss) loss.backward() rie_adam.step() print('after {}'.format(birh.Matrix)) for iter in range(10): rie_adam.zero_grad() out = birh(inputs_permute) # loss = torch.zeros(1) # for param in birh.parameters(): # if isinstance(param, geoopt.tensor.ManifoldParameter): # loss = ortho_loss(birh.Matrix) loss = F.mse_loss(out, inputs) + ortho_loss(birh.Matrix) + 0*torch.sum(torch.abs(birh.Matrix)) loss.backward() rie_adam.step()