def plotcm(args):
    """Plot a row-normalized confusion matrix and save it next to its source file.

    The matrix is read from ``args.cm`` (space-delimited integers). Each row is
    divided by its sum, so cell (x, y) is the fraction of row-x samples that
    fall in column y. The result is drawn as a heat map and saved as
    ``confusion_matrix.png`` and ``confusion_matrix.eps`` in the directory
    returned by ``ost.pathBranch(args.cm)``.

    Args fields used:
        args.cm: path to the confusion-matrix text file.
        args.class_names: list of tick labels, or ``False`` to use "0".."n-1".
        args.num: if truthy, annotate every cell with its percentage.
        args.ft_size: font size for annotations, tick labels and colorbar.
    """
    cm_file = args.cm
    # ndmin=2 keeps a single-class (1x1) matrix two-dimensional.
    conf_arr = np.loadtxt(cm_file, "int", delimiter=" ", ndmin=2)
    print(conf_arr)
    if args.class_names is False:
        classes = [str(i) for i in range(conf_arr.shape[0])]
    else:
        classes = args.class_names

    # Row normalization; a row with no samples stays at 0 instead of raising
    # ZeroDivisionError.
    row_sums = conf_arr.sum(axis=1, keepdims=True).astype(float)
    norm_conf = np.divide(conf_arr, row_sums,
                          out=np.zeros(conf_arr.shape, dtype=float),
                          where=row_sums != 0)

    fig = plt.figure(figsize=(8.5, 7.5))
    plt.clf()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(norm_conf, cmap='RdYlBu_r', interpolation='nearest')

    width, height = conf_arr.shape
    if args.num:
        # Row x is drawn on the vertical axis, hence xy=(y, x).
        for x in range(width):
            for y in range(height):
                ax.annotate("{:.2f}".format(norm_conf[x][y] * 100), xy=(y, x),
                            horizontalalignment='center',
                            verticalalignment='center',
                            fontsize=args.ft_size, rotation=0)

    # Long user-given class names are tilted to avoid overlapping.
    rot = 0 if args.class_names is False else 35

    cb = fig.colorbar(res, fraction=0.046, pad=0.04)
    cb.ax.tick_params(labelsize=args.ft_size)
    # Set the tick labels first, then their font size, so the size applies
    # to the labels that are actually displayed.
    plt.xticks(range(width), classes[:width], rotation=rot)
    plt.yticks(range(height), classes[:height])
    plt.setp(ax.get_xticklabels(), fontsize=args.ft_size)
    plt.setp(ax.get_yticklabels(), fontsize=args.ft_size)

    plt.savefig(ost.pathBranch(cm_file) + '/confusion_matrix.png', format='png')
    plt.savefig(ost.pathBranch(cm_file) + '/confusion_matrix.eps', format='eps')
    plt.close(fig)  # release the figure; pyplot keeps it alive otherwise
def crossTraining(args):
    """Run the cross-validation iterations and return the last trained model.

    Class names are loaded from ``args.meta_gt`` into ``args.c``. The data
    repartition is planned by ``selec.selecter(args)``, which gives one
    (train, val, test) index triple per split. Then ``full_train`` is run
    once per iteration, each in a new ``<out>/stepCrossVal*`` folder.

    The number of iterations is ``min(args.cvIter, args.cvSplit)``; there
    cannot be more iterations than planned splits.

    Returns:
        The model returned by the last ``full_train`` call, or ``None`` if no
        iteration was run.
    """
    outAP = os.path.abspath(args.out)
    # load class names into args to be easily used everywhere
    args.c = readH5Class(args.meta_gt)
    # planning of data repartition
    crossIndex = selec.selecter(args)
    iterartion = min(args.cvIter, args.cvSplit)

    # print CV planning
    print("Split: ", args.cvSplit, end="")
    if args.cvIter > args.cvSplit:
        print(", Iter: ", args.cvIter, " REDUCED TO ", iterartion)
    else:
        print(", Iter: ", args.cvIter)
    for i in range(iterartion):
        print("Iter " + str(i) + ": train=" + str(len(crossIndex[i][0]))
              + " val=" + str(len(crossIndex[i][1]))
              + " test=" + str(len(crossIndex[i][2])))

    # cross-validation loop: bounded by the (possibly reduced) iteration count,
    # otherwise crossIndex[i] raises IndexError when cvIter > cvSplit
    model = None
    for i in range(iterartion):
        outIter = ost.createDirIncremental(outAP + "/stepCrossVal", 0)
        model = full_train(args, crossIndex[i], outIter)
    return model
def addLossResults(self, i_epoch, loss, dataset_instance, show=1):
    """Store a loss value for one epoch and optionally print it in color.

    Args:
        i_epoch: epoch number, appended to ``self.list_loss_register``.
        loss: loss value, appended to ``self.list_loss``.
        dataset_instance: index of the dataset container (0 = train,
            1 = validation, anything else printed as test).
        show: when truthy, print a colored "Epoch ... Loss" line.
    """
    self.list_loss[dataset_instance].append(loss)
    self.list_loss_register[dataset_instance].append(i_epoch)
    if not show:
        return

    # green for train, yellow for validation, red otherwise
    palette = {0: (19, 161, 14), 1: (193, 156, 0)}
    rgb = palette.get(dataset_instance, (204, 0, 0))
    print(ost.PRINTCOLOR(*rgb), end="")
    label = self.dataset_names[dataset_instance]
    print('Epoch %3d -> %s Loss: %1.6f' % (i_epoch, label, loss))
    print(ost.RESETCOLOR, end="")
