Example 1
def main(params,
         snap_coord=None,
         resolution=30,
         n_sizes=5,
         max_features=None,
         n_jobs=1):
    t0 = time.time()

    inputs, df_var = stem.read_params(params)

    # Convert params to named variables and check for required vars
    for i in inputs:
        exec("{0} = str({1})".format(i, inputs[i]))

    try:
        sets_per_cell = int(sets_per_cell)
        cell_size = [int(s) for s in cell_size.split(',')]
        min_size = int(min_size)
        max_size = int(max_size)
    except NameError as e:
        missing_var = str(e).split("'")[1]
        msg = "Variable '%s' not specified in param file:\n%s" % (missing_var,
                                                                  params)
        raise NameError(msg)

    # Read in training samples and check that df_train has exactly the same
    #   columns as variables specified in df_vars
    df_train = pd.read_csv(sample_txt, sep='\t')
    n_samples = len(df_train)
    unmatched_vars = [
        v for v in df_var.index if v not in [c for c in df_train]
    ]
    if len(unmatched_vars) != 0:
        unmatched_str = '\n\t'.join(unmatched_vars)
        msg = 'Columns not in sample_txt but specified in params:\n\t' + unmatched_str
        raise NameError(msg)
    if target_col not in df_train.columns:
        raise NameError('target_col "%s" not in sample_txt: %s' %
                        (target_col, sample_txt))
    if 'max_target_val' in inputs:
        max_target_val = int(max_target_val)
    else:
        max_target_val = df_train[target_col].max()
    if 'n_jobs' in inputs:
        n_jobs = int(n_jobs)

    predict_cols = sorted(
        np.unique(
            [c for c in df_train.columns for v in df_var.index if v in c]))
    # Make sure predict_cols and df_var are in the same order
    df_var = df_var.reindex(df_var.index.sort_values())

    if snap_coord:
        snap_coord = [int(c) for c in snap_coord.split(',')]

    t1 = time.time()
    if model_type.lower() == 'classifier':
        model_func = stem.fit_tree_classifier
    else:
        model_func = stem.fit_tree_regressor

    # Make grid
    x_res = resolution
    y_res = -resolution
    tx, extent = stem.tx_from_shp(mosaic_path,
                                  x_res,
                                  y_res,
                                  snap_coord=snap_coord)
    min_x, max_x, min_y, max_y = [int(i) for i in extent]
    cells = stem.generate_gsrd_grid(cell_size, min_x, min_y, max_x, max_y,
                                    x_res, y_res)
    grid = pd.DataFrame(cells, columns=['ul_x', 'ul_y', 'lr_x', 'lr_y'])
    grid.to_csv(out_txt.replace('.txt', '_grid.txt'))
    grid = intersecting_cells(grid, mosaic_path)
    stem.coords_to_shp(grid, '/vol/v2/stem/extent_shp/CAORWA.shp',
                       out_txt.replace('.txt', '_grid.shp'))

    if 'set_sizes' in inputs:
        set_sizes = np.sort([int(s) for s in set_sizes.split(',')])
    else:
        if 'n_sizes' in inputs:
            n_sizes = int(n_sizes)
        set_sizes = np.arange(min_size, max_size + 1,
                              (max_size - min_size) / n_sizes)

    # Sample grid
    dfs = []
    for i, cell in grid.iterrows():
        ul_x, ul_y, lr_x, lr_y = cell
        min_x, max_x = min(ul_x, lr_x), max(ul_x, lr_x)
        min_y, max_y = min(ul_y, lr_y), max(ul_y, lr_y)

        # Calculate support set centers
        x_centers = [
            int(stem.snap_coordinate(x, snap_coord[0], x_res))
            for x in random.sample(xrange(min_x, max_x + 1), sets_per_cell)
        ]
        y_centers = [
            int(stem.snap_coordinate(y, snap_coord[1], y_res))
            for y in random.sample(xrange(min_y, max_y + 1), sets_per_cell)
        ]

        for size in set_sizes:
            df = stem.sample_gsrd_cell(sets_per_cell,
                                       cell,
                                       size,
                                       size,
                                       x_res,
                                       y_res,
                                       tx,
                                       snap_coord,
                                       center_coords=(zip(
                                           x_centers, y_centers)))
            df['set_size'] = size
            df['cell_id'] = i
            dfs.append(df)

    support_sets = pd.concat(dfs, ignore_index=True)
    n_sets = len(support_sets)
    print 'Testing set sizes with %s jobs...\n' % n_jobs
    oob_metrics = _par_train_estimator(n_jobs, n_sets, df_train, predict_cols,
                                       target_col, support_sets, model_func,
                                       model_type, max_features,
                                       max_target_val)
    # Serial debugging alternative to the parallel call above:
    # args = [[i, n_sets, start_time, df_train, predict_cols, target_col, support_set, model_func, model_type, max_features, max_target_val] for i, (si, support_set) in enumerate(support_sets.ix[:100].iterrows())]
    # oob_metrics = []
    # for arg in args:
    #     oob_metrics.append(par_train_estimator(arg))

    oob_metrics = pd.DataFrame(oob_metrics)
    oob_metrics.set_index('set_id', inplace=True)
    support_sets = pd.merge(support_sets,
                            oob_metrics,
                            left_index=True,
                            right_index=True)
    support_sets.to_csv(out_txt)
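
The set-size sweep in Example 1 (the set_sizes fallback when only n_sizes is given) is just an evenly spaced np.arange. A minimal sketch with made-up values, not taken from any real param file, showing that the arithmetic yields n_sizes + 1 sizes, including both endpoints when the range divides evenly:

import numpy as np

# Hypothetical values, for illustration only
min_size, max_size, n_sizes = 1000, 5000, 5

step = (max_size - min_size) / n_sizes              # 800
set_sizes = np.arange(min_size, max_size + 1, step)
# -> array([1000, 1800, 2600, 3400, 4200, 5000])
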
Example 2
def main(params,
         inventory_txt=None,
         constant_vars=None,
         mosaic_shp=None,
         resolution=30,
         n_jobs=0,
         n_jobs_agg=0,
         mosaic_nodata=0,
         snap_coord=None,
         overwrite_tiles=False,
         tile_id_field='name'):
    inputs = stem.read_params(params)
    for i in inputs:
        exec("{0} = str({1})".format(i, inputs[i]))
    df_var = pd.read_csv(var_info, sep='\t', index_col='var_name')
    df_var.data_band = [int(b)
                        for b in df_var.data_band]  #sometimes read as float

    try:
        support_size = [int(i) for i in support_size.split(',')]
        nodata = int(nodata)
        str_check = model_dir, mosaic_path, out_dir, train_params
    except NameError as e:
        missing_var = str(e).split("'")[1]
        msg = "Variable '%s' not specified in param file:\n%s" % (missing_var,
                                                                  params)
        raise NameError(msg)

    # Check that all the variables given were used in training and vice versa
    try:
        train_inputs = stem.read_params(train_params)
    except:
        raise NameError('train_params not specified or does not exist')
    train_vars = pd.read_csv(train_inputs['var_info'].replace('"', ''),
                             sep='\t',
                             index_col='var_name')
    train_vars = sorted(train_vars.index)
    pred_vars = sorted(df_var.index)
    # Make sure vars are sorted alphabetically since they were for training
    df_var = df_var.reindex(pred_vars)

    unmatched_vars = [v for v in pred_vars if v not in train_vars]
    if len(unmatched_vars) != 0:
        unmatched_str = '\n'.join(unmatched_vars)
        msg = 'Columns not in predict params but specified in train params:\n' + unmatched_str
        raise NameError(msg)

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    else:
        print 'WARNING: out_dir already exists:\n%s\nAny existing files will be overwritten...\n' % out_dir
    if not os.path.exists(os.path.join(out_dir, os.path.basename(params))):
        shutil.copy2(params, out_dir)  #Copy the params for reference

    if 'confusion_params' in inputs:
        conf_bn = os.path.basename(confusion_params)
        new_conf_path = os.path.join(out_dir, conf_bn)
        if not os.path.exists(new_conf_path):
            shutil.copy2(confusion_params, out_dir)
        confusion_params = new_conf_path

    if isinstance(overwrite_tiles, str) and overwrite_tiles.lower() == 'false':
        overwrite_tiles = False

    if not os.path.exists(model_dir):
        sys.exit('model_dir does not exist:\n%s' % model_dir)
    if not os.path.exists(mosaic_path):
        sys.exit('mosaic_path does not exist:\n%s' % mosaic_path)

    if 'file_stamp' not in inputs:
        file_stamp = os.path.basename(model_dir)
    db_path = os.path.join(model_dir, os.path.basename(model_dir) + '.db')
    if os.path.exists(db_path):
        engine = sqlalchemy.create_engine('sqlite:///%s' % db_path)
        with engine.connect() as con, con.begin():
            df_sets = pd.read_sql_table('support_sets',
                                        con,
                                        index_col='set_id')
    else:
        set_txt = stem.find_file(model_dir, '*support_sets.txt')
        if not os.path.isfile(set_txt):
            raise IOError('No database or support set txt file found')
        df_sets = pd.read_csv(set_txt, sep='\t', index_col='set_id')

    if mosaic_path.endswith('.shp'):
        mosaic_type = 'vector'
        # if subset specified, clip the mosaic and set mosaic path to clipped shp
        if 'subset_shp' in inputs:
            out_shp_bn = os.path.basename(mosaic_path).replace(
                '.shp', '_clipped.shp')
            out_shp = os.path.join(out_dir, out_shp_bn)
            cmd = 'ogr2ogr -clipsrc {clip_shp} {out_shp} {in_shp}'.format(
                clip_shp=subset_shp, out_shp=out_shp, in_shp=mosaic_path)
            subprocess.call(cmd, shell=True)
            mosaic_path = out_shp
        mosaic_dataset = ogr.Open(mosaic_path, 1)
        mosaic_ds = mosaic_dataset.GetLayer()
        min_x, max_x, min_y, max_y = mosaic_ds.GetExtent()
        if 'resolution' not in inputs:
            warnings.warn('Resolution not specified. Using default of 30...\n')
        else:
            resolution = int(resolution)
        # If subset specified, just get sets that overlap the subset
        if 'subset_shp' in inputs:
            mosaic_geom = ogr.Geometry(ogr.wkbMultiPolygon)
            i = 0
            for feature in mosaic_ds:
                g = feature.GetGeometryRef()
                # Check that the feature is valid. Clipping can produce a feature
                #  with an area of 0
                if g.GetArea() > 1:
                    mosaic_geom.AddGeometry(g)
                else:
                    fid = feature.GetFID()
                    feature.Destroy()
                    mosaic_ds.DeleteFeature(fid)
            df_sets = stem.get_overlapping_sets(df_sets,
                                                mosaic_geom.UnionCascaded())
        xsize = int((max_x - min_x) / resolution)
        ysize = int((max_y - min_y) / resolution)
        prj = mosaic_ds.GetSpatialRef().ExportToWkt()
        x_res = resolution
        y_res = -resolution
        x_rot = 0
        y_rot = 0
        if 'snap_coord' in train_inputs:
            snap_coord = train_inputs['snap_coord'].replace('"', '')
            snap_coord = [float(c) for c in snap_coord.split(',')]  #'''
        mosaic_tx, extent = stem.tx_from_shp(mosaic_path,
                                             x_res,
                                             y_res,
                                             snap_coord=snap_coord)
        tiles = stem.attributes_to_df(
            mosaic_path)  # Change to accept arbitrary geometry

    else:
        mosaic_type = 'raster'
        mosaic_ds = gdal.Open(mosaic_path)
        mosaic_tx = mosaic_ds.GetGeoTransform()
        xsize = mosaic_ds.RasterXSize
        ysize = mosaic_ds.RasterYSize
        prj = mosaic_ds.GetProjection()
        driver = mosaic_ds.GetDriver()
        m_ulx, x_res, x_rot, m_uly, y_rot, y_res = mosaic_tx
    #driver = gdal.GetDriverByName('gtiff')

    # If number of tiles not given, need to set it
    if 'n_tiles' not in inputs:
        print 'n_tiles not specified. Using default: 90 x 40 ...\n'
        n_tiles = 90, 40
    else:
        n_tiles = [int(i) for i in n_tiles.split(',')]
    #df_tiles, df_tiles_rc, tile_size = stem.get_tiles(n_tiles, xsize, ysize, mosaic_tx)

    total_sets = len(df_sets)
    t0 = time.time()
    last_dts = pd.Series()
    agg_stats = [s.strip().lower() for s in agg_stats.split(',')]
    n_jobs = int(n_jobs)
    tile_dir = os.path.join(out_dir, '_temp_tiles')
    if not os.path.isdir(tile_dir):
        os.mkdir(tile_dir)
    tile_path_template = os.path.join(tile_dir, 'tile_{tile_id}_%(stat)s.tif')
    n_tiles = len(tiles)

    if not overwrite_tiles:
        files = os.listdir(tile_dir)
        tile_files = pd.DataFrame(columns=agg_stats,
                                  index=tiles[tile_id_field])
        for stat in agg_stats:
            pattern = re.compile(r'tile_\d+_%s\.tif' % stat)
            stat_match = [f.split('_')[1] for f in files if pattern.match(f)]
            try:
                tile_files[stat] = pd.Series(np.ones(len(stat_match)),
                                             index=stat_match)
            except:
                pass
        index_field = tiles.index.name
        tiles[index_field] = tiles.index
        tiles = tiles.set_index(tile_id_field, drop=False)
        tiles.set_index(index_field, inplace=True)
    tiles['ul_x'] = [
        stem.get_ul_coord(xmin, xmax, x_res)
        for i, (xmin, xmax) in tiles[['xmin', 'xmax']].iterrows()
    ]
    tiles['ul_y'] = [
        stem.get_ul_coord(ymin, ymax, y_res)
        for i, (ymin, ymax) in tiles[['ymin', 'ymax']].iterrows()
    ]
    tiles['lr_x'] = [
        xmax if ulx == xmin else xmin
        for i, (ulx, xmin, xmax) in tiles[['ul_x', 'xmin', 'xmax']].iterrows()
    ]
    tiles['lr_y'] = [
        ymax if uly == ymin else ymin
        for i, (uly, ymin, ymax) in tiles[['ul_y', 'ymin', 'ymax']].iterrows()
    ]

    support_nrows = int(support_size[0] / abs(y_res))
    support_ncols = int(support_size[1] / abs(x_res))
    t1 = time.time()

    # Patch for unknown Landcover screwup (superseded by the single-tile patch below)
    args = [(i + 1, n_tiles, t1, tile_info, mosaic_path, mosaic_tx, df_sets,
             df_var, (support_nrows, support_ncols), agg_stats,
             tile_path_template, prj, nodata, snap_coord)
            for i, (t_ind,
                    tile_info) in enumerate(tiles.loc[tiles['name'].isin([
                        '1931', '2810', '0705', '0954', '2814', '1986', '2552',
                        '2019', '2355', '3354', '2278', '2559'
                    ])].iterrows())]

    args = [(i + 1, n_tiles, t1, tile_info, mosaic_path, mosaic_tx, df_sets,
             df_var, (support_nrows, support_ncols), agg_stats,
             tile_path_template, prj, nodata, snap_coord)
            for i, (t_ind, tile_info) in enumerate(tiles.loc[
                tiles['name'].isin(['0705'])].iterrows())]

    # Patch for the GEE subset 2 outside-of-buffer 'slice'
    #args = [(i + 1, n_tiles, t1, tile_info, mosaic_path, mosaic_tx, df_sets, df_var, (support_nrows, support_ncols), agg_stats, tile_path_template, prj, nodata, snap_coord) for i, (t_ind, tile_info) in enumerate(tiles.loc[tiles['name'].isin(['0639','0718','0797','0876','0955','1034'])].iterrows())]

    # Original line
    #args = [(i + 1, n_tiles, t1, tile_info, mosaic_path, mosaic_tx, df_sets, df_var, (support_nrows, support_ncols), agg_stats, tile_path_template, prj, nodata, snap_coord) for i, (t_ind, tile_info) in enumerate(tiles.loc[tile_files.isnull().any(axis=1).values].iterrows())]

    limits = []

    for arg in args:
        print arg[3][tile_id_field]  # tile_info is the 4th element of each args tuple
        limits.append(stem.par_predict_tile(arg))

    # NOTE: early return for the tile patch above; the stitching and importance
    #   steps below are currently skipped
    return
    print '\n\nFinished predicting in %.1f hours. \n\nStitching tiles...' % (
        (time.time() - t1) / 3600)

    try:
        limits = pd.concat(limits)
    except:
        # They're all None
        pass

    t1 = time.time()
    mosaic_ul = mosaic_tx[0], mosaic_tx[3]
    driver = gdal.GetDriverByName('gtiff')
    for stat in agg_stats:
        #dtype = mosaic.get_min_numpy_dtype(limits[stat])
        dtype = np.int16
        if stat == 'stdv':
            this_nodata = -9999
            ar = np.full((ysize, xsize), this_nodata, dtype=dtype)
        else:
            this_nodata = nodata
            ar = np.full((ysize, xsize), this_nodata, dtype=dtype)

        for tile_id, tile_coords in tiles.iterrows():
            tile_file = os.path.join(
                tile_dir,
                'tile_%s_%s.tif' % (tile_coords[tile_id_field], stat))
            ds = gdal.Open(tile_file)
            if ds is None:
                print 'Tile not found: %s' % tile_file
                continue
            tile_tx = ds.GetGeoTransform()
            tile_ul = tile_tx[0], tile_tx[3]
            row_off, col_off = stem.calc_offset(mosaic_ul, tile_ul, mosaic_tx)
            # Make sure the tile doesn't exceed the size of ar
            tile_rows = min(ds.RasterYSize + row_off, ysize) - row_off
            tile_cols = min(ds.RasterXSize + col_off, xsize) - col_off
            ar_tile = ds.ReadAsArray(0, 0, tile_cols, tile_rows)
            try:
                ar[row_off:row_off + tile_rows,
                   col_off:col_off + tile_cols] = ar_tile
            except Exception as e:
                pass  # Skip tiles whose array shape does not match the target slice

        out_path = os.path.join(out_dir, '%s_%s.tif' % (file_stamp, stat))
        gdal_dtype = gdal_array.NumericTypeCodeToGDALTypeCode(ar.dtype)
        mosaic.array_to_raster(ar,
                               mosaic_tx,
                               prj,
                               driver,
                               out_path,
                               gdal_dtype,
                               nodata=this_nodata)

    # Clean up the tiles
    #shutil.rmtree(tile_dir)
    print 'Time for stitching: %.1f minutes\n' % ((time.time() - t1) / 60)

    # Get feature importances and max importance per set
    t1 = time.time()
    print 'Getting importance values...'
    importance_cols = sorted([c for c in df_sets.columns if 'importance' in c])
    df_sets['max_importance'] = nodata
    if len(importance_cols) == 0:
        # Loop through and get importance
        importance_per_var = []
        for s, row in df_sets.iterrows():
            with open(row.dt_file, 'rb') as f:
                dt_model = pickle.load(f)
            max_importance, this_importance = stem.get_max_importance(dt_model)
            df_sets.ix[s, 'max_importance'] = max_importance
            importance_per_var.append(this_importance)
        importance = np.array(importance_per_var).mean(axis=0)
    else:
        df_sets['max_importance'] = np.argmax(df_sets[importance_cols].values,
                                              axis=1)
        importance = df_sets[importance_cols].mean(axis=0).values
    pct_importance = importance / importance.sum()
    print '%.1f minutes\n' % ((time.time() - t1) / 60)

    # Save the importance values
    importance = pd.DataFrame({
        'variable': pred_vars,
        'pct_importance': pct_importance,
        'index': range(len(pred_vars))
    })
    importance.set_index('index', inplace=True)
    importance['rank'] = [
        int(r) for r in importance.pct_importance.rank(method='first',
                                                       ascending=False)
    ]
    out_txt = os.path.join(out_dir, '%s_importance.txt' % file_stamp)
    importance.to_csv(out_txt, sep='\t')

    print '\nTotal prediction runtime: %.1f hours\n' % (
        (time.time() - t0) / 3600)
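
The stitching loop in Example 2 places each tile into the output array using stem.calc_offset, which is not shown here. Below is a minimal sketch of the standard geotransform arithmetic such a helper typically performs; the function name, signature, and coordinates are illustrative assumptions, not the project's actual implementation.

def calc_offset_sketch(mosaic_ul, tile_ul, mosaic_tx):
    # mosaic_ul, tile_ul: (x, y) upper-left coordinates
    # mosaic_tx: GDAL geotransform (ulx, x_res, x_rot, uly, y_rot, y_res)
    x_res, y_res = mosaic_tx[1], mosaic_tx[5]
    col_off = int(round((tile_ul[0] - mosaic_ul[0]) / x_res))
    row_off = int(round((tile_ul[1] - mosaic_ul[1]) / y_res))  # y_res is negative
    return row_off, col_off

# Example: mosaic UL (0, 3000), tile UL (600, 2400), 30 m pixels
# calc_offset_sketch((0, 3000), (600, 2400), (0, 30, 0, 3000, 0, -30)) -> (20, 20)
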
def main(n_tiles,
         tile_path=None,
         add_field=True,
         out_path=None,
         snap=True,
         clip=True):

    try:
        if add_field.lower() == 'false':
            add_field = False
    except:
        pass
    try:
        if snap.lower() == 'false':
            snap = False
    except:
        pass

    if tile_path is None:
        tile_path = TILE_PATH

    if not os.path.exists(tile_path):
        raise RuntimeError('tile_path does not exist: %s' % tile_path)

    try:
        n_tiles = tuple([int(i) for i in n_tiles.split(',')])
    except:
        raise ValueError(
            'Could not parse n_tiles %s. It must be given as "n_y_tiles, n_x_tiles"'
            % n_tiles)

    # Get processing tiles
    tx, (xmin, xmax, ymin, ymax) = tx_from_shp(tile_path, XRES, YRES)
    xsize = abs(int(xmax - xmin) / XRES)
    ysize = abs(int(ymax - ymin) / YRES)
    tiles, _, _ = get_tiles(n_tiles, xsize, ysize, tx=tx)
    tile_id_field = 'eetile%sx%s' % n_tiles
    tiles[tile_id_field] = tiles.index

    if snap:
        coords, _ = get_coords(tile_path, multipart='split')
        coords = np.array(coords)  #shape is (nfeatures, ncoords, 2)
        xcoords = np.unique(coords[:, :, 0])
        ycoords = np.unique(coords[:, :, 1])
        for i, processing_coords in tiles.iterrows():
            tiles.loc[i, 'ul_x'] = xcoords[np.argmin(
                np.abs(xcoords - processing_coords.ul_x))]
            tiles.loc[i, 'lr_x'] = xcoords[np.argmin(
                np.abs(xcoords - processing_coords.lr_x))]
            tiles.loc[i, 'ul_y'] = ycoords[np.argmin(
                np.abs(ycoords - processing_coords.ul_y))]
            tiles.loc[i, 'lr_y'] = ycoords[np.argmin(
                np.abs(ycoords - processing_coords.lr_y))]

    if not out_path:
        out_path = os.path.join(OUT_DIR,
                                'ee_processing_tiles_%sx%s.shp' % n_tiles)
    coords_to_shp(tiles, tile_path, out_path)
    descr = ('Tiles for processing data on Google Earth Engine. The tiles ' +
            'have %s row(s) and %s col(s) and are bounded by the extent of %s') %\
            (n_tiles[0], n_tiles[1], tile_path)
    '''if clip:
        ds = ogr.Open(tile_path)
        lyr = ds.GetLayer()
        geoms = ogr.Geometry(ogr.wkbMultiPolygon)
        for feature in lyr:
            g = feature.GetGeometryRef()
            geoms.AddGeometry(g)
        union = geoms.UnionCascaded()
        base_path, ext = os.path.splitext(tile_path)
        temp_file = tile_path.replace(ext, '_uniontemp' + ext)
        feature'''

    createMetadata(sys.argv, out_path, description=descr)
    print '\nNew processing tiles written to', out_path

    # Find which processing tile each CONUS storage tile touches
    #   (get_overlapping_sets() could also be used to find overlaps)
    # Read in the CONUS storage tiles
    if add_field:
        conus_tiles = attributes_to_df(tile_path)

        # Make a temporary copy of it
        base_path, ext = os.path.splitext(tile_path)
        temp_file = tile_path.replace(ext, '_temp' + ext)
        df_to_shp(conus_tiles, tile_path, temp_file, copy_fields=False)

        # Loop through each processing tile and find all overlapping
        conus_tiles[tile_id_field] = -1
        ds = ogr.Open(tile_path)
        lyr = ds.GetLayer()
        for p_fid, processing_coords in tiles.iterrows():
            wkt = 'POLYGON (({0} {1}, {2} {1}, {2} {3}, {0} {3}, {0} {1}))'.format(
                processing_coords.ul_x, processing_coords.ul_y,
                processing_coords.lr_x, processing_coords.lr_y)
            p_geom = ogr.CreateGeometryFromWkt(wkt)
            p_geom.CloseRings()
            for c_fid in conus_tiles.index:
                feature = lyr.GetFeature(c_fid)
                geom = feature.GetGeometryRef()
                if geom.Intersection(p_geom).GetArea() > 0:
                    conus_tiles.loc[c_fid, tile_id_field] = p_fid
        lyr, feature = None, None

        # re-write the CONUS tiles shapefile with the new field
        df_to_shp(conus_tiles, tile_path, tile_path, copy_fields=False)

        # delete temporary file
        driver = ds.GetDriver()
        driver.DeleteDataSource(temp_file)
        ds = None
        print '\nField with processing tile ID added to', tile_path

        # if the metadata text file exists, add a line about appending the field.
        #   otherwise, make a new metadata file.
        meta_file = tile_path.replace(ext, '_meta.txt')
        if os.path.exists(meta_file):
            with open(meta_file, 'a') as f:
                f.write(
                    '\n\nAppended field %s with IDs from the overlapping feature of %s'
                    % (tile_id_field, out_path))
        else:
            descr = 'Tile system with appended field %s with IDs from the overlapping feature of %s' % (
                tile_id_field, out_path)
            createMetadata(sys.argv, tile_path, description=descr)
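
The add_field loop above decides whether a CONUS storage tile belongs to a processing tile by intersecting their geometries and checking for a positive overlap area. A small, self-contained sketch of that test with OGR, using made-up rectangle coordinates:

from osgeo import ogr

# Two illustrative axis-aligned rectangles as WKT ({0}=ul_x, {1}=ul_y, {2}=lr_x, {3}=lr_y)
wkt_template = 'POLYGON (({0} {1}, {2} {1}, {2} {3}, {0} {3}, {0} {1}))'
tile_geom = ogr.CreateGeometryFromWkt(wkt_template.format(0, 100, 100, 0))
proc_geom = ogr.CreateGeometryFromWkt(wkt_template.format(50, 150, 150, 50))

overlap = tile_geom.Intersection(proc_geom)
overlaps = overlap is not None and overlap.GetArea() > 0  # True: a 50 x 50 overlap
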
Example 4
def main(txt, n_sample, out_txt, bins, train_params, by_psu=True, extract_predictors=True):
    
    n_sample = int(n_sample) 
    bins = parse_bins(bins)
    
    df = pd.read_csv(txt, sep='\t', dtype={'tile_id': object})
    sample = pd.DataFrame(columns=df.columns)
    n_bins = len(bins)
    psu_ids = df.tile_id.unique()
    
    train_params = stem.read_params(train_params)
    for var in train_params:
        exec("{0} = str({1})".format(var, train_params[var]))
    tiles = attributes_to_df(MOSAIC_SHP)
    
    if extract_predictors:
        var_info = pd.read_csv(var_info, sep='\t', index_col='var_name')
        for i, tile in enumerate(psu_ids):
            print("extracting %s of %s" % (i, len(psu_ids)))
            sample_mask = df.tile_id == tile
            this_sample = df.loc[sample_mask]
            tile_ul = tiles.loc[tiles['name'] == tile, ['xmin', 'ymax']].values[0]
            #point_dict = get_point_dict(df, psu_ids)
            mosaic_tx, extent = stem.tx_from_shp(MOSAIC_SHP, 30, -30)
            
            row_off, col_off = stem.calc_offset([mosaic_tx[0], mosaic_tx[3]], tile_ul, mosaic_tx)
            this_sample['local_row'] = this_sample.row - row_off
            this_sample['local_col'] = this_sample.col - col_off
    
            for var_name, var_row in var_info.iterrows():
                #tiles = pd.DataFrame({'tile_id': psu_ids, 'tile_str': psu_ids})
                file_path = stem.find_file(var_row.basepath, var_row.search_str, tile)
                ds = gdal.Open(file_path)
                ar = ds.GetRasterBand(var_row.data_band).ReadAsArray()
                try:
                    if len(this_sample) == ar.size:
                        df.loc[sample_mask, var_name] = ar.ravel()
                    else:
                        df.loc[sample_mask, var_name] = ar[this_sample.local_row, this_sample.local_col]
                except Exception as e:
                    print(e)
                ds = None
        df.to_csv(txt.replace('.txt', '_predictors.txt'))
    #df[var_name], _ = extract.extract_var('', var_name, var_row.by_tile, var_row.data_band, var_row.data_type, tiles, df, point_dict, var_row.basepath, var_row.search_str, var_row.path_filter, mosaic_tx, 0, 0, silent=True)
                
    if by_psu: 
        
        n_per_psu = n_sample/len(psu_ids)
        n_per_bin = n_per_psu/n_bins
        
        for i, pid in enumerate(psu_ids):
            psu_pixels = df.loc[df.tile_id == pid]
            print("Sampling for %s of %s PSUs" % (i + 1, len(psu_ids)))
            for l, u in bins:
                this_bin = psu_pixels.loc[(l < psu_pixels.value) & (psu_pixels.value <= u)]
                if len(this_bin) > 0:
                    bin_sample_size = min(n_per_bin, len(this_bin))
                    sample = pd.concat([sample, this_bin.sample(bin_sample_size)])
                    print("Sampled %s for bin %s-%s" % (n_per_bin, l, u))
                else:
                    print("No pixels between %s and %s found" % (l, u))
            print("")
    
    else:
        n_per_bin = n_sample/n_bins
        for l, u in bins:
            this_bin = df.loc[(l < df.value) & (df.value <= u)]
            if len(this_bin) > 0:
                sample = pd.concat([sample, this_bin.sample(min(n_per_bin, len(this_bin)))])
    
    sample.to_csv(out_txt, index=False)
    
    print 'Sample written to ', out_txt
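
The core of Example 4 is per-bin (and optionally per-PSU) sampling with pandas. A compact sketch of the pattern on a toy DataFrame; the value column and bin edges are made up for illustration:

import pandas as pd

df = pd.DataFrame({'value': range(100)})        # toy data
bins = [(0, 25), (25, 50), (50, 75), (75, 100)]
n_per_bin = 5

pieces = []
for l, u in bins:
    this_bin = df.loc[(l < df.value) & (df.value <= u)]
    if len(this_bin) > 0:
        pieces.append(this_bin.sample(min(n_per_bin, len(this_bin))))
sample = pd.concat(pieces, ignore_index=True)   # 20 rows, 5 per bin
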
def main(params, inventory_txt=None, constant_vars=None, mosaic_shp=None, resolution=30, n_jobs=0, n_jobs_agg=0, mosaic_nodata=0, snap_coord=None, overwrite_tiles=False, tile_id_field='name'):
    inputs, df_var = stem.read_params(params)
    for i in inputs:
        exec("{0} = str({1})".format(i, inputs[i]))
    df_var.data_band = [int(b) for b in df_var.data_band]#sometimes read as float

    try:
        support_size = [int(i) for i in support_size.split(',')]
        nodata = int(nodata)
        str_check = model_dir, mosaic_path, out_dir, train_params
    except NameError as e:
        missing_var = str(e).split("'")[1]
        msg = "Variable '%s' not specified in param file:\n%s" % (missing_var, params)
        raise NameError(msg)
    
    # Check that all the variables given were used in training and vice versa
    try:
        train_inputs, train_vars = stem.read_params(train_params)
    except:
        raise NameError('train_params not specified or does not exist')
    train_vars = sorted(train_vars.index)
    pred_vars  = sorted(df_var.index)
    # Make sure vars are sorted alphabetically since they were for training
    df_var = df_var.reindex(pred_vars)
    
    unmatched_vars = [v for v in pred_vars if v not in train_vars]
    if len(unmatched_vars) != 0:
        unmatched_str = '\n'.join(unmatched_vars)
        msg = 'Columns not in predict params but specified in train params:\n' + unmatched_str
        raise NameError(msg)
    
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    else:
        print 'WARNING: out_dir already exists:\n%s\nAny existing files will be overwritten...\n' % out_dir
    if not os.path.exists(os.path.join(out_dir, os.path.basename(params))):
        shutil.copy2(params, out_dir) #Copy the params for reference
    
    if 'confusion_params' in inputs: 
        conf_bn = os.path.basename(confusion_params)
        new_conf_path = os.path.join(out_dir, conf_bn)
        if not os.path.exists(new_conf_path):
            shutil.copy2(confusion_params, out_dir)
        confusion_params = new_conf_path
    
    if not os.path.exists(model_dir):
        sys.exit('model_dir does not exist:\n%s' % model_dir)
    if not os.path.exists(mosaic_path):
        sys.exit('mosaic_path does not exist:\n%s' % mosaic_path)
    
    predict_dir = os.path.join(out_dir, 'decisiontree_predictions')
    if not os.path.exists(predict_dir):
        os.mkdir(predict_dir)
    
    if 'file_stamp' not in inputs:
        file_stamp = os.path.basename(model_dir)
    db_path = os.path.join(model_dir, file_stamp + '.db')
    try:
        engine = sqlalchemy.create_engine('sqlite:///%s' % db_path)
        with engine.connect() as con, con.begin():
            df_sets = pd.read_sql_table('support_sets', con, index_col='set_id')
    except:
        set_txt = glob.glob(os.path.join(model_dir, 'decisiontree_models/*support_sets.txt'))[0]
        if not os.path.isfile(set_txt):
            raise IOError('No database or support set txt file found')
        df_sets = pd.read_csv(set_txt, sep='\t', index_col='set_id')
    
    if mosaic_path.endswith('.shp'):
        mosaic_type = 'vector'
        # if subset specified, clip the mosaic and set mosaic path to clipped shp
        if 'subset_shp' in inputs:
            out_shp_bn = os.path.basename(mosaic_path).replace('.shp', '_clipped.shp')
            out_shp = os.path.join(out_dir, out_shp_bn)
            cmd = 'ogr2ogr -clipsrc {clip_shp} {out_shp} {in_shp}'.format(clip_shp=subset_shp, out_shp=out_shp, in_shp=mosaic_path)
            subprocess.call(cmd, shell=True)
            mosaic_path = out_shp
        mosaic_dataset = ogr.Open(mosaic_path)
        mosaic_ds = mosaic_dataset.GetLayer()
        min_x, max_x, min_y, max_y = mosaic_ds.GetExtent()
        if 'resolution' not in inputs:
            warnings.warn('Resolution not specified. Using default of 30...\n')
        else:
            resolution = int(resolution)
        # If subset specified, just get sets that overlap the subset
        if 'subset_shp' in inputs:
            mosaic_geom = ogr.Geometry(ogr.wkbMultiPolygon)
            for feature in mosaic_ds:
                mosaic_geom.AddGeometry(feature.GetGeometryRef())
            df_sets = stem.get_overlapping_sets(df_sets, mosaic_geom)
        xsize = int((max_x - min_x)/resolution)
        ysize = int((max_y - min_y)/resolution)
        prj = mosaic_ds.GetSpatialRef().ExportToWkt()
        x_res = resolution
        y_res = -resolution
        x_rot = 0
        y_rot = 0
        if 'snap_coord' in train_inputs:
            snap_coord = train_inputs['snap_coord'].replace('"','')
            snap_coord = [float(c) for c in snap_coord.split(',')]
        mosaic_tx, extent = stem.tx_from_shp(mosaic_path, x_res, y_res, snap_coord=snap_coord)
        tiles = stem.attributes_to_df(mosaic_path) # Change to accept arbitrary geometry
        
    else:
        mosaic_type = 'raster'
        mosaic_ds = gdal.Open(mosaic_path)
        mosaic_tx = mosaic_ds.GetGeoTransform()
        xsize = mosaic_ds.RasterXSize
        ysize = mosaic_ds.RasterYSize
        prj = mosaic_ds.GetProjection()
        driver = mosaic_ds.GetDriver()
        m_ulx, x_res, x_rot, m_uly, y_rot, y_res = mosaic_tx
    #driver = gdal.GetDriverByName('gtiff')
        
    # If number of tiles not given, need to set it
    if 'n_tiles' not in inputs:
        print 'n_tiles not specified. Using default: 90 x 40 ...\n'
        n_tiles = 90, 40
    else:
        n_tiles = [int(i) for i in n_tiles.split(',')]
    #df_tiles, df_tiles_rc, tile_size = stem.get_tiles(n_tiles, xsize, ysize, mosaic_tx)
    
    total_sets = len(df_sets)
    t0 = time.time()
    last_dts = pd.Series()
    agg_stats = [s.strip().lower() for s in agg_stats.split(',')]
    n_jobs = int(n_jobs)
    tile_dir = os.path.join(model_dir, 'temp_tiles')
    if not os.path.isdir(tile_dir):
        os.mkdir(tile_dir)
    tile_path_template = os.path.join(tile_dir, 'tile_{tile_id}_%(stat)s.tif')
    n_tiles = len(tiles)
    
    if not overwrite_tiles:
        files = os.listdir(tile_dir)
        tile_files = pd.DataFrame(columns=agg_stats, index=tiles[tile_id_field])
        for stat in agg_stats:
            stat_match = [f.split('_')[1] for f in fnmatch.filter(files, 'tile*%s.tif' % stat)]
            tile_files[stat] = pd.Series(np.ones(len(stat_match)), index=stat_match)
        index_field = tiles.index.name
        tiles[index_field] = tiles.index
        tiles = tiles.set_index(tile_id_field, drop=False)[tile_files.isnull().any(axis=1)]
        tiles.set_index(index_field, inplace=True)
    
    tiles['ul_x'] = [stem.get_ul_coord(xmin, xmax, x_res) 
                    for i, (xmin, xmax) in tiles[['xmin','xmax']].iterrows()]
    tiles['ul_y'] = [stem.get_ul_coord(ymin, ymax, y_res) 
                    for i, (ymin, ymax) in tiles[['ymin','ymax']].iterrows()]
    tiles['lr_x'] = [xmax if ulx == xmin else xmin for i, (ulx, xmin, xmax)
                    in tiles[['ul_x', 'xmin', 'xmax']].iterrows()]
    tiles['lr_y'] = [ymax if uly == ymin else ymin for i, (uly, ymin, ymax) 
                    in tiles[['ul_y', 'ymin', 'ymax']].iterrows()]
    
    support_nrows = int(support_size[0]/abs(y_res))
    support_ncols = int(support_size[1]/abs(x_res))
    t1 = time.time()
    args = [(tile_info, mosaic_path, mosaic_tx, df_sets, df_var, (support_nrows, support_ncols), agg_stats, tile_path_template, prj, nodata, snap_coord) for i, (t_ind, tile_info) in enumerate(tiles[tiles['name'].isin(['1771', '3224', '0333', '0558'])].iterrows())]    
    #args = [(i + 1, n_tiles, t1, tile_info, mosaic_path, mosaic_tx, df_sets, df_var, (support_nrows, support_ncols), agg_stats, tile_path_template, prj, nodata, snap_coord) for i, (t_ind, tile_info) in enumerate(tiles.iterrows())]
    
    if n_jobs > 1:
        print 'Predicting with %s jobs...\n' % n_jobs
        pool = Pool(n_jobs)
        pool.map(stem.predict_tile, args, 1)
        pool.close()
        pool.join()
    else:
        print 'Predicting with 1 job ...\n'
        for arg in args:
            stem.predict_tile(*arg)
    print '\n\nFinished predicting in %.1f hours. \n\nStitching tiles...' % ((time.time() - t1)/3600)
    
    t1 = time.time()
    mosaic_ul = mosaic_tx[0], mosaic_tx[3]
    driver = gdal.GetDriverByName('gtiff')
    for stat in agg_stats:
        if stat == 'stdv':
            this_nodata = -9999
            ar = np.full((ysize, xsize), this_nodata, dtype=np.int16) 
        else:
            this_nodata = nodata
            ar = np.full((ysize, xsize), this_nodata, dtype=np.uint8)
        
        for tile_id, tile_coords in tiles.iterrows():
            tile_file = os.path.join(tile_dir, 'tile_%s_%s.tif' % (tile_coords[tile_id_field], stat))
            ds = gdal.Open(tile_file)
            if ds is None:
                continue  # Tile file not found; skip it
            tile_tx = ds.GetGeoTransform()
            tile_ul = tile_tx[0], tile_tx[3]
            row_off, col_off = stem.calc_offset(mosaic_ul, tile_ul, mosaic_tx)
            # Make sure the tile doesn't exceed the size of ar
            tile_rows = min(ds.RasterYSize + row_off, ysize) - row_off
            tile_cols = min(ds.RasterXSize + col_off, xsize) - col_off
            ar_tile = ds.ReadAsArray(0, 0, tile_cols, tile_rows)
            try:
                ar[row_off : row_off + tile_rows, col_off : col_off + tile_cols] = ar_tile
            except Exception as e:
                pass  # Skip tiles whose array shape does not match the target slice
        
        out_path = os.path.join(model_dir, '%s_%s.tif' % (file_stamp, stat))
        gdal_dtype = gdal_array.NumericTypeCodeToGDALTypeCode(ar.dtype)
        mosaic.array_to_raster(ar, mosaic_tx, prj, driver, out_path, gdal_dtype, nodata=this_nodata)
    
    # Clean up the tiles
    shutil.rmtree(tile_dir)
    print 'Time for stitching: %.1f minutes\n' % ((time.time() - t1)/60)
    
    # Get feature importances and max importance per set
    t1 = time.time()
    print 'Getting importance values...'
    importance_cols = sorted([c for c in df_sets.columns if 'importance' in c])
    df_sets['max_importance'] = nodata
    if len(importance_cols) == 0:
        # Loop through and get importance
        importance_per_var = []
        for s, row in df_sets.iterrows():
            with open(row.dt_file, 'rb') as f: 
                dt_model = pickle.load(f)
            max_importance, this_importance = stem.get_max_importance(dt_model)
            df_sets.ix[s, 'max_importance'] = max_importance
            importance_per_var.append(this_importance)
        importance = np.array(importance_per_var).mean(axis=0)
    else:
        df_sets['max_importance'] = np.argmax(df_sets[importance_cols].values, axis=1)
        importance = df_sets[importance_cols].mean(axis=0).values
    pct_importance = importance / importance.sum()
    print '%.1f minutes\n' % ((time.time() - t1)/60)
    
    # Save the importance values
    importance = pd.DataFrame({'variable': pred_vars,
                               'pct_importance': pct_importance,
                               'index': range(len(pred_vars))
                               })
    importance.set_index('index', inplace=True)
    importance['rank'] = [int(r) for r in importance.pct_importance.rank(method='first', ascending=False)]
    out_txt = os.path.join(out_dir, '%s_importance.txt' % file_stamp)
    importance.to_csv(out_txt, sep='\t')
    
    if 'confusion_params' in locals():
        import confusion_matrix as confusion

        # Read the mean or vote back in
        if 'vote' in agg_stats:
            vote_path = os.path.join(out_dir, '%s_vote.tif' % file_stamp)
            ar_vote = gdal.Open(vote_path)
            print '\nComputing confusion matrix for vote...'
            vote_dir = os.path.join(model_dir, 'evaluation_vote')
            out_txt = os.path.join(vote_dir, 'confusion.txt')
            df_v = confusion.main(confusion_params, ar_vote, out_txt, match=True)
            vote_acc = df_v.ix['producer', 'user']
            vote_kap = df_v.ix['producer', 'kappa']
            '''try:
                out_txt = os.path.join(vote_dir, 'confusion_avg_kernel.txt')
                df_v_off = confusion.main(confusion_params, ar_vote, out_txt)
            except Exception as e:
                print e'''

                
        if 'mean' in agg_stats:
            mean_path = os.path.join(out_dir, '%s_mean.tif' % file_stamp)
            ar_mean = gdal.Open(mean_path)
            print '\nGetting confusion matrix for mean...'
            mean_dir = os.path.join(model_dir, 'evaluation_mean')
            out_txt = os.path.join(mean_dir, 'confusion.txt')
            df_m = confusion.main(confusion_params, ar_mean, out_txt, match=True)
            mean_acc = df_m.ix['user','producer']
            mean_kap = df_m.ix['user', 'kappa']
            '''try:
                out_txt = os.path.join(mean_dir, 'confusion_avg_kernel.txt')
                df_m_off = confusion.main(confusion_params, ar_mean, out_txt)
            except Exception as e:
                print e#'''


        if 'inventory_txt' in inputs:
            df_inv = pd.read_csv(inventory_txt, sep='\t', index_col='stamp')
            cols = ['vote_accuracy', 'vote_kappa']#, 'vote_mask', 'mean_accuracy', 'mean_kappa', 'vote_mask']
            df_inv.ix[file_stamp, cols] = vote_acc, vote_kap#, False, mean_acc, mean_kap, False
            df_inv.to_csv(inventory_txt, sep='\t')
        else:
            print '\n"inventory_txt" was not specified.' +\
            ' Model evaluation scores will not be recorded...'
            
        print ''
        if 'vote' in agg_stats:
            print 'Vote accuracy .............. ', vote_acc
            print 'Vote kappa ................. ', vote_kap
        if 'mean' in agg_stats:
            print 'Mean accuracy .............. ', mean_acc
            print 'Mean kappa ................. ', mean_kap
        
    else:
        print '\n"confusion_params" was not specified.' +\
            ' This model will not be evaluated...'
    
    print '\nTotal prediction runtime: %.1f hours\n' % ((time.time() - t0)/3600)
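
For reference, a minimal sketch of the multiprocessing pattern used by the n_jobs > 1 branch above. Pool.map delivers each args tuple to the worker as a single argument, so the worker must unpack it itself (the serial branch, by contrast, unpacks with *arg). The worker below is a hypothetical stand-in, not stem.predict_tile.

from multiprocessing import Pool

def predict_tile_sketch(arg):
    # Stand-in worker: receives one args tuple per call
    tile_id, payload = arg
    return tile_id, len(payload)

if __name__ == '__main__':
    args = [(i, 'x' * i) for i in range(4)]
    pool = Pool(2)
    results = pool.map(predict_tile_sketch, args, 1)
    pool.close()
    pool.join()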