Large-Scale-OT

PyTorch implementation of the stochastic algorithms for computing regularized optimal transport (OT) proposed in [1].

[1] proposes:

  • A stochastic algorithm (Alg. 1) for computing the optimal dual variables of the regularized OT problem, from which the regularized-OT objective is easily computed.
  • A stochastic algorithm (Alg. 2) for learning an optimal map between the source and target probability measures, parameterized as a deep neural network.

Both entropic and L2 regularization are implemented.
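For orientation, the dual problem that Alg. 1 maximizes can be written as an expectation over independent source and target samples. The notation below is an assumption about how the paper's symbols map onto this code: μ and ν are the source and target measures, c is the ground cost, ε is the regularization strength (presumably `reg_val`), and u, v are the dual variables. The constants correspond to the regularizers ∫(ln(dπ/(dμ dν)) − 1) dπ for entropy and ∫(dπ/(dμ dν))² dμ dν for L2; see [1] for the exact statement.

```latex
\max_{u,v}\;
\mathbb{E}_{(X,Y)\sim\mu\times\nu}
\Big[\, u(X) + v(Y) + F_\varepsilon\big(u(X), v(Y), c(X,Y)\big) \Big],
\qquad
F_\varepsilon(u, v, c) =
\begin{cases}
-\varepsilon \exp\!\Big(\dfrac{u + v - c}{\varepsilon}\Big) & \text{(entropy)} \\[2ex]
-\dfrac{1}{4\varepsilon}\,\big(u + v - c\big)_+^{2} & \text{(L2)}
\end{cases}
```

Because the objective is an expectation under the product measure, averaging the bracketed term over a mini-batch of independently sampled pairs gives an unbiased estimate of the objective and of its gradient.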

Requirements

python2 or python3
pytorch
matplotlib

Install

git clone https://github.com/vivienseguy/Large-Scale-OT.git

Usage

Start by creating the regularized-OT computation class: either PyTorchStochasticDiscreteOT or PyTorchStochasticSemiDiscreteOT depending on your setting.

from StochasticOTClasses.StochasticOTDiscrete import PyTorchStochasticDiscreteOT

discreteOTComputer = PyTorchStochasticDiscreteOT(xs, ws, xt, wt, reg_type='l2', reg_val=0.02, device_type='cpu')
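The constructor call above assumes that `xs`, `ws`, `xt` and `wt` already exist. A minimal sketch of toy inputs follows, assuming the class accepts NumPy arrays with samples of shape (n, d) and weight vectors of shape (n,) that sum to 1. The README does not state the expected types, so treat the shapes and the choice of Gaussians as illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

n_source, n_target = 200, 300

# Source samples: isotropic 2-D Gaussian centred at the origin.
xs = rng.normal(loc=0.0, scale=1.0, size=(n_source, 2))
# Target samples: 2-D Gaussian centred at (4, 4) with a smaller spread.
xt = rng.normal(loc=4.0, scale=0.5, size=(n_target, 2))

# Uniform weights on each point cloud; each vector sums to 1.
ws = np.full(n_source, 1.0 / n_source)
wt = np.full(n_target, 1.0 / n_target)
```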

Compute the optimal dual variables with Alg. 1:

history = discreteOTComputer.learn_OT_dual_variables(epochs=1000, batch_size=50, lr=0.0005)
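To illustrate the idea behind Alg. 1 in the discrete setting, here is a NumPy sketch of stochastic gradient ascent on the L2-regularized dual. It is not the repository's implementation: it uses plain SGD rather than a PyTorch optimizer, the function names are hypothetical, and the step size and regularization value are chosen for this toy problem only.

```python
import numpy as np


def sq_dists(x, y):
    """Pairwise squared Euclidean distances, shape (len(x), len(y))."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def dual_objective(u, v, C, a, b, eps):
    """Exact (non-stochastic) value of the L2-regularized dual."""
    s = u[:, None] + v[None, :] - C
    penalty = np.maximum(s, 0.0) ** 2 / (4.0 * eps)
    return a @ u + b @ v - a @ penalty @ b


def sgd_dual_l2(C, a, b, eps, steps, batch_size, lr, seed=0):
    """Plain stochastic gradient ascent on the L2-regularized dual."""
    rng = np.random.default_rng(seed)
    u = np.zeros(len(a))
    v = np.zeros(len(b))
    for _ in range(steps):
        # Draw source and target indices according to their weights.
        I = rng.choice(len(a), size=batch_size, p=a)
        J = rng.choice(len(b), size=batch_size, p=b)
        s = u[I][:, None] + v[J][None, :] - C[np.ix_(I, J)]
        # For the L2 penalty, dF/du = dF/dv = -(s)_+ / (2 * eps).
        g = -np.maximum(s, 0.0) / (2.0 * eps)
        grad_u = (1.0 + g.mean(axis=1)) / batch_size
        grad_v = (1.0 + g.mean(axis=0)) / batch_size
        # add.at accumulates correctly when an index is drawn more than once.
        np.add.at(u, I, lr * grad_u)
        np.add.at(v, J, lr * grad_v)
    return u, v


# Toy problem: two small 2-D point clouds with uniform weights.
rng = np.random.default_rng(0)
xs = rng.normal(0.0, 1.0, size=(20, 2))
xt = rng.normal(4.0, 0.5, size=(20, 2))
a = np.full(20, 1.0 / 20)
b = np.full(20, 1.0 / 20)
C = sq_dists(xs, xt)

u, v = sgd_dual_l2(C, a, b, eps=1.0, steps=2000, batch_size=10, lr=1.0)
value = dual_objective(u, v, C, a, b, eps=1.0)
```

The dual value starts at 0 (u = v = 0) and increases as the ascent proceeds; at the optimum it equals the regularized OT objective.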

Once the optimal dual variables have been obtained, you can compute the OT loss stochastically:

d_stochastic = discreteOTComputer.compute_OT_MonteCarlo(epochs=20, batch_size=50)
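One plausible way to form such a Monte Carlo estimate is to average c(x, y) · H(x, y) over sampled pairs, where H is the density of the regularized plan with respect to the product of the two measures. The NumPy sketch below follows that reading; what `compute_OT_MonteCarlo` does internally is not stated here, and the function names and the assumption that `u`, `v` are already-optimized dual vectors are hypothetical.

```python
import numpy as np


def plan_density(u_i, v_j, c_ij, eps, reg_type="l2"):
    """Density of the regularized plan with respect to the product a x b."""
    s = u_i + v_j - c_ij
    if reg_type == "entropy":
        return np.exp(s / eps)
    return np.maximum(s, 0.0) / (2.0 * eps)


def monte_carlo_transport_cost(u, v, C, a, b, eps, reg_type="l2",
                               n_batches=20, batch_size=50, seed=0):
    """Average c(x, y) * H(x, y) over index pairs sampled from a and b."""
    rng = np.random.default_rng(seed)
    estimates = []
    for _ in range(n_batches):
        I = rng.choice(len(a), size=batch_size, p=a)
        J = rng.choice(len(b), size=batch_size, p=b)
        c = C[I, J]  # costs of the paired samples (i_k, j_k)
        h = plan_density(u[I], v[J], c, eps, reg_type)
        estimates.append(np.mean(c * h))
    return float(np.mean(estimates))
```

This estimates the transport-cost term only; the regularization term would have to be added separately to obtain the full regularized objective.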

You can also learn an approximate optimal map between the two probability measures by learning the barycentric mapping (Alg. 2). The mapping is parameterized as a deep neural network that you can supply through the function's parameters; otherwise, a small default 3-layer network is used.

bp_history = discreteOTComputer.learn_barycentric_mapping(epochs=300, batch_size=50, lr=0.000002)
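As a point of reference for what the network approximates: with a squared Euclidean cost, the barycentric mapping sends each source point to the plan-weighted mean of the target points. The NumPy sketch below computes that projection directly from a given discrete plan. It is an illustration under that assumption, not the repository's training procedure, and `barycentric_projection` is a hypothetical helper.

```python
import numpy as np


def barycentric_projection(plan, xt):
    """Map each source point to the plan-weighted mean of the target points.

    plan: (n, m) non-negative transport plan; every row must carry some mass.
    xt:   (m, d) target samples.
    """
    row_mass = plan.sum(axis=1, keepdims=True)
    return (plan @ xt) / row_mass


xt_demo = np.array([[1.0, 2.0], [3.0, 4.0]])

# A plan that matches source i to target i maps each point onto its match.
matched = barycentric_projection(np.array([[0.5, 0.0], [0.0, 0.5]]), xt_demo)

# The independent (uniform) plan maps every point to the target mean.
averaged = barycentric_projection(np.full((2, 2), 0.25), xt_demo)
```

The learned network plays the same role but generalizes to source points that were not part of the training samples.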

Once the mapping is learned, you can apply the (approximate) optimal map to samples:

xsf = discreteOTComputer.evaluate_barycentric_mapping(xs)

You can visualize the source, target and mapped samples:

import matplotlib.pylab as pl

pl.figure()
pl.plot(xs[:, 0], xs[:, 1], '+b', label='source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='target samples')
pl.plot(xsf[:, 0], xsf[:, 1], '+g', label='mapped source samples')
pl.legend()
pl.show()

References

[1] Seguy, Vivien and Damodaran, Bharath Bhushan and Flamary, Rémi and Courty, Nicolas and Rolet, Antoine and Blondel, Mathieu. Large-Scale Optimal Transport and Mapping Estimation. Proceedings of the International Conference on Learning Representations (2018).

@inproceedings{seguy2018large,
  title={Large-Scale Optimal Transport and Mapping Estimation},
  author={Seguy, Vivien and Damodaran, Bharath Bhushan and Flamary, R{\'e}mi and Courty, Nicolas and Rolet, Antoine and Blondel, Mathieu},
  booktitle={Proceedings of the International Conference on Learning Representations},
  year={2018},
}
