Пример #1
0
def genAnnualShapePoints(coord,
                         gdalDriver,
                         workingDirectory,
                         rasterResolution,
                         classToKeep,
                         dataField,
                         tile,
                         validityThreshold,
                         validityRaster,
                         classificationRaster,
                         masks,
                         inlearningShape,
                         outlearningShape,
                         epsg,
                         region_field_name,
                         runs,
                         annu_repartition,
                         logger=logger):
    """Draw random learning points of annual classes and write them to a vector.

    For every seed in ``range(runs)`` and every region mask in ``masks``:

    1. An OTB pipeline writes a temporary uint8 "ready" raster. It holds the
       regularized ``classificationRaster`` labels. Pixels are set to 0 where
       ``validityRaster`` is not strictly greater than ``validityThreshold``
       or where the mask is lower than 1.
    2. For each label of ``classToKeep.data``, up to
       ``annu_repartition[str(label)][region][seed]`` pixels of that label
       are drawn at random. They are written as points to a per-region,
       per-seed vector file.
    3. That vector gets a region column and one ``seed_<run>`` column per
       run.

    The per-region vectors are then merged into ``outlearningShape`` and
    deleted.

    Parameters
    ----------
    coord : sequence
        known ``(x, y)`` coordinates. A drawn point is written only if
        ``coord`` is non-empty and the point is not in it.
    gdalDriver : string
        OGR driver name used to create the per-region vector files
    workingDirectory : string
        directory receiving the temporary rasters and vector files
    rasterResolution :
        not used by this function
    classToKeep : object
        its ``data`` attribute lists the labels to sample
    dataField : string
        field receiving the label of each point
    tile :
        not used by this function
    validityThreshold : number
        pixels whose validity is not strictly greater are discarded
    validityRaster : string
        path to the validity raster
    classificationRaster : string
        path to the classification raster
    masks : list of string
        region masks. The region name is the 3rd '_'-separated token of
        each mask file name.
    inlearningShape : string
        input learning shape. Its base name prefixes the temporary files.
    outlearningShape : string
        path of the merged output vector
    epsg : int or string
        EPSG code of the output layers
    region_field_name : string
        name of the region column
    runs : int
        number of seeds
    annu_repartition : dict
        number of samples to draw, indexed by label (as string), region
        and seed
    logger : logging.Logger
        logger used to report labels with no repartition entry

    Returns
    -------
    bool
        True if at least one point was written, False otherwise
    """
    # the region name is the 3rd '_'-separated token of each mask file name
    region_pos = 2
    learn_flag = "learn"
    undetermined_flag = "XXXX"

    currentTile = os.path.splitext(os.path.basename(inlearningShape))[0]
    classifName = currentTile + "_Classif.tif"
    projection = int(epsg)

    # invariant across seeds and regions: build them once
    driver = ogr.GetDriverByName(gdalDriver)
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(projection)
    # same ``in`` results as ``coord``, but O(1) when items are hashable
    coord_lookup = _annual_coord_lookup(coord)

    vector_regions = []
    add = 0

    for current_seed in range(runs):
        for currentMask in masks:
            currentRegion = os.path.split(currentMask)[-1].split(
                "_")[region_pos]
            vector_region = os.path.join(
                workingDirectory, "Annual_" + currentTile + "_region_" +
                currentRegion + "_seed_" + str(current_seed) + ".sqlite")
            vector_regions.append(vector_region)
            rasterRdy = os.path.join(
                workingDirectory,
                classifName.replace(
                    ".tif", "_RDY_" + str(currentRegion) + "_seed_" +
                    str(current_seed) + ".tif"))

            _annual_ready_raster(classificationRaster, validityRaster,
                                 validityThreshold, currentMask, rasterRdy)

            rasterArray = raster2array(rasterRdy)
            rasterFile = gdal.Open(rasterRdy)
            geo_transform = rasterFile.GetGeoTransform()
            # release the GDAL handle: the file is deleted further down
            rasterFile = None
            x_origin, sizeX = geo_transform[0], geo_transform[1]
            y_origin, sizeY = geo_transform[3], geo_transform[5]

            if os.path.exists(vector_region):
                driver.DeleteDataSource(vector_region)
            data_source = driver.CreateDataSource(vector_region)
            layerOUT = data_source.CreateLayer("output", srs, ogr.wkbPoint)
            add_origin_fields(inlearningShape, layerOUT, region_field_name,
                              runs)

            for currentVal in classToKeep.data:
                try:
                    nbSamples = annu_repartition[str(
                        currentVal)][currentRegion][current_seed]
                except (KeyError, IndexError):
                    logger.info(
                        "class : {} does not exists in {} at seed {} in region {}"
                        .format(currentVal, inlearningShape, current_seed,
                                currentRegion))
                    continue
                Y, X = np.where(rasterArray == int(currentVal))
                XYcoordinates = []
                for y, x in zip(Y, X):
                    X_c, Y_c = pixCoordinates(x, y, x_origin, y_origin, sizeX,
                                              sizeY)
                    XYcoordinates.append((X_c, Y_c))
                # cannot draw more points than there are pixels of this label
                nbSamples = min(nbSamples, len(XYcoordinates))
                for Xc, Yc in random.sample(XYcoordinates, nbSamples):
                    # NOTE(review): no point is written at all when ``coord``
                    # is empty; confirm this is intended.
                    if coord and (Xc, Yc) not in coord_lookup:
                        feature = ogr.Feature(layerOUT.GetLayerDefn())
                        feature.SetField(dataField, int(currentVal))
                        point = ogr.CreateGeometryFromWkt("POINT(%f %f)" %
                                                          (Xc, Yc))
                        feature.SetGeometry(point)
                        layerOUT.CreateFeature(feature)
                        feature.Destroy()
                        add += 1

            # drop the layer reference before its data source is destroyed
            layerOUT = None
            data_source.Destroy()
            os.remove(rasterRdy)

            _annual_add_region_and_seed_fields(vector_region,
                                               region_field_name,
                                               currentRegion, current_seed,
                                               runs, learn_flag,
                                               undetermined_flag)

    outlearningShape_name = os.path.splitext(
        os.path.split(outlearningShape)[-1])[0]
    outlearningShape_dir = os.path.split(outlearningShape)[0]

    fu.mergeSQLite(outlearningShape_name, outlearningShape_dir, vector_regions)

    for vec in vector_regions:
        os.remove(vec)

    return add > 0


