Beispiel #1
0
def plotcm(args):
    """Plot a row-normalised confusion matrix and save it as PNG and EPS.

    The matrix is read from the space-separated integer text file ``args.cm``;
    the figures are written next to it as ``confusion_matrix.png`` and
    ``confusion_matrix.eps``.

    Attributes read from ``args``:
        cm: path of the confusion-matrix text file.
        class_names: list of tick labels, or ``False`` to use row indices.
        num: if truthy, write each cell's percentage inside the cell.
        ft_size: font size of annotations, tick labels and colour bar.
    """
    cm_file = args.cm

    # ndmin=2 keeps a single-row file two-dimensional
    conf_arr = np.loadtxt(cm_file, "int", delimiter=" ", ndmin=2)
    print(conf_arr)

    if args.class_names is False:
        classes = [str(i) for i in range(conf_arr.shape[0])]
        rot = 0
    else:
        classes = args.class_names
        rot = 35  # tilt named labels so that long names do not overlap

    # Normalise every row by its total. A row summing to 0 (class without
    # samples) stays at 0 instead of raising ZeroDivisionError.
    row_sums = conf_arr.sum(axis=1, keepdims=True).astype(float)
    norm_conf = np.zeros(conf_arr.shape, dtype=float)
    np.divide(conf_arr, row_sums, out=norm_conf, where=row_sums != 0)

    fig = plt.figure(figsize=(8.5, 7.5))
    plt.clf()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(norm_conf, cmap='RdYlBu_r', interpolation='nearest')

    width, height = conf_arr.shape
    if args.num:
        for x in range(width):
            for y in range(height):
                # x is the row and y the column; annotate takes
                # (horizontal, vertical), hence xy=(y, x)
                ax.annotate("{:.2f}".format(norm_conf[x, y] * 100),
                            xy=(y, x),
                            horizontalalignment='center',
                            verticalalignment='center',
                            fontsize=args.ft_size,
                            rotation=0)
    cb = fig.colorbar(res, fraction=0.046, pad=0.04)
    plt.xticks(range(width), classes[:width], rotation=rot)
    cb.ax.tick_params(labelsize=args.ft_size)
    plt.setp(ax.get_xticklabels(), fontsize=args.ft_size)
    plt.setp(ax.get_yticklabels(), fontsize=args.ft_size)
    plt.yticks(range(height), classes[:height])

    out_dir = ost.pathBranch(cm_file)
    plt.savefig(out_dir + '/confusion_matrix.png', format='png')
    plt.savefig(out_dir + '/confusion_matrix.eps', format='eps')
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
Beispiel #2
0
def crossTraining(args):
    """Run the cross-validation loop and return the model of the last fold.

    Reads the class names from ``args.meta_gt`` into ``args.c``, plans the
    data repartition with ``selec.selecter`` and trains one model per fold
    with ``full_train``, each fold in its own incremental output folder under
    ``<args.out>/stepCrossVal``.

    The number of folds run is ``min(args.cvIter, args.cvSplit)``.

    Returns:
        The value returned by ``full_train`` for the last fold run, or
        ``None`` when no fold is run.
    """
    outAP = os.path.abspath(args.out)

    # load class names into args to be easily used everywhere
    args.c = readH5Class(args.meta_gt)

    # planning of data repartition: crossIndex[i] holds (train, val, test)
    # index sets of fold i
    crossIndex = selec.selecter(args)
    n_iter = min(args.cvIter, args.cvSplit)

    # print CV planning
    print("Split: ", args.cvSplit, end="")
    if args.cvIter > args.cvSplit:
        print(", Iter: ", args.cvIter, " REDUCED TO ", n_iter)
    else:
        print(", Iter: ", args.cvIter)
    for i in range(n_iter):
        print("Iter " + str(i) + ": train=" + str(len(crossIndex[i][0])) +
              " val=" + str(len(crossIndex[i][1])) +
              " test=" + str(len(crossIndex[i][2])))

    # cross validation loop. Iterate over the clamped count: crossIndex is
    # presumably only cvSplit long, so looping over args.cvIter could index
    # past its end.
    model = None
    for i in range(n_iter):
        outIter = ost.createDirIncremental(outAP + "/stepCrossVal", 0)
        model = full_train(args, crossIndex[i], outIter)

    return model
Beispiel #3
0
 def addLossResults(self, i_epoch, loss, dataset_instance, show=1):
     """Record the loss of one epoch and optionally print it in color.

     Args:
         i_epoch: epoch number, appended to ``list_loss_register``.
         loss: loss value, appended to ``list_loss``.
         dataset_instance: index selecting the per-instance lists, the
             printed dataset name and the print color.
         show: print the loss line when truthy.
     """
     self.list_loss[dataset_instance].append(loss)
     self.list_loss_register[dataset_instance].append(i_epoch)
     if not show:
         return
     # green for instance 0, yellow for instance 1, red for any other one
     palette = {0: (19, 161, 14), 1: (193, 156, 0)}
     rgb = palette.get(dataset_instance, (204, 0, 0))
     print(ost.PRINTCOLOR(*rgb), end="")  # set color
     label = self.dataset_names[dataset_instance]
     print('Epoch %3d -> %s Loss: %1.6f' % (i_epoch, label, loss))
     print(ost.RESETCOLOR, end="")  # reset color
Beispiel #4
0
def main(args):
    """Rebuild full images from folders of tiles, as VRT (fast) or TIF.

    For every folder (``args.dir`` alone, or ``args.dir`` and its sub-folders
    sorted by depth when ``args.rec`` is set), all ``*.tif`` tiles are merged
    into ``<folder>_merged.vrt`` with ``gdalbuildvrt`` when ``args.vrt`` is set,
    otherwise into ``<folder>_merged.tif`` with ``gdal_merge.py`` (the TIF
    path is much slower and almost deprecated). ``args.nodata`` is used as
    the nodata value of the output.

    Returns:
        0
    """
    import shlex  # local import: quoting of the arguments passed to the shell

    # get folders and sort them by depth if recursive stitching
    if args.rec:
        l = ost.sortFoldersByDepth(*ost.checkFoldersWithDepth(args.dir))
    else:
        l = [args.dir]
    # display found folders (the first 60 characters of each path are cut)
    print([el[60:] for el in l])

    nodata = shlex.quote(str(args.nodata))

    # recursive stitching
    for path in tqdm(l):
        path = os.path.abspath(path)
        current = os.getcwd()  # save current console position
        os.chdir(ost.pathBranch(path))
        try:
            # Quote the folder so spaces/special characters survive the
            # shell; "/*.tif" stays unquoted so the shell still expands it.
            tiles = shlex.quote(path) + "/*.tif"
            # "ulimit -s 65536;" is to avoid command line length restrictions
            if args.vrt:
                out = shlex.quote(path + "_merged.vrt")
                cmd = ("ulimit -s 65536; gdalbuildvrt -vrtnodata " + nodata +
                       " " + out + " " + tiles)
            else:
                out = shlex.quote(path + "_merged.tif")
                cmd = ("ulimit -s 65536; gdal_merge.py -o " + out + " -n " +
                       nodata + " -a_nodata " + nodata + " " + tiles)
            subprocess.call(cmd, shell=True)
        finally:
            # always move back to the original console position
            os.chdir(current)
    return 0
