Pymanopt is a Python port of the MATLAB package Manopt for optimization on manifolds. It uses Theano (http://deeplearning.net/software/theano/) for automatic differentiation.
This project is independent of the Manopt project.
Manifold | Implemented |
---|---|
Sphere | Partially |
Stiefel | Partially |
Grassmann | Partially |
Symmetric positive semidefinite, fixed-rank (complex) | Partially |
Oblique manifold | Partially |
Solver | Type | Implemented |
---|---|---|
Steepest-descent | First-order | Partially |
Conjugate-gradient | First-order | Partially |
Trust-regions | Second-order | Partially |
Particle swarm (PSO) | Derivative-free | Partially |
Nelder-Mead | Derivative-free | Partially |
This package depends on Python 2.7.*, NumPy, SciPy and Theano. Instructions for installing NumPy, SciPy and Theano on different operating systems can be found here.
You can install pymanopt with the following command:
```
pip install --user git+https://github.com/j-towns/pymanopt.git
```
To optimize with pymanopt, you need to create a manifold object, a solver object, and a cost function. Pymanopt provides the manifold and solver classes. Cost functions have to be defined with Theano. A tutorial on Theano can be found here.
```python
import theano.tensor as T
import numpy as np
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent
from pymanopt.manifolds import Stiefel
# ---------------------------------
# Define cost function using Theano
# ---------------------------------
# Note, your cost function needs to have one (matrix) input and one (scalar) output.
X = T.matrix()
# Cost is the sum of all of the elements of the matrix X.
cost = T.sum(X)
# ----------------------------------
# Set up solver and manifold objects
# ----------------------------------
solver = SteepestDescent()
manifold = Stiefel(5, 2)
# ---------------------
# Set up problem object
# ---------------------
problem = Problem(man=manifold, ad_cost=cost, ad_arg=X)
# --------------------
# Perform optimization
# --------------------
# Currently the solve function takes the problem object as input.
Xopt = solver.solve(problem)
print(Xopt)
```
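To make concrete what a steepest-descent solver does on a manifold, here is a minimal pure-Python sketch. It is not pymanopt's implementation. It assumes the simplest case, the unit sphere (a Stiefel manifold with a single column), and the same cost as above, the sum of the entries. It also assumes a fixed step size rather than a proper step-size rule. The names `project_to_tangent`, `retract` and `sphere_steepest_descent` are hypothetical and chosen only for illustration.

```python
import math


def project_to_tangent(x, g):
    """Remove from g its component along x (tangent space of the unit sphere at x)."""
    inner = sum(xi * gi for xi, gi in zip(x, g))
    return [gi - inner * xi for xi, gi in zip(x, g)]


def retract(x, v):
    """Step from x along v, then renormalize back onto the unit sphere."""
    y = [xi + vi for xi, vi in zip(x, v)]
    norm = math.sqrt(sum(yi * yi for yi in y))
    return [yi / norm for yi in y]


def sphere_steepest_descent(x0, step=0.1, tol=1e-10, max_iter=10000):
    """Minimize sum(x) over the unit sphere using a fixed step size (illustrative only)."""
    norm0 = math.sqrt(sum(xi * xi for xi in x0))
    x = [xi / norm0 for xi in x0]  # make sure the starting point is on the sphere
    for _ in range(max_iter):
        euclidean_grad = [1.0] * len(x)  # gradient of sum(x) in the ambient space
        riemannian_grad = project_to_tangent(x, euclidean_grad)
        if math.sqrt(sum(g * g for g in riemannian_grad)) < tol:
            break  # stationary point on the sphere
        x = retract(x, [-step * g for g in riemannian_grad])
    return x


x_opt = sphere_steepest_descent([1.0, 0.5, -0.3, 0.2, 0.0])
cost_opt = sum(x_opt)
print(x_opt)
print(cost_opt)
```

For a 5-dimensional starting point, the minimum of this cost on the sphere is `-sqrt(5)`, reached where every entry equals `-1/sqrt(5)`, so the printed cost should be close to -2.236. The two helpers reflect the two ingredients a manifold supplies to a first-order solver: a way to turn an ambient gradient into a tangent direction, and a way to move along that direction while staying on the manifold.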
See here for more examples.