def _optimize_current_loss(basis, S, R, Mphi, Mrew, max_iter):
    """Run one conjugate-gradient fit of the loss currently selected on ``basis``.

    Minimizes ``basis.loss`` (gradient ``basis.grad``) starting from
    ``basis.flat_params``, passing the training data as the extra arguments,
    and writes the optimizer's solution back with ``basis.set_params``.
    """
    basis.set_params(scipy.optimize.fmin_cg(
            basis.loss, basis.flat_params, basis.grad,
            args = (S, R, Mphi, Mrew),
            full_output = False,
            maxiter = max_iter,
            ))


def _save_learning_plots(fl_dir, d_run_params, Bs, d_loss_learning, switch, env_size, model):
    """Save basis, learning-curve, value-function and spectrum plots as PDFs.

    Every figure goes under ``fl_dir + 'sirf/output/plots/learning/'``. Each
    file name is built by ``out_string`` from a prefix and ``d_run_params``.
    """
    plot_dir = fl_dir + 'sirf/output/plots/learning/'

    # basis functions: first 36 columns stacked, then all of them
    plot_stacked_features(Bs[:, :36])
    plt.savefig(out_string(plot_dir, 'basis_stacked', d_run_params, '.pdf'))

    plot_features(Bs)
    plt.savefig(out_string(plot_dir, 'basis_all', d_run_params, '.pdf'))

    # learning curves; only save when plot_learning_curves reports that it actually plotted
    for filt, prefix in (('test', 'test_loss'), ('true', 'true_loss')):
        if plot_learning_curves(d_loss_learning, switch, filt = filt):
            plt.savefig(out_string(plot_dir, prefix, d_run_params, '.pdf'))

    # value functions
    plot_value_functions(env_size, model, Bs)
    plt.savefig(out_string(plot_dir, 'value_funcs', d_run_params, '.pdf'))

    # spectrum of reward and features
    gen_spectrum(Bs, model.P, model.R)
    plt.savefig(out_string(plot_dir, 'spectrum', d_run_params, '.pdf'))