Beispiel #5
0
    def addFullResults(self, i_epoch, cm, loss, dataset_instance, show=1):
        """Store the metrics of one epoch and tell whether its mIoU is the best.

        Computes mIoU / per-class IoU (``cm.class_IoU``) and overall accuracy
        (``cm.overall_accuracy``) from ``cm``, optionally prints them in
        color, and appends cm, loss, mIoU, IoUs, OA and the epoch to the
        per-instance lists.

        Args:
            i_epoch: epoch number.
            cm: confusion-matrix object of the epoch.
            loss: loss value of the epoch.
            dataset_instance: dataset index; 1 is the validation set.
            show: forwarded to ``cm.class_IoU`` and enables the summary print.

        Returns:
            True when this mIoU is >= every mIoU stored so far for this
            instance. ``self.best`` and ``self.best_id`` are only updated for
            the validation instance (1).
        """
        # color given the dataset instance
        if dataset_instance == 0:
            color = (19, 161, 14)
        elif dataset_instance == 1:
            color = (193, 156, 0)
        else:
            color = (204, 0, 0)
        # dataset instance name (train/val/test)
        text = self.dataset_names[dataset_instance]

        print(ost.PRINTCOLOR(*color), end="")  # set color
        mIoU, ious = cm.class_IoU(show)  # compute (and display) IoUs
        oa = cm.overall_accuracy()
        if show:
            print('Epoch %3d -> %s Overall Accuracy: %3.2f%% %s mIoU : %3.2f%% %s Loss: %1.6f'
                  % (i_epoch, text, oa, text, mIoU, text, loss))
        print(ost.RESETCOLOR, end="")  # reset color

        # save metric values in containers
        self.list_cm[dataset_instance].append(cm)
        self.list_loss[dataset_instance].append(loss)
        self.list_mIoU[dataset_instance].append(mIoU)
        self.list_ious[dataset_instance].append(ious)
        self.list_oa[dataset_instance].append(oa)
        self.list_register[dataset_instance].append(i_epoch)
        self.list_loss_register[dataset_instance].append(i_epoch)

        # best epoch = mIoU at least as high as every stored one (ties favour
        # the most recent epoch)
        history = self.list_mIoU[dataset_instance]
        isBest = bool(mIoU >= max(history))
        if isBest and dataset_instance == 1:  # only track the VALIDATION set
            self.best = i_epoch
            # position of the entry just appended; list.index(mIoU) would
            # return an older epoch on ties, disagreeing with self.best
            self.best_id = len(history) - 1
        return isBest
Beispiel #6
0
def _cut_margin(arr, margin):
    """Drop ``margin`` rows/columns from both ends of the first two axes of ``arr``."""
    if margin == 0:
        return arr
    return arr[margin:arr.shape[0] - margin, margin:arr.shape[1] - margin]


def multiprocessing_func(mpArg):
    """Write the requested inference products of one tile.

    Launched through ``Pool.map``, so all arguments come in one sequence:
    ``(prediction, gt, base_name, outs, args, colors)``:
        prediction: score tensor of the tile, classes on axis 0.
        gt: ground-truth tile, or ``None`` when running without GT.
        base_name: source tile path; used to name outputs and passed to
            ``ost.array2raster`` as its reference raster.
        outs: [inf dir, dif dir, list of per-class proba dirs, probaH5 dir].
        args: parsed command-line arguments.
        colors: color table of the inference map, or ``None``.
    """
    prediction, gt, base_name, outs, args, colors = mpArg[:6]

    # arg max of the prediction tensor for dif and inf products
    if args.inf or args.dif:
        pred = _cut_margin(prediction.argmax(0).squeeze().numpy(), args.margin)
        # gt is None when running without ground truth (inf only)
        if gt is not None:
            gt = _cut_margin(gt, args.margin)

    # semantic segmentation map
    if args.inf:
        path = outs[0] + "/inf_" + ost.pathLeaf(base_name) + ".tif"
        ost.array2raster(pred,
                         path,
                         base_name,
                         rasterType='GTiff',
                         datatype=gdal.GDT_Byte,
                         noDataValue=args.nodata,
                         colors=colors,
                         margin=args.margin)

    # difference map between GT and prediction arg max map
    if args.dif:
        colors_dif = {0: (0, 0, 0, 0), 1: (255, 0, 0, 255)}
        dif = ost.arrayDif(gt, pred, args.nodata)
        path = outs[1] + "/dif_" + ost.pathLeaf(base_name) + ".tif"
        ost.array2raster(dif,
                         path,
                         base_name,
                         rasterType='GTiff',
                         datatype=gdal.GDT_Byte,
                         noDataValue=0,
                         colors=colors_dif,
                         margin=args.margin)

    # proba maps (one image per class)
    if args.proba:
        # convert logits to probabilities with softmax over the class axis
        pred_soft = nn.functional.softmax(prediction, 0)
        for i, o in enumerate(outs[2]):
            path = o + "/proba_" + ost.pathLeaf(base_name) + ".tif"
            c = _cut_margin(pred_soft[i].numpy(), args.margin)
            ost.array2raster(c,
                             path,
                             base_name,
                             rasterType='GTiff',
                             datatype=gdal.GDT_Float32,
                             margin=args.margin)

    # proba maps in h5 format (full 3D tensor, margin cut on spatial axes)
    if args.probaH5:
        path = outs[3] + "/proba_" + ost.pathLeaf(base_name) + ".h5"
        with h5.File(path, "w") as f:
            c = prediction.numpy()
            f.create_dataset("pred",
                             data=c[:, args.margin:c.shape[1] - args.margin,
                                    args.margin:c.shape[2] - args.margin])