def main(args):
    """Merge the GeoTIFF tiles of one or more folders into a single raster.

    For each folder, every ``*.tif`` it contains is merged into
    ``<folder>_merged.vrt`` (``gdalbuildvrt``, when ``args.vrt``) or into
    ``<folder>_merged.tif`` (``gdal_merge.py`` otherwise). The TIF path is
    almost deprecated; building a VRT is much quicker.

    Args fields used:
        args.dir: folder of tiles, or root folder when ``args.rec``.
        args.rec: if truthy, process the folders returned by
            ``ost.checkFoldersWithDepth`` in the order given by
            ``ost.sortFoldersByDepth``.
        args.vrt: build a VRT instead of a merged GeoTIFF.
        args.nodata: nodata value passed to the GDAL tool.

    Returns:
        0.
    """
    import shlex  # quote paths/values embedded in the shell command

    if args.rec:
        folders = ost.sortFoldersByDepth(*ost.checkFoldersWithDepth(args.dir))
    else:
        folders = [args.dir]
    # display found folders, dropping their first 60 characters
    print([el[60:] for el in folders])

    # str() accepts a numeric nodata too; quoting keeps a multi-value
    # "-vrtnodata" list as a single argument.
    nodata = shlex.quote(str(args.nodata))
    for path in tqdm(folders):
        path = os.path.abspath(path)
        current = os.getcwd()  # save current working directory
        os.chdir(ost.pathBranch(path))
        try:
            # "ulimit -s 65536;" is there to avoid command-line length
            # restrictions when the glob expands to many tiles.
            # "/*.tif" is left unquoted so the shell still expands it.
            if args.vrt:
                out = path + "_merged.vrt"
                cmd = ("ulimit -s 65536; gdalbuildvrt -vrtnodata " + nodata
                       + " " + shlex.quote(out) + " " + shlex.quote(path)
                       + "/*.tif")
            else:
                out = path + "_merged.tif"
                cmd = ("ulimit -s 65536; gdal_merge.py -o " + shlex.quote(out)
                       + " -n " + nodata + " -a_nodata " + nodata + " "
                       + shlex.quote(path) + "/*.tif")
            subprocess.call(cmd, shell=True)
        finally:
            os.chdir(current)  # restore the working directory even on error
    return 0
def addFullResults(self, i_epoch, cm, loss, dataset_instance, show=1):
    """Record the loss and confusion-matrix metrics for one epoch and dataset.

    Args:
        i_epoch: epoch number being registered.
        cm: confusion-matrix object exposing ``class_IoU(show)`` (returning
            ``(mIoU, ious)``) and ``overall_accuracy()``.
        loss: loss value for this epoch/dataset.
        dataset_instance: index of the dataset containers (0 = train,
            1 = validation, anything else printed as test).
        show: if truthy, print the summary line (also forwarded to
            ``cm.class_IoU``).

    Returns:
        True if this epoch's mIoU is >= every mIoU registered so far for the
        same dataset instance. ``self.best`` and ``self.best_id`` are updated
        only for ``dataset_instance == 1`` (validation).
    """
    # color given the dataset instance
    if dataset_instance == 0:
        color = (19, 161, 14)
    elif dataset_instance == 1:
        color = (193, 156, 0)
    else:
        color = (204, 0, 0)
    # dataset instance name (train/val/test)
    text = self.dataset_names[dataset_instance]

    print(ost.PRINTCOLOR(*color), end="")  # set color
    mIoU, ious = cm.class_IoU(show)  # compute IoU per class and mIoU
    oa = cm.overall_accuracy()
    if show:
        print('Epoch %3d -> %s Overall Accuracy: %3.2f%% %s mIoU : %3.2f%% %s Loss: %1.6f'
              % (i_epoch, text, oa, text, mIoU, text, loss))
    # Reset unconditionally: the color was set unconditionally and
    # class_IoU may have printed with it.
    print(ost.RESETCOLOR, end="")

    # save metric values in containers
    self.list_cm[dataset_instance].append(cm)
    self.list_loss[dataset_instance].append(loss)
    self.list_mIoU[dataset_instance].append(mIoU)
    self.list_ious[dataset_instance].append(ious)
    self.list_oa[dataset_instance].append(oa)
    self.list_register[dataset_instance].append(i_epoch)
    self.list_loss_register[dataset_instance].append(i_epoch)

    # best epoch = highest mIoU so far (ties go to the latest epoch)
    isBest = bool(mIoU >= max(self.list_mIoU[dataset_instance]))
    if isBest and dataset_instance == 1:  # only track best on VALIDATION
        self.best = i_epoch
        # Position of the entry just appended. list.index(mIoU) would return
        # an earlier epoch with an equal mIoU and disagree with self.best.
        self.best_id = len(self.list_mIoU[dataset_instance]) - 1
    return isBest
def multiprocessing_func(mpArg):
    """Write the requested inference products for one tile (pool worker).

    Args:
        mpArg: sequence ``(prediction, gt, base_name, outs, args, colors)``:
            prediction: per-class logits, indexed as ``[class, row, col]``.
            gt: ground-truth tile, or ``None`` when running without GT.
            base_name: path of the source tile (used for naming/georeference).
            outs: ``[inf_dir, dif_dir, [proba_dir per class], probaH5_dir]``.
            args: parsed CLI arguments (inf/dif/proba/probaH5/margin/nodata).
            colors: color table dict for the inference map, or ``None``.
    """
    prediction, gt, base_name, outs, args, colors = mpArg[:6]
    m = args.margin
    leaf = ost.pathLeaf(base_name)

    # arg max of the prediction tensor for dif and inf products
    if args.inf or args.dif:
        pred = prediction.argmax(0).squeeze().numpy()
        # cut the margin; gt is None when inferring without ground truth
        if m != 0:
            pred = pred[m:pred.shape[0] - m, m:pred.shape[1] - m]
            if gt is not None:
                gt = gt[m:gt.shape[0] - m, m:gt.shape[1] - m]

    # semantic segmentation map
    if args.inf:
        path = outs[0] + "/inf_" + leaf + ".tif"
        ost.array2raster(pred, path, base_name, rasterType='GTiff',
                         datatype=gdal.GDT_Byte, noDataValue=args.nodata,
                         colors=colors, margin=m)

    # difference map between GT and arg max prediction (dif requires GT)
    if args.dif:
        colors_dif = {0: (0, 0, 0, 0), 1: (255, 0, 0, 255)}
        dif = ost.arrayDif(gt, pred, args.nodata)
        path = outs[1] + "/dif_" + leaf + ".tif"
        ost.array2raster(dif, path, base_name, rasterType='GTiff',
                         datatype=gdal.GDT_Byte, noDataValue=0,
                         colors=colors_dif, margin=m)

    # probability maps, one image per class
    if args.proba:
        # softmax over the class axis converts logits to probabilities
        pred_soft = nn.functional.softmax(prediction, 0)
        for i, o in enumerate(outs[2]):
            path = o + "/proba_" + leaf + ".tif"
            c = pred_soft[i].numpy()
            if m != 0:
                c = c[m:c.shape[0] - m, m:c.shape[1] - m]
            ost.array2raster(c, path, base_name, rasterType='GTiff',
                             datatype=gdal.GDT_Float32, margin=m)

    # raw logits saved as a full 3D tensor in h5 format (margin removed)
    if args.probaH5:
        path = outs[3] + "/proba_" + leaf + ".h5"
        with h5.File(path, "w") as f:
            c = prediction.numpy()
            f.create_dataset("pred",
                             data=c[:, m:c.shape[1] - m, m:c.shape[2] - m])
