def yield_jobs():
        """Generate one ``(condor_job, args)`` pair per sample size, run and training method.

        Reads its configuration from names that are not defined in this
        function (``n_samples``, ``n_runs``, ``dim``, ``k``, ``lam``, ``gam``,
        ``eps``, ``reward_init``, ``m``, ``mdp``, ``weighting``,
        ``beta_ratio``, ``partition``, ``losses``, ``nonlin``,
        ``training_methods``, ``patience``, ``max_iter``).

        For every entry ``n`` of ``n_samples`` the mixing matrices are built
        once; for every run fresh random initial parameters and fresh
        training / validation / test samples are drawn, and a single
        ``BellmanBasis`` is constructed.  That same ``bb`` object is placed
        in the argument list of every job yielded for the run.

        Yields:
            tuple ``(condor_job, args)`` where ``args`` is the list
            ``[(i, r, j), bb, m, tm, S, R, S_val, R_val, S_test, R_test,
            Mphi, Mrew, patience, max_iter, weighting]``.
        """
        for i, n in enumerate(n_samples):

            Mphi, Mrew = BellmanBasis.get_mixing_matrices(
                n, lam, gam, sampled=True, eps=eps)

            for r in xrange(n_runs):

                # initialize features; every column is scaled to unit L2 norm
                theta_init = numpy.random.standard_normal((dim+1, k))
                if reward_init:
                    theta_init[:-1, -1] = m.R  # XXX set last column to reward
                    theta_init[-1, -1] = 0
                theta_init /= numpy.sqrt((theta_init * theta_init).sum(axis=0))

                # initial weight vector, scaled to unit norm
                w_init = numpy.random.standard_normal((k+1, 1))
                w_init = w_init / numpy.linalg.norm(w_init)

                # sample data: training, validation, and test sets.  Each
                # state matrix gets the last next-state row appended.
                # The training set previously used numpy.vstack while the
                # validation and test sets used scipy.sparse.vstack; all
                # three come from the same sampler, so they are stacked the
                # same way here.
                # NOTE(review): assumes sample_grid_world returns sparse
                # matrices (the result of scipy.sparse.vstack is sparse) --
                # TODO confirm against the sampler.
                S, Sp, R, _ = mdp.sample_grid_world(n, distribution=weighting)
                S = scipy.sparse.vstack((S, Sp[-1, :]))
                S_val, Sp_val, R_val, _ = mdp.sample_grid_world(n, distribution=weighting)
                S_val = scipy.sparse.vstack((S_val, Sp_val[-1, :]))
                S_test, Sp_test, R_test, _ = mdp.sample_grid_world(n, distribution=weighting)
                S_test = scipy.sparse.vstack((S_test, Sp_test[-1, :]))

                bb = BellmanBasis(dim+1, k, beta_ratio, partition=partition,
                                  theta=theta_init, w=w_init,
                                  record_loss=losses, nonlin=nonlin)

                for j, tm in enumerate(training_methods):

                    yield (condor_job, [(i, r, j), bb, m, tm,
                                        S, R, S_val, R_val, S_test, R_test,
                                        Mphi, Mrew, patience, max_iter, weighting])
    def yield_jobs():
        """Generate one ``(train_basis, args)`` job per sample size, run and training method.

        All configuration is read from names defined outside this function.
        When ``n_samples`` is falsy a single pass with ``n_states`` is made
        and the data comes from ``full_info()`` instead of ``sample(n)``.

        Yields:
            tuple ``(train_basis, args)`` with ``args`` in the positional
            order of ``train_basis``'s signature.
        """
        for n in (n_samples or [n_states]):

            logger.info('creating job with %i samples/states' % n)

            # build bellman operator matrices
            logger.info('making mixing matrices')
            Mphi, Mrew = BellmanBasis.get_mixing_matrices(
                n, lam, gam, sampled=bool(n_samples), eps=eps)

            for _ in xrange(n_runs):

                n_features = encoder.n_features

                # initialize parameters: random features with unit-norm columns
                theta_init = numpy.random.standard_normal((n_features, k))
                column_norms = numpy.sqrt((theta_init * theta_init).sum(axis=0))
                theta_init /= column_norms

                # w_init is not referenced again in this function; the draw
                # is kept so the global numpy random stream advances exactly
                # as before.
                w_init = numpy.random.standard_normal((k+1, 1))
                w_init = w_init / numpy.linalg.norm(w_init)

                # sample or gather full info data
                if n_samples:
                    X_data, R_data, weighting = sample(n)
                else:
                    X_data, R_data, weighting = full_info()

                bb_params = [n_features, [k], beta_ratio]
                bb_dict = {
                    'alpha': alpha,
                    'reg_tuple': reg,
                    'nonlin': nonlin,
                    'nonzero': nonzero,
                    'thetas': [theta_init],
                }

                for tm in training_methods:
                    loss_list, wrt_list = tm
                    assert len(loss_list) == len(wrt_list)

                    # string labels for the regulariser and nonlinearity
                    if reg:
                        reg_label = reg[0] + str(reg[1])
                    else:
                        reg_label = 'None'
                    nonlin_label = nonlin or 'None'

                    run_param_values = [
                        k, tm, encoder, n,
                        n_reward_samples, n_reward_runs,
                        env_size, weighting,
                        lam, gam, alpha, eta,
                        reg_label,
                        nonlin_label,
                    ]
                    d_run_params = dict(izip(run_param_keys, run_param_values))

                    job_args = [
                        d_run_params, bb_params, bb_dict,
                        env, m, losses,  # environment, model and loss list
                        X_data, R_data, Mphi, Mrew,  # training data
                        max_iter, patience, min_imp, min_delta,  # optimization params
                        fldir, record_runs,  # recording params
                    ]
                    yield (train_basis, job_args)
