def relax_primal_loss(primal_var, dual_var, beta, lamb):
    """
    Loss to minimize in the primal iteration of the DCA algorithm
    """
    alpha = torch.sum(primal_var, 1)
    g = lamb * sinkhorn._KL(alpha, beta, primal_var)
    h_relax = torch.sum(primal_var * dual_var)
    return g - h_relax
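

# Minimal sketch (illustrative only, not the solver used by train_dc): one way to
# minimize relax_primal_loss over `primal_var` with plain gradient descent.
# The actual primal step below is performed by solve_relaxed_primal, defined
# elsewhere in this module; the helper name and loop here are for exposition.
def _example_relaxed_primal_descent(dual_var, beta, gamma_init, lamb=1.0,
                                    max_iter=100, learning_rate=0.01):
    # optimize a detached copy of the current transport plan
    primal_var = gamma_init.clone().detach().requires_grad_(True)
    optimizer = torch.optim.SGD([primal_var], lr=learning_rate)
    for _ in range(max_iter):
        optimizer.zero_grad()
        loss = relax_primal_loss(primal_var, dual_var, beta, lamb)
        loss.backward()
        optimizer.step()
    return primal_var.detach()
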
def eval(net, y, beta, lamb=1, niter_sink=5, err_threshold=1e-4,
         size_batch=100000, rp=0, eta=10000):
    """
    Evaluate the true loss by sampling with a larger batch size and removing the smoothing parameters.
    """
    distrib = torch.distributions.Exponential(torch.tensor(1.))
    # sample a large batch of valuations as a leaf tensor requiring gradients
    samples = distrib.sample((size_batch, 1)).requires_grad_(True)
    ###### Compute loss matrix ######
    C = LazySecondPriceLoss(net, samples, y, size_batch,
                            distribution="exponential", rp=rp, eta=eta)

    ###### Sinkhorn loss #########
    alpha = net.alpha()
    _, gamma, _, _ = sinkhorn.sinkhorn_loss_primal(
        alpha, C, beta, y, lamb, niter=niter_sink,
        cost="matrix", err_threshold=err_threshold, verbose=False)
    true_alpha = torch.sum(gamma, dim=1)
    p_loss = sinkhorn._KL(true_alpha, beta, gamma, epsilon=0)
    u_loss = torch.sum(gamma*C)
    loss = u_loss + lamb*p_loss
    return loss, gamma, C, p_loss, u_loss
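

# Example (sketch, illustrative values): evaluating a trained net on a fresh
# large batch; `net`, `y` and `beta` come from the experiment setup defined
# elsewhere in the repository.
#
#   loss, gamma, C, p_loss, u_loss = eval(net, y, beta, lamb=1,
#                                         niter_sink=100, size_batch=100000)
#   print(float(loss))
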
def train_dc(net,
             y,
             beta,
             lamb=1,
             max_time=10,
             err_threshold=1e-4,
             cost=sinkhorn._squared_distances,
             dual_iter=100,
             debug=False,
             learning_rate=0.01,
             experiment=0,
             verbose=False,
             verbose_freq=100,
             device="cpu",
             **kwargs):
    """
    learn a discrete distribution (gamma) with a prior (beta, y) using DCA algorithm.
    """
    actionstr = "_actions{}".format(
        net.nactions) if (net.nactions != y.size(0) + 2) else ""

    fold = '{}_lamb{}{}_k{}_dim{}_dualiter{}_lr{}_dc_{}'.format(
        experiment, lamb, actionstr, y.size(0), y.size(1), dual_iter,
        learning_rate, device)
    if experiment != 0:
        os.makedirs('experiments/dc/' + fold, exist_ok=True)

    one = torch.FloatTensor([1]).to(device)

    iterations = 0
    loss_profile = []
    time_profile = []
    start_time = timeit.default_timer()
    running_time = 0
    while running_time < max_time:

        # ---------------------------
        #        Optimize over net
        # ---------------------------
        time = timeit.default_timer()
        # gamma is the parameter to optimize; the best x for a given gamma
        # is computed automatically by the network
        gamma, x = net(one)
        dual_var = -torch.mm(x, y.t())  # dual iteration of DCA
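        # the primal step below minimizes the DC surrogate
        #   relax_primal_loss(gamma) = lamb * KL(alpha, beta, gamma) - <gamma, dual_var>,
        # i.e. the convex KL part is kept and the rest of the objective is
        # replaced by its linearization through dual_var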
        # primal iteration of DCA
        gamma_it = solve_relaxed_primal(dual_var,
                                        beta,
                                        gamma,
                                        lamb=lamb,
                                        max_iter=dual_iter,
                                        learning_rate=learning_rate,
                                        err_threshold=err_threshold,
                                        debug=debug,
                                        device=device)

        # reasons explained in train_sinkhorn
        running_time += (timeit.default_timer() - time)
        time_profile.append(running_time)

        alpha = torch.sum(gamma_it, 1)
        # C = cost(x, y, **kwargs)
        # loss = torch.sum( gamma_it * C )  + lamb*sinkhorn._KL(alpha, beta, gamma_it)
        loss = lamb*sinkhorn._KL(alpha, beta, gamma_it) - \
            torch.sum(gamma_it*dual_var)  # accurate loss
        loss_profile.append(loss.cpu().detach().numpy())
        net.gamma = gamma_it
        iterations += 1

        if verbose:
            if iterations % verbose_freq == 0:
                print('iterations=' + str(iterations))
    if verbose:
        t_expe = (timeit.default_timer() - start_time)
        print('done in {0} s'.format(t_expe))
        print('total running time: {0} s'.format(running_time))
    if experiment != 0:
        # save data
        torch.save(net, 'experiments/dc/' + fold + '/network')
        np.save('experiments/dc/' + fold + '/losses.npy', loss_profile)
        np.save('experiments/dc/' + fold + '/time.npy', time_profile)
    return loss_profile
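

# Example usage (sketch, illustrative values): `net`, `y` and `beta` come from
# the experiment setup defined elsewhere in the repository.
#
#   losses = train_dc(net, y, beta, lamb=1, max_time=10,
#                     dual_iter=100, learning_rate=0.01)
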
def train_sinkhorn(net,
                   y,
                   beta,
                   lamb=1,
                   niter_sink=1,
                   max_time=10,
                   cost=sinkhorn._squared_distances,
                   experiment=0,
                   differentiation="analytic",
                   learning_rate=0.1,
                   err_threshold=1e-4,
                   verbose=False,
                   verbose_freq=100,
                   device="cpu",
                   optim="descent",
                   warm_restart=False,
                   **kwargs):
    """
    learn a discrete distribution (alpha, x) with a prior (beta, y)
    """
    momentum = kwargs.get('momentum', 0)
    momstring = "_{}".format(momentum) if momentum != 0 else ""
    restart_string = "_warm" if warm_restart else ""
    actionstr = "_actions{}".format(
        net.nactions) if (net.nactions != y.size(0) + 2) else ""

    # if experiment != 0, save the simulation
    fold = '{}_lamb{}_k{}{}_dim{}_sinkiter{}_lr{}'.format(
        experiment, lamb, y.size(0), actionstr, y.size(1), niter_sink,
        learning_rate)
    fold += '_sinkhorn_{}_{}_{}{}{}'.format(device, optim, differentiation,
                                            momstring, restart_string)
    if experiment != 0:
        os.makedirs('experiments/sinkhorn/' + fold, exist_ok=True)

    # optimizer choice
    if optim == "SGD":
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=learning_rate,
                                    momentum=momentum)
    elif optim == "adam":
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    elif optim == "rms":
        optimizer = torch.optim.RMSprop(net.parameters(), lr=learning_rate)
    else:
        print('Invalid choice of optimizer.')
        return None
    one = torch.FloatTensor([1]).to(device)

    iterations = 0
    loss_profile = []  # evolution of the loss
    time_profile = []  # evolution of the training time
    if verbose:
        u_profile = []
        v_profile = []  # evolution of u and v (sinkhorn dual variables)
        iteration_profile = []  # number of iterations per sinkhorn call
        alpha_profile = []
        x_profile = []
    start_time = timeit.default_timer()
    running_time = 0
    # convergence of Sinkhorn is assumed if we use analytic differentiation
    conv = (differentiation == "analytic")

    while running_time < max_time:  # remains time to train

        # ---------------------------
        #        Optimize over net
        # ---------------------------
        time = timeit.default_timer()
        optimizer.zero_grad()

        # if decreasing_rate:
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = learning_rate/(iterations+1)

        alpha, x = net(one)  # output of the net (parameters to optimize)

        ###### Sinkhorn loss #########
        # compute the loss; after the first iteration the dual potentials (u, v)
        # from the previous call are passed back in as a warm start
        sink_kwargs = dict(niter=niter_sink,
                           verbose=verbose,
                           cost=cost,
                           convergence=conv,
                           err_threshold=err_threshold,
                           warm_restart=warm_restart)
        if iterations > 0:
            sink_kwargs.update(u=u, v=v)
        z = sinkhorn.sinkhorn_loss_primal(alpha, x, beta, y, lamb, **sink_kwargs)

        if verbose:
            # the verbose call also returns the number of Sinkhorn iterations
            loss, _, u, v, sink_iter = z
        else:
            loss, _, u, v = z

        loss.backward(one)  # automatic differentiation
        optimizer.step()  # gradient descent step

        # if projected gradient
        if net.proj:
            net.projection()

        # Stop the timer here: computing the accurate loss requires running the
        # Sinkhorn algorithm with many more iterations (niter_sink is small during
        # training), so the true loss below is computed for the plots only and
        # does not count in the running time.
        running_time += (timeit.default_timer() - time)
        time_profile.append(running_time)

        # compute the true loss for the plots (not counted in the running time);
        # niter=100 is enough with the parameters chosen in the experiments
        _, gamma, _, _ = sinkhorn.sinkhorn_loss_primal(alpha,
                                                       x,
                                                       beta,
                                                       y,
                                                       lamb,
                                                       niter=100,
                                                       cost=cost,
                                                       err_threshold=1e-4)
        loss_p = torch.sum(gamma*cost(x, y)) + lamb * \
            sinkhorn._KL(alpha, beta, gamma, epsilon=0)
        # print(gamma.shape)
        # print(torch.sum(gamma))
        loss_profile.append(loss_p.cpu().detach().numpy())

        iterations += 1

        if verbose:
            u_profile.append(u.detach().numpy())
            v_profile.append(v.detach().numpy())
            iteration_profile.append(sink_iter)
            alpha_profile.append(alpha.detach().numpy())
            x_profile.append(x.detach().numpy())
            if iterations % verbose_freq == 0:
                print('iterations=' + str(iterations))

    if verbose:
        t_expe = (timeit.default_timer() - start_time)
        print('done in {0} s'.format(t_expe))
        print('total running time: {0} s'.format(running_time))
    if experiment != 0:
        # save data
        torch.save(net, 'experiments/sinkhorn/' + fold + '/network')
        np.save('experiments/sinkhorn/' + fold + '/losses.npy', loss_profile)
        np.save('experiments/sinkhorn/' + fold + '/time.npy', time_profile)
    if verbose:
        u_profile = np.array(u_profile)
        v_profile = np.array(v_profile)
        iteration_profile = np.array(iteration_profile)
        alpha_profile = np.array(alpha_profile)
        x_profile = np.array(x_profile)
        return (loss_profile, time_profile, u_profile, v_profile,
                iteration_profile, x_profile, alpha_profile)
    else:
        return loss_profile
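

# Example usage (sketch, illustrative values): Sinkhorn-based training with the
# analytic gradient, Adam, and warm-restarted dual potentials.
#
#   losses = train_sinkhorn(net, y, beta, lamb=1, niter_sink=5, max_time=10,
#                           optim="adam", differentiation="analytic",
#                           learning_rate=0.1, warm_restart=True)
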
def train_descent(net,
                  y,
                  beta,
                  lamb=1,
                  max_time=10,
                  cost=sinkhorn._squared_distances,
                  learning_rate=0.1,
                  experiment=0,
                  verbose=False,
                  verbose_freq=100,
                  device="cpu",
                  optim="descent",
                  **kwargs):
    """
    learn a discrete distribution (gamma, x) with a prior (beta, y) using gradient descent on (gamma, x). Similar structure than train_sinkhorn
    """
    momentum = kwargs.get('momentum', 0)
    momstring = "_{}".format(momentum) if momentum != 0 else ""
    actionstr = "_actions{}".format(
        net.nactions) if (net.nactions != y.size(0) + 2) else ""

    fold = '{}_lamb{}{}_k{}_dim{}_lr{}_descent_{}_{}{}'.format(
        experiment, lamb, actionstr, y.size(0), y.size(1), learning_rate,
        device, optim, momstring)
    if experiment != 0:
        os.makedirs('experiments/descent/' + fold, exist_ok=True)

    # optimizer choice
    if optim == "SGD":
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=learning_rate,
                                    momentum=momentum)
    elif optim == "adam":
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    elif optim == "rms":
        optimizer = torch.optim.RMSprop(net.parameters(), lr=learning_rate)
    else:
        print('Invalid choice of optimizer.')
        return None
    one = torch.FloatTensor([1]).to(device)

    iterations = 0
    loss_profile = []
    time_profile = []
    start_time = timeit.default_timer()
    running_time = 0
    while running_time < max_time:

        # ---------------------------
        #        Optimize over net
        # ---------------------------

        time = timeit.default_timer()
        optimizer.zero_grad()
        gamma, x = net(one)  # output of the network. Parameters to optimize
        alpha = torch.sum(gamma, dim=1)

        ###### Total loss #########
        C = cost(x, y)
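        # objective minimized jointly in (gamma, x): the transport cost
        # <gamma, C(x, y)> plus the penalty lamb * KL(alpha, beta, gamma),
        # where alpha = gamma.sum(1)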

        loss = torch.sum(gamma * C) + lamb * \
            sinkhorn._KL(alpha, beta, gamma)  # loss
        loss.backward(one)  # autodiff
        optimizer.step()  # gradient descent step

        if net.proj:  # for projected gradient
            net.projection()

        # reasons explained in train_sinkhorn
        running_time += (timeit.default_timer() - time)
        time_profile.append(running_time)

        # here the training loss is accurate
        loss_profile.append(loss.cpu().detach().numpy())

        iterations += 1

        if verbose:
            if iterations % verbose_freq == 0:
                print('iterations=' + str(iterations))

    if verbose:
        t_expe = (timeit.default_timer() - start_time)
        print('done in {0} s'.format(t_expe))
        print('total running time: {0} s'.format(running_time))
    if experiment != 0:
        # save data
        torch.save(net, 'experiments/descent/' + fold + '/network')
        np.save('experiments/descent/' + fold + '/losses.npy', loss_profile)
        np.save('experiments/descent/' + fold + '/time.npy', time_profile)
    return loss_profile
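

# Example usage (sketch, illustrative values): plain joint descent on (gamma, x),
# with the momentum forwarded through **kwargs.
#
#   losses = train_descent(net, y, beta, lamb=1, max_time=10,
#                          learning_rate=0.1, optim="SGD", momentum=0.9)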