def _annual_coord_lookup(coord):
    """Return a container giving the same ``in`` results as ``coord``.

    Lists and tuples of hashable items are converted to a set so membership
    tests are O(1). Anything else is returned unchanged.
    """
    if isinstance(coord, (list, tuple)):
        try:
            return set(coord)
        except TypeError:
            # unhashable items (e.g. [x, y] lists): keep linear lookups
            return coord
    return coord


def _annual_ready_raster(classificationRaster, validityRaster,
                         validityThreshold, mask, out_raster):
    """Write to ``out_raster`` the uint8 labels kept for one region mask.

    The classification is regularized (undecided label 0). It is set to 0
    where the validity is not strictly greater than ``validityThreshold``,
    and to 0 where ``mask`` is lower than 1.
    """
    mapReg = otb.Registry.CreateApplication("ClassificationMapRegularization")
    mapReg.SetParameterString("io.in", classificationRaster)
    mapReg.SetParameterString("ip.undecidedlabel", "0")
    mapReg.Execute()

    # identity BandMath apps: bring both inputs into the in-memory pipeline
    validity = otb.Registry.CreateApplication("BandMath")
    validity.SetParameterString("exp", "im1b1")
    validity.SetParameterStringList("il", [validityRaster])
    validity.SetParameterString("ram", "10000")
    validity.Execute()

    region_mask = otb.Registry.CreateApplication("BandMath")
    region_mask.SetParameterString("exp", "im1b1")
    region_mask.SetParameterStringList("il", [mask])
    region_mask.SetParameterString("ram", "10000")
    region_mask.Execute()

    valid = otb.Registry.CreateApplication("BandMath")
    valid.SetParameterString(
        "exp", "im1b1>" + str(validityThreshold) + "?im2b1:0")
    valid.AddImageToParameterInputImageList(
        "il", validity.GetParameterOutputImage("out"))
    valid.AddImageToParameterInputImageList(
        "il", mapReg.GetParameterOutputImage("io.out"))
    valid.SetParameterString("ram", "10000")
    valid.Execute()

    rdy = otb.Registry.CreateApplication("BandMath")
    rdy.SetParameterString("exp", "im1b1*(im2b1>=1?1:0)")
    rdy.AddImageToParameterInputImageList(
        "il", valid.GetParameterOutputImage("out"))
    rdy.AddImageToParameterInputImageList(
        "il", region_mask.GetParameterOutputImage("out"))
    rdy.SetParameterString(
        "out", out_raster +
        "?&streaming:type=stripped&streaming:sizemode=nbsplits&streaming:sizevalue=10"
    )
    rdy.SetParameterOutputImagePixelType("out", otb.ImagePixelType_uint8)
    rdy.ExecuteAndWriteOutput()


