def segmentPortions(segment, geoTransform, cropSize):
    '''
    Return the interpolation fractions at which a segment is sampled.

    Both endpoints of the segment are converted to pixels with
    utils.CoordinateToPixel; the larger of the two per-axis pixel distances
    is then divided into steps of cropSize / 9 pixels.

    segment      -- pair of coordinates (p1, p2)
    geoTransform -- geotransform forwarded to utils.CoordinateToPixel
    cropSize     -- crop side, in the same pixel units as the distance

    Returns a numpy array. When the segment is shorter than one step the
    array is [1.0]; otherwise it holds numPoints + 1 evenly spaced values
    starting at 0.0 and ending exactly at 1.0.
    '''
    p1px = utils.CoordinateToPixel(geoTransform, segment[0])
    p2px = utils.CoordinateToPixel(geoTransform, segment[1])

    # Largest per-axis pixel distance between the two endpoints
    d = max(abs(p1px[0] - p2px[0]), abs(p1px[1] - p2px[1]))

    numPoints = int(d // (cropSize / 9))
    if numPoints <= 0:
        return np.array([1.0])

    # linspace pins both ends of the interval. The previous
    # arange(0, 1.01, 1 / numPoints) produced fractions above 1.0 as soon as
    # numPoints exceeded 100 (step < 0.01), i.e. points past the segment end.
    return np.linspace(0.0, 1.0, numPoints + 1)
def _addPoints(list, p1, p2, step, geotransform):
    for t in np.arange(0., 1.0, step):
        px = (p2[0] - p1[0]) * t + p1[0]
        py = (p2[1] - p1[1]) * t + p1[1]

        pixel = utils.CoordinateToPixel(geotransform, (px, py))  # Col, Row
        list.add((pixel[1], pixel[0]))
def getVertices(geometry, geotransform):
    '''
    Return two lists with the pixel positions (first and second component
    of utils.CoordinateToPixel) of every point of the geometry's first ring.

    NOTE: a later definition of getVertices in this file rebinds the name,
    so this version is not the one callers end up with.
    '''
    ring = geometry.GetGeometryRef(0)

    pixelPairs = []
    for index in range(ring.GetPointCount()):
        lon, lat, _ = ring.GetPoint(index)
        pixel = utils.CoordinateToPixel(geotransform, (lon, lat))
        pixelPairs.append((pixel[0], pixel[1]))

    firstComponents = [pair[0] for pair in pixelPairs]
    secondComponents = [pair[1] for pair in pixelPairs]
    return firstComponents, secondComponents
def getIntersectionPoints(ring, geotImg):
    '''
    Return the pixel bounding box of a ring as (pMinX, pMaxY, pMaxX, pMinY).

    The coordinate extremes of the ring are converted with
    utils.CoordinateToPixel. If any extreme is infinite (for instance when
    the ring has no points) all four values are float('inf').
    '''
    inf = float('inf')
    lowX, lowY = inf, inf
    highX, highY = -inf, -inf

    for index in range(ring.GetPointCount()):
        lon, lat, _ = ring.GetPoint(index)
        if lon < lowX:
            lowX = lon
        if lat < lowY:
            lowY = lat
        if lon > highX:
            highX = lon
        if lat > highY:
            highY = lat

    if any(math.isinf(value) for value in (lowX, lowY, highX, highY)):
        return (inf, inf, inf, inf)

    pMinX, pMinY = utils.CoordinateToPixel(geotImg, (lowX, lowY))
    pMaxX, pMaxY = utils.CoordinateToPixel(geotImg, (highX, highY))

    # Coordinates have their origin at the bottom left while images have it
    # at the top left, so the minimum latitude maps to the maximum row.
    # That is why the two y values are swapped in the result.
    return (pMinX, pMaxY, pMaxX, pMinY)
def genPoint(extent_list,
             bounds_list,
             geotransformations,
             total_imgs,
             maps=None):
    '''
    Draw a random pixel close to a random segment of one image.

    When `maps` is None an image index is drawn at random and advanced
    (wrapping around) until an image with at least one segment in
    bounds_list is found; if none exists the program exits. Otherwise the
    image index is picked from `maps`.

    A point is interpolated at a random position along a random segment of
    that image, converted to pixels, shifted by up to 100 pixels on each
    axis and clamped to [0, extent] on each axis.

    Returns (row, col, m) where m is the chosen image index.
    '''
    if maps is not None:
        m = random.choice(maps)
    else:
        first = random.randint(0, total_imgs - 1)
        m = first
        while len(bounds_list[m]) == 0:
            m = (m + 1) % total_imgs
            if m == first:
                # Every image was visited without finding a segment
                print("\n\n ------------------------------------------- \n\n")
                sys.exit(
                    " Error: There is no intersection between images and Railway shape.\n\n ------------------------------------------- \n\n"
                )

    part = 0
    if len(bounds_list[m]) > 1:
        part = random.randint(0, len(bounds_list[m]) - 1)

    t = random.uniform(0, 1)
    chosen = bounds_list[m][part]
    start = chosen[0]
    end = chosen[1]

    x = (end[0] - start[0]) * t + start[0]
    y = (end[1] - start[1]) * t + start[1]

    pixels = utils.CoordinateToPixel(geotransformations[m], (x, y))
    col = pixels[0] + random.randint(-100, 100)
    row = pixels[1] + random.randint(-100, 100)

    maxCol = extent_list[m][0]
    maxRow = extent_list[m][1]
    if row < 0:
        row = 0
    elif row > maxRow:
        row = maxRow
    if col < 0:
        col = 0
    elif col > maxCol:
        col = maxCol

    return row, col, m
def getVertices(geometry, geotransform):
    '''
    Return two lists with the col, row vertices of the geometry
    in a image.

    For a GEOMETRYCOLLECTION the first POLYGON member is used. If the
    collection holds no polygon, the first sub-geometry is read as the ring.

    geometry     -- OGR geometry whose first ring is read
    geotransform -- geotransform forwarded to utils.CoordinateToPixel

    Returns (pX, pY). When the geometry has no sub-geometry at index 0 the
    function returns (0, 0) instead of two lists.
    '''
    if geometry.GetGeometryName() == "GEOMETRYCOLLECTION":
        for k in range(0, geometry.GetGeometryCount()):
            g = geometry.GetGeometryRef(k)
            # OGR exposes GetGeometryName(); the previous GeometryName() call
            # raised AttributeError, and the recursion used the undefined
            # name geoTransform instead of this function's parameter.
            if g is not None and g.GetGeometryName() == "POLYGON":
                return getVertices(g, geotransform)

    ring = geometry.GetGeometryRef(0)
    if ring is None:
        return 0, 0

    pX = []
    pY = []
    for i in range(ring.GetPointCount()):
        lon, lat, z = ring.GetPoint(i)
        p = utils.CoordinateToPixel(geotransform, (lon, lat))
        pX.append(p[0])
        pY.append(p[1])

    return pX, pY
def _appendRailSegments(line, geot, lims, pixelBounds):
    '''
    Append every pair of consecutive points of `line` to `lims` as [p1, p2].

    `pixelBounds` is a list [minCol, minRow, maxCol, maxRow] that is widened
    in place so it covers the pixel position of both points of each pair.
    '''
    for i in range(line.GetPointCount() - 1):
        p1p = line.GetPoint(i)
        p2p = line.GetPoint(i + 1)

        p1 = (p1p[0], p1p[1])
        p2 = (p2p[0], p2p[1])
        pxl1 = utils.CoordinateToPixel(geot, p1)
        # The second pixel must come from p2; it used to be computed from p1
        # again, which left the last point of each line out of the bounds.
        pxl2 = utils.CoordinateToPixel(geot, p2)

        pixelBounds[0] = min(pixelBounds[0], pxl1[0], pxl2[0])
        pixelBounds[1] = min(pixelBounds[1], pxl1[1], pxl2[1])
        pixelBounds[2] = max(pixelBounds[2], pxl1[0], pxl2[0])
        pixelBounds[3] = max(pixelBounds[3], pxl1[1], pxl2[1])

        lims.append([p1, p2])


def _savePoints(path, array):
    '''Write `array` to `path` with np.save, closing the file afterwards.'''
    with open(path, 'wb') as handle:
        np.save(handle, array)


def _loadPoints(path):
    '''Read an array previously written by _savePoints.'''
    with open(path, 'rb') as handle:
        # The cached arrays are stored with dtype=object, which NumPy only
        # loads when pickling is allowed. Only point this at cache files
        # produced by this program: unpickling untrusted data is unsafe.
        return np.load(handle, allow_pickle=True)


def processInput(shapefile,
                 shapeLabels,
                 railwayShp,
                 imgList,
                 gtFolders,
                 stepSize,
                 patchSize,
                 ignoreRelevant=False,
                 cacheDir='/media/tcu/PointDistribution/Parte3/'):
    '''
    Build the ground-truth masks and the points of interest of every image.

    shapefile      -- path of the shapefile with the labelled features
    shapeLabels    -- text file with one "fid,label" line per feature; the
                      first line is discarded and reading stops at the first
                      line shorter than two characters
    railwayShp     -- path of the railway shapefile
    imgList        -- sequence of raster paths
    gtFolders      -- per image, the folder receiving its mask; the mask
                      path is built by plain concatenation, so each entry is
                      expected to end with a path separator
    stepSize       -- forwarded to GetInterestPoints
    patchSize      -- forwarded to GetInterestPoints and balancePatches
    ignoreRelevant -- when true the returned relevant points are replaced by
                      an object array of None, one entry per image
    cacheDir       -- folder holding the cached .npy point files

    For each image a boolean mask is filled with the features labelled
    'Erosao' that intersect the image extent and saved as
    "mask_<image name>.png" unless that file already exists. Target and
    relevant points are computed with GetInterestPoints only when the
    target-points cache file is missing; otherwise they are loaded from the
    cache. The same applies to the balanced points from balancePatches.

    Returns (tpoints, rpoints, trainPointsIdx, testPointsIdx, imgs_files,
    gt_files). The test indices are a random sample of the images that have
    at least one target point; the train indices are the remaining ones.
    '''
    print("\n\n ------------- Creating points of interest ------------------- \n\n")

    ds = ogr.Open(shapefile)
    dsRails = ogr.Open(railwayShp)
    layer = ds.GetLayer(0)
    layerRails = dsRails.GetLayer(0)
    imgs_files = {}
    gt_files = {}

    shpRef = layer.GetSpatialRef()
    rlwRef = layerRails.GetSpatialRef()

    tpoints = []
    rpoints = []

    # Cache file locations, all keyed by step and patch size
    suffix = str(stepSize) + '_' + str(patchSize) + '.npy'
    targetPath = os.path.join(cacheDir, 'targetpoints_' + suffix)
    relevantPath = os.path.join(cacheDir, 'relevantpoints_' + suffix)
    balancedTargetPath = os.path.join(cacheDir,
                                      'balancedtargetpoints_' + suffix)
    balancedRelevantPath = os.path.join(cacheDir,
                                        'balancedrelevantpoints_' + suffix)

    erosionFeatures = []
    with open(shapeLabels) as labelFile:
        labelFile.readline()  # first line is discarded (presumably a header)
        # splitlines() copes with both '\r\n' and '\n' terminated files
        for line in labelFile.read().splitlines():
            if len(line) < 2:
                break
            lsplit = line.split(',')
            if len(lsplit) < 2:
                continue

            fid, label = lsplit[0], lsplit[1]
            if label == 'Erosao':
                erosionFeatures.append(int(fid))

    print(erosionFeatures)
    erosionFids = set(erosionFeatures)

    # Decided once so every image is treated the same way
    computePoints = not os.path.isfile(targetPath)

    for num, imgPath in enumerate(imgList):
        imgs_files[num] = imgPath
        img = gdal.Open(imgPath)
        geot = img.GetGeoTransform()
        xAxis = img.RasterXSize  # Max columns
        yAxis = img.RasterYSize  # Max rows

        imgRef = osr.SpatialReference(wkt=img.GetProjectionRef())
        transform = osr.CoordinateTransformation(shpRef, imgRef)
        rlwTransform = osr.CoordinateTransformation(rlwRef, imgRef)

        ext = utils.GetExtentGeometry(geot, xAxis, yAxis)
        ext.FlattenTo2D()

        # Build the point-generation limits (railway segments) of the image
        layerRails.ResetReading()
        lims = []
        # [minCol, minRow, maxCol, maxRow] of the railway inside the image
        pixelBounds = [
            float('inf'),
            float('inf'), -float('inf'), -float('inf')
        ]

        for feature in layerRails:
            railway = feature.GetGeometryRef()
            railway.Transform(rlwTransform)
            if ext.Intersect(railway):
                print(num)
                intersection = ext.Intersection(railway)
                if intersection.GetGeometryName() == 'MULTILINESTRING':
                    for l in range(intersection.GetGeometryCount()):
                        _appendRailSegments(intersection.GetGeometryRef(l),
                                            geot, lims, pixelBounds)
                else:
                    _appendRailSegments(intersection, geot, lims,
                                        pixelBounds)

        groundTruth = np.zeros((yAxis, xAxis), dtype=bool)

        # Burn the erosion features that touch the image into the mask
        layer.ResetReading()
        for fid, feature in enumerate(layer):
            if fid in erosionFids:
                geometry = feature.GetGeometryRef()
                geometry.Transform(transform)
                if ext.Intersect(geometry):
                    intersection = ext.Intersection(geometry)
                    FillImage(groundTruth, intersection, geot)

        gtName = os.path.splitext(os.path.split(imgPath)[-1])
        if not os.path.isdir(gtFolders[num]):
            os.mkdir(gtFolders[num])
        gt_files[num] = gtFolders[num] + "mask_" + gtName[0] + ".png"
        if not os.path.isfile(gt_files[num]):
            io.imsave(gt_files[num], groundTruth)

        print("Create Mask for img {0} [{1}]".format(num, gtName[0]))

        img = None  # drop the reference to the opened raster
        railLimits = [
            pixelBounds[1], pixelBounds[0], pixelBounds[3], pixelBounds[2]
        ]  # minRow, minCol, maxRow, maxCol
        if computePoints:
            targetPoints, relevantPoints = GetInterestPoints(
                groundTruth, geot, num, yAxis, xAxis, railLimits, lims,
                stepSize, patchSize)

            print("TARGET POINTS (%d)" % len(targetPoints))
            print("RELEVANT POINTS (%d)" % len(relevantPoints))

            tpoints.append(targetPoints)
            rpoints.append(relevantPoints)

    print("\n\n ------------------------------------------------------------- \n\n")

    if computePoints:
        tpoints = np.asarray(tpoints, dtype=np.dtype(object))
        rpoints = np.asarray(rpoints, dtype=np.dtype(object))
        _savePoints(targetPath, tpoints)
        _savePoints(relevantPath, rpoints)
    else:
        tpoints = _loadPoints(targetPath)
        rpoints = _loadPoints(relevantPath)

    if not os.path.isfile(balancedTargetPath):
        tpoints, rpoints = balancePatches(tpoints, rpoints, gt_files,
                                          patchSize)
        _savePoints(balancedTargetPath, tpoints)
        _savePoints(balancedRelevantPath, rpoints)
    else:
        tpoints = _loadPoints(balancedTargetPath)
        rpoints = _loadPoints(balancedRelevantPath)

    # Create train/test distribution over the images that have target points
    zeros = [i for i in range(len(imgList)) if len(tpoints[i]) == 0]
    ids = np.array(range(len(imgList)))
    filteredids = np.delete(ids, zeros)

    # About 20% of the usable images (at least one) go to the test split.
    # min() keeps the sample size valid when no image has target points.
    numTest = min(len(filteredids), int(len(filteredids) * 0.2 + 1))
    testPointsIdx = random.sample(list(filteredids), numTest)
    trainPointsIdx = [x for x in filteredids if x not in testPointsIdx]

    if ignoreRelevant:
        dummy = [None] * len(imgList)
        rpoints = np.asarray(dummy, dtype=object)

    return tpoints, rpoints, trainPointsIdx, testPointsIdx, imgs_files, gt_files
def _collectCropFeatures(layer, transform, cropExt, geoTransform, task):
    '''
    Find the features of `layer` that intersect the crop extent.

    Every feature geometry is transformed with `transform` first. For the
    'segmentation' task any intersection counts. For the 'detection' task
    the intersection is only kept when the pixel bounding box returned by
    getIntersectionPoints covers more than 100 square pixels, so that tiny
    overlaps are ignored.

    Returns (fids, interPs): the positions of the selected features in the
    layer and, for detection, their pixel bounding boxes.
    '''
    fids = []
    interPs = []
    layer.ResetReading()
    for fid, feature in enumerate(layer):
        geometry = feature.GetGeometryRef()
        geometry.Transform(transform)
        if not cropExt.Intersect(geometry):
            continue

        if task == 'segmentation':
            fids.append(fid)
        if task == 'detection':
            intersection = cropExt.Intersection(geometry)
            interPnts = getIntersectionPoints(intersection.GetGeometryRef(0),
                                              geoTransform)
            if abs(interPnts[2] - interPnts[0]) * abs(
                    interPnts[1] - interPnts[3]) > 100:
                interPs.append(interPnts)
                fids.append(fid)

    return fids, interPs


def createDatasetFromImg(imgName,
                         railwayShapefile,
                         featuresShapefile,
                         cropSize,
                         process,
                         outputPath,
                         task,
                         uniqueSegments,
                         trainFids,
                         valFids,
                         trainFile=None,
                         valFile=None,
                         valInfoFile=None):
    '''
    Cut crops of an image along the railway and store them as a dataset.

    imgName           -- path of the raster to crop
    railwayShapefile  -- path of the railway shapefile
    featuresShapefile -- sequence of feature shapefiles; the first one is
                         opened when process is 'train', and the whole
                         sequence is forwarded to createPatches
    cropSize          -- side of each crop, in pixels
    process           -- 'train' or 'test'
    outputPath        -- root folder of the dataset being written
    task              -- 'segmentation' or 'detection'
    uniqueSegments    -- dict updated in place; segments already present are
                         skipped and each processed segment is stored with
                         the last coordinate computed for it
    trainFids         -- feature ids that send a crop to the train split
    valFids           -- not used by this function
    trainFile, valFile, valInfoFile
                      -- open text files receiving the crop names of the
                         detection task

    Crops whose window does not fit inside the image are skipped, and so are
    crops with too many zero-valued samples. For 'test' + 'segmentation' the
    pixel centre of every crop and the image path are saved under
    ReferencePoints. Nothing is returned.
    '''
    # Open image
    img = gdal.Open(imgName)

    # Drop only a trailing raster extension. The longest one is tested first:
    # chaining replace('.tif', '') before replace('.tiff', '') used to turn
    # "x.tiff" into "xf".
    name = os.path.split(imgName)[-1]
    for rasterExt in ('.tiff', '.tif', '.img'):
        if name.endswith(rasterExt):
            name = name[:-len(rasterExt)]
            break

    geoTransform = img.GetGeoTransform()
    xAxis = img.RasterXSize
    yAxis = img.RasterYSize

    isTrain = process == 'train'
    isTest = process == 'test'
    createMask = isTrain and task == 'segmentation'

    # Open railway shapefile (the data source stays referenced while its
    # layer is in use)
    rlwDs = ogr.Open(railwayShapefile)
    rlwLayer = rlwDs.GetLayer(0)
    rlwSpr = rlwLayer.GetSpatialRef()

    # Translate coordinates of the shapefiles into the image's
    imgSpr = osr.SpatialReference(wkt=img.GetProjectionRef())
    transformRlwToImg = osr.CoordinateTransformation(rlwSpr, imgSpr)

    # The first feature shapefile is only needed for training
    shapeZero = None
    shapeZeroLayer = None
    transformZeroToImg = None
    if isTrain:
        shapeZero = ogr.Open(featuresShapefile[0])
        shapeZeroLayer = shapeZero.GetLayer(0)
        transformZeroToImg = osr.CoordinateTransformation(
            shapeZeroLayer.GetSpatialRef(), imgSpr)

    # Image extent
    imgExt = utils.GetExtentGeometry(geoTransform, xAxis, yAxis)
    imgExt.FlattenTo2D()

    # Railway segments
    segments = extractRailwaySegmentsFromShapefile(rlwLayer, transformRlwToImg,
                                                   imgExt)
    segments = sortSegments(segments)

    cropCount = 0
    referencePoints = []
    uniqueCenters = set()

    for segment in segments:
        segmentKey = tuple(segment[0]) + tuple(segment[1])
        if segmentKey in uniqueSegments:
            continue
        curPoint = checkSegmentOverlap(imgExt, segment, uniqueSegments)

        p1 = curPoint[0]
        p2 = curPoint[1]
        x = 0
        y = 0
        for t in segmentPortions(segment, geoTransform, cropSize):
            x = (p2[0] - p1[0]) * t + p1[0]
            y = (p2[1] - p1[1]) * t + p1[1]

            if (x, y) in uniqueCenters:
                continue
            uniqueCenters.add((x, y))

            xPixel, yPixel = utils.CoordinateToPixel(geoTransform, (x, y))
            fitsInImage = (yPixel - cropSize / 2 >= 0
                           and yPixel + cropSize / 2 <= yAxis
                           and xPixel - cropSize / 2 >= 0
                           and xPixel + cropSize / 2 <= xAxis)
            if not fitsInImage:
                continue

            # Create crop and crop extent
            cropExt, crop, cropMask = createPatches(img,
                                                    geoTransform,
                                                    x,
                                                    y,
                                                    cropSize,
                                                    featuresShapefile,
                                                    imgName,
                                                    outputPath,
                                                    isTrain=isTrain,
                                                    createMask=createMask,
                                                    cropCount=cropCount)

            # Skip crops that are mostly black. This was created for rasters
            # that are larger than their content and padded with black.
            # NOTE(review): zeros are counted over every channel of the crop
            # while the threshold is 30% of cropSize * cropSize - confirm
            # this is the intended ratio.
            if np.bincount(crop.astype(
                    int).flatten())[0] > cropSize * cropSize * 0.3:
                continue

            # Name used by the detection task for this crop
            cropId = name + '_' + str(xPixel) + '_' + str(yPixel)

            # The MaskRCNN needs reference points to recreate the complete
            # image segmentation after processing
            if task == 'segmentation' and isTest:
                referencePoints.append((yPixel, xPixel))

            # For test only the crop itself is needed. The dataset layout
            # differs between detection and segmentation.
            if isTest:
                if task == 'detection':
                    scipy.misc.imsave(
                        os.path.join(outputPath, 'JPEGImages',
                                     cropId + '.png'), crop)
                    valFile.write(cropId + '\n')
                    valInfoFile.write(cropId + ' -1' + '\n')
                if task == 'segmentation':
                    scipy.misc.imsave(
                        os.path.join(
                            outputPath, 'JPEGImages',
                            name + '_' + 'crop' + str(cropCount) + '.png'),
                        crop)
                    cropCount += 1

            # For training, the features intersecting the crop decide where
            # it is stored
            if isTrain:
                fids, interPs = _collectCropFeatures(shapeZeroLayer,
                                                     transformZeroToImg,
                                                     cropExt, geoTransform,
                                                     task)

                # Create XML for detection task
                if fids and task == 'detection':
                    saveXML(os.path.join(outputPath, 'Annotations'), cropId,
                            yAxis, xAxis, interPs, (xPixel - cropSize // 2),
                            (yPixel - cropSize // 2))

                # Crops touching a selected train feature go to the train
                # split, every other crop to validation
                inTrainSplit = np.any(np.isin(fids, trainFids))

                if task == 'segmentation':
                    subset = 'Train' if inTrainSplit else 'Validation'
                    cropName = name + "_crop" + str(cropCount)
                    # Save crop
                    scipy.misc.imsave(
                        os.path.join(outputPath, subset, 'JPEGImages',
                                     cropName + '.png'), crop)
                    # Save crop mask
                    scipy.misc.imsave(
                        os.path.join(outputPath, subset, 'Masks',
                                     cropName + '_mask.png'), cropMask)
                    cropCount += 1
                if task == 'detection':
                    # Save crop
                    scipy.misc.imsave(
                        os.path.join(outputPath, 'JPEGImages',
                                     cropId + '.png'), crop)
                    if inTrainSplit:
                        # Register the crop as a train sample
                        trainFile.write(cropId + '\n')
                    else:
                        # Register the crop as a validation/test sample
                        valInfoFile.write(cropId + ' ' +
                                          ('1' if fids else '-1') + '\n')
                        valFile.write(cropId + '\n')
                    cropCount += 1

        uniqueSegments[segmentKey] = (x, y)

    # The MaskRCNN needs reference points to recreate the complete image
    # segmentation after processing. This saves the points.
    if task == 'segmentation' and isTest:
        np.save(
            os.path.join(outputPath, 'ReferencePoints',
                         name + '_refpoints.npy'), np.asarray(referencePoints))
        np.save(
            os.path.join(outputPath, 'ReferencePoints', name + '_refpath.npy'),
            np.asarray([imgName]))
def createPatches(img,
                  geotImg,
                  x,
                  y,
                  outputWindowSize,
                  shapefiles,
                  imgName,
                  datasetPath,
                  isTrain=False,
                  createMask=False,
                  cropCount=0):
    '''
    Cut a window of side 'outputWindowSize' centred on coordinate x,y.

    Returns (poly, npImageArray, npMask): the polygon of the window in
    coordinates, the first three channels of the window with channels as
    last axis, and the mask from createGroundTruth. The mask is None unless
    both isTrain and createMask are set.
    '''
    # Origin and pixel size taken from the geotransform of the image
    originX, originY = geotImg[0], geotImg[3]
    pixelSize = geotImg[1]
    halfWindow = int(outputWindowSize / 2)

    # Central coordinate expressed in pixels
    centerCol, centerRow = utils.CoordinateToPixel(geotImg, (x, y))

    # Upper-left pixel, clamped at zero; the clamped amount is carried over
    # to the lower-right pixel
    firstCol = centerCol - halfWindow
    firstRow = centerRow - halfWindow
    shiftCol = shiftRow = 0
    if firstCol < 0:
        shiftCol, firstCol = -firstCol, 0
    if firstRow < 0:
        shiftRow, firstRow = -firstRow, 0

    # Lower-right pixel
    lastCol = centerCol + halfWindow + shiftCol
    lastRow = centerRow + halfWindow + shiftRow

    # Corners converted back to coordinates
    xBegin, yBegin = utils.PixelToCoordinate(geotImg, (firstCol, firstRow))
    xFinal, yFinal = utils.PixelToCoordinate(geotImg, (lastCol, lastRow))

    # Polygon of the window, built from its four corners
    ring = ogr.Geometry(ogr.wkbLinearRing)
    corners = ((xBegin, yBegin), (xBegin, yFinal), (xFinal, yFinal),
               (xFinal, yBegin), (xBegin, yBegin))
    poly = ogr.Geometry(ogr.wkbPolygon)
    for cornerX, cornerY in corners:
        ring.AddPoint(cornerX, cornerY)
    ring.CloseRings()
    poly.AddGeometry(ring)

    # Read offsets derived from the upper-left coordinate; the same pixel
    # size is used on both axes
    xoff = int((xBegin - originX) / pixelSize)
    yoff = int((originY - yBegin) / pixelSize)

    window = img.ReadAsArray(xoff, yoff, outputWindowSize, outputWindowSize)
    npImageArray = np.moveaxis(window, 0, -1)[:, :, 0:3]

    imgref = osr.SpatialReference(wkt=img.GetProjectionRef())
    # Geotransform of the window: same terms as the image, new origin
    xoffCoord, yoffCoord = utils.PixelToCoordinate(geotImg, (xoff, yoff))
    geotCoord = (xoffCoord, geotImg[1], geotImg[2], yoffCoord, geotImg[4],
                 geotImg[5])

    if not (isTrain and createMask):
        return poly, npImageArray, None

    npMask = createGroundTruth(shapefiles,
                               npImageArray,
                               imgref,
                               geotCoord,
                               imgName,
                               datasetPath,
                               cropCount=cropCount)
    return poly, npImageArray, npMask