def createGroundTruth(shapefiles,
                      img,
                      imgref,
                      geot,
                      imgName,
                      datasetPath,
                      cropCount=0):
    '''
    Build a labeled mask for *img* from a list of shapefiles.

    Each shapefile gets its own integer label (1 for the first, 2 for the
    second, and so on); every feature that overlaps the image extent is
    clipped to it and rasterized into the mask via FillImage.

    Parameters:
        shapefiles  -- list of shapefile paths, one class per file
        img         -- image array; only its shape (rows, cols) is used
        imgref      -- osr.SpatialReference of the image
        geot        -- GDAL geotransform of the image
        imgName, datasetPath -- unused here, kept for interface compatibility
        cropCount   -- forwarded to FillImage

    Returns:
        uint8 mask of shape (rows, cols) with per-shapefile labels.
    '''
    rows, cols = img.shape[0], img.shape[1]
    groundTruth = np.zeros((rows, cols), dtype='uint8')

    # Geometry of the image footprint, used to clip features to the image.
    extent = utils.GetExtentGeometry(geot, cols, rows)
    extent.FlattenTo2D()

    for label, shpPath in enumerate(shapefiles, start=1):
        dataset = ogr.Open(shpPath)
        shpLayer = dataset.GetLayer(0)
        toImg = osr.CoordinateTransformation(shpLayer.GetSpatialRef(), imgref)

        shpLayer.ResetReading()
        for feature in shpLayer:
            geom = feature.GetGeometryRef()
            geom.Transform(toImg)
            if extent.Intersect(geom):
                FillImage(groundTruth,
                          geom.Intersection(extent),
                          geot,
                          fill=label,
                          cropCount=cropCount)

    return groundTruth
def createDetectionData(shapefile,
                        railwayshapefile,
                        imglist,
                        outputfolder,
                        size,
                        test=False):
    '''
    Create detection crops along the railway for every image in imglist.

    For training (test=False), only polygons lying within 5.1 map units of
    the railway are rasterized into the ground-truth mask, and the crops are
    split into Train/Validation sets via create_sets. For testing, all crops
    are written to outputfolder together with the auxiliary arrays needed to
    reassemble the full-image prediction later.

    Parameters:
        shapefile        -- polygon shapefile with candidate features
                            (only opened when test is False)
        railwayshapefile -- line shapefile with the railway track
        imglist          -- list of raster image paths
        outputfolder     -- root folder for all generated files
        size             -- crop size forwarded to createData
        test             -- True to generate test data, False for training

    Side effects only; returns None.
    '''
    imgs = imglist

    rlwds = ogr.Open(railwayshapefile)
    rlwlayer = rlwds.GetLayer(0)
    rlwref = rlwlayer.GetSpatialRef()

    ds = None
    layer = None
    shpref = None
    closetorlw = None

    imgs_files = {}

    if not test:
        ds = ogr.Open(shapefile)
        layer = ds.GetLayer(0)
        shpref = layer.GetSpatialRef()
        rlwtransform = osr.CoordinateTransformation(rlwref, shpref)

        # Flag the polygons that lie within 5.1 map units of the railway.
        # (Original code also tracked the minimum distance in a dead
        # variable and computed Distance twice per feature; removed.)
        closetorlw = [False] * len(layer)
        for fid, f in enumerate(layer):
            geometry = f.GetGeometryRef()
            rlwlayer.ResetReading()
            for feature in rlwlayer:
                railway = feature.GetGeometryRef()
                railway.Transform(rlwtransform)
                if geometry.Distance(railway) <= 5.1:
                    closetorlw[fid] = True
                    break

    for num, imgfile in enumerate(imgs):
        references = set()
        imgs_files[num] = imgfile
        img = gdal.Open(imgfile)
        geot = img.GetGeoTransform()
        xAxis = img.RasterXSize  # Max columns
        yAxis = img.RasterYSize  # Max rows

        # Coordinate transformations from the shapefiles into image space.
        imgref = osr.SpatialReference(wkt=img.GetProjectionRef())
        if not test:
            transform = osr.CoordinateTransformation(shpref, imgref)
        rlwtransform = osr.CoordinateTransformation(rlwref, imgref)

        # Image extent, used to test whether geometries overlap this image.
        ext = utils.GetExtentGeometry(geot, xAxis, yAxis)
        ext.FlattenTo2D()

        # Ground-truth mask and its output file.
        groundTruth = np.zeros((yAxis, xAxis), dtype='bool8')
        gtname = os.path.splitext(os.path.split(imgfile)[-1])
        if not os.path.isdir(os.path.join(outputfolder, "GroundTruths")):
            os.mkdir(os.path.join(outputfolder, "GroundTruths"))
        gtfile = os.path.join(outputfolder, "GroundTruths",
                              "mask_" + gtname[0] + ".png")

        # Walk the railway lines and collect the crop-center points.
        rlwlayer.ResetReading()
        for feature in rlwlayer:
            railway = feature.GetGeometryRef()
            railway.Transform(rlwtransform)
            if ext.Intersect(railway):
                addPoints(references, ext.Intersection(railway), geot)

        # Rasterize the erosion polygons (training/fine-tuning only).
        if not test:
            layer.ResetReading()
            for fid, feature in enumerate(layer):
                if closetorlw[fid]:
                    geometry = feature.GetGeometryRef()
                    geometry.Transform(transform)
                    if ext.Intersect(geometry):
                        FillImage(groundTruth, geometry.Intersection(ext),
                                  geot)

        refs = list(references)
        if not test:
            val_ref, train_ref = create_sets(np.asarray(refs))
            createData(imgfile, groundTruth,
                       os.path.join(outputfolder, 'Validation'), val_ref, geot,
                       size)
            createData(imgfile, groundTruth,
                       os.path.join(outputfolder, 'Train'),
                       train_ref, geot, size)
        else:
            # Crops of both the image and the ground truth, plus the mask.
            createData(imgfile, groundTruth, outputfolder, references, geot,
                       size)
            scipy.misc.imsave(gtfile, groundTruth * 255)
            # Auxiliary arrays used later to rebuild the complete
            # prediction for the full image.
            np.save(os.path.join(outputfolder, "referencepoints.npy"),
                    np.asarray(refs))
            np.save(os.path.join(outputfolder, "referenceimgspaths.npy"),
                    np.asarray([imgfile]))
