Example #1
    def _initBary(self, bary):
        if not isinstance(bary, Distribution):
            # sample x uniformly in [0,1] and map x -> x*(max-min) + min,
            # which corresponds to sampling uniformly in [min,max]
            diff_max_min = self.min_max_range[1] - self.min_max_range[0]

            # Distribution class wants support tensors n x d
            bary = torch.rand(1, self.d) * diff_max_min + self.min_max_range[0]

            bary_full_size = self.support_budget
        else:
            # the barycenter support may grow up to the budget
            # (or keep its current support if that is already larger)
            # bary_full_size = bary.support_size + self.niter
            bary_full_size = max(bary.support_size, self.support_budget)

        self.bary = Distribution(bary, max_support_size=bary_full_size)

        self.best_bary = Distribution(bary)
        self.best_func_val = -1

        # we store the potentials for all sinkhorn computations of
        # all distributions against the current barycenter estimate
        self.potential_bary = torch.zeros(
            self.bary.support_size * self.num_distributions, 1)

        # the potential for the OT(alpha,alpha)
        self.potential_bary_sym = torch.zeros(self.bary.support_size, 1)
Example #2
def testDistributions():

    # generate a distribution with three points
    mu_support = torch.tensor([[1., 2.], [-3., 4.], [5., 9.]])
    mu0 = Distribution(mu_support)
    new_point = torch.tensor([9., 8.])

    rho = 0.1
    mu0.convexAddSupportPoint(new_point, rho)
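
The Distribution class itself never appears in this listing. Below is a minimal sketch of what convexAddSupportPoint plausibly does, assuming Distribution stores an n x d support tensor and an n x 1 weights tensor; MiniDistribution and its fields are stand-ins for illustration, not the repository's code, and the real class also recycles slots once max_support_size is reached, which this sketch ignores:

import torch

class MiniDistribution:
    # minimal stand-in for the Distribution class used in these examples
    def __init__(self, support, weights=None):
        self.support = support                      # n x d tensor
        n = support.size(0)
        self.weights = weights if weights is not None \
            else torch.full((n, 1), 1.0 / n)        # n x 1 tensor

    def convexAddSupportPoint(self, point, rho):
        # self <- (1 - rho) * self + rho * delta_point:
        # shrink the existing weights and append the new atom with mass rho
        self.weights = (1.0 - rho) * self.weights
        self.support = torch.cat([self.support, point.view(1, -1)], dim=0)
        self.weights = torch.cat([self.weights, torch.tensor([[rho]])], dim=0)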
Example #3
    def train_single_distribution(self, distribution_set: list):
        distribution_sequence = Distribution(distribution_set[0])
        for i in range(1, self.__dim):
            distribution_sequence += Distribution(distribution_set[i])
        self.__class_set = [str(distribution_sequence)]
        self.__determine_columns()
        cwd = self.__change_cwd()
        distribution_df = self.__oa_hidden_single(distribution_sequence)
        self.__store_seq(distribution_df, distribution_sequence)
        self.__revert(cwd)
Example #4
def testGridBarycenter():

    sizes = [10, 20, 14]

    nus = [
        Distribution(torch.randn(s, 2), torch.rand(s, 1)).normalize()
        for s in sizes
    ]

    init_size = 20
    init_bary = Distribution(torch.randn(init_size, 2),
                             torch.rand(init_size, 1)).normalize()

    bary = GridBarycenter(nus, init_bary, support_budget=init_size + 2)

    bary.performFrankWolfe(4)
Example #5
def from_image_to_distribution(image, rescale):
    # support = coordinates of the nonzero pixels, weights = their intensities
    loc = np.argwhere(image > 0)
    weights = image[loc[:, 0], loc[:, 1]]
    weights = weights / weights.sum()
    loc = rescale * loc
    # cast to float32 so both tensors match torch's default dtype
    distrib = Distribution(torch.tensor(loc, dtype=torch.float),
                           torch.tensor(weights, dtype=torch.float))
    return distrib
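
A hedged usage sketch for the helper above; the random 28 x 28 array stands in for a real grayscale image:

import numpy as np
import torch

image = np.random.rand(28, 28)  # placeholder grayscale image
distrib = from_image_to_distribution(image, rescale=1 / 28)
# support points land in [0, 1]^2 and the weights sum to one
assert abs(distrib.weights.sum().item() - 1.0) < 1e-6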
Example #6
    def performFrankWolfe(self, num_itr=1):

        for idx_itr in range(num_itr):
            potentials_distributions, _ = self._computeSinkhorn()
            potentials_bary_sym = self._computeSymSinkhorn()

            new_support_point = self._argminGradient(potentials_distributions,
                                                     potentials_bary_sym)

            rho = self.currentRho()

            # add the new point to the support of the barycenter
            # perform self = (1-rho) * self + rho * other
            self.bary.convexAddSupportPoint(new_support_point, rho)

            # update the distance matrices with the new point (needs to be a 1 x d tensor)
            self._updateDistanceMatrices(new_support_point.view(1, -1),
                                         self.bary.last_updated_idx)

            # update the tensors containing the Sinkhorn potentials
            self._updatePotentialsContainers()

            # update the iteration counter
            self.current_iteration = self.current_iteration + 1

            # if the current barycenter is the best one so far...
            if self.best_func_val < 0 or self.best_func_val > self.func_val[-1]:
                self.best_bary = Distribution(self.bary)
                self.best_func_val = self.func_val[-1]
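
With currentRho() returning rho_k = 1/(1+k), the convex update above makes the barycenter a uniform average of the atoms picked so far: an atom added at step k receives mass rho_k and is then shrunk by every later step, which works out to exactly 1/(T+1) after T steps, the same mass the initial barycenter retains. A quick numeric check of that product:

# mass left on the initial support after T FW steps with rho_k = 1/(1+k):
# prod_{k=1..T} (1 - 1/(1+k)) = prod_{k=1..T} k/(k+1) = 1/(T+1)
mass = 1.0
T = 150
for k in range(1, T + 1):
    mass *= 1.0 - 1.0 / (1.0 + k)
print(mass, 1.0 / (T + 1))  # both ~0.0066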
Example #7
def testFirstFW():

    nu1 = Distribution(torch.tensor([0., 0.]).view(1, -1)).normalize()
    nu2 = Distribution(torch.tensor([1., 1.]).view(1, -1)).normalize()
    nus = [nu1, nu2]

    init_bary = Distribution(torch.randn(1, 2)).normalize()

    bary = GridBarycenter(nus,
                          init_bary,
                          support_budget=200,
                          grid_step=200,
                          eps=0.1)

    bary.performFrankWolfe(150)

    bary.performFrankWolfe(1)
Example #8
def testFW():

    sizes = [10, 20, 14]

    nus = [
        Distribution(torch.randn(s, 2), torch.rand(s, 1)).normalize()
        for s in sizes
    ]

    init_size = 3
    init_bary = Distribution(torch.randn(init_size, 2),
                             torch.rand(init_size, 1)).normalize()

    bary = Barycenter(nus, init_bary, support_budget=init_size + 2)

    try:
        bary.performFrankWolfe(4)
    except Exception as msg:
        if msg.args[0] != "You are using a 'virtual' argminGradient method. This is doing nothing":
            raise Exception('Error in testFW')
    else:
        # the base class must raise: _argminGradient is "virtual"
        raise Exception('Error in testFW: no exception was raised')
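
Example #8 confirms that Barycenter is abstract: _argminGradient raises until a child class overrides it (GridBarycenter does so by minimizing the FW gradient over a grid). A minimal, hypothetical subclass sketch; PoolBarycenter, its pool parameter, and the scoring rule are placeholders for illustration, not the repository's actual gradient computation:

class PoolBarycenter(Barycenter):
    # hypothetical subclass: pick each new support point from a fixed pool
    def _inner_init(self, pool=None, **params):
        self.pool = pool if pool is not None else torch.randn(100, self.d)

    def _argminGradient(self, potentials_distributions, potentials_bary_sym):
        # placeholder rule: return the pool point farthest from the current
        # support (the real method would minimize the FW gradient instead)
        scores = torch.cdist(self.pool, self.bary.support).min(dim=1)[0]
        return self.pool[scores.argmax()]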
Example #9
    def train(self, distributions):
        """
        Train on every combination (with repetition) of the given distributions.

        :param distributions:
        :return:
        """
        assert self.__theta is not None
        import itertools
        for index in itertools.product(range(len(distributions)), repeat=self.__dim):
            dist = Distribution(distributions[index[0]])
            for i in range(1, len(index)):
                dist += Distribution(distributions[index[i]])
            assert dist.dimension() == self.__dim, \
                'Dimension mismatch. Dimension should be %s but is %s' % (self.__dim, dist.dimension())
            self.__class_set.append(dist)
        self.__determine_columns()
        cwd = self.__change_cwd()
        for i in progressbar.progressbar(range(len(self.__class_set)), redirect_stdout=True):
            distribution_df = self.__oa_hidden_single(self.__class_set[i]).astype(dtype=self.__dtype).fillna(value=0)
            self.__store_seq(distribution_df, self.__class_set[i])
        self.__revert(cwd)
Example #10
rescale = 1 / 28
distrib = []  # one Distribution per input image
for i in range(images.shape[0]):
    distrib.append(from_image_to_distribution(images[i], rescale))

reg = 0.001

ind_rand = np.random.randint(0, num_images)
first_centroid = from_image_to_distribution(images[ind_rand], rescale)
centroids_distrib = initial_plus(first_centroid, distrib, num_groups, reg)

# Frank-Wolfe stuff
grid_step = 30
fw_iter = 1500
support_budget = fw_iter + 100
init = torch.Tensor([0.5, 0.5])
init_bary = Distribution(init.view(1, -1)).normalize()
# -----

kmeans_iteration = 0
kmeans_iteration_max = 1000
while kmeans_iteration < kmeans_iteration_max:
    tic = time.time()
    print('\nK-Means Iteration N:', kmeans_iteration)
    t_group = time.time()

    num_groups = len(centroids_distrib)

    print("Total number of groups: ", num_groups)

    groups = partition_into_groups(distrib, centroids_distrib, num_groups, reg,
                                   rescale)
Example #11
X = torch.linspace(0, scale * 1, 50)
Y = torch.linspace(0, scale * 1, 50)
X, Y = torch.meshgrid(X, Y)
X1 = X.reshape(X.shape[0] ** 2)
Y1 = Y.reshape(Y.shape[0] ** 2)

distributions = []
supp_meas = []
weights_meas = []
for i in range(len(pre_supp)):
    supp = torch.zeros((pre_supp[i].shape[0], 2))
    supp[:, 0] = X1[pre_supp[i]]
    supp[:, 1] = Y1[pre_supp[i]]
    supp_meas.append(supp)
    weights = (1 / pre_supp[i].shape[0]) * torch.ones(pre_supp[i].shape[0], 1)
    weights_meas.append(weights)
    distributions.append(Distribution(supp, weights))


init_bary = Distribution(torch.rand(10, 2)).normalize()


total_iter = 500
eps = 0.001

grid_step = 50

support_budget = total_iter + 100

bary = GridBarycenter(distributions, init_bary, support_budget=support_budget,
                      grid_step=grid_step, eps=eps)
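
The setup above stops right after construction; as in the tests elsewhere in this listing, the FW loop is driven explicitly. A short continuation in that same pattern:

# run the Frank-Wolfe iterations in chunks and track the best value found
for _ in range(total_iter // 100):
    bary.performFrankWolfe(100)
    print('best functional value so far:', min(bary.func_val))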
Example #12
def testProvideGridBarycenter():
    d = 2
    n = 4
    m = 5
    y1 = torch.Tensor([[0.05, 0.2], [0.05, 0.7], [0.05, 0.8], [0.05, 0.9]])
    y1 = torch.reshape(y1, (n, d))

    y2 = torch.Tensor([[0.6, 0.25], [0.8, 0.1], [0.8, 0.23], [0.8, 0.61],
                       [1., 0.21]])
    y2 = torch.reshape(y2, (m, d))

    eps = 0.01
    niter = 1000

    init = torch.Tensor([0.5, 0.5])

    nu1 = Distribution(y1).normalize()
    nu2 = Distribution(y2).normalize()
    nus = [nu1, nu2]

    init_bary = Distribution(init.view(1, -1)).normalize()

    # init_bary = Distribution(torch.rand(100,2)).normalize()

    # support_budget = niter + 100
    support_budget = 100

    # create grid
    grid_step = 50

    min_max_range = torch.tensor([[0.0500, 0.1000],
                                  [1.0000, 0.9000]])

    margin_percentage = 0.05
    margin = (min_max_range[0, :] -
              min_max_range[1, :]).abs() * margin_percentage

    tmp_ranges = [
        torch.arange(min_max_range[0, i] - margin[i],
                     min_max_range[1, i] + margin[i],
                     ((min_max_range[1, i] - min_max_range[0, i]).abs()
                      + 2 * margin[i]) / grid_step)
        for i in range(d)
    ]

    tmp_meshgrid = torch.meshgrid(*tmp_ranges)

    grid = torch.cat(
        [mesh_column.reshape(-1, 1) for mesh_column in tmp_meshgrid], dim=1)
    # created grid

    bary = GridBarycenter(nus, init_bary, support_budget=support_budget,
                          grid=grid, eps=eps,
                          sinkhorn_n_itr=100, sinkhorn_tol=1e-3)

    for i in range(10):
        t1 = time.time()
        bary.performFrankWolfe(100)
        t1 = time.time() - t1

        print('Time for 100 FW iterations:', t1)
        ### DEBUG FOR PRINTING

        print(min(bary.func_val) / 2)

        plot(bary.bary.support, bary.bary.weights)

        plt.figure()
        plt.plot(bary.func_val[30:])
        plt.show()

        ciao = 3
Example #13
    def __beta_reduction(self, stats: Distribution):
        return self.__z_2(stats.rvs(size=self.__size, samples=self.__monte_carlo, dtype=self.__dtype))
Example #14
class Barycenter:

    # "virtual" method for custom initialization of child Barycenter classes
    def _inner_init(self, **params):
        pass


    def __init__(self, distributions, bary=torch.empty(0),
                 eps=0.1, mixing_weights=torch.empty(0),
                 support_budget=100, sinkhorn_tol=1e-3,
                 sinkhorn_n_itr=100, **params):

        # basic variables
        self.eps = eps
        self.support_budget = support_budget

        self.sinkhorn_tol = sinkhorn_tol
        self.sinkhorn_n_itr = sinkhorn_n_itr

        self.func_val = []

        # current iteration of FW
        self.current_iteration = 1

        # the weights for each distribution
        self.mixing_weights = mixing_weights.clone().detach()

        # store the information about the distributions
        self._storeDistributions(distributions)

        # initialize the barycenter
        self._initBary(bary)

        # customizable initialization function
        # it is performed *before* the creation of the distance matrices
        self._inner_init(**params)

        # now that we have both the starting barycenter and the distributions
        # compute the distance matrices of the corresponding supports
        self._initDistanceMatrices()

    # Store information about the distributions
    def _storeDistributions(self, distributions):

        # number of distributions
        self.num_distributions = len(distributions)

        # dimension of the ambient space
        self.d = distributions[0].support.size(1)

        # save a tensor filled with all support points of all distributions
        self.full_support = torch.cat([nu.support for nu in distributions],
                                      dim=0)
        self.full_weights = torch.cat([nu.weights for nu in distributions],
                                      dim=0)

        # if the mixing weights were not provided (or have the wrong length),
        # assign uniform weights summing to one
        if self.mixing_weights.size(0) == 0 or self.mixing_weights.size(
                0) != self.num_distributions:
            self.mixing_weights = torch.ones(
                self.num_distributions) * (1.0 / self.num_distributions)

        # if some weight is negative, issue a warning
        if (self.mixing_weights < 0).sum() > 0:
            warnings.warn(
                "Warning! Negative weights assigned to barycenter distributions!"
            )

        # smallest box containing all distributions (hence also the barycenter),
        # stored as a 2 x d tensor (first row "min", second row "max")
        self.min_max_range = torch.cat(
            (self.full_support.min(0)[0].view(-1, 1),
             self.full_support.max(0)[0].view(-1, 1)),
            dim=1).t()

        # list of support sizes
        self.support_number_points = torch.tensor(
            [nu.support_size for nu in distributions])

        # indices to recover the support points (we start with a leading 0 for the first position)
        self.support_location_indices = torch.cat(
            [torch.tensor([0]), self.support_number_points.cumsum(dim=0)])

        self.potential_distributions = torch.zeros_like(self.full_weights)

    # Initialize the barycenter with the one provided or with a random point in the distributions range
    def _initBary(self, bary):
        if not isinstance(bary, Distribution):
            # sample x uniformly in [0,1] and map x -> x*(max-min) + min,
            # which corresponds to sampling uniformly in [min,max]
            diff_max_min = self.min_max_range[1] - self.min_max_range[0]

            # Distribution class wants support tensors n x d
            bary = torch.rand(1, self.d) * diff_max_min + self.min_max_range[0]

            bary_full_size = self.support_budget
        else:
            # the barycenter support may grow up to the budget
            # (or keep its current support if that is already larger)
            # bary_full_size = bary.support_size + self.niter
            bary_full_size = max(bary.support_size, self.support_budget)

        self.bary = Distribution(bary, max_support_size=bary_full_size)

        self.best_bary = Distribution(bary)
        self.best_func_val = -1

        # we store the potentials for all sinkhorn computations of
        # all distributions against the current barycenter estimate
        self.potential_bary = torch.zeros(
            self.bary.support_size * self.num_distributions, 1)

        # the potential for the OT(alpha,alpha)
        self.potential_bary_sym = torch.zeros(self.bary.support_size, 1)

    # initializes the big matrix containing all distances
    def _initDistanceMatrices(self):

        x_max_size = self.bary.max_support_size

        # big matrix containing support_bary x (full_support + support_bary) distances
        self._bigC = torch.empty(x_max_size,
                                 self.full_support.size(0) + x_max_size)
        self._updateDistanceMatrices(self.bary.support, 0)

    # updates the pointers (views) of all distribution-vs-bary distances
    # on the big matrix of all distances
    def _updateDistanceMatrices(self, x_new, idx):

        x_size = self.bary.support_size
        y_size = self.full_support.size(0)

        # update the pointers
        self.Cxy = self._bigC[:x_size, :y_size]
        self.Cxx = self._bigC[:x_size, y_size:(y_size + x_size)]

        sl_idx = self.support_location_indices
        self.Cxy_list = [
            self.Cxy[:, sl_idx[i]:sl_idx[i + 1]]
            for i in range(self.num_distributions)
        ]

        # current support of the barycenter
        bary_supp = self.bary.support

        self.Cxy[idx:idx + x_new.size(0), :].copy_(
            dist_matrix(x_new, self.full_support) / self.eps)
        self.Cxx[idx:idx + x_new.size(0), :].copy_(
            dist_matrix(x_new, bary_supp) / self.eps)

        # if x_new is not the whole new support (i.e. a single added point),
        # also update the corresponding columns of Cxx
        if x_new.size(0) != self.bary.support_size:
            self.Cxx[:, idx:idx + x_new.size(0)].copy_(
                dist_matrix(x_new, bary_supp).t() / self.eps)

    # whenever a new point is added to the bary support,
    # we need to update the shape of the tensors containing the potentials
    # here we keep track of the previous potential as starting point for the next Sinkhorn
    # computation. We add a zero in the new position added
    def _updatePotentialsContainers(self):

        # make the potentials larger only if the bary has not yet reached its maximum size (budget)
        if self.bary.support_size > self.potential_bary_sym.size(0):

            tmp_potential_bary = torch.zeros(
                self.bary.support_size * self.num_distributions, 1)
            for k in range(self.num_distributions):
                idx_pre = self.bary.support_size * k
                idx_next = self.bary.support_size * (k + 1) - 1

                idx_pre_old = (self.bary.support_size - 1) * k
                idx_next_old = (self.bary.support_size - 1) * (k + 1)

                tmp_potential_bary[idx_pre:idx_next].copy_(
                    self.potential_bary[idx_pre_old:idx_next_old])

            self.potential_bary = tmp_potential_bary

            # the potential for the OT(alpha,alpha)
            tmp_potential_bary_sym = torch.empty(self.bary.support_size, 1)
            tmp_potential_bary_sym[:-1].copy_(self.potential_bary_sym)
            tmp_potential_bary_sym[-1] = 0
            self.potential_bary_sym = tmp_potential_bary_sym

    # evaluate sinkhorn for the current barycenter and all distributions
    # code adapted from https://github.com/jeanfeydy/global-divergences
    def _computeSinkhorn(self):

        # log-weights of the barycenter and of all target distributions
        α_log = self.bary.weights.log()
        β_log = self.full_weights.log()

        A = self.potential_distributions
        B = self.potential_bary

        # the iterations are performed on the scaled potentials u/eps and v/eps;
        # eps is multiplied back in at the end of the Sinkhorn loop
        A.mul_(1 / self.eps)
        B.mul_(1 / self.eps)

        A_prev = torch.empty_like(A)

        # create list of pointers
        B_list = [B[(i * self.bary.support_size):((i + 1) * self.bary.support_size), :]
                  for i in range(self.num_distributions)]

        Cxy = self.Cxy
        Cxy_list = self.Cxy_list
        tmpM = torch.empty_like(Cxy)

        # create list of pointers to the temporary matrix M
        sl_idx = self.support_location_indices
        tmpM_list = [
            tmpM[:, sl_idx[i]:sl_idx[i + 1]]
            for i in range(self.num_distributions)
        ]

        perform_last_step = False

        for idx_itr in range(self.sinkhorn_n_itr):

            A_prev.copy_(A)

            tmpM.copy_((A + β_log).view(1, -1) - Cxy)

            for idx_nu in range(self.num_distributions):
                B_list[idx_nu].copy_(-lse(tmpM_list[idx_nu]))

            # add alpha log (in place)
            for idx_nu in range(self.num_distributions):
                tmpM_list[idx_nu].copy_(B_list[idx_nu] + α_log -
                                        Cxy_list[idx_nu])

            A.copy_(-lse(tmpM.t()))

            if perform_last_step: break

            # stopping criterion: L1 norm of the updates
            err = self.eps * (A - A_prev).abs().mean()
            if self.num_distributions * err.item() < self.sinkhorn_tol:
                perform_last_step = True

        A.mul_(self.eps)
        B.mul_(self.eps)

        # compute the sinkhorn functional OTe(alpha,beta)
        tmp_func_val = 0
        for idx_nu in range(self.num_distributions):

            inner_tmp_func_val = 0

            a = A[sl_idx[idx_nu]:sl_idx[idx_nu + 1], :].view(-1)
            s_a = self.full_weights[sl_idx[idx_nu]:sl_idx[idx_nu + 1], :].view(-1)
            inner_tmp_func_val = inner_tmp_func_val + torch.dot(a, s_a)

            b = B_list[idx_nu].view(-1)

            inner_tmp_func_val = inner_tmp_func_val + torch.dot(
                b, self.bary.weights.view(-1))

            tmp_func_val = tmp_func_val + self.mixing_weights[idx_nu] * inner_tmp_func_val

        self.func_val.append(tmp_func_val.item())

        return A, B

    # Compute OTe(alpha,alpha)
    # code adapted from https://github.com/jeanfeydy/global-divergences
    def _computeSymSinkhorn(self):

        α_log = self.bary.weights.log()
        A = self.potential_bary_sym

        # the iterations are performed on the scaled potential u/eps;
        # eps is multiplied back in at the end of the Sinkhorn loop
        A.mul_(1 / self.eps)

        A_prev = torch.empty_like(A)

        for idx_itr in range(self.sinkhorn_n_itr):

            A_prev.copy_(A)

            # a(x)/ε = 0.5 * ( a(x)/ε + Smin_ε,y~α [ C(x,y) - a(y) ] / ε )
            A.copy_(0.5 * (A - lse((A + α_log).view(1, -1) - self.Cxx)))

            # stopping criterion: L1 norm of the updates
            err = self.eps * (A - A_prev).abs().mean()
            if err.item() < self.sinkhorn_tol:
                break

        # a(x) = Smin_ε,z~α [ C(x,z) - a(z) ]
        A.copy_(-lse((A + α_log).view(1, -1) - self.Cxx))

        A.mul_(self.eps)

        tmp_func_val = self.mixing_weights.sum() * torch.dot(
            A.view(-1), self.bary.weights.view(-1))
        self.func_val[-1] = self.func_val[-1] - tmp_func_val.item()

        return A

    # "Virtual" method for gradient minimization w.r.t. z
    def _argminGradient(self, potentials_distributions, potentials_bary_sym):
        # warnings.warn("You are using a 'virtual' argminGradient method. This is doing nothing")
        raise Exception(
            "You are using a 'virtual' argminGradient method. This is doing nothing"
        )
        # return torch.randn(self.full_support[0,:].size(0),1)
        return None

    # FW step size: rho_k = 1 / (1 + k), the standard Frank-Wolfe schedule
    def currentRho(self):
        return 1.0 / (1.0 + float(self.current_iteration))

    # perform a single step of the Frank Wolfe algorithm
    def performFrankWolfe(self, num_itr=1):

        for idx_itr in range(num_itr):
            potentials_distributions, _ = self._computeSinkhorn()
            potentials_bary_sym = self._computeSymSinkhorn()

            new_support_point = self._argminGradient(potentials_distributions,
                                                     potentials_bary_sym)

            rho = self.currentRho()

            # add the new point to the support of the barycenter
            # perform self = (1-rho) * self + rho * other
            self.bary.convexAddSupportPoint(new_support_point, rho)

            # update the distance matrices with the new point (needs to be a 1 x d tensor)
            self._updateDistanceMatrices(new_support_point.view(1, -1),
                                         self.bary.last_updated_idx)

            # update the tensors containing the Sinkhorn potentials
            self._updatePotentialsContainers()

            # update the iteration counter
            self.current_iteration = self.current_iteration + 1

            # if the current barycenter is the best one so far...
            if self.best_func_val < 0 or self.best_func_val > self.func_val[-1]:
                self.best_bary = Distribution(self.bary)
                self.best_func_val = self.func_val[-1]

    # perform one single step of FW
    def performFrankWolfeStep(self):
        self.performFrankWolfe(1)
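
Example #14 leans on helpers the listing never shows: lse, dist_matrix and the warnings module. Below are minimal sketches consistent with how the class uses them; these are reconstructions under stated assumptions, not the repository's code, and in particular whether dist_matrix squares the Euclidean distances is a guess:

import torch
import warnings  # used by _storeDistributions

def lse(M):
    # row-wise log-sum-exp returned as a column vector (n x 1),
    # matching calls such as A.copy_(-lse(tmpM.t()))
    return M.logsumexp(dim=1, keepdim=True)

def dist_matrix(x, y):
    # pairwise squared Euclidean distances between the rows of x and y
    return torch.cdist(x, y) ** 2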
Example #15
weights = []
pix_arr = pix[:, :, 0].reshape(pix.shape[0]**2)
for i in range(n):
    if pix_arr[i] > 50:
        y1.append(torch.tensor([Y1[i], MX - X1[i]]))
        weights.append(torch.tensor(pix_arr[i], dtype=torch.float32))

nu1t = torch.stack(y1)
w1 = torch.stack(weights).reshape((len(weights), 1))
w1 = w1 / (torch.sum(w1, dim=0)[0])
supp_meas = [nu1t]
weights_meas = [w1]

# create the list of "distributions" of which we will compute the barycenter
distributions = [Distribution(nu1t, w1)]

# barycenter initialization
init = torch.Tensor([0.5, 0.5]).view(1, -1)
# init = Distribution(torch.rand(100,2)).normalize()

init_bary = Distribution(init).normalize()

total_iter = 20000
eps = 0.005

support_budget = total_iter + 100
grid_step = 100

grid_step = min(grid_step, im_size)
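
The snippet above stops before the barycenter is built; judging from testMatch (Example #17), it presumably continues along these lines:

bary = GridBarycenter(distributions, init_bary, support_budget=support_budget,
                      grid_step=grid_step, eps=eps,
                      sinkhorn_n_itr=100, sinkhorn_tol=1e-3)
bary.performFrankWolfe(total_iter)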
Example #16
def testBarycenter():

    # generate a distribution with three points
    mu_support = torch.tensor([[1., 2.], [-3., 14.], [5., 9.]])
    mu_support2 = torch.tensor([[11., -2.], [53., 4.], [21., 0.]])
    mu_support3 = torch.tensor([[3., 4.], [83., 7.], [3., 3.]])
    mu_weights = torch.tensor([0.3, 1.0, 0.8])
    mu_weights2 = torch.tensor([0.3, 1.0, 0.8]).unsqueeze(1)

    mu0 = Distribution(mu_support)
    mu1 = Distribution(mu_support2, mu_weights)
    mu2 = Distribution(mu_support3, mu_weights2)

    mu0.normalize()
    mu1.normalize()
    mu2.normalize()

    bary_init_support = torch.tensor([[11., 12.], [8., 10.]])
    bary_init = Distribution(bary_init_support)
    bary_init.normalize()

    bary = Barycenter([mu0, mu1, mu2], bary=bary_init)

    bary._computeSinkhorn()
Example #17
def testMatch():
    ##########IMAGE TEST

    im_size = 100

    img = Image.open(r"cheeta2.jpg")
    I = np.asarray(Image.open(r"cheeta2.jpg"))
    img.thumbnail((im_size, im_size),
                  Image.LANCZOS)  # resizes image in-place; Image.ANTIALIAS was removed in Pillow 10
    imgplot = plt.imshow(img)
    plt.show()
    pix = np.array(img)
    min_side = np.min(pix[:, :, 0].shape)
    # pix = pix[2:,:,0]/sum(sum(pix[2:,:,0]))
    pix = 255 - pix[0:min_side, 0:min_side]
    # x=torch.linspace(0,1,steps=62)
    # y=torch.linspace(0,1,steps=62)
    # X, Y = torch.meshgrid(x, y)
    # X1 = X.reshape(X.shape[0]**2)
    # Y1 = Y.reshape(Y.shape[0] ** 2)

    imgplot = plt.imshow(img)
    # pix = pix/sum(sum(pix))
    x = torch.linspace(0, 1, steps=pix.shape[0])
    y = torch.linspace(0, 1, steps=pix.shape[0])
    X, Y = torch.meshgrid(x, y)
    X1 = X.reshape(X.shape[0]**2)
    Y1 = Y.reshape(Y.shape[0]**2)
    n = X.shape[0]**2
    y1 = []

    MX = max(X1)

    weights = []
    pix_arr = pix[:, :, 0].reshape(pix.shape[0]**2)
    for i in range(n):
        if pix_arr[i] > 50:
            y1.append(torch.tensor([Y1[i], MX - X1[i]]))
            # y1.append(torch.tensor([X1[i], Y1[i]]))
            weights.append(torch.tensor(pix_arr[i], dtype=torch.float32))

    nu1t = torch.stack(y1)
    w1 = torch.stack(weights).reshape((len(weights), 1))
    w1 = w1 / (torch.sum(w1, dim=0)[0])
    supp_meas = [nu1t]
    weights_meas = [w1]

    distributions = [Distribution(nu1t, w1)]

    init = torch.Tensor([0.5, 0.5])

    init_bary = Distribution(init.view(1, -1)).normalize()

    # init_bary = Distribution(torch.rand(100,2)).normalize()

    # init_bary = distributions[0]

    niter = 10000
    eps = 0.01

    support_budget = niter + 100
    grid_step = 100

    grid_step = min(grid_step, im_size)

    bary = GridBarycenter(distributions, init_bary, support_budget=support_budget,
                          grid_step=grid_step, eps=eps,
                          sinkhorn_n_itr=100, sinkhorn_tol=1e-3)

    plot(distributions[0].support, distributions[0].weights, bins=im_size)
    plt.show()

    n_iter_per_loop = 200

    print('starting FW iterations')
    for i in range(100):
        t1 = time.time()
        bary.performFrankWolfe(n_iter_per_loop)
        t1 = time.time() - t1

        print('Time for ', n_iter_per_loop, ' FW iterations:', t1)
        ### DEBUG FOR PRINTING
        print('n iterations = ', (i + 1) * n_iter_per_loop, 'n support points',
              bary.bary.support_size)

        print(min(bary.func_val))

        plt.figure()
        plt.plot(bary.func_val[30:])
        # plot(bary.bary.support, bary.bary.weights,bins=grid_step)
        plt.show()

        plot(bary.best_bary.support, bary.best_bary.weights, bins=im_size)
        plt.show()

        plot(bary.best_bary.support, bary.best_bary.weights,\
             bins=im_size, thresh=bary.best_bary.weights.min().item())
        plt.show()

        plot(bary.bary.support, bary.bary.weights, bins=im_size)
        plt.show()

        ciao = 3