Beispiel #7
0
def main(args):
    """Run inference with a trained model and write the requested products.

    Validates the combination of options in ``args`` (see the parser),
    builds one data loader per split (train/val/test when ``args.cv`` is
    given, a single one otherwise), loads the checkpoint ``args.state`` and
    calls ``inference`` for every loader. Products go under
    ``<args.out>/drawings``.

    Returns:
        0 when the options are inconsistent (an error is printed), else None.
        Exits the process when the checkpoint lacks 'channels' or 'names'.
    """
    # manage unmatching args
    if not (args.inf or args.dif or args.proba or args.probaH5):
        print(
            "ERROR : WAKE UP !! ... select at least one things to do (inf/dif/proba/probaH5)"
        )
        return 0
    if args.dir_gt is False and args.meta_gt is False:
        args.noGT = True  # no gt given at all
    elif args.dir_gt is not False and args.meta_gt is False:
        print("ERROR : Either set -dir_gt to False or give meta_gt path")
        return 0
    elif args.dir_gt is False and args.meta_gt is not False:
        print("ERROR : Either set -meta_gt to False or give dir_gt path")
        return 0
    else:
        args.noGT = False
    if args.noGT and args.dif:
        print(
            "ERROR : You can't get dif between Inf and GT without GT ... Either set -noGT to False or -dif to False"
        )
        return 0
    if args.noGT and args.metric:
        print(
            "ERROR : You can't get metric without GT ... Either set -noGT to False or -metric to False"
        )
        return 0

    # create output dir
    args.out = ost.createDir(args.out + "/drawings")

    # if a data repartition file is given, one dataset reader per split
    common = (args.dir_img, args.dir_gt, args.meta_img, args.meta_gt)
    DS_list = []
    if args.cv is not False:
        args.train_set = True
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # repartition files produced by this project.
        with open(args.cv, 'rb') as fp:
            train = pickle.load(fp)
            val = pickle.load(fp)
            test = pickle.load(fp)
        # order must match the split names used by inference (train/val/test)
        for subset in (train, val, test):
            DS_list.append(reader.DatasetTIF(*common, subset, noGT=args.noGT))
    else:
        DS_list.append(reader.DatasetTIF(*common, noGT=args.noGT))

    # one data loader per dataset reader
    loader_list = [
        data.DataLoader(ds,
                        batch_size=args.bs,
                        shuffle=False,
                        drop_last=False,
                        num_workers=args.num_workers) for ds in DS_list
    ]

    # load model; map tensors to CPU when CUDA is not used so that a
    # GPU-saved checkpoint can still be loaded
    checkpoint = torch.load(args.state,
                            map_location=None if args.cuda else 'cpu')
    print("Best epoch: ", checkpoint['epoch'])
    try:
        channels = checkpoint['channels']
        args.c = checkpoint['names']
    except KeyError:
        print("ERROR : OLD MODEL, PLEASE ACTUALIZE")
        sys.exit()
    print("Model channels :", channels, "\nModel class :", args.c)
    model = nb.OLIVENET(channels, len(args.c))
    model.load_state_dict(checkpoint['state_dict'])
    # inference() only moves batches to the GPU when args.cuda is set, so
    # the model must follow the same flag
    if args.cuda:
        model.cuda()
    model.eval()  # model in evaluation mode

    # create output dirs for the requested operations
    if args.inf:
        ost.createDir(args.out + "/inf")
    if args.dif:
        ost.createDir(args.out + "/dif")
    if args.proba:
        ost.createDir(args.out + "/proba")
    if args.probaH5:
        ost.createDir(args.out + "/probaH5")

    # inference loop through all data loaders
    for i, loader in enumerate(loader_list):
        inference(i, model, loader, args)
Beispiel #8
0
def inference(ind, model, loader, args):
    """Produce the inference products of one data loader.

    Args:
        ind: position of ``loader`` in the caller's loader list; selects the
            "train"/"val"/"test" sub-folder when ``args.train_set`` is set.
        model: network returning per-class scores for a batch of images.
        loader: data loader yielding ``(imgs, gt, names)`` batches.
        args: parsed command-line arguments (see the parser).

    Returns:
        0
    """
    import torch  # local import: only needed to disable autograd below

    # create all subfolders if train/val/test split, else use the top ones
    outs = ["", "", [], []]
    name = ["train", "val", "test"]  # need to match the order of dataloader
    if args.inf:
        outs[0] = ost.createDir(
            args.out + "/inf/" +
            name[ind]) if args.train_set else args.out + "/inf"
    if args.dif:
        outs[1] = ost.createDir(
            args.out + "/dif/" +
            name[ind]) if args.train_set else args.out + "/dif"
    if args.proba:
        outHM = [ost.createDir(args.out + "/proba/" + c) for c in args.c]
        outs[2] = [ost.createDir(p + "/" + name[ind])
                   for p in outHM] if args.train_set else outHM
    if args.probaH5:
        outs[3] = ost.createDir(
            args.out + "/probaH5/" +
            name[ind]) if args.train_set else args.out + "/probaH5"

    # colors as dict for the colortable: first column is the class value,
    # the next four columns its color entry (ndmin=2: single-line files)
    if args.color is not False:
        c = np.loadtxt(args.color, delimiter=",", dtype=int, ndmin=2)
        colors = {int(row[0]): tuple(row[1:5]) for row in c}
    else:
        colors = None

    # metric container
    if args.metric:
        cm = m.ConfusionMatrix(len(args.c), args.c, args.nodata)

    # One worker pool for the whole loader instead of one per batch, and no
    # autograd graph: only forward passes are needed.
    with multiprocessing.Pool() as pool, torch.no_grad():
        for imgs, gt, names in tqdm(loader):
            # if CUDA, load on GPU
            batch_tensor = imgs.cuda() if args.cuda else imgs

            prediction = model(batch_tensor).cpu()
            n_tiles = prediction.shape[0]

            # one worker instruction per tile of the batch
            if args.noGT:
                mpArg = [(prediction[i].detach(), None, names[i], outs, args,
                          colors) for i in range(n_tiles)]
            else:
                mpArg = [(prediction[i].detach(), gt[i], names[i], outs, args,
                          colors) for i in range(n_tiles)]
            pool.map(multiprocessing_func, mpArg)

            # batch metric added to the metric container
            if args.metric:
                for i in range(n_tiles):
                    pred = prediction[i].argmax(0).squeeze()
                    cm.add_batch(gt[i].numpy(), pred.numpy())

            # free memory
            del imgs, gt, batch_tensor, prediction

    # write current dataset metrics to the metric output folder
    if args.metric:
        out_perf = ost.createDir(
            args.out + "/" + name[ind] +
            "_perf_inf") if args.train_set else ost.createDir(args.out +
                                                              "/perf_inf")
        cm.printPerf(out_perf)
    return 0