def processInput(shapefile,
                 shapeLabels,
                 railwayShp,
                 imgList,
                 gtFolders,
                 stepSize,
                 patchSize,
                 ignoreRelevant=False):
    print "\n\n ------------- Creating points of interest ------------------- \n\n"
    shpfile = shapefile
    imgs = imgList

    ds = ogr.Open(shpfile)
    dsRails = ogr.Open(railwayShp)
    layer = ds.GetLayer(0)
    layerRails = dsRails.GetLayer(0)
    points = []
    bounds = []
    extents = []
    geotransformations = []
    imgs_files = {}
    gt_files = {}

    shpRef = layer.GetSpatialRef()
    rlwRef = layerRails.GetSpatialRef()

    tpoints = []
    rpoints = []

    erosionFeatures = []
    with open(shapeLabels) as file:
        file.readline()
        text = file.read().split('\r\n')
        for line in text:
            lsplit = line.split(',')
            if len(line) < 2:
                break

            fid, label = lsplit[0], lsplit[1]
            if label == 'Erosao':
                erosionFeatures.append(int(fid))

    print erosionFeatures
    for num, file in enumerate(imgs):
        points_img = []
        imgs_files[num] = file
        img = gdal.Open(file)
        geot = img.GetGeoTransform()
        geotransformations.append(geot)
        xAxis = img.RasterXSize  # Max columns
        yAxis = img.RasterYSize  # Max rows
        extents.append((xAxis, yAxis))

        imgRef = osr.SpatialReference(wkt=img.GetProjectionRef())
        transform = osr.CoordinateTransformation(shpRef, imgRef)
        rlwTransform = osr.CoordinateTransformation(rlwRef, imgRef)

        ext = utils.GetExtentGeometry(geot, xAxis, yAxis)
        ext.FlattenTo2D()

        # Gera os limites de geracao de pontos para a imagem
        layerRails.ResetReading()
        lims = []

        minRailCol, minRailRow = float('inf'), float('inf')
        maxRailCol, maxRailRow = -float('inf'), -float('inf')

        for feature in layerRails:
            railway = feature.GetGeometryRef()
            railway.Transform(rlwTransform)
            if ext.Intersect(railway):
                print num
                intersection = ext.Intersection(railway)
                if intersection.GetGeometryName() == 'MULTILINESTRING':
                    for l in range(intersection.GetGeometryCount()):
                        line = intersection.GetGeometryRef(l)
                        for i in range(line.GetPointCount() - 1):
                            pnts = []
                            p1p = line.GetPoint(i)
                            p2p = line.GetPoint(i + 1)

                            p1 = (p1p[0], p1p[1])
                            p2 = (p2p[0], p2p[1])
                            pxl1 = utils.CoordinateToPixel(geot, p1)
                            pxl2 = utils.CoordinateToPixel(geot, p1)

                            minRailCol = min(minRailCol, pxl1[0], pxl2[0])
                            minRailRow = min(minRailRow, pxl1[1], pxl2[1])
                            maxRailCol = max(maxRailCol, pxl1[0], pxl2[0])
                            maxRailRow = max(maxRailRow, pxl1[1], pxl2[1])

                            pnts.append(p1)
                            pnts.append(p2)
                            lims.append(pnts)
                else:
                    for i in range(intersection.GetPointCount() - 1):
                        pnts = []
                        p1p = intersection.GetPoint(i)
                        p2p = intersection.GetPoint(i + 1)

                        p1 = (p1p[0], p1p[1])
                        p2 = (p2p[0], p2p[1])

                        pxl1 = utils.CoordinateToPixel(geot, p1)
                        pxl2 = utils.CoordinateToPixel(geot, p1)

                        minRailCol = min(minRailCol, pxl1[0], pxl2[0])
                        minRailRow = min(minRailRow, pxl1[1], pxl2[1])
                        maxRailCol = max(maxRailCol, pxl1[0], pxl2[0])
                        maxRailRow = max(maxRailRow, pxl1[1], pxl2[1])

                        pnts.append(p1)
                        pnts.append(p2)
                        lims.append(pnts)

        bounds.append(lims)

        groundTruth = np.zeros((yAxis, xAxis), dtype='bool8')

        # Adiciona os pontos de erosao a pool de pontos
        layer.ResetReading()
        for fid, feature in enumerate(layer):
            if fid in erosionFeatures:
                geometry = feature.GetGeometryRef()
                geometry.Transform(transform)
                if ext.Intersect(geometry):
                    #gera lista de pontos do poligono que pertencem a imagem
                    #AddPolygon(geometry, geot, points_img, num, xAxis, yAxis)
                    intersection = ext.Intersection(geometry)
                    FillImage(groundTruth, intersection, geot)

        #print os.path.split(file)[-1]
        gtName = os.path.splitext(os.path.split(file)[-1])
        if not os.path.isdir(gtFolders[num]):
            os.mkdir(gtFolders[num])
        #gt_files[num] = gtFolder + str(num) + gtName[0] + gtName[1]
        gt_files[num] = gtFolders[num] + "mask_" + gtName[0] + ".png"
        if not os.path.isfile(gt_files[num]):
            io.imsave(gt_files[num], groundTruth)

        print "Create Mask for img {0} [{1}]".format(num, gtName[0])
        #if num == 0:
        #gdalnumeric.SaveArray(groundTruth, "GT.tif", format="GTiff", prototype=file)

        img = None
        railLimits = [minRailRow, minRailCol, maxRailRow, maxRailCol]
        if not os.path.isfile(
                '/media/tcu/PointDistribution/Parte3/targetpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy'):
            targetPoints, relevantPoints = GetInterestPoints(
                groundTruth, geot, num, yAxis, xAxis, railLimits, lims,
                stepSize, patchSize)

            print "TARGET POINTS (%d)" % len(targetPoints)
            print "RELEVANT POINTS (%d)" % len(relevantPoints)

            tpoints.append(targetPoints)
            rpoints.append(relevantPoints)

    #print "Total points in polygons: {0}".format(len(points[0]))
    print "\n\n ------------------------------------------------------------- \n\n"

    if not os.path.isfile('/media/tcu/PointDistribution/Parte3/targetpoints_' +
                          str(stepSize) + '_' + str(patchSize) + '.npy'):
        tpoints = np.asarray(tpoints, dtype=np.dtype(object))
        rpoints = np.asarray(rpoints, dtype=np.dtype(object))

        np.save(
            open(
                '/media/tcu/PointDistribution/Parte3/targetpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'wb'), tpoints)
        np.save(
            open(
                '/media/tcu/PointDistribution/Parte3/relevantpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'wb'), rpoints)
    else:
        tpoints = np.load(
            open(
                '/media/tcu/PointDistribution/Parte3/targetpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'rb'))
        rpoints = np.load(
            open(
                '/media/tcu/PointDistribution/Parte3/relevantpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'rb'))

    if not os.path.isfile(
            '/media/tcu/PointDistribution/Parte3/balancedtargetpoints_' +
            str(stepSize) + '_' + str(patchSize) + '.npy'):
        tpoints, rpoints = balancePatches(tpoints, rpoints, gt_files,
                                          patchSize)
        np.save(
            open(
                '/media/tcu/PointDistribution/Parte3/balancedtargetpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'wb'), tpoints)
        np.save(
            open(
                '/media/tcu/PointDistribution/Parte3/balancedrelevantpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'wb'), rpoints)
    else:
        tpoints = np.load(
            open(
                '/media/tcu/PointDistribution/Parte3/balancedtargetpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'rb'))
        rpoints = np.load(
            open(
                '/media/tcu/PointDistribution/Parte3/balancedrelevantpoints_' +
                str(stepSize) + '_' + str(patchSize) + '.npy', 'rb'))

    # Create train/test distribution
    zeros = [i for i in range(len(imgs)) if len(tpoints[i]) == 0]
    #print zeros
    ids = np.array(range(len(imgs)))
    filteredids = np.delete(ids, zeros)

    testPointsIdx = random.sample(
        filteredids, int(len(filteredids) * 0.2 + 1)
    )  #np.asarray(random.sample(filteredids, int(len(filteredids)*0.2 + 1)))
    trainPointsIdx = [
        x for x in filteredids if x not in testPointsIdx
    ]  #np.asarray([x for x in range(len(imgs)) if x not in zeros or x not in testPointsIdx])
    #points = np.asarray(points, dtype=np.dtype(object))

    if ignoreRelevant:
        dummy = [None] * len(imgs)
        rpoints = np.asarray(dummy, dtype=object)
    #print points[0:10]
    #testPoints = points[testPointsIdx]
    #trainPoints = np.delete(points,testPointsIdx)

    #return tpoints, rpoints, trainPointsIdx, testPointsIdx, bounds, extents, geotransformations, imgs_files, gt_files
    return tpoints, rpoints, trainPointsIdx, testPointsIdx, imgs_files, gt_files
def createDatasetFromImg(imgName,
                         railwayShapefile,
                         featuresShapefile,
                         cropSize,
                         process,
                         outputPath,
                         task,
                         uniqueSegments,
                         trainFids,
                         valFids,
                         trainFile=None,
                         valFile=None,
                         valInfoFile=None):
    '''
    Cut square crops of side cropSize along the railway segments of one image
    and save them (plus masks/annotations/index files) in the layout expected
    by the chosen task ('segmentation' or 'detection') and process
    ('train' or 'test').

    Parameters:
        imgName           -- path to the raster image (.img/.tif/.tiff)
        railwayShapefile  -- line shapefile with the railway track
        featuresShapefile -- list of feature shapefiles; only the first is
                             opened here (train only), the whole list is
                             forwarded to createPatches
        cropSize          -- side, in pixels, of each square crop
        process           -- 'train' or 'test'
        outputPath        -- root folder for all generated files
        task              -- 'segmentation' or 'detection'
        uniqueSegments    -- dict keyed by segment endpoint tuples, shared
                             across calls to avoid reprocessing segments;
                             updated in place
        trainFids         -- feature ids whose crops go to the Train split
        valFids           -- feature ids for validation (not referenced in
                             this function body)
        trainFile         -- open handle for the detection train index file
        valFile           -- open handle for the detection val index file
        valInfoFile       -- open handle for the detection val info file

    Side effects only (files written, uniqueSegments mutated); returns None.
    '''
    # Open image
    img = gdal.Open(imgName)
    # Base name with the raster extension stripped; used in output filenames.
    name = os.path.split(imgName)[-1].replace('.img',
                                              '').replace('.tif', '').replace(
                                                  '.tiff', '')

    geoTransform = img.GetGeoTransform()
    xAxis = img.RasterXSize
    yAxis = img.RasterYSize

    # Open railway shapefile
    rlwDs = ogr.Open(railwayShapefile)
    rlwLayer = rlwDs.GetLayer(0)
    rlwSpr = rlwLayer.GetSpatialRef()

    # Open first shapefile in featuresShapes (only needed for training)
    shapeZero = None
    shapeZeroLayer = None
    shapeZeroSpr = None
    transformZeroToImg = None

    if process == "train":
        shapeZero = ogr.Open(featuresShapefile[0])
        shapeZeroLayer = shapeZero.GetLayer(0)
        shapeZeroSpr = shapeZeroLayer.GetSpatialRef()

    # translate coordinates of shapefiles into images'
    imgSpr = osr.SpatialReference(wkt=img.GetProjectionRef())
    transformRlwToImg = osr.CoordinateTransformation(rlwSpr, imgSpr)

    if process == "train":
        transformZeroToImg = osr.CoordinateTransformation(shapeZeroSpr, imgSpr)

    # Image extent
    imgExt = utils.GetExtentGeometry(geoTransform, xAxis, yAxis)
    imgExt.FlattenTo2D()

    # Railway Segments
    segments = extractRailwaySegmentsFromShapefile(rlwLayer, transformRlwToImg,
                                                   imgExt)
    segments = sortSegments(segments)

    cropCount = 0
    isTrain = process == 'train'
    createMask = isTrain and task == 'segmentation'
    referencePoints = []
    uniqueCenters = set()  # crop centers (map coords) already generated

    for w, segment in enumerate(segments):
        # Skip segments already handled in a previous call/image.
        if tuple(segment[0]) + tuple(segment[1]) in uniqueSegments:
            continue
        else:
            curPoint = checkSegmentOverlap(imgExt, segment, uniqueSegments)

        p1 = curPoint[0]
        p2 = curPoint[1]
        # (x, y) holds the last crop center computed for this segment;
        # it stays (0, 0) if no portion of the segment yields a center.
        x = 0
        y = 0
        tValues = segmentPortions(segment, geoTransform, cropSize)
        for t in tValues:
            # Interpolate the crop center along the segment at parameter t.
            x = (p2[0] - p1[0]) * t + p1[0]
            y = (p2[1] - p1[1]) * t + p1[1]

            if (x, y) in uniqueCenters:
                continue
            else:
                uniqueCenters.add((x, y))

            xPixel, yPixel = utils.CoordinateToPixel(geoTransform, (x, y))
            # Only keep crops that fit fully inside the raster.
            # NOTE(review): cropSize / 2 is integer division under Python 2;
            # presumably cropSize is even — confirm.
            if (yPixel - cropSize / 2 >= 0 and yPixel + cropSize / 2 <= yAxis
                    and xPixel - cropSize / 2 >= 0
                    and xPixel + cropSize / 2 <= xAxis):

                # Create Crop and CropExtent
                cropExt, crop, cropMask = createPatches(img,
                                                        geoTransform,
                                                        x,
                                                        y,
                                                        cropSize,
                                                        featuresShapefile,
                                                        imgName,
                                                        outputPath,
                                                        isTrain=isTrain,
                                                        createMask=createMask,
                                                        cropCount=cropCount)

                # check if the patch has more than 30% of pixels with black color
                # this was created for a specific case when the raster is larger but most of it is composed of black
                if np.bincount(crop.astype(
                        int).flatten())[0] > cropSize * cropSize * 0.3:
                    continue

                # The MaskRCNN needs reference points to recreate the complete image segmentation after processing
                if task == 'segmentation' and process == 'test':
                    referencePoints.append((yPixel, xPixel))

                # If it's for test we only need the crop from image. The dataset is organized in different ways for detection and segmentation.
                if process == 'test':
                    if task == 'detection':
                        # Save crop
                        scipy.misc.imsave(
                            os.path.join(
                                outputPath, 'JPEGImages', name + '_' +
                                str(xPixel) + '_' + str(yPixel) + '.png'),
                            crop)
                        # Index the crop; '-1' marks an unknown/negative label.
                        valFile.write(name + '_' + str(xPixel) + '_' +
                                      str(yPixel) + '\n')
                        valInfoFile.write(name + '_' + str(xPixel) + '_' +
                                          str(yPixel) + ' -1' + '\n')
                    if task == 'segmentation':
                        scipy.misc.imsave(
                            os.path.join(
                                outputPath, 'JPEGImages',
                                name + '_' + 'crop' + str(cropCount) + '.png'),
                            crop)
                        cropCount += 1

                # If it's for training we need more informations to feed the networks
                # Check if crop contains selected train features
                if process == 'train':
                    fids = []
                    shapeZeroLayer.ResetReading()
                    interPs = []
                    for fid, feature in enumerate(shapeZeroLayer):
                        geometry = feature.GetGeometryRef()
                        geometry.Transform(transformZeroToImg)
                        if cropExt.Intersect(geometry):
                            # In segmentation only matters if the crop intersects a feature
                            if task == 'segmentation':
                                fids.append(fid)
                            # But for detection the size of this intersection is important to don't allow
                            # small overlaps that are virtualy not detectable
                            if task == 'detection':
                                intersection = cropExt.Intersection(geometry)
                                interPnts = getIntersectionPoints(
                                    intersection.GetGeometryRef(0),
                                    geoTransform)
                                # Keep only overlaps larger than 100 pixel
                                # units of area (tiny slivers are ignored).
                                if abs(interPnts[2] - interPnts[0]) * abs(
                                        interPnts[1] - interPnts[3]) > 100:
                                    interPs.append(interPnts)
                                    fids.append(fid)

                    # Create XML for detection task
                    if fids and task == 'detection':
                        saveXML(os.path.join(outputPath, 'Annotations'),
                                name + '_' + str(xPixel) + '_' + str(yPixel),
                                yAxis, xAxis, interPs,
                                (xPixel - cropSize // 2),
                                (yPixel - cropSize // 2))

                    # If intersects some train selected features
                    if np.any(np.isin(fids, trainFids)):
                        if task == 'segmentation':
                            # Save crop
                            scipy.misc.imsave(
                                os.path.join(
                                    outputPath, 'Train', 'JPEGImages',
                                    name + "_crop" + str(cropCount) + '.png'),
                                crop)
                            # Save crop mask
                            scipy.misc.imsave(
                                os.path.join(
                                    outputPath, 'Train', 'Masks', name +
                                    "_crop" + str(cropCount) + '_mask.png'),
                                cropMask)
                            cropCount += 1
                        if task == 'detection':
                            # Save crop
                            scipy.misc.imsave(
                                os.path.join(
                                    outputPath, 'JPEGImages', name + '_' +
                                    str(xPixel) + '_' + str(yPixel) + '.png'),
                                crop)
                            # Write in file that this crop it's for train
                            trainFile.write(
                                os.path.join(name + '_' + str(xPixel) + '_' +
                                             str(yPixel)) + '\n')
                            cropCount += 1
                    else:
                        if task == 'segmentation':
                            # Save crop
                            scipy.misc.imsave(
                                os.path.join(
                                    outputPath, 'Validation', 'JPEGImages',
                                    name + "_crop" + str(cropCount) + '.png'),
                                crop)
                            # Save crop mask
                            scipy.misc.imsave(
                                os.path.join(
                                    outputPath, 'Validation', 'Masks', name +
                                    "_crop" + str(cropCount) + '_mask.png'),
                                cropMask)
                            cropCount += 1
                        if task == 'detection':
                            # Save crop
                            scipy.misc.imsave(
                                os.path.join(
                                    outputPath, 'JPEGImages', name + '_' +
                                    str(xPixel) + '_' + str(yPixel) + '.png'),
                                crop)
                            # Write in file that this crop it's for validation/test
                            valInfoFile.write(
                                os.path.join(name + '_' + str(xPixel) + '_' +
                                             str(yPixel)) + ' ' +
                                ('1' if fids else '-1') + '\n')
                            valFile.write(
                                os.path.join(name + '_' + str(xPixel) + '_' +
                                             str(yPixel)) + '\n')
                            # print(name)
                            cropCount += 1

        # Record the last crop center generated for this segment
        # ((0, 0) when no valid center was produced).
        uniqueSegments[tuple(segment[0]) + tuple(segment[1])] = (x, y)

    # if (task == "detection" and process == "test"):
    #scipy.misc.imsave(os.path.join(outputPath, 'JPEGImages', name + "_crop" + str(cropCount) + '.png'), crop)
    # Write in file that this crop it's for validation/test
    #valInfoFile.write(os.path.join(outputPath, 'JPEGImages', name + "_crop" + str(cropCount)) + ' ' + ('1' if fids else '-1') + '\n')
    # valFile.write(os.path.join(outputPath, 'JPEGImages', name + "_crop" + str(cropCount)) + '\n')
    # print(name)
    # cropCount += 1

    # The MaskRCNN needs reference points to recreate the complete image segmentation after processing
    # This saves the points
    if task == 'segmentation' and process == 'test':
        np.save(
            os.path.join(outputPath, 'ReferencePoints',
                         name + '_refpoints.npy'), np.asarray(referencePoints))
        np.save(
            os.path.join(outputPath, 'ReferencePoints', name + '_refpath.npy'),
            np.asarray([imgName]))