def main(args):
    """Run inference with a saved model and write the requested products.

    See the argument parser for every ``args`` field. The checkpoint
    ``args.state`` must contain ``epoch``, ``channels``, ``names`` and
    ``state_dict``. When ``args.cv`` is given, it is a pickle holding the
    train, val and test index lists (in that order); each gets its own loader
    and output subfolders. Products go under ``<args.out>/drawings``.

    Returns:
        0 (also on invalid argument combinations, after printing an error).
        Calls ``sys.exit()`` for checkpoints missing ``channels``/``names``.
    """
    # manage incompatible args
    if not args.inf and not args.dif and not args.proba and not args.probaH5:
        print("ERROR : WAKE UP !! ... select at least one things to do (inf/dif/proba/probaH5)")
        return 0
    if args.dir_gt is False and args.meta_gt is False:
        args.noGT = True  # no GT given at all
    elif args.dir_gt is not False and args.meta_gt is False:
        print("ERROR : Either set -dir_gt to False or give meta_gt path")
        return 0
    elif args.dir_gt is False and args.meta_gt is not False:
        print("ERROR : Either set -meta_gt to False or give dir_gt path")
        return 0
    else:
        args.noGT = False
    if args.noGT and args.dif:
        print("ERROR : You can't get dif between Inf and GT without GT ... Either set -noGT to False or -dif to False")
        return 0
    if args.noGT and args.metric:
        print("ERROR : You can't get metric without GT ... Either set -noGT to False or -metric to False")
        return 0

    # create output dir
    args.out = ost.createDir(args.out + "/drawings")

    # if a data repartition file is given, build separate train/val/test readers
    DS_list = []
    if args.cv is not False:
        args.train_set = True
        # NOTE(review): pickle.load executes arbitrary code; only load
        # repartition files produced by this project's training script.
        with open(args.cv, 'rb') as fp:
            train = pickle.load(fp)
            val = pickle.load(fp)
            test = pickle.load(fp)
        for subset in (train, val, test):
            DS_list.append(
                reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img,
                                  args.meta_gt, subset, noGT=args.noGT))
    else:
        DS_list.append(
            reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img,
                              args.meta_gt, noGT=args.noGT))

    loader_list = [
        data.DataLoader(ds, batch_size=args.bs, shuffle=False, drop_last=False,
                        num_workers=args.num_workers)
        for ds in DS_list
    ]

    # load model; map to CPU when not using CUDA so GPU checkpoints still load
    checkpoint = torch.load(args.state, map_location=None if args.cuda else "cpu")
    print("Best epoch: ", checkpoint['epoch'])
    try:
        channels = checkpoint['channels']
        args.c = checkpoint['names']
        print("Model channels :", channels, "\nModel class :", args.c)
    except KeyError:
        print("ERROR : OLD MODEL, PLEASE ACTUALIZE")
        sys.exit()
    model = nb.OLIVENET(channels, len(args.c))
    model.load_state_dict(checkpoint['state_dict'])
    # keep the model on the same device as the inputs built in inference()
    if args.cuda:
        model.cuda()
    model.eval()  # model in evaluation mode

    # create output dirs for the requested operations
    if args.inf:
        ost.createDir(args.out + "/inf")
    if args.dif:
        ost.createDir(args.out + "/dif")
    if args.proba:
        ost.createDir(args.out + "/proba")
    if args.probaH5:
        ost.createDir(args.out + "/probaH5")

    # inference loop through all data loaders
    for i, loader in enumerate(loader_list):
        inference(i, model, loader, args)
def inference(ind, model, loader, args):
    """Produce inference products (and optional metrics) for one data loader.

    Args:
        ind: position of the loader; with ``args.train_set`` it selects the
            "train"/"val"/"test" subfolder name (must match the loader order).
        model: network already placed in eval mode by the caller.
        loader: DataLoader yielding ``(imgs, gt, names)`` batches.
        args: parsed CLI arguments (see the parser).

    Returns:
        0.
    """
    # create sub-folders when data is split in train/val/test
    outs = ["", "", [], []]
    name = ["train", "val", "test"]  # must match the order of the loaders
    if args.inf:
        outs[0] = ost.createDir(args.out + "/inf/" + name[ind]) if args.train_set else args.out + "/inf"
    if args.dif:
        outs[1] = ost.createDir(args.out + "/dif/" + name[ind]) if args.train_set else args.out + "/dif"
    if args.proba:
        outHM = [ost.createDir(args.out + "/proba/" + c) for c in args.c]
        outs[2] = [ost.createDir(p + "/" + name[ind]) for p in outHM] if args.train_set else outHM
    if args.probaH5:
        outs[3] = ost.createDir(args.out + "/probaH5/" + name[ind]) if args.train_set else args.out + "/probaH5"

    # colors as {class value: (r, g, b, a)} for the color table;
    # ndmin=2 keeps a one-line color file indexable by row
    if args.color is not False:
        c = np.loadtxt(args.color, delimiter=",", dtype=int, ndmin=2)
        colors = {int(row[0]): tuple(row[1:5]) for row in c}
    else:
        colors = None

    if args.metric:
        cm = m.ConfusionMatrix(len(args.c), args.c, args.nodata)

    # one pool for the whole loader instead of one per batch
    with multiprocessing.Pool() as pool:
        for imgs, gt, names in tqdm(loader):
            batch_tensor = imgs.cuda() if args.cuda else imgs
            # no autograd graph is needed for inference
            with torch.no_grad():
                prediction = model(batch_tensor).cpu()

            # one instruction per tile of the batch
            mpArg = [(prediction[i], None if args.noGT else gt[i], names[i],
                      outs, args, colors)
                     for i in range(prediction.shape[0])]
            pool.map(multiprocessing_func, mpArg)

            # accumulate batch metrics
            if args.metric:
                for i in range(prediction.size()[0]):
                    pred = prediction[i].argmax(0).squeeze()
                    cm.add_batch(gt[i].numpy(), pred.numpy())

            # free memory
            del imgs
            del gt
            del batch_tensor
            del prediction

    # write dataset metrics to the metric output folder
    if args.metric:
        out_perf = ost.createDir(args.out + "/" + name[ind] + "_perf_inf") if args.train_set else ost.createDir(args.out + "/perf_inf")
        cm.printPerf(out_perf)
    return 0
