import os

import pandas as pd
from suanpan import asyncio, path, utils
from suanpan.arguments import String
from suanpan.docker import DockerComponent as dc
from suanpan.docker.arguments import Folder, HiveTable


@dc.input(
    HiveTable(key="inputData", table="inputDataTable", partition="inputDataPartition"))
@dc.input(Folder(key="inputDataFolder", required=True))
@dc.output(Folder(key="outputImagesFolder", required=True))
@dc.column(String(key="idColumn", default="id"))
@dc.column(String(key="dataColumn", default="data_path"))
def SPData2Images(context):
    args = context.args
    with asyncio.multiThread() as pool:
        for _, row in args.inputData.iterrows():
            # Load the npy volume referenced by this row and dump its slices
            # as images under a per-id prefix in the output folder.
            image = utils.loadFromNpy(
                os.path.join(args.inputDataFolder, row[args.dataColumn]))
            prefix = os.path.join(args.outputImagesFolder, row[args.idColumn])
            utils.saveAllAsImages(prefix, image, pool=pool)
    return args.outputImagesFolder


if __name__ == "__main__":
    SPData2Images()
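# A minimal sketch of the npy-volume-to-image-slices step that
# utils.saveAllAsImages presumably performs above, written with plain
# numpy/imageio instead of the suanpan helpers. The function name, the
# normalization, and the example paths are illustrative assumptions, not the
# suanpan implementation.
import os

import imageio
import numpy as np


def save_volume_as_images(prefix, volume):
    # Scale the volume to 0-255 and write one PNG per slice along the first
    # axis, e.g. "<prefix>/000.png", "<prefix>/001.png", ...
    os.makedirs(prefix, exist_ok=True)
    lo, hi = float(volume.min()), float(volume.max())
    scaled = (volume - lo) / max(hi - lo, 1e-8) * 255.0
    for i, plane in enumerate(scaled.astype(np.uint8)):
        imageio.imwrite(os.path.join(prefix, "{:03d}.png".format(i)), plane)


# Hypothetical usage:
# volume = np.load("preprocessed/patient_clean.npy")
# save_volume_as_images("images/patient-id", volume)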
@dc.input(
    Folder(
        key="inputDataFolder",
        required=True,
        help="DSB stage1/2 or similar directory path.",
    ))
@dc.output(
    HiveTable(key="outputData", table="outputDataTable", partition="outputDataPartition"))
@dc.output(
    Folder(
        key="outputDataFolder",
        required=True,
        help="Directory to save preprocessed npy files to.",
    ))
@dc.column(String(key="idColumn", default="patient"))
@dc.column(String(key="imageColumn", default="image_path"))
def SPPredictPreprocess(context):
    args = context.args
    stage1Path = args.inputDataFolder
    preprocessResultPath = args.outputDataFolder
    # Run the full preprocessing pipeline in parallel, then index the
    # generated npy files into a table of (patient id, image path) rows.
    full_prep(stage1Path,
              preprocessResultPath,
              n_worker=asyncio.WORKERS,
              use_existing=False)
    data = scan_prep_results(preprocessResultPath, args.idColumn, args.imageColumn)
    return data, args.outputDataFolder
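# A hedged sketch of what a helper like scan_prep_results could look like,
# assuming the preprocessing step writes one "<id>_clean.npy" file per patient
# into the output folder. The file-name convention and this reimplementation
# are assumptions for illustration; the component's actual helper is not shown
# in this file.
import os

import pandas as pd


def scan_prep_results(folder, id_column, image_column):
    # Build one DataFrame row per preprocessed patient: id plus the relative
    # path of its clean npy file.
    rows = [{
        id_column: name[:-len("_clean.npy")],
        image_column: name,
    } for name in sorted(os.listdir(folder)) if name.endswith("_clean.npy")]
    return pd.DataFrame(rows, columns=[id_column, image_column])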
@dc.input(
    Checkpoint(key="inputCheckpoint", required=True, help="Ckpt model file."))
@dc.output(
    HiveTable(
        key="outputBboxData",
        table="outputBboxDataTable",
        partition="outputBboxDataPartition",
    ))
@dc.output(
    Folder(
        key="outputBboxFolder",
        required=True,
        help="Directory to save bbox npy files to.",
    ))
@dc.column(
    String(key="idColumn", default="id", help="ID column of inputImagesData."))
@dc.column(
    String(key="lbbColumn", default="lbb_path", help="Lbb column of inputImagesData."))
@dc.column(
    String(key="pbbColumn", default="pbb_path", help="Pbb column of inputImagesData."))
@dc.param(Int(key="margin", default=16, help="Patch margin."))
@dc.param(Int(key="sidelen", default=64, help="Patch side length."))
@dc.param(Int(key="batchSize", default=16))
def SPNNetPredict(context):
    args = context.args
    data = args.inputData
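# A minimal sketch of how the margin and sidelen params are typically used in
# this kind of prediction step: pad the scan and cut it into overlapping cubes
# so the network can run on fixed-size patches. This is a simplified numpy
# illustration under that assumption, not the component's actual split/combine
# utility.
import numpy as np


def split_volume(volume, sidelen=64, margin=16):
    # Pad each dimension up to a multiple of sidelen (plus a margin border),
    # then cut cubes of size sidelen + 2 * margin with stride sidelen.
    d, h, w = volume.shape
    pad = [(margin, int(np.ceil(s / sidelen)) * sidelen - s + margin)
           for s in (d, h, w)]
    padded = np.pad(volume, pad, mode="edge")
    patches = []
    for z in range(0, d, sidelen):
        for y in range(0, h, sidelen):
            for x in range(0, w, sidelen):
                patches.append(padded[z:z + sidelen + 2 * margin,
                                      y:y + sidelen + 2 * margin,
                                      x:x + sidelen + 2 * margin])
    return np.stack(patches)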
@dc.input(
    HiveTable(
        key="inputTrainData",
        table="inputTrainDataTable",
        partition="inputTrainDataPartition",
    ))
@dc.input(
    HiveTable(
        key="inputValidateData",
        table="inputValidateDataTable",
        partition="inputValidateDataPartition",
    ))
@dc.input(Folder(key="inputDataFolder", required=True))
@dc.input(Checkpoint(key="inputCheckpoint"))
@dc.output(Checkpoint(key="outputCheckpoint", required=True))
@dc.column(String(key="idColumn", default="id"))
@dc.param(Int(key="epochs", default=100))
@dc.param(Int(key="batchSize", default=16))
@dc.param(Float(key="learningRate", default=0.01))
@dc.param(Float(key="momentum", default=0.9))
@dc.param(Float(key="weightDecay", default=1e-4))
@dc.param(Bool(key="distributed", default=False))
@dc.param(String(key="ckptFolder"))
@dc.param(Int(key="saveFreq", default=1))
def SPNNetTrain(context):
    torch.manual_seed(0)
    args = context.args
    saveFolder = os.path.dirname(args.outputCheckpoint)
    dataFolder = args.inputDataFolder
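# An illustrative sketch of what a saveFreq-controlled checkpoint loop usually
# looks like with torch. The function and file-name pattern are assumptions;
# only the torch.save / state_dict calls are standard PyTorch API.
import os

import torch


def save_checkpoint(net, optimizer, epoch, save_folder, save_freq=1):
    # Persist model and optimizer state every `save_freq` epochs.
    if epoch % save_freq != 0:
        return
    state = {
        "epoch": epoch,
        "state_dict": net.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(state, os.path.join(save_folder, "{:03d}.ckpt".format(epoch)))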
@dc.input(
    HiveTable(
        key="inputTrainData",
        table="inputTrainDataTable",
        partition="inputTrainDataPartition",
    ))
@dc.input(
    HiveTable(
        key="inputValidateData",
        table="inputValidateDataTable",
        partition="inputValidateDataPartition",
    ))
@dc.input(Folder(key="inputDataFolder", required=True))
@dc.input(Checkpoint(key="inputCheckpoint"))
@dc.output(Checkpoint(key="outputCheckpoint", required=True))
@dc.column(String(key="idColumn", default="id"))
@dc.param(Int(key="epochs", default=100))
@dc.param(Int(key="batchSize", default=16))
@dc.param(Float(key="learningRate", default=0.01))
@dc.param(Float(key="momentum", default=0.9))
@dc.param(Float(key="weightDecay", default=1e-4))
def SPNNetTrain(context):
    torch.manual_seed(0)
    args = context.args
    saveFolder = os.path.dirname(args.outputCheckpoint)
    dataFolder = args.inputDataFolder
    checkoutPointPath = args.inputCheckpoint
    trainIds = args.inputTrainData[args.idColumn]
    validateIds = args.inputValidateData[args.idColumn]
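# The declared learningRate, momentum, and weightDecay params map naturally
# onto torch.optim.SGD; whether the component actually uses SGD is an
# assumption, but the optimizer construction below is standard PyTorch.
import torch


def build_optimizer(net, learning_rate=0.01, momentum=0.9, weight_decay=1e-4):
    # Wire the component's hyper-parameters into a stochastic gradient
    # descent optimizer over the network's parameters.
    return torch.optim.SGD(net.parameters(),
                           lr=learning_rate,
                           momentum=momentum,
                           weight_decay=weight_decay)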