def _annual_add_region_and_seed_fields(vector, region_field_name, region,
                                       current_seed, runs, learn_flag,
                                       undetermined_flag):
    """Add the region column and one ``seed_<run>`` column per run to ``vector``.

    The ``seed_<current_seed>`` column holds ``learn_flag``. Every other seed
    column holds ``undetermined_flag``.
    """
    addField(vector,
             region_field_name,
             str(region),
             valueType=str,
             driver_name="SQLite")
    for run in range(runs):
        flag = learn_flag if run == current_seed else undetermined_flag
        addField(vector,
                 "seed_" + str(run),
                 flag,
                 valueType=str,
                 driver_name="SQLite")
Пример #2
0
def mergeFinalClassifications(iota2_dir,
                              dataField,
                              nom_path,
                              colorFile,
                              runs=1,
                              pixType='uint8',
                              method="majorityvoting",
                              undecidedlabel=255,
                              dempstershafer_mob="precision",
                              keep_runs_results=True,
                              enableCrossValidation=False,
                              validationShape=None,
                              workingDirectory=None,
                              logger=logger):
    """Merge the per-seed final classifications and evaluate the fusion.

    Each ``Classif_Seed_<run>.tif`` (run in ``range(runs)``) found under
    ``<iota2_dir>/final`` is fused into ``Classifications_fusion.tif``, using
    either majority voting or Dempster-Shafer. A colour-indexed version of
    the fusion is created. The fusion is then evaluated against a validation
    vector with the ``results_utils`` library, which writes:

    - a confusion matrix CSV,
    - its figure ``fusionConfusion.png``,
    - the report ``final/RESULTS.txt``.

    Parameters
    ----------

    iota2_dir : string
        path to the iota2 output directory
    dataField : string
        name of the field holding the labels (used lower-cased)
    nom_path : string
        path to the nomenclature file
    colorFile : string
        path to the color description file
    runs : int
        number of iota2 runs (random learning splits)
    pixType : string
        output pixel format (available in OTB). "uint8" gives a Byte color
        image; anything else gives UInt16.
    method : string
        fusion method ("majorityvoting" or "dempstershafer")
    undecidedlabel : int
        label given to pixels on which the fusion cannot decide
    dempstershafer_mob : string
        mass of belief measurement ("precision", "recall", "accuracy" or
        "kappa")
    keep_runs_results : bool
        if True, a RESULTS.txt found under ``final`` (first match of a
        recursive search) is copied to ``RESULTS_seeds.txt`` before
        ``final/RESULTS.txt`` is overwritten by the fusion report
    enableCrossValidation : bool
        if True, the validation samples are the ``val.sqlite`` files found
        under ``<iota2_dir>/dataAppVal``. Otherwise they are the
        ``_majvote.sqlite`` files found under
        ``final/merge_final_classifications``.
    validationShape : string
        if given, this vector is used for validation instead of merging
        the validation samples
    workingDirectory : string
        if given, the fusion raster and the merged validation vector are
        written there. The fusion raster and its colour-indexed version
        are then copied to ``final``.

    Raises
    ------
    ValueError
        if ``method`` or ``dempstershafer_mob`` is not supported

    See Also
    --------

    results_utils.gen_confusion_matrix_fig
    results_utils.stats_report
    """
    import shutil

    from Common import OtbAppBank as otbApp
    from Validation import ResultsUtils as ru
    from Common import CreateIndexedColorImage as color

    fusion_name = "Classifications_fusion.tif"
    new_results_seed_file = "RESULTS_seeds.txt"
    fusion_vec_name = "fusion_validation"  # without extension
    confusion_matrix_name = "fusionConfusion.png"

    if method not in ["majorityvoting", "dempstershafer"]:
        err_msg = "the fusion method must be 'majorityvoting' or 'dempstershafer'"
        logger.error(err_msg)
        raise ValueError(err_msg)
    if dempstershafer_mob not in ["precision", "recall", "accuracy", "kappa"]:
        err_msg = "the dempstershafer MoB must be 'precision' or 'recall' or 'accuracy' or 'kappa'"
        logger.error(err_msg)
        raise ValueError(err_msg)

    iota2_dir_final = os.path.join(iota2_dir, "final")
    merge_dir_final = os.path.join(iota2_dir_final,
                                   "merge_final_classifications")
    wd = iota2_dir_final
    wd_merge = merge_dir_final
    if workingDirectory:
        wd = workingDirectory
        wd_merge = workingDirectory

    final_classifications = [
        fut.FileSearch_AND(iota2_dir_final, True,
                           "Classif_Seed_{}.tif".format(run))[0]
        for run in range(runs)
    ]
    fusion_path = os.path.join(wd, fusion_name)

    fusion_options = compute_fusion_options(iota2_dir_final,
                                            final_classifications, method,
                                            undecidedlabel, dempstershafer_mob,
                                            pixType, fusion_path)
    logger.debug("fusion options:")
    logger.debug(fusion_options)
    fusion_app = otbApp.CreateFusionOfClassificationsApplication(
        fusion_options)
    logger.debug("START fusion of final classifications")
    fusion_app.ExecuteAndWriteOutput()
    logger.debug("END fusion of final classifications")

    fusion_color_index = color.CreateIndexedColorImage(
        fusion_path,
        colorFile,
        co_option=["COMPRESS=LZW"],
        output_pix_type=gdal.GDT_Byte
        if pixType == "uint8" else gdal.GDT_UInt16)

    confusion_matrix = os.path.join(merge_dir_final,
                                    "confusion_mat_maj_vote.csv")
    if not enableCrossValidation:
        vector_val = fut.FileSearch_AND(merge_dir_final, True,
                                        "_majvote.sqlite")
    else:
        vector_val = fut.FileSearch_AND(os.path.join(iota2_dir, "dataAppVal"),
                                        True, "val.sqlite")
    if validationShape:
        validation_vector = validationShape
    else:
        fut.mergeSQLite(fusion_vec_name, wd_merge, vector_val)
        validation_vector = os.path.join(wd_merge, fusion_vec_name + ".sqlite")

    confusion = otbApp.CreateComputeConfusionMatrixApplication({
        "in": fusion_path,
        "out": confusion_matrix,
        "ref": "vector",
        "ref.vector.nodata": "0",
        "ref.vector.in": validation_vector,
        "ref.vector.field": dataField.lower(),
        "nodatalabel": "0",
        "ram": "5000"
    })
    confusion.ExecuteAndWriteOutput()

    maj_vote_conf_mat = os.path.join(iota2_dir_final, confusion_matrix_name)
    ru.gen_confusion_matrix_fig(csv_in=confusion_matrix,
                                out_png=maj_vote_conf_mat,
                                nomenclature_path=nom_path,
                                undecidedlabel=undecidedlabel,
                                dpi=900)

    if keep_runs_results:
        # save the seeds report before the fusion report overwrites it
        seed_results = fut.FileSearch_AND(iota2_dir_final, True,
                                          "RESULTS.txt")[0]
        shutil.copy(seed_results,
                    os.path.join(iota2_dir_final, new_results_seed_file))

    maj_vote_report = os.path.join(iota2_dir_final, "RESULTS.txt")

    ru.stats_report(csv_in=[confusion_matrix],
                    nomenclature_path=nom_path,
                    out_report=maj_vote_report,
                    undecidedlabel=undecidedlabel)

    if workingDirectory:
        shutil.copy(fusion_path, iota2_dir_final)
        shutil.copy(fusion_color_index, iota2_dir_final)
        os.remove(fusion_path)
