def iterative( pts, vals, pce ):
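    """
    Iteratively prune a PCE basis: fit the current basis with SPGL1,
    rank the terms by squared coefficient magnitude and shrink the kept
    set each iteration until the final solve uses only num_pts terms.
    """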

    num_pts = pts.shape[1]

    all_indices = PolyIndexVector()
    pce.get_basis_indices( all_indices )

    sorter = lambda x: abs( x.get_array_index() )
    remain_indices = sorted( all_indices, key = sorter )
    num_keep = len( all_indices )


    max_iter = 3
    # shrink the kept basis linearly from all terms down to num_pts
    increment = ( len( all_indices ) - num_pts ) // ( max_iter - 1 )
    for it in xrange( max_iter ):
        i = 0
        indices = PolyIndexVector()
        indices.resize( int( num_keep ) )
        for index in remain_indices:
            indices[int(i)] = index
            index.set_array_index( i )
            i += 1 
            if ( i >= num_keep ):
                break
        
        pce.set_basis_indices( indices )
        #least_interpolation( pts, vals, pce )
        A = pce.build_vandermonde( pts )
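        # a zero residual tolerance makes SPGL1 solve the basis-pursuit
        # problem, i.e. the minimum l1-norm interpolant of vals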
        solver = SPGL1Solver( residual_tolerance = 0. )
        coeff, tmp = solver.solve( A, vals.squeeze() )
        #coeff = pce.get_coefficients()
        #pce.get_basis_indices( indices )
        #print coeff.shape, len( indices )

        sorter = lambda x: coeff[x.get_array_index(),0]**2
        remain_indices = sorted( indices, key = sorter )[::-1]
        num_keep = num_pts + increment * ( max_iter - it - 2 )

    pce.set_coefficients( coeff.reshape(coeff.shape[0], 1 ) )
    pce.set_basis_indices( indices )
    def test_tensor_product_basis( self ):
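        """
        Check that TensorProductBasis values match products of 1D
        orthogonal polynomial values, for shared and per-dimension bases.
        """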

        # test orthogonal basis functions
        num_dims = 2
        #domain = TensorProductDomain( num_dims, ranges = [[-1.,1.]] )
        domain = get_symmetric_hypercube( -1., 1., num_dims )
        poly_1d = OrthogPolyVector()
        poly_1d.resize( 1 )
        poly_1d[0] = JacobiPolynomial1D( 0.5, 0.5 )
        #basis = TensorProductBasis( num_dims, poly_1d )
        basis = TensorProductBasis()
        basis.set_domain( domain )
        basis.set_bases( poly_1d, [numpy.arange( num_dims,dtype=numpy.int32 )], 
                         num_dims )
        #assert basis.num_dims == num_dims
        x = numpy.random.uniform( -1., 1., ( num_dims, 20 ) )
        index = PolynomialIndex( numpy.array( [[1, 1]], numpy.int32 ) )
        true_val = poly_1d[0].value_set( x[0,:], index.level(0), -1., 1. ) * \
            poly_1d[0].value_set( x[1,:], index.level(1), -1., 1. )
        assert numpy.allclose( basis.value_set( x, index), true_val )
        #true_grad = poly_1d[0].gradient_set( x[0,:], index.level( 0 ), -1., 1. ) * \
        #    poly_1d[0].value_set( x[1,:], index.level(1), -1., 1. )
        #assert numpy.allclose( basis.gradient( x, index, 0), 
        #                       true_grad )
        #true_grad = poly_1d[0].value_set( x[0,:], index.level(0), -1., 1. ) * \
        #    poly_1d[0].gradient_set( x[1,:], index.level(1), -1., 1. )
        #assert numpy.allclose( basis.gradient( x, index, 1), 
        #                       true_grad )

        num_dims = 2
        poly_1d = OrthogPolyVector()
        poly_1d.resize( 2 )
        poly_1d[0] = JacobiPolynomial1D( 0.5, 0.5 )
        poly_1d[1] = JacobiPolynomial1D( 0.5, 1.5 )
        #basis = TensorProductBasis( num_dims, poly_1d )
        basis = TensorProductBasis()
        basis.set_domain( domain )
        basis.set_bases( poly_1d, [numpy.array([0],numpy.int32),
                                   numpy.array([1],numpy.int32)], 
                         num_dims )
        x_1 = numpy.random.uniform( -1., 1., ( 1, 20 ) )
        x_2 = numpy.random.normal( 0., 1., ( 1, 20 ) )
        x = numpy.vstack( ( x_1, x_2 ) )
        index = PolynomialIndex( numpy.array( [[1,1]], numpy.int32 ) )
        true_val = poly_1d[0].value_set( x[0,:], index.level(0), -1., 1. ) * \
            poly_1d[1].value_set( x[1,:], index.level(1), -1., 1. )
        assert numpy.allclose( basis.value_set( x, index ), true_val )
        #true_grad = poly_1d[0].gradient_set( x[0,:], index.level(0), -1., 1. ) * \
        #    poly_1d[1].value_set( x[1,:], index.level(1), -1., 1. )
        #assert numpy.allclose( basis.gradient( x, index, 0), 
        #                       true_grad )
        #true_grad = poly_1d[0].value_set( x[0,:], index.level(0), -1., 1. ) * \
        #    poly_1d[1].gradient_set( x[1,:], index.level(1), -1., 1. )
        #assert numpy.allclose( basis.gradient( x, index, 1), 
        #                       true_grad )
        num_dims = 3
        domain = get_symmetric_hypercube( -1., 1., num_dims )
        basis = TensorProductBasis()
        basis.set_domain( domain )
        basis.set_bases( poly_1d, [numpy.array([0,2],numpy.int32),
                                   numpy.array([1],numpy.int32)], 
                         num_dims )
        indices = PolyIndexVector()
        indices.resize( 3 )
        indices[0] = PolynomialIndex( numpy.array( [[1,1]], numpy.int32 ) )
        indices[1] = PolynomialIndex( numpy.array( [[2,1]], numpy.int32 ) )
        indices[2] = PolynomialIndex( numpy.array( [[0,3],[1,1],[2,2]],
                                                   numpy.int32 ) )
        x_1 = numpy.random.uniform( -1., 1., ( 1, 20 ) )
        x_2 = numpy.random.normal( 0., 1., ( 1, 20 ) )
        x_3 = numpy.random.normal( 0., 1., ( 1, 20 ) )
        x = numpy.vstack( ( x_1, x_2, x_3 ) )
        vals_m = basis.value_set_multiple( x, indices )
        for i, index in enumerate( indices ):
            true_vals = \
                poly_1d[0].value_set( x[0,:], index.level(0),
                                      domain[0], domain[1] ) * \
                poly_1d[1].value_set( x[1,:], index.level(1),
                                      domain[2], domain[3] ) * \
                poly_1d[0].value_set( x[2,:], index.level(2),
                                      domain[4], domain[5] )
            vals = basis.value_set( x, index )
            assert numpy.allclose( true_vals, vals )
            assert numpy.allclose( true_vals, vals_m[:,i] )
    def test_index_generation( self ):
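        """
        Check get_hyperbolic_indices against hand-enumerated index sets
        for several dimensions, orders and hyperbolicity parameters.
        """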
        num_dims = 2
        order = 3
        indices = PolyIndexVector()
        get_hyperbolic_indices( num_dims, order, 1., indices )        
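        # with hyperbolicity parameter 1. the hyperbolic cross reduces to
        # the total-degree set: every index with level sum <= order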

        true_indices = [[0, 0],
                        [1, 0],
                        [2, 0],
                        [0, 1],
                        [1, 1],
                        [0, 2],
                        [1, 2],
                        [2, 1],
                        [3, 0],
                        [0, 3]]

        indices_list = []
        for i in xrange( indices.size() ):
            indices_list.append( indices[i].uncompressed_data( num_dims ) )
        
        indices = unique_matrix_rows( numpy.array( indices_list ))
        true_indices =  unique_matrix_rows( numpy.array( true_indices ) )
        assert numpy.allclose( true_indices, indices )

        num_dims = 3
        order = 2
        indices = PolyIndexVector()
        get_hyperbolic_indices( num_dims, order, 1., indices )

        true_indices = [[0, 0, 0],
                        [1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1],
                        [2, 0, 0],
                        [1, 1, 0],
                        [0, 2, 0],
                        [1, 0, 1],
                        [0, 1, 1],
                        [0, 0, 2]]

        indices_list = []
        for i in xrange( indices.size() ):
            indices_list.append( indices[i].uncompressed_data( num_dims ) )
        
        indices = unique_matrix_rows( numpy.array( indices_list ))
        true_indices =  unique_matrix_rows( numpy.array( true_indices ) )
        assert numpy.allclose( true_indices, indices )

        num_dims = 2
        order = 3
        indices = PolyIndexVector()
        get_hyperbolic_indices( num_dims, order, .5, indices )
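        # hyperbolicity parameter .5 penalises interaction terms; at this
        # order only the pure univariate indices remain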

        true_indices = [[0, 0],
                        [1, 0],
                        [2, 0],
                        [3, 0],
                        [0, 1],
                        [0, 2],
                        [0, 3]]

        indices_list = []
        for i in xrange( indices.size() ):
            indices_list.append( indices[i].uncompressed_data( num_dims ) )
        
        indices = unique_matrix_rows( numpy.array( indices_list ))
        true_indices =  unique_matrix_rows( numpy.array( true_indices ) )
        assert numpy.allclose( true_indices, indices )

        num_dims = 2
        order = 6
        indices = PolyIndexVector()
        get_hyperbolic_indices( num_dims, order, .5, indices )

        true_indices = [[0, 0],
                        [1, 0],
                        [2, 0],
                        [3, 0],
                        [4, 0],
                        [5, 0],
                        [6, 0],
                        [1, 1],
                        [2, 1],
                        [1, 2],
                        [0, 1],
                        [0, 2],
                        [0, 3],
                        [0, 4],
                        [0, 5],
                        [0, 6]]

        indices_list = []
        for i in xrange( indices.size() ):
            indices_list.append( indices[i].uncompressed_data( num_dims ) )
        
        indices = unique_matrix_rows( numpy.array( indices_list ))
        true_indices =  unique_matrix_rows( numpy.array( true_indices ) )
        assert numpy.allclose( true_indices, indices )
def oscillator_study():
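    """
    Convergence study for the random oscillator model: build a
    sparse-grid reference surrogate, convert it to a PCE, then compare
    the best N-term truncation against PCEs recovered from Latin
    hypercube designs by least interpolation, OMP, NESTA, SPGL1 and
    oracle least squares.
    """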

    oscillator_model = RandomOscillator()
    g = lambda x: oscillator_model( x ).squeeze()
    f = CppModel( g )
    num_dims = 6
    domain =  numpy.array( [0.08,0.12,0.03,0.04,0.08,0.12,
                            0.8,1.2,0.45,0.55,-0.05,0.05], numpy.double )

    ##---------------------------------------------------------------------
    # Read in test data
    ##---------------------------------------------------------------------
    """
    func_basename = 'oscillator-'
    data_path = '/home/jdjakem/software/pyheat/work/gp-pce-comparison/data'
    data_path = join(data_path, 'convergence-data' )
    test_pts_filename = join( data_path, "TestPts" + str( num_dims )+
    ".txt" )
    test_pts = numpy.loadtxt( test_pts_filename, delimiter = ' ' ).T    
    test_pts = domain.map_from_unit_hypercube( test_pts )
    test_vals_filename = join( join( data_path, 'random-oscillator'  ), 
    func_basename + 
    str( num_dims )+
    '-test-values.txt' )
    test_vals = numpy.loadtxt( test_vals_filename, 
    delimiter = ' ' )
    """
    test_pts = numpy.random.uniform( 0, 1, (num_dims, 10000) )
    unit_hypercube = set_hypercube_domain( num_dims, 0., 1. )
    test_pts = hypercube_map( test_pts, unit_hypercube, domain )
    test_vals = f.evaluate_set(test_pts).squeeze()

    
    ##---------------------------------------------------------------------
    # Build sparse grid to machine precision
    ##--------------------------------------------------------------------- 
    print 'Building sparse grid'
    quadrature_rule_1d = ClenshawCurtisQuadRule1D()
    #quadrature_rule_1d = GaussPattersonQuadRule1D()
    basis = LagrangePolynomialBasis()
    max_level = 20
    SG, sg_error, num_pts = build_sparse_grid_cpp( quadrature_rule_1d, basis,
                                                   domain = domain, 
                                                   f = f, num_dims = num_dims, 
                                                   max_level = max_level,
                                                   max_num_points = 1000,
                                                   test_pts = test_pts, 
                                                   test_vals = test_vals,
                                                   breaks = None )
    
    print 'num points in sparse grid', SG.num_points()
    pred_vals = SG.evaluate_set( test_pts ).squeeze()
    print 'num function evaluations', SG.num_function_evaluations()
    print 'test_mean', test_vals.mean()
    print 'grid mean', pred_vals.mean()
    print 'sparse grid error: ', get_l2_error( test_vals, pred_vals )

    
    ##---------------------------------------------------------------------
    # Convert the sparse grid into a PCE
    ##--------------------------------------------------------------------- 
    print 'Building PCE'
    # test conversion to pce
    poly_1d = OrthogPolyVector()
    poly_1d.resize( 1 )
    poly_1d[0] = LegendrePolynomial1D()
    basis = TensorProductBasis()
    basis.set_domain( domain )
    basis.set_bases( poly_1d, 
                     [numpy.arange( num_dims, dtype=numpy.int32 )], 
                     num_dims )
    pce = PolynomialChaosExpansion()
    pce.domain( domain )
    pce.basis( basis )
    SG.convert_to_polynomial_chaos_expansion( pce )
    print 'PCE error: ', get_l2_error( test_vals, 
                                       pce.evaluate_set( test_pts ).squeeze() )

    oracle_basis = PolyIndexVector()
    pce.get_basis_indices( oracle_basis )

    ##---------------------------------------------------------------------
    # Initialise PCEs
    ##---------------------------------------------------------------------
    pce_least = PolynomialChaosExpansion()
    pce_least.domain( domain )
    pce_least.basis( basis )

    pce_least_oracle = PolynomialChaosExpansion()
    pce_least_oracle.domain( domain )
    pce_least_oracle.basis( basis )

    pce_omp = PolynomialChaosExpansion()
    pce_omp.domain( domain )
    pce_omp.basis( basis )

    pce_omp_oracle = PolynomialChaosExpansion()
    pce_omp_oracle.domain( domain )
    pce_omp_oracle.basis( basis )

    pce_nesta = PolynomialChaosExpansion()
    pce_nesta.domain( domain )
    pce_nesta.basis( basis )

    pce_nesta_oracle = PolynomialChaosExpansion()
    pce_nesta_oracle.domain( domain )
    pce_nesta_oracle.basis( basis )

    pce_spgl1 = PolynomialChaosExpansion()
    pce_spgl1.domain( domain )
    pce_spgl1.basis( basis )

    pce_lsq_oracle = PolynomialChaosExpansion()
    pce_lsq_oracle.domain( domain )
    pce_lsq_oracle.basis( basis )

    solver_basenames = ['lsq-oracle', 'least-interpolant', 'nesta',
                        'nesta-oracle', 'omp', 'omp-oracle']

    ##---------------------------------------------------------------------
    # Precompute information necessary for the best N-term PCE
    ##---------------------------------------------------------------------
    coeff = pce.get_coefficients()
    basis_indices = PolyIndexVector()
    pce.get_basis_indices( basis_indices )
    indices = sorted( basis_indices, 
                      key = lambda x: x.get_array_index() )

    
    ##---------------------------------------------------------------------
    # Perform convergence study
    ##---------------------------------------------------------------------

    num_pts_sg = num_pts
    max_num_points = 200
    num_pts = numpy.logspace( numpy.log10( 10*num_dims ),
                              numpy.log10( max_num_points ), 5 )
    num_pts = numpy.asarray( num_pts, dtype = numpy.int32 )

    nterm_pce_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    least_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    least_oracle_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    omp_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    omp_oracle_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    nesta_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    nesta_oracle_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    spgl1_error = numpy.empty( ( len( num_pts ) ), numpy.double )
    lsq_oracle_error = numpy.empty( ( len( num_pts ) ), numpy.double )


    # rank basis terms by decreasing squared coefficient magnitude
    I = numpy.argsort( coeff[:,0]**2 )[::-1]
    for n in xrange( len( num_pts ) ):
        #
        # Construct best N term approximation
        #
        num_indices = 0
        pce.set_coefficients( coeff[I[:num_pts[n]]].reshape( ( num_pts[n], 1 ) ))
        N_term_basis_indices = PolyIndexVector()
        N_term_basis_indices.resize( int(num_pts[n]) )
        max_degree = 0
        for i in I[:num_pts[n]]:
            N_term_basis_indices[num_indices] = indices[i]
            indices[i].set_array_index( num_indices )
            num_indices += 1
            max_degree = max( max_degree, indices[i].level_sum() )
            #print indices[i]
        print 'max_degree', max_degree, num_pts[n]
        pce.set_basis_indices( N_term_basis_indices )
        pred_vals = pce.evaluate_set( test_pts ).squeeze()
        l2_error = get_l2_error( test_vals, pred_vals )

        print 'best N term pce error: ', l2_error
        nterm_pce_error[n] = l2_error
        
        #
        # Construct PCE on an LHD and repeat to account for statistical
        # variation in the LHD
        #
        from math_tools_cpp import ilhs, lhs
        seed = 1
        num_reps = 10
        least_errors = numpy.ones( num_reps )
        least_oracle_errors = numpy.ones( num_reps )
        omp_errors = numpy.ones( num_reps )
        omp_oracle_errors = numpy.ones( num_reps )
        nesta_errors = numpy.ones( num_reps )
        nesta_oracle_errors = numpy.ones( num_reps )
        spgl1_errors = numpy.ones( num_reps )
        lsq_oracle_errors = numpy.ones( num_reps )
        for k in xrange( num_reps ):
            build_pts = lhs( num_dims, num_pts[n], seed )
            seed += 1
            #build_pts = ilhs( num_dims, num_pts[n], 2, 0 );
            build_pts = hypercube_map( build_pts, unit_hypercube, domain )
            build_vals = f.evaluate_set( build_pts )
            """
            # least interpolant
            pce_least.set_basis_indices( PolyIndexVector() )
            least_interpolation( build_pts, 
                                 build_vals.reshape( build_vals.shape[0], 1 ), 
                                 pce_least );
            pred_vals = pce_least.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            least_errors[k] = l2_error
            print 'least interpolant ', l2_error
           
            # oracle least interpolant
            pce_least_oracle.set_basis_indices( N_term_basis_indices )
            least_interpolation( build_pts, 
            build_vals.reshape( build_vals.shape[0], 1 ), 
            pce_least_oracle );
            pred_vals = pce_least_oracle.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            least_oracle_errors[k] = l2_error
            
            
            # NESTA pce
            pce_nesta.set_basis_indices( oracle_basis )
            NESTA( build_pts, 
                   build_vals.reshape( build_vals.shape[0], 1 ), 
                   pce_nesta );
            pred_vals = pce_nesta.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            nesta_errors[k] = l2_error
            print 'nesta ', l2_error

            # NESTA oracle pce
            pce_nesta_oracle.set_basis_indices( N_term_basis_indices )
            NESTA( build_pts, 
                   build_vals.reshape( build_vals.shape[0], 1 ), 
                   pce_nesta_oracle );
            pred_vals = pce_nesta_oracle.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            nesta_oracle_errors[k] = l2_error
            print 'nesta oracle', l2_error
            
            # SPGL1 pce
            #pce_spgl1.set_basis_indices( N_term_basis_indices )
            pce_spgl1.set_basis_indices( oracle_basis )
            SPGL1( build_pts, 
            build_vals.reshape( build_vals.shape[0], 1 ), 
            pce_spgl1, test_pts, test_vals );
            pred_vals = pce_spgl1.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            spgl1_errors[k] = l2_error
            print 'spgl1', l2_error
            
            
            # least squares oracle pce
            pce_lsq_oracle.set_basis_indices( N_term_basis_indices )
            least_squares( build_pts, 
                           build_vals.reshape( build_vals.shape[0], 1 ), 
                           pce_lsq_oracle );
            pred_vals = pce_lsq_oracle.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            lsq_oracle_errors[k] = l2_error
            print 'lsq', l2_error
            """
            # omp pce
            from indexing_cpp import get_hyperbolic_indices
            total_degree_indices = PolyIndexVector()
            get_hyperbolic_indices( num_dims, 8, 1,
                                    total_degree_indices )
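            # hyperbolicity parameter 1 gives a degree-8 total-degree
            # candidate basis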
            #pce_omp.set_basis_indices( oracle_basis )
            pce_omp.set_basis_indices( total_degree_indices )
            OMP_fast_cv( build_pts, 
                 build_vals.reshape( build_vals.shape[0], 1 ), 
                 pce_omp );
            pred_vals = pce_omp.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            omp_errors[k] = l2_error
            print 'omp', l2_error
            
            # omp oracle pce
            #pce_omp_oracle.set_basis_indices( N_term_basis_indices )
            #pce_omp_oracle.set_basis_indices( oracle_basis )
            pce_omp_oracle.set_basis_indices( total_degree_indices )
            OMP_brute_cv( build_pts, 
                          build_vals.reshape( build_vals.shape[0], 1 ), 
                          pce_omp_oracle );
            pred_vals = pce_omp_oracle.evaluate_set( test_pts ).squeeze()
            l2_error = get_l2_error( test_vals, pred_vals )
            omp_oracle_errors[k] = l2_error
            print 'omp oracle', l2_error
            
            

        least_error[n] = least_errors.mean()
        least_oracle_error[n] = least_oracle_errors.mean()
        #bregman_error[n] = bregman_errors.mean()
        omp_error[n] = omp_errors.mean()
        omp_oracle_error[n] = omp_oracle_errors.mean()
        nesta_error[n] = nesta_errors.mean()
        nesta_oracle_error[n] = nesta_oracle_errors.mean()
        spgl1_error[n] = spgl1_errors.mean()
        lsq_oracle_error[n] = lsq_oracle_errors.mean()
        print 'least interpolant error: ',  least_error[n]
        print 'least interpolant oracle error: ',  least_oracle_error[n]
        print 'omp error: ',  omp_error[n]
        print 'omp oracle error: ',  omp_oracle_error[n]
        print 'nesta error: ',  nesta_error[n]
        print 'nesta oracle error: ',  nesta_oracle_error[n]
        print 'spgl1 error: ',  spgl1_error[n]
        print 'lsq oracle error: ',  lsq_oracle_error[n]
        
    print 'sparse grid ', sg_error
    print 'n term', nterm_pce_error
    print 'least', least_error
    print 'least_oracle', least_oracle_error
    print 'omp error: ',  omp_error
    print 'omp oracle error: ',  omp_oracle_error
    print 'nesta', nesta_error
    print 'nesta oracle', nesta_oracle_error
    print 'spgl1', spgl1_error
    print 'lsq oracle', lsq_oracle_error

    print 'serializing'
    import pickle
    pickle.dump( nterm_pce_error, open( 'nterm_pce_error.p', 'wb' ) )
    pickle.dump( sg_error, open( 'sg_error.p', 'wb' ) )
    pickle.dump( least_error, open( 'least_error.p', 'wb' ) )
    pickle.dump( omp_error, open( 'omp_error.p', 'wb' ) )
    pickle.dump( omp_oracle_error, open( 'omp_oracle_error.p', 'wb' ) )
    pickle.dump( nesta_error, open( 'nesta_error.p', 'wb' ) )
    pickle.dump( nesta_oracle_error, open( 'nesta_oracle_error.p', 'wb' ) )

    print 'serialization complete'

    # plot convergence
    import pylab
    #pylab.loglog( num_pts_sg, sg_error, 'o-', label = 'sparse grid' )
    pylab.loglog( num_pts, nterm_pce_error, 'o-', label = 'best N term')
    
    pylab.loglog( num_pts, least_error, 'o-', label = 'least interpolant')
    #pylab.loglog( num_pts, least_oracle_error, 'o-', 
    #              label = 'oracle least interpolant')
    pylab.loglog( num_pts, omp_error, 'o-', 
                  label = 'OMP')
    pylab.loglog( num_pts, omp_oracle_error, 'o-', 
                  label = 'oracle OMP')
    print nesta_error.shape, nesta_oracle_error.shape
    pylab.loglog( num_pts, nesta_error, 'o-', 
                  label = 'NESTA')
    pylab.loglog( num_pts, nesta_oracle_error, 'o-', 
                  label = 'NESTA oracle')
    #pylab.loglog( num_pts, spgl1_error, 'o-', 
    #              label = 'SPGL1')                  
    #pylab.loglog( num_pts, lsq_oracle_error, 'o-', 
    #              label = 'Least squares')
    pylab.legend()
    pylab.show()
    # deliberate early exit: skip the exploratory coefficient-plot code below
    assert False

    # plot pce coefficients
    from utilities.visualisation import plot_pce_coefficients
    indices = sorted( pce.basis_indices, 
                      key = lambda x: x.norm_of_index( 1 ) )
    coeff = numpy.empty( ( pce.coeff.shape[0] ), numpy.double )
    for i in xrange( len( indices ) ):
        coeff[i] = pce.coeff[indices[i].get_array_index()]
        if ( abs( coeff[i] ) < numpy.finfo( numpy.double ).eps ):
            coeff[i] = 0.
    plot_pce_coefficients( coeff.reshape(coeff.shape[0],1) )
def OMP_fast_cv( pts, vals, pce, test_pts = None, test_vals = None ):
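    """
    Fit pce with orthogonal matching pursuit, selecting the truncation
    step that minimises the K-fold cross-validation error reported by
    the Cholesky OMP solver, then re-solve for the final coefficients.

    Minimal usage sketch (assumes the candidate basis indices have
    already been set on pce):

        pce.set_basis_indices( candidate_indices )
        OMP_fast_cv( build_pts, build_vals, pce )
        pred_vals = pce.evaluate_set( test_pts )
    """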
    
    i = 0
    indices = PolyIndexVector()
    pce.get_basis_indices( indices )
    for index in indices:
        index.set_array_index( i )
        i += 1 


        
    all_train_indices = []
    all_validation_indices = []
    #cv_iterator = LeaveOneOutCrossValidationIterator( pts.shape[1] )
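    # 10-fold cross validation: each fold trains on roughly 90% of the
    # points and validates on the remainder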
    cv_iterator = KFoldCrossValidationIterator( 10, pts.shape[1] )
    for train_indices, validation_indices in cv_iterator:
        all_train_indices.append( train_indices )
        all_validation_indices.append( validation_indices )


    A = pce.build_vandermonde( pts )
    out = orthogonal_matching_pursuit_cholesky( A, vals.squeeze(), 
                                                all_train_indices,
                                                all_validation_indices, 
                                                0., 
                                                numpy.iinfo(numpy.int32).max, 
                                                0 )
    """    
    # check cholesky omp is producing correct one at a time estimates
    new_indices = PolyIndexVector() 
    new_indices.resize( 20 )
    pce.get_basis_indices( indices )
    for i in xrange( 20 ):
    new_indices[int(i)] = indices[int(out[1][1,i])]
    new_indices[int(i)].set_array_index( i )
    pce.set_basis_indices( new_indices )
    i = 0
    error = numpy.empty( (len(all_validation_indices)), numpy.double )
    for train_indices, validation_indices in cv_iterator:
    A = pce.build_vandermonde( pts[:,train_indices] )
    coeff = svd_solve_default( A, vals[train_indices] )
    pce.set_coefficients( coeff[0] );
    pred_vals = pce.evaluate_set( pts[:,validation_indices] )
    print pred_vals, pts[:,validation_indices], i
    error[i] = vals[validation_indices,0] - pred_vals.squeeze()
    i += 1
    print error
    print numpy.asarray( out[2][19] )[:,0]
    assert False
    """
    
    
    solutions = out[0]      # num_terms x num_steps coefficient history
    metrics = out[1]
    cv_residuals = out[2]   # num_steps x num_folds x num_pts_per_fold

    print len( cv_residuals ), len( cv_residuals[0] ), len( cv_residuals[0][0] )

    num_steps = len( cv_residuals )
    num_folds = len( cv_residuals[0] )

    # cross-validation score at each OMP step: average over folds of the
    # RMSE of the validation residuals
    scores = numpy.zeros( ( num_steps ), numpy.double )
    for i in xrange( num_folds ):
        for j in xrange( num_steps ):
            scores[j] += numpy.sqrt( numpy.mean( cv_residuals[j][i]**2, axis = 0 ) )
    scores /= num_folds

    argmin = int( numpy.argmin( scores ) )
    
    # plot true error vs cross validation error

    if test_pts is not None:
        import pylab
        # plot cv errors
        tau = numpy.sum( numpy.absolute( out[0] ), axis = 0 )
        pylab.semilogy( tau, scores, 'bh-' )
        pylab.semilogy( tau[argmin], scores[argmin], 'gd' )

        # plot true errors
        pce.set_coefficients( out[0] )
        pred_vals = pce.evaluate_set( test_pts )
        print numpy.linalg.norm( pred_vals[:,30] - test_vals )
        errors = pred_vals - numpy.tile( test_vals.reshape( test_vals.shape[0], 1 ),
                                         ( 1, pred_vals.shape[1] ) )
        error_norms = numpy.sqrt( numpy.sum( errors**2, axis = 0 ) / 
                                  test_vals.shape[0] )
        error_argmin = numpy.argmin( error_norms )

        omp_tau, cv_error = OMP_brute_cv( pts, vals, pce, True )
        omp_argmin = numpy.argmin( cv_error )
        pylab.semilogy( omp_tau[::-1], cv_error[::-1], 'o-' )
        pylab.semilogy( [omp_tau[omp_argmin]], [cv_error[omp_argmin]], 'go' )
        print [omp_tau[omp_argmin]], [cv_error[omp_argmin]]
        print cv_error, omp_argmin

        #pylab.loglog( tau, out[1][0,:], 'o-' )
        pylab.semilogy( tau, error_norms, 'ks-' )
        pylab.semilogy( [tau[error_argmin]], [error_norms[error_argmin]], 'rd' )
        pylab.semilogy( [tau[argmin]], [error_norms[argmin]], 'gd' )
        print error_norms[error_argmin]
        pylab.show()

    
    best_indices = PolyIndexVector()
    best_indices.resize( argmin )
    for i in xrange( argmin ):
        best_indices[i] = indices[int(metrics[1,i])]
        best_indices[i].set_array_index( i )
    
    pce.set_basis_indices( best_indices )

    A = pce.build_vandermonde( pts )
    coeff = svd_solve( A, vals )
    pce.set_coefficients( coeff[0] )
def least_squares_modified( pts, vals, pce, max_degree ):
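    """
    Experimental basis selection: draw num_reps random subsets of a
    total-degree candidate basis, run OMP on each subset, accumulate
    the recovered coefficients per index, and rank the indices by the
    mean magnitude of their coefficient estimates.
    """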
        
    num_pts = pts.shape[1]
    num_reps = 2
    num_dims = pts.shape[0]

    indices = PolyIndexVector()
    from indexing_cpp import get_hyperbolic_indices
    all_indices = PolyIndexVector()
    get_hyperbolic_indices( num_dims, max_degree, 1, all_indices )
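    # hyperbolicity parameter 1 gives the full total-degree basis of
    # order max_degree as the candidate set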
    #for index in all_indices:
    #    print index

    num_basis_terms = min( 2*num_pts, len( all_indices ) // ( num_reps // 2 ) )
    #num_basis_terms = len( all_indices )
    print num_basis_terms

    key_sorter = lambda x: x.level_sum()

    indices_dict = dict()
    for n in xrange( num_reps ):
        I = numpy.random.randint( 0, len( all_indices ), (num_basis_terms) )
        #I = numpy.arange( len( all_indices ) )
        indices = PolyIndexVector()
        indices.resize( num_basis_terms )
        for i in xrange( I.shape[0] ):
            indices[i] = all_indices[int(I[i])]
            indices[i].set_array_index( i )

        indices_list = sorted( indices, key = key_sorter )[::-1]

        for i in xrange( len( indices ) ):
            indices[i] = indices_list[i]
            indices[i].set_array_index( i )

        #for index in indices:
        #    print index, max_degree

        pce.set_basis_indices( indices )
            
        A = pce.build_vandermonde( pts )
        
        #coeff = svd_solve_default( A, vals )[0]

        #pce.set_coefficients( coeff.reshape( ( coeff.shape[0], 1 ) ) )

        from compressed_sensing_cpp import orthogonal_matching_pursuit
        sols, metrics = orthogonal_matching_pursuit( A, vals.squeeze(), 
                                                     0., 
                                                     num_pts,
                                                     0 )

        coeff = sols[:,-1].reshape( sols.shape[0], 1 )
            
        #for index in indices:
        for i in metrics[1,:]:
            index = indices[int(i)]
            if index not in indices_dict:
                indices_dict[index] = [coeff[index.get_array_index(),0]]
            else:
                indices_dict[index].append( coeff[index.get_array_index(),0] )

    best_indices = PolyIndexVector()
    best_indices.resize( len( indices_dict ) )
    i = 0
    for index in indices_dict:
        best_indices[i] = index
        i += 1
    
    key_sorter = lambda x: numpy.absolute( numpy.asarray( indices_dict[x] ) ).mean()

    best_indices = sorted( best_indices, key = key_sorter )[::-1]

    
    # print the retained indices, largest mean coefficient first
    i = 0
    for index in best_indices:
        print index, key_sorter( index )
        i += 1
        if i >= num_pts:
            break

    # deliberate stop: selection of the final PCE coefficients is not
    # implemented beyond this point
    assert False