def train_basis(d_run_params, basis_params, basis_dict, 
                env, model, losses, 
                S_data, R_data, Mphi, Mrew, 
                max_iter, patience, min_imp, min_delta, 
                fl_dir, record_runs):
    """Train a BellmanBasis with a sequence of training methods and record losses.

    The basis is first fit on the 'bellman' loss w.r.t. ``w`` only. Then, for
    each (loss, wrt) pair of the training method, it alternates two steps:
    a fmin_cg fit of ``loss`` w.r.t. ``wrt``, and a 'bellman' re-fit w.r.t.
    ``['w']``. This repeats until ``patience`` consecutive iterations fail to
    improve the validation loss enough. A KeyboardInterrupt ends only the
    current method. The best parameters seen for that method are then
    restored and training continues with the next method.

    Args:
        d_run_params: dict of run settings. Reads the keys 'method',
            'weighting', 'encoding', 'reward_samples', 'reward_runs' and
            'size'. 'method' is a (loss_list, wrt_list) pair. It is
            overwritten IN PLACE with '-'.join(loss_list) before returning.
        basis_params, basis_dict: positional / keyword arguments for BellmanBasis.
        env: environment passed to sample_policy_reward ('sample-reward' loss).
        model: provides the 'true-*' error metrics, plus P and R for the spectrum plot.
        losses: names of the losses to record after every iteration.
        S_data, R_data: (train, validation, test) triples.
        Mphi, Mrew: passed alongside the data to every basis loss.
        max_iter: maxiter for every fmin_cg call.
        patience: consecutive non-improving iterations allowed per method.
        min_imp: minimum relative validation-loss improvement that resets patience.
        min_delta: minimum parameter-change norm that resets patience.
        fl_dir: root directory for plot output.
        record_runs: if true, save the plots.

    Returns:
        reorder_columns(d_run_params, d_loss_batch), where d_loss_batch maps
        each recorded loss name to its last recorded value (per the original
        comment, a keys array and a values array).

    Raises:
        ValueError: if ``losses`` contains a name with no known evaluator.
    """
    method, weighting, encoder, n_reward_samples, n_reward_runs, env_size = [
        d_run_params[key]
        for key in 'method weighting encoding reward_samples reward_runs size'.split()]
    logger.info('training basis using training method: %s' % str(method))

    S, S_val, S_test = S_data
    R, R_val, R_test = R_data

    # NOTE(review): test-set losses below are divided by the number of *training*
    # rows, not test rows -- confirm this normalization is intended.
    n_rows = float(S.shape[0])

    loss_list, wrt_list = method
    assert len(loss_list) == len(wrt_list)

    logger.info('constructing bellman basis')
    basis = BellmanBasis(*basis_params, **basis_dict)

    def _test_args():
        # positional arguments for the per-term symbolic losses on the test split
        return basis.params + [S_test, R_test, Mphi, Mrew]

    # Lazy evaluators: only the requested losses are computed. Attribute lookups
    # (and the enclosing Bs) are resolved at call time, i.e. on the current basis.
    loss_fns = {
        'sample-reward': lambda: sample_policy_reward(
            env, basis.estimated_value(encoder.B, False), n_reward_samples, n_reward_runs),
        'test-training': lambda: basis.loss(basis.flat_params, S_test, R_test, Mphi, Mrew) / n_rows,
        'test-bellman': lambda: basis.loss_be(*_test_args()) / n_rows,
        'test-lsbellman': lambda: basis.loss_lsbe(*_test_args()) / n_rows,
        'test-reward': lambda: basis.loss_r(*_test_args()) / n_rows,
        'test-model': lambda: basis.loss_m(*_test_args()) / n_rows,
        'test-fullmodel': lambda: basis.loss_fm(*_test_args()) / n_rows,
        'true-policy': lambda: model.policy_distance(Bs, weighting = 'policy'),
        'true-policy-uniform': lambda: model.policy_distance(Bs, weighting = 'uniform'),
        'true-bellman': lambda: model.bellman_error(Bs, w = basis.params[-1], weighting = weighting),
        'true-lsbellman': lambda: model.bellman_error(Bs, weighting = weighting),
        'true-reward': lambda: model.reward_error(Bs, weighting = weighting),
        'true-model': lambda: model.model_error(Bs, weighting = weighting),
        'true-fullmodel': lambda: model.fullmodel_error(Bs, weighting = weighting),
        'true-lsq': lambda: model.lsq_error(Bs, weighting = weighting),
    }

    # Fail fast, before any optimization, instead of after the first fit. Also
    # unlike an assert, this check is not stripped under `python -O`.
    unknown = [key for key in losses if key not in loss_fns]
    if unknown:
        raise ValueError('unknown loss name(s): %s' % ', '.join(map(str, unknown)))

    # one (initially empty) history array per recorded loss
    d_loss_learning = {key: numpy.array([]) for key in losses}

    def record_loss(d_loss):
        # append the current value of every tracked loss to its history
        for loss, arr in d_loss.items():
            d_loss[loss] = numpy.append(arr, loss_fns[loss]())
        return d_loss

    switch = [] # list of indices where a training method switch occurred
    it = 0

    # train once on w to initialize
    basis.set_loss('bellman', ['w'])
    _optimize_current_loss(basis, S, R, Mphi, Mrew, max_iter)

    # TODO keep?
    IM = encoder.weights_to_basis(basis.thetas[-1])
    Bs = basis.encode(encoder.B, False)
    if len(basis.thetas) == 1:
        assert (IM == Bs).all()
    d_loss_learning = record_loss(d_loss_learning)

    for loss, wrt in zip(loss_list, wrt_list):

        waiting = 0
        best_params = None
        best_val_loss = 1e20

        # NOTE(review): wrt_list holds one wrt-list per loss (each wrt is passed to
        # set_loss like ['w'] below), so this tests whether the *string* 'w' is an
        # element of the outer list and is likely always False. `'w' in wrt` was
        # probably intended -- confirm before changing, since it enables this branch.
        if 'w' in wrt_list: # initialize w to the lstd soln given the current basis
            logger.info('initializing w to lstd soln')
            basis.params[-1] = BellmanBasis.lstd_weights(basis.encode(S), R, Mphi, Mrew) # TODO change to iteration of opt on w?

        try:
            while waiting < patience:
                it += 1
                logger.info('*** iteration ' + str(it) + '***')

                old_params = copy.deepcopy(basis.flat_params)
                # one fit on the current objective, then re-fit w alone on the bellman loss
                for loss_, wrt_ in ((loss, wrt), ('bellman', ['w'])):
                    basis.set_loss(loss_, wrt_)
                    _optimize_current_loss(basis, S, R, Mphi, Mrew, max_iter)
                basis.set_loss(loss, wrt) # reset loss back from bellman

                delta = numpy.linalg.norm(old_params - basis.flat_params)
                logger.info('delta theta: %.2f' % delta)

                norms = numpy.apply_along_axis(numpy.linalg.norm, 0, basis.thetas[0])
                logger.info( 'column norms: %.2f min / %.2f avg / %.2f max' % (
                    norms.min(), norms.mean(), norms.max()))

                # early stopping is driven by the current loss on the validation split
                err = basis.loss(basis.flat_params, S_val, R_val, Mphi, Mrew)

                if err < best_val_loss:
                    # a new best only resets patience if both the relative gain
                    # and the parameter change are large enough
                    if ((best_val_loss - err) / best_val_loss > min_imp) and (delta > min_delta):
                        waiting = 0
                    else:
                        waiting += 1
                        logger.info('iters without better %s loss: %i' % (basis.loss_type, int(waiting)))

                    best_val_loss = err
                    best_params = copy.deepcopy(basis.flat_params)
                    logger.info('new best %s loss: %.2f' % (basis.loss_type, best_val_loss))
                else:
                    waiting += 1
                    logger.info('iters without better %s loss: %i' % (basis.loss_type, int(waiting)))

                Bs = basis.encode(encoder.B, False)
                d_loss_learning = record_loss(d_loss_learning)

        except KeyboardInterrupt:
            logger.info( '\n user stopped current training loop')

        # Restore the best params from this method. No best exists if the
        # validation loss never dropped below the sentinel (e.g. NaN from a
        # diverging fit, patience <= 0, or an interrupt before the first
        # evaluation); in that case keep the current params rather than
        # passing None to set_params.
        if best_params is not None:
            basis.set_params(vec = best_params)
        else:
            logger.info('no best params recorded for %s loss; keeping current params' % str(loss))
        switch.append(it-1)

    sparse_eps = 1e-5
    Bs = basis.encode(encoder.B, False)
    d_loss_learning = record_loss(d_loss_learning)
    logger.info( 'final test bellman error: %.2f' % model.bellman_error(Bs, weighting = weighting))
    # fraction of near-zero entries in each parameter array
    logger.info( 'final sparsity: ' + str( [(numpy.sum(abs(th) < sparse_eps) / float(len(th.flatten()))) for th in basis.params]))

    # edit d_run_params to not include wrt list in method (mutates the caller's dict)
    d_run_params['method'] = '-'.join(d_run_params['method'][0])

    # TODO change to output log file and plot afterwards
    if record_runs:
        _save_learning_plots(fl_dir, d_run_params, Bs, d_loss_learning, switch, env_size, model)

    # final value of each recorded loss
    d_loss_batch = {key: history[-1] for key, history in d_loss_learning.items()}
    # returns keys_array, values_array
    return reorder_columns(d_run_params, d_loss_batch)