Пример #3
0
def DoAugmentation(samples,
                   class_augmentation,
                   strategy,
                   field,
                   excluded_fields=None,
                   Jstdfactor=None,
                   Sneighbors=None,
                   workingDirectory=None,
                   logger=logger):
    """Perform data augmentation according to the input parameters.

    For each class of ``class_augmentation``, OTB generates the requested
    number of new samples into a temporary per-class SQLite file. The
    original samples and every per-class file are merged into
    ``<samples name>_augmented.sqlite``. That file then replaces ``samples``
    in place. Temporary files (per-class outputs and, when a working
    directory is used, the working copy of ``samples``) are removed.

    Parameters
    ----------

    samples : string
        path to the set of samples to augment (OGR geometries must be 'POINT')
    class_augmentation : dict
        number of new samples to compute, by class
    strategy : string
        method used to perform data augmentation (replicate/jitter/smote)
    field : string
        data's field
    excluded_fields : list
        fields not considered when performing data augmentation
        (``None`` or empty: no exclusion)
    Jstdfactor : float
        factor dividing the standard deviation of each feature
        (jitter method only)
    Sneighbors : int
        number of nearest neighbors (smote method only)
    workingDirectory : string
        if given, ``samples`` is copied there and all intermediate files
        are written there

    Note
    ----
    This function uses the OTB application **SampleAugmentation**,
    more documentation
    `here <http://www.orfeo-toolbox.org/Applications/SampleAugmentation.html>`_
    """
    import shutil

    from Common import OtbAppBank

    samples_dir_o, samples_name = os.path.split(samples)
    samples_basename = os.path.splitext(samples_name)[0]
    samples_dir = samples_dir_o
    if workingDirectory:
        samples_dir = workingDirectory
        shutil.copy(samples, samples_dir)
    samples = os.path.join(samples_dir, samples_name)

    augmented_files = []
    for class_name, class_samples_augmentation in class_augmentation.items():
        logger.info(
            "{} samples of class {} will be generated by data augmentation ({} method) in {}"
            .format(class_samples_augmentation, class_name, strategy, samples))
        sample_name_augmented = "_".join(
            [samples_basename, "aug_class_{}.sqlite".format(class_name)])
        output_sample_augmented = os.path.join(samples_dir,
                                               sample_name_augmented)
        parameters = {
            "in": samples,
            "field": field,
            "out": output_sample_augmented,
            "label": class_name,
            "strategy": strategy,
            "samples": class_samples_augmentation
        }
        if excluded_fields:
            parameters["exclude"] = excluded_fields
        if strategy.lower() == "jitter":
            parameters["strategy.jitter.stdfactor"] = Jstdfactor
        elif strategy.lower() == "smote":
            parameters["strategy.smote.neighbors"] = Sneighbors

        augmentation_application = OtbAppBank.CreateSampleAugmentationApplication(
            parameters)
        augmentation_application.ExecuteAndWriteOutput()
        logger.debug("{} samples of class {} were added in {}".format(
            class_samples_augmentation, class_name, samples))
        augmented_files.append(output_sample_augmented)

    merged_name = "_".join([samples_basename, "augmented"])
    outputVector = os.path.join(samples_dir, merged_name + ".sqlite")

    fut.mergeSQLite(merged_name, samples_dir, [samples] + augmented_files)
    logger.info("Every data augmentation done in {}".format(samples))
    # the merged vector replaces the original samples file
    shutil.move(outputVector, os.path.join(samples_dir_o, samples_name))

    # clean-up
    for augmented_file in augmented_files:
        os.remove(augmented_file)
    if workingDirectory:
        # the working copy of the input samples is no longer needed
        os.remove(samples)