def train_basis(d_run_params, basis_params, basis_dict,
                env, model, losses,
                S_data, R_data, Mphi, Mrew,
                max_iter, patience, min_imp, min_delta,
                fl_dir, record_runs):
    """Train a BellmanBasis stage by stage with early stopping and report losses.

    Args:
        d_run_params: mapping of run parameters.  The keys ``'method'``,
            ``'weighting'``, ``'encoding'``, ``'reward_samples'``,
            ``'reward_runs'`` and ``'size'`` are read here; ``'method'`` is a
            ``(loss_list, wrt_list)`` pair of equal-length sequences.  The
            mapping itself is not modified.
        basis_params: positional arguments for ``BellmanBasis``.
        basis_dict: keyword arguments for ``BellmanBasis``.
        env: environment handed to ``sample_policy_reward``.
        model: object providing the ``true-*`` error measures
            (``policy_distance``, ``bellman_error``, ...) and ``P`` / ``R``
            for the spectrum plot.
        losses: names of the losses to record after every iteration.
        S_data: ``(S, S_val, S_test)`` state data.
        R_data: ``(R, R_val, R_test)`` reward data.
        Mphi, Mrew: matrices passed through to every loss / gradient call.
        max_iter: ``maxiter`` for each ``scipy.optimize.fmin_cg`` call.
        patience: number of iterations without sufficient validation
            improvement after which a training stage stops.
        min_imp: minimum relative validation-loss improvement that resets
            the patience counter.
        min_delta: minimum parameter-change norm that resets the patience
            counter.
        fl_dir: path prefix; plots go to ``fl_dir + 'sirf/output/plots/learning/'``.
        record_runs: when true, plots are written to disk.

    Returns:
        Whatever ``reorder_columns(run_params, final_losses)`` returns
        (keys array and values array, per the original comment).

    Raises:
        AssertionError: on mismatched ``loss_list`` / ``wrt_list`` lengths or
            an unknown name in ``losses``.
    """
    (method, weighting, encoder,
     n_reward_samples, n_reward_runs, env_size) = [
        d_run_params[key] for key in
        'method weighting encoding reward_samples reward_runs size'.split()]
    logger.info('training basis using training method: %s' % str(method))

    S, S_val, S_test = S_data
    R, R_val, R_test = R_data

    # NOTE(review): test losses below are divided by the number of *training*
    # rows, not test rows -- confirm this is intended.
    n_rows = float(S.shape[0])

    # initialize loss dictionary: one (initially empty) array per tracked loss
    d_loss_learning = dict((key, numpy.array([])) for key in losses)

    loss_list, wrt_list = method
    assert len(loss_list) == len(wrt_list)

    logger.info('constructing bellman basis')
    basis = BellmanBasis(*basis_params, **basis_dict)

    # test-set losses: name of the BellmanBasis method computing each one
    test_loss_methods = {
        'test-bellman': 'loss_be',
        'test-lsbellman': 'loss_lsbe',
        'test-reward': 'loss_r',
        'test-model': 'loss_m',
        'test-fullmodel': 'loss_fm',
    }
    # model-based losses called as model.<name>(Bs, weighting=weighting)
    true_loss_methods = {
        'true-lsbellman': 'bellman_error',
        'true-reward': 'reward_error',
        'true-model': 'model_error',
        'true-fullmodel': 'fullmodel_error',
        'true-lsq': 'lsq_error',
    }

    def record_loss(d_loss):
        """Append the current value of every loss in d_loss and return d_loss.

        Reads ``Bs`` from the enclosing scope, so ``Bs`` must be refreshed
        before each call.
        """
        test_args = [S_test, R_test, Mphi, Mrew]
        for loss, arr in d_loss.items():
            if loss == 'sample-reward':
                val = sample_policy_reward(
                    env, basis.estimated_value(encoder.B, False),
                    n_reward_samples, n_reward_runs)
            elif loss == 'test-training':
                val = basis.loss(basis.flat_params, *test_args) / n_rows
            elif loss in test_loss_methods:
                # attribute is looked up at call time, as in a direct call
                loss_fn = getattr(basis, test_loss_methods[loss])
                val = loss_fn(*(basis.params + test_args)) / n_rows
            elif loss == 'true-policy':
                val = model.policy_distance(Bs, weighting='policy')
            elif loss == 'true-policy-uniform':
                val = model.policy_distance(Bs, weighting='uniform')
            elif loss == 'true-bellman':
                val = model.bellman_error(Bs, w=basis.params[-1], weighting=weighting)
            elif loss in true_loss_methods:
                error_fn = getattr(model, true_loss_methods[loss])
                val = error_fn(Bs, weighting=weighting)
            else:
                # explicit raise: a bare `assert False` disappears under -O
                raise AssertionError('unknown loss type: %s' % loss)

            d_loss[loss] = numpy.append(arr, val)
        return d_loss

    def optimize():
        """Run conjugate gradient on the currently selected loss and store the result."""
        basis.set_params(scipy.optimize.fmin_cg(
            basis.loss, basis.flat_params, basis.grad,
            args=(S, R, Mphi, Mrew),
            full_output=False,
            maxiter=max_iter,
        ))

    switch = []  # list of indices where a training method switch occurred
    it = 0

    # train once on w to initialize
    basis.set_loss('bellman', ['w'])
    optimize()

    # TODO keep?
    IM = encoder.weights_to_basis(basis.thetas[-1])
    Bs = basis.encode(encoder.B, False)
    if len(basis.thetas) == 1:
        assert (IM == Bs).all()
    d_loss_learning = record_loss(d_loss_learning)

    for loss, wrt in zip(loss_list, wrt_list):

        waiting = 0
        best_params = None
        best_test_loss = 1e20

        # Membership is tested on this stage's own `wrt`.  The previous test
        # against `wrt_list` (a list of lists) could never match 'w', so the
        # initialisation below never ran.
        if 'w' in wrt:  # initialize w to the lstd soln given the current basis
            logger.info('initializing w to lstd soln')
            basis.params[-1] = BellmanBasis.lstd_weights(basis.encode(S), R, Mphi, Mrew)  # TODO change to iteration of opt on w?

        try:
            while waiting < patience:
                it += 1
                logger.info('*** iteration ' + str(it) + '***')

                old_params = copy.deepcopy(basis.flat_params)
                # one step on the stage loss, then one bellman step on w
                for loss_, wrt_ in ((loss, wrt), ('bellman', ['w'])):
                    basis.set_loss(loss_, wrt_)
                    optimize()
                basis.set_loss(loss, wrt)  # reset loss back from bellman

                delta = numpy.linalg.norm(old_params - basis.flat_params)
                logger.info('delta theta: %.2f' % delta)

                norms = numpy.apply_along_axis(numpy.linalg.norm, 0, basis.thetas[0])
                logger.info('column norms: %.2f min / %.2f avg / %.2f max' % (
                    norms.min(), norms.mean(), norms.max()))

                # early stopping is driven by the validation loss
                err = basis.loss(basis.flat_params, S_val, R_val, Mphi, Mrew)

                if err < best_test_loss:

                    # only a large enough improvement AND parameter change
                    # resets the patience counter
                    if ((best_test_loss - err) / best_test_loss > min_imp) and (delta > min_delta):
                        waiting = 0
                    else:
                        waiting += 1
                        logger.info('iters without better %s loss: %i' % (basis.loss_type, int(waiting)))

                    best_test_loss = err
                    best_params = copy.deepcopy(basis.flat_params)
                    logger.info('new best %s loss: %.2f' % (basis.loss_type, best_test_loss))

                else:
                    waiting += 1
                    logger.info('iters without better %s loss: %i' % (basis.loss_type, int(waiting)))

                Bs = basis.encode(encoder.B, False)
                d_loss_learning = record_loss(d_loss_learning)

        except KeyboardInterrupt:
            logger.info('\n user stopped current training loop')

        # set params to best params from last loss.  best_params is still
        # None when no iteration improved on the initial bound (immediate
        # interrupt, non-positive patience, NaN loss); keep the current
        # parameters in that case.
        if best_params is not None:
            basis.set_params(vec=best_params)
        switch.append(it-1)

    sparse_eps = 1e-5
    Bs = basis.encode(encoder.B, False)
    d_loss_learning = record_loss(d_loss_learning)
    logger.info('final test bellman error: %.2f' % model.bellman_error(Bs, weighting=weighting))
    logger.info('final sparsity: ' + str([(numpy.sum(abs(th) < sparse_eps) / float(len(th.flatten()))) for th in basis.params]))

    # Replace the (loss_list, wrt_list) method entry by the joined loss names.
    # Done on a shallow copy so the caller's mapping keeps its original
    # 'method' entry and can be reused.
    run_params = copy.copy(d_run_params)
    run_params['method'] = '-'.join(loss_list)

    # TODO change to output log file and plot afterwards
    if record_runs:

        plot_dir = fl_dir + 'sirf/output/plots/learning/'

        def save_current_figure(name):
            """Save the active matplotlib figure as a pdf named after the run parameters."""
            plt.savefig(out_string(plot_dir, name, run_params, '.pdf'))

        # save results!
        # plot basis functions
        plot_stacked_features(Bs[:, :36])
        save_current_figure('basis_stacked')

        # plot the basis functions again!
        plot_features(Bs)
        save_current_figure('basis_all')

        # plot learning curves; the helper's return value tells whether
        # anything was actually plotted
        if plot_learning_curves(d_loss_learning, switch, filt='test'):
            save_current_figure('test_loss')
        if plot_learning_curves(d_loss_learning, switch, filt='true'):
            save_current_figure('true_loss')

        # plot value functions
        plot_value_functions(env_size, model, Bs)
        save_current_figure('value_funcs')

        # plot spectrum of reward and features
        gen_spectrum(Bs, model.P, model.R)
        save_current_figure('spectrum')

    # keep only the last recorded value of every loss
    d_loss_batch = dict((key, arr[-1]) for key, arr in d_loss_learning.items())
    # returns keys_array, values_array
    return reorder_columns(run_params, d_loss_batch)