def projection_solver(M, constraints, rank=2, maxiter=100, tol=1e-6, L0=None):
    """Alternating-projection solver for a rank-constrained moment matrix.

    Each iteration performs two projections:

    1. Read a moment vector ``yL`` off ``L.T L`` (each moment is the
       ``Bf``-weighted average of the matrix entries that map to it), then
       let ``y = util.project_nullspace(A, b, yL)`` so that ``y`` satisfies
       the linear constraints ``A y = b`` derived from ``constraints``.
    2. Replace ``L`` by the rank-``rank`` factor of ``M(y)`` taken from its
       truncated SVD, so that ``L.T L`` approximates ``M(y)``.

    A fixed point of this iteration is a valid solution.  Progress is printed
    every iteration as ``i, objective-before-refactor,
    objective-after-refactor``, where each objective is
    ``||L.T L - M(y)|| + ||A y - b||``.

    Args:
        M: moment-matrix object; must provide ``row_monos``, ``get_Bflat()``,
            ``get_Ab(constraints, cvxoptmode=False)`` and
            ``numeric_instance(y)``.
        constraints: constraint description forwarded to ``M.get_Ab``.
        rank: number of rows of the factor ``L``.
        maxiter: maximum number of iterations.
        tol: stop as soon as the post-refactor objective drops below this.
        L0: optional initial factor of shape ``(rank, len(M.row_monos))``.
            When omitted, a standard Gaussian random matrix is used.

    Returns:
        ``(y, L)`` from the last iteration.  ``y`` is ``None`` if no
        iteration ran (``maxiter <= 0``).

    Raises:
        ValueError: if ``L0`` is given with the wrong shape.
    """
    rowsM = len(M.row_monos)
    if L0 is None:
        L = np.random.randn(rank, rowsM)
    else:
        L = np.asarray(L0, dtype=float)
        if L.shape != (rank, rowsM):
            raise ValueError('L0 must have shape (%d, %d), got %r'
                             % (rank, rowsM, L.shape))

    Bf = M.get_Bflat()
    A, b = M.get_Ab(constraints, cvxoptmode=False)
    # Relative weight of the constraint residual; 1 leaves A and b unchanged.
    weightone = 1
    A = A * weightone
    b = b * weightone
    # Normalizer turning the Bf-weighted sum of matrix entries into an
    # average per moment; it does not depend on the iterate, so compute once.
    Bf_rowsums = np.sum(Bf, 1)[:, np.newaxis]

    y = None
    for i in range(maxiter):
        yL = Bf.dot(L.T.dot(L).flatten()[:, np.newaxis]) / Bf_rowsums
        # randomize=0: plain projection, no random perturbation.
        y = util.project_nullspace(A, b, yL, randomize=0)
        My = M.numeric_instance(y)
        # y does not change within the iteration, so neither does this term.
        constraint_res = scipy.linalg.norm(A.dot(y) - b)
        objyconstrs = scipy.linalg.norm(L.T.dot(L) - My) + constraint_res

        # Best rank-`rank` factor: top right-singular vectors scaled by the
        # square roots of the corresponding singular values.
        U, D, V = scipy.linalg.svd(My)
        L = V[0:rank, :] * np.sqrt(D[0:rank, np.newaxis])
        objprojL = scipy.linalg.norm(L.T.dot(L) - My) + constraint_res

        print('%d:\t%f\t%f' % (i, objyconstrs, objprojL))
        if objprojL < tol:
            break
    return y, L
def convex_projection_solver(M, constraints, rank=2, tau=1, delta=0.1,
                             maxiter=100, tol=1e-6):
    """Alternating projection with singular-value soft-thresholding.

    Keeps a full square matrix ``X`` of size ``len(M.row_monos)`` instead of
    a low-rank factor.  Each iteration:

    1. Read a moment vector ``yX`` off ``X`` (each moment is the
       ``Bf``-weighted average of the entries of ``X`` that map to it), then
       let ``y = util.project_nullspace(A, b, yX)`` so that ``y`` satisfies
       the linear constraints ``A y = b`` derived from ``constraints``.
    2. Take the SVD ``M(y) = U diag(D) V`` and set
       ``X = U diag(max(D - tau, 0)) V``, i.e. shrink every singular value
       by ``tau`` and clip at zero.

    Progress is printed every iteration as ``i, objective-before-update,
    objective-after-update``, where each objective is
    ``||X - M(y)|| + ||A y - b||``.

    Args:
        M: moment-matrix object; must provide ``row_monos``, ``get_Bflat()``,
            ``get_Ab(constraints, cvxoptmode=False)`` and
            ``numeric_instance(y)``.
        constraints: constraint description forwarded to ``M.get_Ab``.
        rank: unused; kept so the signature mirrors the factored solver.
        tau: amount subtracted from every singular value of ``M(y)``.
        delta: unused.  NOTE(review): appears intended as a damping step
            ``y = yX + delta * (yproj - yX)``; confirm before enabling.
        maxiter: maximum number of iterations.
        tol: stop as soon as the post-update objective drops below this.

    Returns:
        ``(y, X)`` from the last iteration.  ``y`` is ``None`` if no
        iteration ran (``maxiter <= 0``).
    """
    rowsM = len(M.row_monos)
    Bf = M.get_Bflat()
    A, b = M.get_Ab(constraints, cvxoptmode=False)
    # Relative weight of the constraint residual; 1 leaves A and b unchanged.
    weightone = 1
    A = A * weightone
    b = b * weightone
    X = np.random.randn(rowsM, rowsM)
    # Normalizer turning the Bf-weighted sum of entries of X into an average
    # per moment; it does not depend on the iterate, so compute once.
    Bf_rowsums = np.sum(Bf, 1)[:, np.newaxis]

    y = None
    for i in range(maxiter):
        yX = Bf.dot(X.flatten()[:, np.newaxis]) / Bf_rowsums
        y = util.project_nullspace(A, b, yX, randomize=0)
        My = M.numeric_instance(y)
        # y does not change within the iteration, so neither does this term.
        constraint_res = scipy.linalg.norm(A.dot(y) - b)
        objyconstrs = scipy.linalg.norm(X - My) + constraint_res

        # Soft-threshold the singular values of M(y) by tau.
        U, D, V = scipy.linalg.svd(My)
        D = np.fmax(D - tau, 0)
        X = U.dot(np.diag(D)).dot(V)
        objprojL = scipy.linalg.norm(X - My) + constraint_res

        print('%d:\t%f\t%f' % (i, objyconstrs, objprojL))
        if objprojL < tol:
            break
    return y, X