def gtBuilder_v4(base_img, path, path_img, nodata_value, nodata_mask=False, split_field='', datatype=gdal.GDT_UInt16, arrayBack=False, colors=None):
    """Rasterize a set of shapefiles into a ground-truth image.

    The GT raster copies the metadata of ``base_img`` (1 band, ``datatype``).
    Shapefile number i (in the listing order of ``ost.getFileByExt``) is burnt
    with value i, so the order of the shapefiles matters. With a single
    shapefile the raster is first filled with 1 and an extra "others" class
    is added (binary mode).

    Args:
        base_img: image whose extent/georeference the GT copies.
        path: a shapefile, or a folder containing shapefiles.
        path_img: output path of the GT image.
        nodata_value: nodata value of the GT band.
        nodata_mask: path to a mask vector; pixels outside its polygons are
            set to ``nodata_value`` (via ``gdal_rasterize -i``). ``False``
            disables it.
        split_field: deprecated, unused.
        datatype: GDAL data type of the GT band.
        arrayBack: if true, also return the GT as an array.
        colors: color table, either a dict ``{value: (r, g, b, a)}`` or a
            sequence indexed by class value; ``None`` for no table.

    Returns:
        ``names`` (class names), or ``(gt_array, names)`` if ``arrayBack``.
        Returns 0 if ``path`` is invalid or contains no shapefile.

    Raises:
        RuntimeError: if a shapefile cannot be opened by OGR.
    """
    path = os.path.abspath(path)
    print('image saved at: ', path_img)

    # path can be a single shapefile or a folder of shapefiles
    list_shp = []
    if os.path.isdir(path):
        list_shp = ost.getFileByExt(path, ".shp")
    elif os.path.isfile(path):
        list_shp.append(path)
    else:
        print('PATH GIVEN IS NOT SHAPE FILE OR DIR')
        return 0
    if len(list_shp) == 0:
        print('NO SHAPE FILE FOUNDED')
        return 0

    # create GT raster by copying base img metadata
    base_img = os.path.abspath(base_img)
    base = gdal.Open(base_img)
    ras_c = ost.rasterCopy(base, path_img, datatype=datatype, bands=1)
    ras_c.GetRasterBand(1).SetNoDataValue(nodata_value)

    # class names are the shapefile names
    names = [ost.pathLeaf(p) for p in list_shp]
    print(names)

    # single class: binary mode, everything starts as "others" (value 1)
    if len(names) == 1:
        print("/!\\ SWITCH TO CLASSIFICATION BINAIRE...")
        ras_c.GetRasterBand(1).Fill(1)
        names.append("others")

    # rasterize every shapefile; class index i is the burn value
    for i, path_shp in enumerate(list_shp):
        try:
            shp = ogr.Open(path_shp)
        except Exception:
            print('EXCEPTION RAISED WHILE LOADING: %s' % (path_shp))
            raise
        if not shp:
            print('COULD NOT LOAD SHAPE: %s' % (path_shp))
            raise RuntimeError('COULD NOT LOAD SHAPE: %s' % (path_shp))
        print('loading: %s' % (path_shp))
        source_layer = shp.GetLayer()
        gdal.RasterizeLayer(ras_c, [1], source_layer, burn_values=[i])

    # apply color table
    if colors is not None:
        cT = gdal.ColorTable()
        entries = colors.items() if isinstance(colors, dict) else enumerate(colors)
        for value, rgba in entries:
            cT.SetColorEntry(value, rgba)
        cT.SetColorEntry(nodata_value, (0, 0, 0, 0))  # nodata is transparent
        ras_c.GetRasterBand(1).SetRasterColorTable(cT)
        ras_c.GetRasterBand(1).SetRasterColorInterpretation(gdal.GCI_PaletteIndex)

    if arrayBack:
        gt = ras_c.ReadAsArray()
    ras_c = None  # close the dataset so it is flushed to disk

    # burn nodata outside the mask polygons (-i inverts the rasterization)
    if nodata_mask is not False:
        print("Applying Nodata_Mask:", nodata_mask)
        gdal_instruc = [
            "gdal_rasterize", "-b", "1", "-i", "-burn",
            str(nodata_value), nodata_mask, path_img
        ]
        subprocess.call(gdal_instruc)
        if arrayBack:
            # re-read so the returned array matches the masked file on disk
            gt = gdal.Open(path_img).ReadAsArray()

    if arrayBack:
        return gt, names
    return names