Beispiel #9
0
def gtBuilder_v4(base_img,
                 path,
                 path_img,
                 nodata_value,
                 nodata_mask=False,
                 split_field='',
                 datatype=gdal.GDT_UInt16,
                 arrayBack=False,
                 colors=None):
    """Create a ground-truth raster by rasterizing shapefiles over a base image.

    The GT raster copies the metadata of ``base_img`` (one band of
    ``datatype``) and is saved to ``path_img``. Each shapefile is burned with
    its position in the listing (natural order), so the listing order defines
    the class values. With a single shapefile the raster is first filled with
    1 and an extra class "others" is appended (binary mode: shape = 0,
    others = 1).

    Args:
        base_img: image on which the GT extent is based.
        path: shapefile, or folder where the shapefiles are stored.
        path_img: path where the GT image is saved.
        nodata_value: nodata value of the GT band.
        nodata_mask: path of a mask; values outside its polygons are burned
            with ``nodata_value`` (``gdal_rasterize -i``). False to disable.
        split_field: deprecated, unused.
        datatype: GDAL data type of the GT band.
        arrayBack: if True, also return the GT as an array. NOTE: the array
            is read before ``nodata_mask`` is applied.
        colors: color entries per class value, either a ``{value: color}``
            dict (keys need not be contiguous) or a sequence indexed by value.

    Returns:
        The class names, or ``(gt_array, names)`` when ``arrayBack``; 0 when
        ``path`` is neither a shapefile nor a folder containing shapefiles.

    Raises:
        RuntimeError: a listed shapefile could not be opened.
    """
    path = os.path.abspath(path)
    print('image saved at: ', path_img)

    # path may be a single shapefile or a folder of shapefiles
    list_shp = []
    if os.path.isdir(path):
        list_shp = ost.getFileByExt(path, ".shp")
    elif os.path.isfile(path):
        list_shp.append(path)
    else:
        print('PATH GIVEN IS NOT SHAPE FILE OR DIR')
        return 0
    if len(list_shp) == 0:
        print('NO SHAPE FILE FOUNDED')
        return 0

    # create GT raster by copying base img metadata
    base_img = os.path.abspath(base_img)
    base = gdal.Open(base_img)
    ras_c = ost.rasterCopy(base, path_img, datatype=datatype, bands=1)
    ras_c.GetRasterBand(1).SetNoDataValue(nodata_value)

    # class names from shapefile names
    names = [ost.pathLeaf(p) for p in list_shp]
    print(names)

    # only 1 class: binary mode, initialize GT with the extra "others" class
    if len(names) == 1:
        print("/!\\ SWITCH TO CLASSIFICATION BINAIRE...")
        ras_c.GetRasterBand(1).Fill(1)
        names.append("others")

    # rasterize all shapefiles in natural name order; burn value = position
    for i, path_shp in enumerate(list_shp):
        try:
            shp = ogr.Open(path_shp)
        except Exception:
            # report which file failed, then propagate the original error
            print('EXCEPTION RAISED WHILE LOADING: %s' % (path_shp))
            raise
        if not shp:
            print('COULD NOT LOAD SHAPE: %s' % (path_shp))
            # fail explicitly instead of crashing on None.GetLayer()
            raise RuntimeError('COULD NOT LOAD SHAPE: %s' % (path_shp))
        print('loading: %s' % (path_shp))
        source_layer = shp.GetLayer()
        gdal.RasterizeLayer(ras_c, [1], source_layer, burn_values=[i])

    # apply given colors
    if colors is not None:
        cT = gdal.ColorTable()
        entries = colors.items() if isinstance(colors, dict) else enumerate(colors)
        for value, color in entries:
            cT.SetColorEntry(int(value), color)
        cT.SetColorEntry(nodata_value, (0, 0, 0, 0))  # transparent nodata
        # set color table and color interpretation
        ras_c.GetRasterBand(1).SetRasterColorTable(cT)
        ras_c.GetRasterBand(1).SetRasterColorInterpretation(
            gdal.GCI_PaletteIndex)

    if arrayBack:
        gt = ras_c.ReadAsArray()

    # release the dataset so the raster is written to disk before
    # gdal_rasterize opens it
    ras_c = None

    # burn nodata zone with nodata_mask
    if nodata_mask is not False:
        print("Applying Nodata_Mask:", nodata_mask)
        gdal_instruc = [
            "gdal_rasterize", "-b", "1", "-i", "-burn",
            str(nodata_value), nodata_mask, path_img
        ]
        subprocess.call(gdal_instruc)

    if arrayBack:
        return gt, names
    return names