def megaTailor(args):
    """Cut a large image and/or its ground truth into tiles and write metadata.

    Main entry point of the tiling tool; see the argument parser for each
    ``args`` field. Steps:

    1. If ``args.create_gt``: build ``<out>/gt.tif`` from shapefiles
       (``gtBuilder_v4``) or crop a GT GeoTIFF to the image extent, then
       compute per-class pixel statistics.
    2. If ``args.create_tile``: get per-band normalization bounds from
       ``args.band_norm`` or from ``ost.getBorderOutliers(..., lower=2, upper=98)``.
    3. Generate tiles in parallel with ``multiprocessing_func``, optionally
       discard empty ones, and renumber the remaining tiles 0..n-1.
    4. Write ``meta_img.h5`` / ``meta_gt.h5``. If ``args.create_select``,
       also write tile index shapefiles under ``<out>/data_selection``.

    Returns:
        0. Calls ``sys.exit()`` on invalid arguments.
    """
    # GT given as shp or tif; a TIF GT needs a nomenclature for class names
    if args.create_gt:
        if ost.pathExt(args.gt) == ".tif":
            gt_is_shp = False
            print("GT given as tif img")
            if args.nomenclature is False:
                print("ERROR : TIF GT NEED NOMENCLATURE")
                sys.exit()
        else:
            gt_is_shp = True
            print("GT given as shp")

    imgAP = os.path.abspath(args.img)
    outAP = os.path.abspath(args.out)
    meta_img = outAP + "/meta_img.h5"
    meta_gt = outAP + "/meta_gt.h5"

    if not args.create_tile and not args.create_gt:
        print("ERROR : WAKE UP !! ... Choose at least one between GT or TILE !!")
        sys.exit()

    # main image size
    ds = gdal.Open(imgAP)
    band = ds.GetRasterBand(1)
    xsize = band.XSize
    ysize = band.YSize
    tile_size_x = args.x
    tile_size_y = args.y
    # number of tiles produced by the range() loops below (ceil division)
    lenght_dataset = (-(-xsize // tile_size_x)) * (-(-ysize // tile_size_y))
    gtAP = None

    # ---- ground truth ----
    if args.create_gt:
        gtAP = outAP + "/gt.tif"
        if gt_is_shp:
            # colors as {class value: (r, g, b, a)} for the color table
            if args.color is not False:
                print("Color given by :", args.color)
                c = np.loadtxt(args.color, delimiter=",", dtype=int, ndmin=2)
                colors = {int(row[0]): tuple(row[1:5]) for row in c}
            else:
                colors = None
            # create GT; class names come from the shapefile names
            shpAP = os.path.abspath(args.gt)
            names = gtBuilder_v4(imgAP, shpAP, gtAP,
                                 nodata_value=args.nodata_value,
                                 nodata_mask=args.nodata_mask,
                                 datatype=gdal.GDT_Byte, colors=colors)
        else:
            geotransform = ds.GetGeoTransform()
            originX = geotransform[0]
            originY = geotransform[3]
            xres = geotransform[1]
            yres = geotransform[5]
            bottomX = originX + xsize * xres
            bottomY = originY + ysize * yres
            img_gtAP = os.path.abspath(args.gt)
            # crop the given GT TIF to the image extent
            gdal.Translate(gtAP, img_gtAP,
                           projWin=[originX, originY, bottomX, bottomY],
                           xRes=xres, yRes=yres, resampleAlg="nearest",
                           noData=args.nodata_value)

        dsGT = gdal.Open(gtAP)
        band = dsGT.GetRasterBand(1)
        # nomenclature file: "<index> <name>" per line; names replace any others
        if args.nomenclature is not False:
            print("Class names given by :", args.nomenclature)
            names_indices = np.loadtxt(args.nomenclature, delimiter=" ", dtype=str)
            names = list(names_indices[:, 1])

        # class proportions in the GT image
        hist = band.GetHistogram(approx_ok=False)
        prop = np.array(hist[0:len(names)])
        propPercent = 100 * prop / np.sum(prop)
        totalPix = dsGT.RasterXSize * dsGT.RasterYSize
        propPercentTotal = 100 * prop / totalPix
        str_rep = ('GT size = ' + str(args.x) + '*' + str(args.y) + '*' + str(lenght_dataset)
                   + ' with ' + str(int(np.sum(prop))) + '/' + str(totalPix) + ' labeled pixels\n'
                   + "Dataset repartition :\n"
                   + '{:20} : {:12} || {:12} || {:4}'.format('CLASS', 'PIX NB', '% IN LABELED ', '% IN TOTAL\n')
                   + '{:20} : {:12d} || {:12.2f}% || {:12.2f}%\n'.format(
                       'no_data', int(totalPix - np.sum(prop)), 0, (1 - (np.sum(prop) / totalPix)) * 100)
                   + '\n'.join('{:20} : {:12d} || {:12.2f}% || {:12.2f}%'.format(name, int(p), percent, pt)
                               for name, percent, p, pt in zip(names, propPercent, prop, propPercentTotal)))
        print(str_rep)

    # ---- normalization bounds per band ----
    if args.create_tile:
        nb_channels = ds.RasterCount
        print("Getting img quartiles...")
        list_border_norm = []
        if args.band_norm is not False:
            # ndmin=2: a single-band file still gives one (min, max) row
            list_border_norm = np.loadtxt(args.band_norm, dtype=float, ndmin=2)
            if nb_channels != list_border_norm.shape[0]:
                print("ERROR BAND_NORM FILE DOES NOT MATCH RASTER BAND COUNT :", nb_channels)
                for el in list_border_norm:
                    print(el)
                sys.exit()
        else:
            # computed from the image when not given (slow for big images)
            for i in range(1, nb_channels + 1):
                list_border_norm.append(
                    ost.getBorderOutliers(ds.GetRasterBand(i).ReadAsArray(),
                                          lower=2, upper=98))
        for i in range(1, nb_channels + 1):
            print("Band {} normailzed to : {:7.2f}-{:7.2f}".format(
                i, list_border_norm[i - 1][0], list_border_norm[i - 1][1]))

    # ---- output folders (existing ones are deleted) ----
    if args.create_tile:
        shutil.rmtree(outAP + "/tiles", ignore_errors=True)
        out_tiles = ost.createDir(outAP + "/tiles")
        output_filename = '/tile_'
        names_list_img = []
    if args.create_gt:
        shutil.rmtree(outAP + "/tiles_gt", ignore_errors=True)
        out_gt = ost.createDir(outAP + "/tiles_gt")
        output_filename_gt = '/gt_'
        names_list_gt = []

    # ---- tiling instructions: [args, count, border_norm, geometry, None] ----
    instruc_list = []
    count = 0
    for i in range(0, xsize, tile_size_x):
        for j in range(0, ysize, tile_size_y):
            instructions = [None] * 5  # stable shape for the worker
            img_tile = None
            if args.create_tile:
                img_tile = out_tiles + output_filename + str(count) + "_" + str(i) + "_" + str(j) + ".tif"
                instructions[2] = list_border_norm
            gt_tile = None
            if args.create_gt:
                gt_tile = out_gt + output_filename_gt + str(count) + "_" + str(i) + "_" + str(j) + ".tif"
            instructions[0] = args
            instructions[1] = count
            instructions[3] = [i, j, tile_size_x, tile_size_y, imgAP, gtAP, img_tile, gt_tile]
            if args.create_tile:
                names_list_img.append(ost.pathLeafExt(img_tile))
            if args.create_gt:
                names_list_gt.append(ost.pathLeafExt(gt_tile))
            instruc_list.append(instructions)
            count += 1

    # ---- generate tiles ----
    print("Generating tiles...")
    list_discarded = []
    with multiprocessing.Pool() as pool:
        for dis in tqdm(pool.imap_unordered(multiprocessing_func, instruc_list),
                        total=len(instruc_list)):
            if dis is not False:
                list_discarded.append(dis)

    if args.discard is True:
        dis_file = outAP + "/discarded_tiles.txt"
        np.savetxt(dis_file, list_discarded, fmt='%s')

    # delete GT tiles listed in a discard file when no img tiles are produced
    if args.discard_file is not False and args.create_tile is False:
        print("Suppressing tiles according to discard file :", args.discard_file)
        # reshape: a one-line file gives a 1-D array, an empty one a (0,) array
        list_discarded = np.genfromtxt(args.discard_file, dtype='str').reshape(-1, 2)
        for (ti, tg) in list_discarded:
            try:
                os.remove(out_gt + "/" + tg)
            except OSError:
                pass

    # Renumber tiles 0..n-1 after discarding; the selecter algorithm relies
    # on contiguous tile numbers.
    if args.discard or args.discard_file is not False:
        print("Renaming tiles...")
        if args.create_tile:
            to_rename_img = ost.getFileByExt(out_tiles, ".tif")
            if to_rename_img:
                root = ost.pathBranch(ost.pathBranch(to_rename_img[0]))
                for i, name in tqdm(enumerate(to_rename_img)):
                    name_split = ost.pathLeaf(name).split("_")
                    x_name = name_split[2]
                    y_name = name_split[3]
                    new_name_img = root + "/tiles/tile_" + str(i) + "_" + str(x_name) + "_" + str(y_name) + ".tif"
                    os.rename(name, new_name_img)
            names_list_img = ost.getFileByExt(out_tiles, ".tif")
        if args.create_gt:
            to_rename_gt = ost.getFileByExt(out_gt, ".tif")
            if to_rename_gt:
                root = ost.pathBranch(ost.pathBranch(to_rename_gt[0]))
                for i, name in tqdm(enumerate(to_rename_gt)):
                    name_split = ost.pathLeaf(name).split("_")
                    x_name = name_split[2]
                    y_name = name_split[3]
                    new_name_gt = root + "/tiles_gt/gt_" + str(i) + "_" + str(x_name) + "_" + str(y_name) + ".tif"
                    os.rename(name, new_name_gt)
            names_list_gt = ost.getFileByExt(out_gt, ".tif")

    # remove the .aux.xml stats files generated by the histogram check
    if args.create_tile:
        for aux in ost.getFileByExt(out_tiles, ".aux.xml"):
            os.remove(aux)

    # ---- h5 metadata (existing files are replaced) ----
    print("Writting h5 file...", end='')
    if args.create_tile:
        try:
            os.remove(meta_img)
        except OSError:
            pass
        with h5.File(meta_img, "w") as f:
            f.attrs["base_img"] = imgAP
            f.attrs["len"] = len(names_list_img)
            f.attrs["margin"] = args.margin
            f.attrs["channels"] = nb_channels
            f.attrs["norm_border"] = list_border_norm
            a = [ost.pathLeafExt(a) for a in names_list_img]
            f.create_dataset("tile_path", data=np.array(a, dtype=h5.special_dtype(vlen=str)))
    if args.create_gt:
        try:
            os.remove(meta_gt)
        except OSError:
            pass
        with h5.File(meta_gt, "w") as f:
            f.attrs["base_img"] = imgAP
            f.attrs["len"] = len(names_list_gt)
            f.attrs["info"] = str_rep
            f.attrs["names"] = names
            f.attrs["class_weight"] = propPercent / 100
            a = [ost.pathLeafExt(a) for a in names_list_gt]
            f.create_dataset("gt_path", data=np.array(a, dtype=h5.special_dtype(vlen=str)))
    print(" done.")

    # ---- data selection folder and index shapefiles ----
    if args.create_select:
        current = os.getcwd()
        os.chdir(outAP)  # relative paths keep the gdaltindex command short
        print("Generating index shape of tiles...", end='')
        shutil.rmtree(outAP + "/data_selection", ignore_errors=True)
        out_selec = ost.createDir(outAP + "/data_selection")
        out_shp_index = ost.pathLeaf(out_selec) + "/all_index.shp"
        num = ost.getFileByExt(out_tiles if args.create_tile else out_gt, ".tif")
        gdalTIndex_instruc = [
            "gdaltindex", "-f", "ESRI Shapefile", "-t_srs", "EPSG:2154",
            out_shp_index, *num
        ]
        subprocess.call(gdalTIndex_instruc, stdout=subprocess.DEVNULL)
        os.chdir(current)
        # copy all_index.* into test/val/train index files
        list_TIndex = ost.getFileBySubstr(out_selec, "all_index")
        index_names = ["test_index", "val_index", "train_index"]
        for n in index_names:
            for a in list_TIndex:
                shutil.copy(a, out_selec + "/" + n + ost.pathExt(a))
        print(" done.")
    return 0
def multiprocessing_func(mlparg):
    """Create one image tile and/or its GT tile (pool worker, single argument).

    Args:
        mlparg: ``[args, count, list_border_norm, geometry, _]`` where
            ``geometry = [ulx, uly, tsx, tsy, imgAP, gtAP, img_tile, gt_tile]``
            (pixel offsets, tile size, source paths, output paths; an output
            path of ``None`` means that tile is not produced).

    Returns:
        ``(img_tile_name, gt_tile_name)`` if the image tile was discarded as
        empty, else ``False``.
    """
    discard = mlparg[0].discard
    marge = mlparg[0].margin
    nodata_value = mlparg[0].nodata_value
    verif_nodata = mlparg[0].verif_nodata
    list_border_norm = mlparg[2]
    ulx, uly, tsx, tsy, imgAP, gtAP, img_tile, gt_tile = mlparg[3]

    discarded = False
    tf = None  # per-pixel count of non-zero bands; only set when an img tile is made

    # ---- image tile (with margin on every side) ----
    if img_tile is not None:
        srcwin = [ulx - marge, uly - marge, tsx + 2 * marge, tsy + 2 * marge]
        opts = gdal.TranslateOptions(format="GTIFF", outputType=gdal.GDT_Float32, srcWin=srcwin)
        out_ras = gdal.Translate(img_tile, imgAP, options=opts)

        # discard the tile if every band has no pixel >= 1
        if discard:
            tot = 0
            for i in range(1, out_ras.RasterCount + 1):
                # Single large bucket: counts pixels in [1, 65536]; the tile is
                # empty if the sum is 0. WARNING: img should have a nodata
                # value set. This generates .aux.xml files.
                hist = out_ras.GetRasterBand(i).GetHistogram(min=1, max=65536, buckets=1, approx_ok=False)
                tot += np.sum(hist)
            if tot == 0:
                discarded = True

        if discarded:
            out_ras = None  # close before deleting
            os.remove(img_tile)
        else:
            tf = np.full((out_ras.GetRasterBand(1).YSize, out_ras.GetRasterBand(1).XSize), 0)
            for i in range(1, out_ras.RasterCount + 1):
                a = out_ras.GetRasterBand(i).ReadAsArray()
                tf += np.array(a, dtype=bool)  # +1 where this band is non-zero
                out_ras.GetRasterBand(i).WriteArray(
                    ost.normalize(a, force=True,
                                  mini=list_border_norm[i - 1][0],
                                  maxi=list_border_norm[i - 1][1]))
            out_ras.FlushCache()

    # ---- GT tile (skipped if the image tile was discarded) ----
    if not discarded and gt_tile is not None:
        srcwin = [ulx, uly, tsx, tsy]
        if marge != 0:
            # GT is not written on the margin, to avoid data leaking between
            # train/validation/test sets: crop without margin in memory, then
            # pad back to the margin size (outside pixels get nodata).
            file_temp = "/vsimem/" + ost.pathBranch(gt_tile) + "/" + ost.pathLeaf(gt_tile) + "temp.tif"
            opts = gdal.TranslateOptions(format="GTiff", outputType=gdal.GDT_Byte, srcWin=srcwin, noData=nodata_value)
            tmp_ds = gdal.Translate(file_temp, gtAP, options=opts)
            srcwin = [-marge, -marge, tsx + 2 * marge, tsy + 2 * marge]
            opts = gdal.TranslateOptions(format="GTiff", outputType=gdal.GDT_Byte, srcWin=srcwin, noData=nodata_value)
            out_ds = gdal.Translate(gt_tile, tmp_ds, options=opts)
            # release and delete the in-memory file, otherwise it stays
            # allocated in the worker for the rest of its life
            tmp_ds = None
            gdal.Unlink(file_temp)
        else:
            opts = gdal.TranslateOptions(format="GTiff", outputType=gdal.GDT_Byte, srcWin=srcwin, noData=nodata_value)
            out_ds = gdal.Translate(gt_tile, gtAP, options=opts)

        # Set GT to nodata where the image is 0 on every band. Needs the tf
        # mask, which only exists when an image tile was produced.
        if verif_nodata and tf is not None:
            a = np.where(np.array(tf, dtype=bool), out_ds.GetRasterBand(1).ReadAsArray(), nodata_value)
            out_ds.GetRasterBand(1).WriteArray(a)
        out_ds.FlushCache()

    if discarded:
        # NOTE(review): gt_tile is None when GT tiles are not requested;
        # confirm ost.pathLeafExt accepts None.
        return ost.pathLeafExt(img_tile), ost.pathLeafExt(gt_tile)
    return False
def full_train(args, cv, outIter):
    """Train and evaluate one model for one cross-validation iteration.

    Args:
        args: parsed CLI arguments (see the parser).
        cv: ``(train_ids, val_ids, test_ids)`` tile index lists. They are
            pickled to ``<outIter>/crossValID.pckl`` in that order, for reuse
            by the Drawer/inference script.
        outIter: output folder of this iteration (perf logs and states).

    Returns:
        The trained model. Ctrl+C stops training early but still writes the
        reports.
    """
    print("Iteration file is ", outIter)

    # channel count; older meta_img files do not store it
    try:
        channels = readH5Channels(args.meta_img)
        print("Image channels =", channels)
    except Exception:
        channels = 4
        print("WARNING ! : OLD META_IMG : guessed 4 channels image")

    model = nb.OLIVENET(channels, len(args.c))
    print('Total number of parameters: {}'.format(sum(p.numel() for p in model.parameters())))
    if args.cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # learning-rate decay: milestones and gamma are HARD CODED
    plan = list(range(50, 300, 50))
    gamma = 0.7
    print(' | '.join('{}: {:.2e}'.format(pl, lr)
                     for pl, lr in zip(plan, [args.lr * gamma**i for i in range(len(plan))])))
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=plan, gamma=gamma)

    # pretrained model (transfer learning); map to CPU when CUDA is off
    if args.state is not False:
        print("TRANSFERT LEARNING : loading of :\n", args.state)
        checkpoint = torch.load(args.state, map_location=None if args.cuda else "cpu")
        model.load_state_dict(checkpoint['state_dict'])

    # class weights from the class proportions in meta_gt: 1/sqrt(proportion),
    # 0 for absent classes, then L2-normalized
    w = readH5Cweight(args.meta_gt)
    w = np.asarray([1 / np.sqrt(i) if i != 0 else 0 for i in w], dtype=np.float64)
    norm = np.linalg.norm(w)
    if norm > 0:  # avoid NaN weights when every class is absent
        w = w / norm
    w = torch.from_numpy(w.astype(np.float32))
    print("Class weights:", [args.c[i] + ": {:.4f} ".format(weight) for i, weight in enumerate(w)])

    # perf logging; dataset instance: 0 = TRAIN, 1 = VAL, 2 = TEST
    outLog = ost.createDir(outIter + "/perf_log")
    save_lr = []
    perf = met.SavePerf(args.e, outLog, args.c)
    outState = ost.createDir(outIter + "/states")

    # save the data repartition of this iteration
    with open(outIter + "/crossValID.pckl", 'wb') as fp:
        pickle.dump(cv[0], fp)
        pickle.dump(cv[1], fp)
        pickle.dump(cv[2], fp)

    # shuffle/drop_last only for training; eval batches are 3x larger because
    # no gradients are stored
    train_loader = data.DataLoader(
        reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img, args.meta_gt, cv[0], args.aug, (200, 200)),
        batch_size=args.bs, shuffle=True, drop_last=True, num_workers=args.num_workers)
    val_loader = data.DataLoader(
        reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img, args.meta_gt, cv[1]),
        batch_size=args.bs * 3, shuffle=False, drop_last=False, num_workers=args.num_workers)
    test_loader = data.DataLoader(
        reader.DatasetTIF(args.dir_img, args.dir_gt, args.meta_img, args.meta_gt, cv[2]),
        batch_size=args.bs * 3, shuffle=False, drop_last=False, num_workers=args.num_workers)

    try:
        for i_epoch in range(args.e):
            # full metrics + checkpoint every args.mem epochs and on the last one
            if i_epoch % args.mem == 0 or i_epoch == args.e - 1:
                print("epoch:", i_epoch, end='', flush=True)
                cm, train_loss = learningPhase(args, model, train_loader, w, optimizer, True)
                perf.addFullResults(i_epoch, cm, train_loss, 0)  # 0 for TRAIN
                with torch.no_grad():
                    cm, loss = evalutionPhase(args, model, val_loader, w)
                    is_best = perf.addFullResults(i_epoch, cm, loss, 1)  # 1 for VAL
                    cm, loss = evalutionPhase(args, model, test_loader, w)
                    perf.addFullResults(i_epoch, cm, loss, 2)  # 2 for TEST
                save_checkpoint({
                    'epoch': i_epoch,
                    'channels': channels,
                    'names': args.c,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best, outState)
            else:  # train only, loss logged without metrics
                print("epoch:", i_epoch, end='', flush=True)
                __, train_loss = learningPhase(args, model, train_loader, w, optimizer, False)
                perf.addLossResults(i_epoch, train_loss, 0)  # 0 for TRAIN
            scheduler.step()
            save_lr.append([group['lr'] for group in optimizer.param_groups][0])
    except KeyboardInterrupt:
        pass  # allow ctrl+c to stop training and still write the reports

    print("Best epoch:", perf.best)
    perf.printResults()  # graphs
    perf.printResultsTxt()  # values in txt

    ost.plotLogGraph(np.arange(len(save_lr)), save_lr, 'lr', 'epochs', 'Lr',
                     'Learning rate decay according to epochs',
                     outLog + '/lr_decay.tif', show=0)
    df = pd.DataFrame({'epochs': np.arange(len(save_lr)), 'lr': save_lr})
    df.to_csv(outLog + "/lr_decay.csv", sep=";", index=False)
    return model
def selecter(args):
    """Plan the cross-validation data repartition.

    Tiles listed in the selection shapefiles (train/val/test) under
    ``args.selection`` are locked into that dataset for every iteration. A
    shapefile with at least as many features as tiles is ignored. The random
    state is loaded from ``args.seed``, or created and saved under
    ``<args.out>/seed.dump`` when ``args.seed`` is ``False``.

    Returns:
        The result of ``crossValSpliterSelection``: per CV iteration, the
        train/val/test lists of tile indices.
    """
    # Shapefiles listed alphabetically: all_index, test, train, val.
    # Read them in train/val/test order; index 0 (all_index) is skipped.
    order_good = [2, 3, 1]
    data_lenght = readH5Length(args.meta_gt)
    selected = [[] for _ in order_good]  # locked tiles per dataset

    try:
        shps = ost.getFileByExt(args.selection, "shp")
        print("\nReading selection shapefiles:")
        driver = ogr.GetDriverByName("ESRI Shapefile")
        for i, order in enumerate(order_good):
            dataSource = driver.Open(shps[order], 0)
            print("\t", shps[order])
            layer = dataSource.GetLayer()
            if layer.GetFeatureCount() < data_lenght:  # full files are ignored
                for feature in layer:
                    name = ost.pathLeaf(feature.GetField("location"))
                    # tile number = first "_"-separated part that is an int
                    # (stops instead of looping forever when none is)
                    tile_num = None
                    for part in ost.pathLeaf(name).split("_"):
                        try:
                            tile_num = int(part)
                        except ValueError:
                            continue
                        break
                    if tile_num is None:
                        print("\tWARNING/!\\: Error in tile numerotation, was not able to get the tile number")
                    else:
                        selected[i].append(tile_num)
                layer.ResetReading()
        print("\nSelection found :\n\tTrain samples:", len(selected[0]),
              "\n\tVal samples: ", len(selected[1]),
              "\n\tTest samples: ", len(selected[2]))
    except Exception as err:
        # best effort: without a selection, the split is fully random
        print("\tWARNING/!\\: Path to data selection shapefiles directory not found")
        print("\t", repr(err))

    # load/set seed for the random repartition (a new seed is saved)
    if args.seed is False:
        seedsave = ost.checkNIncrementLeaf(args.out + "/seed.dump")
        with open(seedsave, 'wb') as fp:
            prng = np.random.RandomState()
            pickle.dump(prng, fp)
    else:
        # NOTE(review): pickle.load runs arbitrary code; only load seed files
        # written by this function.
        with open(args.seed, 'rb') as fp:
            prng = pickle.load(fp)

    return crossValSpliterSelection(prng, data_lenght, args.cvSplit, selected)