Beispiel #10
0
def megaTailor(args):
    """Cut an image and/or its ground truth into tiles and write metadata.

    Main entry point; see the argument parser for the meaning of ``args``.

    - ``args.create_gt``: build ``<out>/gt.tif`` either by rasterizing the
      shapefile(s) ``args.gt`` (``gtBuilder_v4``) or by cropping the GT TIF
      ``args.gt`` to the image extent (a TIF GT requires
      ``args.nomenclature``), print the class repartition and cut the GT
      into ``<out>/tiles_gt``.
    - ``args.create_tile``: load (``args.band_norm``) or compute the per-band
      normalization borders and cut ``args.img`` into ``<out>/tiles``.
    - ``args.discard`` / ``args.discard_file``: drop the tiles flagged as
      discarded by the workers (or listed in the file), then renumber the
      remaining tiles contiguously.
    - ``args.create_select``: build index shapefiles of the tiles in
      ``<out>/data_selection``.

    Metadata are written to ``<out>/meta_img.h5`` and ``<out>/meta_gt.h5``.

    Returns:
        0. Exits the process through ``sys.exit`` on invalid arguments.
    """
    # check if gt is given as shp or tif. A TIF GT needs a nomenclature
    if args.create_gt:
        if ost.pathExt(args.gt) == ".tif":
            gt_is_shp = False
            print("GT given as tif img")
            if args.nomenclature is False:
                print("ERROR : TIF GT NEED NOMENCLATURE")
                sys.exit()
        else:
            gt_is_shp = True
            print("GT given as shp")

    # create path variables
    imgAP = os.path.abspath(args.img)
    outAP = os.path.abspath(args.out)
    meta_img = outAP + "/meta_img.h5"
    meta_gt = outAP + "/meta_gt.h5"

    # check at least one of gt or tile is activated
    if not args.create_tile and not args.create_gt:
        print(
            "ERROR : WAKE UP !! ... Choose at least one between GT or TILE !!")
        sys.exit()

    # extract main img vital information
    ds = gdal.Open(imgAP)
    band = ds.GetRasterBand(1)
    xsize = band.XSize
    ysize = band.YSize
    tile_size_x = args.x
    tile_size_y = args.y
    # number of tiles produced by the tiling loop below: ceil(size / tile)
    # per axis ("size // tile + 1" over-counted by one row/column when the
    # size is an exact multiple of the tile size)
    lenght_dataset = (((xsize + tile_size_x - 1) // tile_size_x) *
                      ((ysize + tile_size_y - 1) // tile_size_y))

    gtAP = None  # stays None when no GT is created

    # work GT
    if args.create_gt:
        gtAP = outAP + "/gt.tif"

        # create GT img or crop it to imgAP extent
        if gt_is_shp:
            # colors as dict for the colortable: first column is the class
            # value, the next four its color entry
            if args.color is not False:
                print("Color given by :", args.color)
                c = np.loadtxt(args.color, delimiter=",", dtype=int, ndmin=2)
                colors = {int(row[0]): tuple(row[1:5]) for row in c}
            else:
                colors = None

            # create gt and get class names from shapefiles names
            shpAP = os.path.abspath(args.gt)
            names = gtBuilder_v4(imgAP,
                                 shpAP,
                                 gtAP,
                                 nodata_value=args.nodata_value,
                                 nodata_mask=args.nodata_mask,
                                 datatype=gdal.GDT_Byte,
                                 colors=colors)
            # gtBuilder_v4 returns 0 when no shapefile could be found
            if not names:
                print("ERROR : GT COULD NOT BE CREATED FROM", shpAP)
                sys.exit()

        else:
            geotransform = ds.GetGeoTransform()
            originX = geotransform[0]
            originY = geotransform[3]
            xres = geotransform[1]
            yres = geotransform[5]
            bottomX = originX + xsize * xres
            bottomY = originY + ysize * yres
            img_gtAP = os.path.abspath(args.gt)
            # crop GT TIF given to IMG extent
            gdal.Translate(gtAP,
                           img_gtAP,
                           projWin=[originX, originY, bottomX, bottomY],
                           xRes=xres,
                           yRes=yres,
                           resampleAlg="nearest",
                           noData=args.nodata_value)

        # open newly created gt
        dsGT = gdal.Open(gtAP)
        band = dsGT.GetRasterBand(1)

        # replace names by the nomenclature names (second column of the file)
        if args.nomenclature is not False:
            print("Class names given by :", args.nomenclature)
            names_indices = np.loadtxt(args.nomenclature,
                                       delimiter=" ",
                                       dtype=str,
                                       ndmin=2)
            names = list(names_indices[:, 1])

        # class proportion in GT img. Assumes histogram bucket k counts the
        # pixels of class value k (GDAL default buckets) — TODO confirm
        hist = band.GetHistogram(approx_ok=False)
        prop = np.array(hist[0:len(names)])
        propPercent = 100 * prop / np.sum(prop)

        totalPix = dsGT.RasterXSize * dsGT.RasterYSize
        propPercentTotal = 100 * prop / totalPix
        str_rep = (
            'GT size = ' + str(args.x) + '*' + str(args.y) + '*' +
            str(lenght_dataset) + ' with ' + str(int(np.sum(prop))) + '/' +
            str(totalPix) + ' labeled pixels\n' +
            "Dataset repartition :\n" +
            '{:20} : {:12} || {:12} || {:4}'.format(
                'CLASS', 'PIX NB', '% IN LABELED ', '% IN TOTAL\n') +
            '{:20} : {:12d} || {:12.2f}% || {:12.2f}%\n'.format(
                'no_data', int(totalPix - np.sum(prop)), 0,
                (1 - (np.sum(prop) / totalPix)) * 100) +
            '\n'.join('{:20} : {:12d} || {:12.2f}% || {:12.2f}%'.format(
                name, int(p), percent, pt) for name, percent, p, pt in zip(
                    names, propPercent, prop, propPercentTotal)))
        print(str_rep)

    # get img borders (one "lower upper" pair per band)
    if args.create_tile:
        nb_channels = ds.RasterCount
        print("Getting img quartiles...")
        list_border_norm = []

        # load band limits if given
        if args.band_norm is not False:
            list_border_norm = np.loadtxt(args.band_norm, dtype=float)
            # a single band is read as a 1-D array: make it one row
            if list_border_norm.ndim < 2:
                list_border_norm = np.expand_dims(list_border_norm, axis=0)
            # manage unmatching files
            if nb_channels != list_border_norm.shape[0]:
                print(
                    "ERROR BAND_NORM FILE DOES NOT MATCH RASTER BAND COUNT :",
                    nb_channels)
                for el in list_border_norm:
                    print(el)
                sys.exit()
        else:  # get borders from img if not given (long for big images)
            for i in range(1, nb_channels + 1):
                list_border_norm.append(
                    ost.getBorderOutliers(ds.GetRasterBand(i).ReadAsArray(),
                                          lower=2,
                                          upper=98))

        # display borders found
        for i in range(1, nb_channels + 1):
            print("Band {} normailzed to : {:7.2f}-{:7.2f}".format(
                i, list_border_norm[i - 1][0], list_border_norm[i - 1][1]))

    # create output tiles folders, deleting existing ones
    if args.create_tile:
        shutil.rmtree(outAP + "/tiles", ignore_errors=True)
        out_tiles = ost.createDir(outAP + "/tiles")
        output_filename = '/tile_'
        names_list_img = []

    if args.create_gt:
        shutil.rmtree(outAP + "/tiles_gt", ignore_errors=True)
        out_gt = ost.createDir(outAP + "/tiles_gt")
        output_filename_gt = '/gt_'
        names_list_gt = []

    # building tiling instructions
    instruc_list = []
    count = 0
    for i in range(0, xsize, tile_size_x):
        for j in range(0, ysize, tile_size_y):
            instructions = [None] * 5  # to have a stable form

            # img tiles instructions
            img_tile = None
            if args.create_tile:
                img_tile = out_tiles + output_filename + str(
                    count) + "_" + str(i) + "_" + str(j) + ".tif"
                instructions[2] = list_border_norm
            # gt tiles instructions
            gt_tile = None
            if args.create_gt:
                gt_tile = out_gt + output_filename_gt + str(count) + "_" + str(
                    i) + "_" + str(j) + ".tif"
            # general instructions
            instructions[0] = args
            instructions[1] = count
            instructions[3] = [
                i, j, tile_size_x, tile_size_y, imgAP, gtAP, img_tile, gt_tile
            ]

            # store tiles names
            if args.create_tile:
                names_list_img.append(ost.pathLeafExt(img_tile))
            if args.create_gt:
                names_list_gt.append(ost.pathLeafExt(gt_tile))

            instruc_list.append(instructions)
            count += 1

    # generating tiles with instructions
    print("Generating tiles...")
    list_discarded = []
    with multiprocessing.Pool() as pool:
        for dis in tqdm(pool.imap_unordered(multiprocessing_func,
                                            instruc_list),
                        total=len(instruc_list)):
            if dis is not False:
                list_discarded.append(dis)

    # register discarded tiles
    if args.discard is True:
        dis_file = outAP + "/discarded_tiles.txt"
        np.savetxt(dis_file, list_discarded, fmt='%s')

    # delete GT tiles if no img tiles and discard_file is given
    if args.discard_file is not False and args.create_tile is False:
        print(
            "Suppressing tiles according to discard file :",
            args.discard_file,
        )
        list_discarded = np.genfromtxt(args.discard_file, dtype='str')
        # a file with a single (img, gt) line (or none) is read as 1-D
        if list_discarded.ndim == 1:
            list_discarded = list_discarded.reshape(-1, 2)
        for (ti, tg) in list_discarded:
            try:
                os.remove(out_gt + "/" + tg)  # delete
            except OSError:
                pass  # tile already absent

    # rename tiles after a discard so they are numbered 0..n without gaps
    # (renaming is vital for the selecter algorithm)
    if args.discard or args.discard_file is not False:
        print("Renaming tiles...")
        if args.create_tile:
            to_rename_img = ost.getFileByExt(out_tiles, ".tif")
            if to_rename_img:  # nothing to rename if every tile was discarded
                root = ost.pathBranch(ost.pathBranch(to_rename_img[0]))
                for i, name in tqdm(enumerate(to_rename_img)):
                    old_name = ost.pathLeaf(name)
                    name_split = old_name.split("_")
                    x_name = name_split[2]
                    y_name = name_split[3]
                    new_name_img = root + "/tiles/tile_" + str(i) + "_" + str(
                        x_name) + "_" + str(y_name) + ".tif"
                    os.rename(name, new_name_img)
            names_list_img = ost.getFileByExt(out_tiles, ".tif")

        if args.create_gt:
            to_rename_gt = ost.getFileByExt(out_gt, ".tif")
            if to_rename_gt:
                root = ost.pathBranch(ost.pathBranch(to_rename_gt[0]))
                for i, name in tqdm(enumerate(to_rename_gt)):
                    old_name = ost.pathLeaf(name)
                    name_split = old_name.split("_")
                    x_name = name_split[2]
                    y_name = name_split[3]
                    new_name_gt = root + "/tiles_gt/gt_" + str(i) + "_" + str(
                        x_name) + "_" + str(y_name) + ".tif"
                    os.rename(name, new_name_gt)
            names_list_gt = ost.getFileByExt(out_gt, ".tif")

    # remove the stats files generated by the histogram check of the workers
    if args.create_tile:
        for aux_file in ost.getFileByExt(out_tiles, ".aux.xml"):
            os.remove(aux_file)

    # write h5 metadata files, deleting already existing ones
    print("Writting h5 file...", end='')
    if args.create_tile:
        try:
            os.remove(meta_img)
        except OSError:
            pass  # no previous file
        with h5.File(meta_img, "w") as f:
            f.attrs["base_img"] = imgAP
            f.attrs["len"] = len(names_list_img)

            f.attrs["margin"] = args.margin
            f.attrs["channels"] = nb_channels
            f.attrs["norm_border"] = list_border_norm
            a = [ost.pathLeafExt(a) for a in names_list_img]
            f.create_dataset("tile_path",
                             data=np.array(a,
                                           dtype=h5.special_dtype(vlen=str)))
    if args.create_gt:
        try:
            os.remove(meta_gt)
        except OSError:
            pass  # no previous file
        with h5.File(meta_gt, "w") as f:
            f.attrs["base_img"] = imgAP
            f.attrs["len"] = len(names_list_gt)

            f.attrs["info"] = str_rep
            f.attrs["names"] = names
            f.attrs["class_weight"] = propPercent / 100
            a = [ost.pathLeafExt(a) for a in names_list_gt]
            f.create_dataset("gt_path",
                             data=np.array(a,
                                           dtype=h5.special_dtype(vlen=str)))
    print(" done.")

    # creating data selection folder and shapefiles
    if args.create_select:
        current = os.getcwd()  # save current console position
        # move to outAP so the index path given to gdaltindex stays short
        os.chdir(outAP)
        try:
            print("Generating index shape of tiles...", end='')
            shutil.rmtree(outAP + "/data_selection", ignore_errors=True)
            out_selec = ost.createDir(outAP + "/data_selection")
            out_shp_index = ost.pathLeaf(out_selec) + "/all_index.shp"
            # index the img tiles when they exist, the GT tiles otherwise
            num = ost.getFileByExt(out_tiles if args.create_tile else out_gt,
                                   ".tif")
            gdalTIndex_instruc = [
                "gdaltindex", "-f", "ESRI Shapefile", "-t_srs", "EPSG:2154",
                out_shp_index, *num
            ]
            subprocess.call(gdalTIndex_instruc, stdout=subprocess.DEVNULL)
        finally:
            os.chdir(current)  # back to the original console position

        # copy the full index to create the other indexes
        list_TIndex = ost.getFileBySubstr(out_selec, "all_index")
        index_names = ["test_index", "val_index", "train_index"]
        for n in index_names:
            for a in list_TIndex:
                shutil.copy(a, out_selec + "/" + n + ost.pathExt(a))
        print(" done.")

    return 0
Beispiel #11
0
def multiprocessing_func(mlparg):
    """Create one image tile and/or its matching ground-truth (GT) tile.

    Meant to be launched through a multiprocessing pool, so every input
    travels inside the single ``mlparg`` list:

    * ``mlparg[0]`` -- parsed arguments; reads ``discard``, ``margin``,
      ``nodata_value`` and ``verif_nodata``.
    * ``mlparg[1]`` -- not read by this function.
    * ``mlparg[2]`` -- per-band ``(min, max)`` pairs forwarded to
      ``ost.normalize`` (indexed by band number - 1).
    * ``mlparg[3]`` -- tuple ``(ulx, uly, tsx, tsy, imgAP, gtAP, img_tile,
      gt_tile)``: pixel offset and size of the tile window (GDAL ``srcWin``),
      source image / GT paths, and output tile paths. Either output path may
      be ``None`` to skip producing that tile.

    Returns:
        ``(ost.pathLeafExt(img_tile), ost.pathLeafExt(gt_tile))`` when the
        image tile was judged empty and deleted, otherwise ``False``.
    """
    args = mlparg[0]
    discard = args.discard
    marge = args.margin
    nodata_value = args.nodata_value
    verif_nodata = args.verif_nodata

    list_border_norm = mlparg[2]
    ulx, uly, tsx, tsy, imgAP, gtAP, img_tile, gt_tile = mlparg[3]

    # True once the image tile has been found empty and deleted.
    discarded = False
    # Per-pixel count of image bands with a non-zero value. It is only
    # computed when an image tile is produced, so it stays None otherwise and
    # the GT masking step below is skipped instead of raising NameError.
    tf = None

    # ---- image tile -------------------------------------------------------
    if img_tile is not None:
        # The image window is enlarged by the margin on every side.
        srcwin = [ulx - marge, uly - marge, tsx + 2 * marge, tsy + 2 * marge]
        opts = gdal.TranslateOptions(format="GTIFF",
                                     outputType=gdal.GDT_Float32,
                                     srcWin=srcwin)
        out_ras = gdal.Translate(img_tile, imgAP, options=opts)

        if discard:
            tot = 0
            for i in range(1, out_ras.RasterCount + 1):
                # Fast emptiness test: one huge bucket starting at 1, so pixels
                # valued 0 (and nodata) are not counted; a zero total over all
                # bands means the tile holds no data.
                # WARNING: the source image should have a nodata value set.
                # This generates .aux.xml side-car files.
                hist = out_ras.GetRasterBand(i).GetHistogram(min=1,
                                                             max=65536,
                                                             buckets=1,
                                                             approx_ok=False)
                tot += np.sum(hist)
            if tot == 0:
                discarded = True

        if discarded:
            out_ras = None  # close the dataset before deleting its file
            os.remove(img_tile)
        else:
            first_band = out_ras.GetRasterBand(1)
            tf = np.full((first_band.YSize, first_band.XSize), 0)
            for i in range(1, out_ras.RasterCount + 1):
                band = out_ras.GetRasterBand(i)
                a = band.ReadAsArray()
                # Count, per pixel, the bands holding a non-zero value.
                tf += np.array(a, dtype=bool)
                band.WriteArray(
                    ost.normalize(a,
                                  force=True,
                                  mini=list_border_norm[i - 1][0],
                                  maxi=list_border_norm[i - 1][1]))
            out_ras.FlushCache()  # write to disk

    # ---- GT tile (never produced for a discarded image tile) -----------------
    if not discarded and gt_tile is not None:
        gt_opts = gdal.TranslateOptions(format="GTiff",
                                        outputType=gdal.GDT_Byte,
                                        srcWin=[ulx, uly, tsx, tsy],
                                        noData=nodata_value)
        if marge != 0:
            # Cut the GT on the un-enlarged window, then pad it back to the
            # image-tile size with nodata. Labels thus never cover the margin,
            # which avoids data leaking between train/validation/test sets.
            file_temp = "/vsimem/" + ost.pathBranch(
                gt_tile) + "/" + ost.pathLeaf(gt_tile) + "temp.tif"
            tmp_ds = gdal.Translate(file_temp, gtAP, options=gt_opts)

            pad_opts = gdal.TranslateOptions(
                format="GTiff",
                outputType=gdal.GDT_Byte,
                srcWin=[-marge, -marge, tsx + 2 * marge, tsy + 2 * marge],
                noData=nodata_value)
            out_ds = gdal.Translate(gt_tile, tmp_ds, options=pad_opts)

            # /vsimem/ files live until explicitly unlinked; a pool worker may
            # process many tiles, so release the temporary now.
            tmp_ds = None
            gdal.Unlink(file_temp)
        else:
            out_ds = gdal.Translate(gt_tile, gtAP, options=gt_opts)

        # Final check: set GT to nodata wherever every image band is 0.
        # Only possible when an image tile (and thus tf) was produced.
        if verif_nodata and tf is not None:
            gt_band = out_ds.GetRasterBand(1)
            a = np.where(np.array(tf, dtype=bool), gt_band.ReadAsArray(),
                         nodata_value)
            gt_band.WriteArray(a)

        out_ds.FlushCache()

    if discarded:
        return ost.pathLeafExt(img_tile), ost.pathLeafExt(gt_tile)
    return False
Beispiel #12
0
def full_train(args, cv, outIter):
    """Train and evaluate one model for a single cross-validation iteration.

    Args:
        args: parsed command-line arguments. Reads ``meta_img``, ``meta_gt``,
            ``c`` (class names), ``cuda``, ``lr``, ``state`` (checkpoint path
            or ``False``), ``e`` (number of epochs), ``mem`` (period, in
            epochs, of metric computation and checkpointing), ``bs``,
            ``aug`` (passed to the training dataset), ``num_workers``,
            ``dir_img`` and ``dir_gt``.
        cv: indexable of three tile-index collections used as the train
            (``cv[0]``), validation (``cv[1]``) and test (``cv[2]``) sets.
        outIter: output directory of this iteration; ``crossValID.pckl``,
            ``perf_log/`` and ``states/`` are written under it.

    Returns:
        The model in its final state. Training may be stopped early with
        Ctrl+C, in which case the current model is returned.
    """
    print("Iteration file is ", outIter)

    # Newer meta_img files store the channel count; older ones do not, in
    # which case a 4-channel image is assumed.
    try:
        channels = readH5Channels(args.meta_img)
        print("Image channels =", channels)
    except Exception:
        channels = 4
        print("WARNING ! : OLD META_IMG : guessed 4 channels image")

    # Model initialisation.
    model = nb.OLIVENET(channels, len(args.c))
    print('Total number of parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    if args.cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Learning-rate decay: lr is multiplied by `gamma` at each milestone epoch.
    # NOTE: milestones and gamma are hard-coded.
    plan = list(range(50, 300, 50))
    gamma = 0.7
    print('  |  '.join(
        '{}: {:.2e}'.format(pl, lr)
        for pl, lr in zip(plan, [args.lr * gamma**i for i in range(len(plan))])))
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=plan,
                                               gamma=gamma)

    # Transfer learning: start from a pretrained state.
    if args.state is not False:
        print("TRANSFERT LEARNING : loading of :\n", args.state)
        # map_location='cpu' lets a checkpoint saved on GPU be loaded on a
        # CPU-only machine; load_state_dict then copies the weights onto the
        # device the model already lives on.
        checkpoint = torch.load(args.state, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])

    # Class weights: 1/sqrt of the per-class values read from meta_gt
    # (presumably the class proportions; classes at 0 keep weight 0),
    # then L2-normalised.
    w = readH5Cweight(args.meta_gt)
    w = [1 / np.sqrt(i) if i != 0 else 0 for i in w]
    w = w / np.linalg.norm(w)
    w = torch.from_numpy(w.astype(np.float32))
    print("Class weights:",
          [args.c[i] + ": {:.4f} ".format(weight) for i, weight in enumerate(w)])

    # Performance log / checkpoint directories.
    outLog = ost.createDir(outIter + "/perf_log")
    save_lr = []
    # PERF OBJECT: dataset id 0 = train, 1 = validation, 2 = test.
    perf = met.SavePerf(args.e, outLog, args.c)
    outState = ost.createDir(outIter + "/states")

    # Save the data repartition of this CV iteration (used by the Drawer tool).
    with open(outIter + "/crossValID.pckl", 'wb') as fp:
        pickle.dump(cv[0], fp)
        pickle.dump(cv[1], fp)
        pickle.dump(cv[2], fp)

    # Data loaders (they handle batching). Only the train loader shuffles and
    # drops the last incomplete batch. Evaluation uses 3x the batch size since
    # no gradients are stored during inference.
    train_loader = data.DataLoader(
        reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img,
                          args.meta_gt, cv[0], args.aug, (200, 200)),
        batch_size=args.bs, shuffle=True, drop_last=True,
        num_workers=args.num_workers)
    val_loader = data.DataLoader(
        reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img,
                          args.meta_gt, cv[1]),
        batch_size=args.bs * 3, shuffle=False, drop_last=False,
        num_workers=args.num_workers)
    test_loader = data.DataLoader(
        reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img,
                          args.meta_gt, cv[2]),
        batch_size=args.bs * 3, shuffle=False, drop_last=False,
        num_workers=args.num_workers)

    # Epoch loop; KeyboardInterrupt (Ctrl+C) stops training gracefully.
    try:
        for i_epoch in range(args.e):
            print("epoch:", i_epoch, end='', flush=True)

            # Periodic full metric computation and checkpoint (also on the last epoch).
            if i_epoch % args.mem == 0 or i_epoch == args.e - 1:
                cm, train_loss = learningPhase(args, model, train_loader, w,
                                               optimizer, True)
                perf.addFullResults(i_epoch, cm, train_loss, 0)  # 0 = TRAIN

                with torch.no_grad():
                    # Validation metrics; also yields the best-epoch flag.
                    cm, loss = evalutionPhase(args, model, val_loader, w)
                    is_best = perf.addFullResults(i_epoch, cm, loss, 1)  # 1 = VAL

                    # Test metrics (not used for training decisions).
                    cm, loss = evalutionPhase(args, model, test_loader, w)
                    perf.addFullResults(i_epoch, cm, loss, 2)  # 2 = TEST

                save_checkpoint({
                    'epoch': i_epoch,
                    'channels': channels,
                    'names': args.c,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best, outState)
            else:
                # Plain training epoch, only the loss is recorded.
                _, train_loss = learningPhase(args, model, train_loader, w,
                                              optimizer, False)
                perf.addLossResults(i_epoch, train_loss, 0)  # 0 = TRAIN

            # Advance the lr schedule and record the current value.
            scheduler.step()
            save_lr.append(optimizer.param_groups[0]['lr'])
    except KeyboardInterrupt:
        pass

    # Performance results.
    print("Best epoch:", perf.best)
    perf.printResults()  # graphs
    perf.printResultsTxt()  # values as text

    # Learning-rate graph and values.
    ost.plotLogGraph(np.arange(len(save_lr)), save_lr, 'lr', 'epochs', 'Lr',
                     'Learning rate decay according to epochs',
                     outLog + '/lr_decay.tif', show=0)
    df = pd.DataFrame({'epochs': np.arange(len(save_lr)), 'lr': save_lr})
    df.to_csv(outLog + "/lr_decay.csv", sep=";", index=False)

    return model
Beispiel #13
0
def selecter(args):
    """Build the cross-validation split of the tile couples.

    Main function (see the argument parser for ``args``). Reads
    ``args.meta_gt``, ``args.selection``, ``args.seed``, ``args.out`` and
    ``args.cvSplit``.

    If ``args.selection`` points to the selection shapefiles, every tile
    listed in the train/val/test shapefile is locked into that dataset for
    all CV iterations. A shapefile listing every tile is ignored. If reading
    fails, a warning is printed and reading stops there.

    Returns:
        list of list of int: for each CV iteration, the train/val/test lists
        of tile-couple indexes, as built by ``crossValSpliterSelection``.
    """
    # Shapefiles are expected in alphabetical order: all_index (0, skipped),
    # test_index (1), train_index (2), val_index (3). They are read in
    # train/val/test order.
    order_good = [2, 3, 1]
    data_lenght = readH5Length(args.meta_gt)
    selected = [[] for _ in order_good]  # locked tile numbers per dataset

    # Read shapefiles to find tiles locked into a dataset across all CV iterations.
    try:
        shps = ost.getFileByExt(args.selection, "shp")
        print("\nReading selection shapefiles:")
        driver = ogr.GetDriverByName("ESRI Shapefile")
        for i, order in enumerate(order_good):
            dataSource = driver.Open(shps[order], 0)
            print("\t", shps[order])
            layer = dataSource.GetLayer()
            # A shapefile listing every tile means "no constraint": ignore it.
            if layer.GetFeatureCount() < data_lenght:
                for feature in layer:
                    name = ost.pathLeaf(feature.GetField("location"))
                    # Tile number = first '_'-separated field parsing as an int.
                    # Bounded loop: the previous while-loop spun forever when
                    # no field was numeric.
                    tile_num = None
                    for part in ost.pathLeaf(name).split("_"):
                        try:
                            tile_num = int(part)
                            break
                        except ValueError:
                            continue
                    if tile_num is None:
                        print(
                            "\tWARNING/!\\: Error in tile numerotation, was not able to get the tile number"
                        )
                    else:
                        selected[i].append(tile_num)
            layer.ResetReading()  # reset the layer's read cursor

        print("\nSelection found :\n\tTrain samples:", len(selected[0]),
              "\n\tVal samples:  ", len(selected[1]), "\n\tTest samples: ",
              len(selected[2]))
    except Exception as err:
        # Best effort: a missing/unreadable selection only disables locking.
        print(
            "\tWARNING/!\\: Path to data selection shapefiles directory not found"
        )
        print("\t" + repr(err))

    # Load or create the generator driving the random repartition. A new one
    # is pickled before any draw, so reloading it reproduces the same split.
    if args.seed is False:
        seedsave = ost.checkNIncrementLeaf(args.out + "/seed.dump")
        with open(seedsave, 'wb') as fp:
            prng = np.random.RandomState()
            pickle.dump(prng, fp)
    else:
        # NOTE(security): unpickling can execute arbitrary code; only load
        # seed files produced by this tool.
        with open(args.seed, 'rb') as fp:
            prng = pickle.load(fp)

    # Cross-validation creation.
    return crossValSpliterSelection(prng, data_lenght, args.cvSplit, selected)