def main(args):
    # Build mask parameters DataFrame
    df_params = pd.DataFrame({"mask_name": args.mask_name,
                              "slice_axis": args.slice_axis,
                              "n_patches": args.n_patches,
                              "overlap": args.overlap,
                              "rotation": args.rotation})
    mpl.use(args.mpl_agg)
    data_io.show_header()

    if not os.path.exists(args.seg_path):
        os.makedirs(args.seg_path)

    if args.run_seg:
        # Understand input data format
        if os.path.isdir(args.ct_fpath):
            tiff_input = True
        elif args.ct_fpath.split('.')[-1] in ("hdf5", "h5"):
            tiff_input = False
            if args.ct_data_tag == "":
                raise ArgumentTypeError("dataset-name required for hdf5")
        else:
            raise ArgumentTypeError("input file type not recognized; must be a tiff folder or an hdf5 file")

        ct_dfile = data_io.DataFile(args.ct_fpath,
                                    tiff=tiff_input,
                                    data_tag=args.ct_data_tag,
                                    VERBOSITY=args.rw_verbosity)
        ct_dfile.show_stats()
        chunk_shape = ct_dfile.chunk_shape

        if args.stats_only:
            print("\nSet stats_only = False and start over to run the program.")
            sys.exit()

        # Load model from the model repository
        model_filename = os.path.join(args.model_path, args.model_name + '.hdf5')
        print("\nStarting segmentation mode ...")
        segmenter = Segmenter(model_filename=model_filename)

        print("Reading CT volume into memory...")
        dd = ct_dfile.read_full()

        if args.preprocess:
            print("\tPreprocessing volume...")
            if not os.path.exists("preprocessor.py"):
                input("preprocessor.py not found. Create it and press Enter, "
                      "or press CTRL+C to exit.")
            from preprocessor import preprocessor
            dd = preprocessor(dd)

        for idx, row in df_params.iterrows():  # iterate over masks
            # assign arguments from df_params for this mask
            slice_axis = row["slice_axis"]
            max_patches = row["n_patches"]
            segfile_tag = row["mask_name"]
            overlap = row["overlap"]
            rotation = row["rotation"]

            # define DataFile object for this mask
            seg_fname = os.path.join(args.seg_path, segfile_tag)
            if not args.tiff_output:
                seg_fname = seg_fname + ".hdf5"
            seg_dfile = data_io.DataFile(seg_fname,
                                         data_tag="SEG",
                                         tiff=args.tiff_output,
                                         d_shape=ct_dfile.d_shape,
                                         d_type=np.uint8,
                                         chunk_shape=chunk_shape,
                                         VERBOSITY=args.rw_verbosity)
            seg_dfile.create_new(overwrite=args.overwrite_OK)

            t0 = time.time()
            print("\nWorking on %s\n" % segfile_tag)
            ch = process_data(dd, segmenter,
                              slice_axis=slice_axis,
                              rot_angle=rotation,
                              max_patches=max_patches,
                              overlap=overlap,
                              nprocs=args.nprocs,
                              arr_split=args.arr_split,
                              arr_split_infer=args.arr_split_infer,
                              crops=args.crops)
            seg_dfile.write_full(ch)
            t1 = time.time()
            total_time = (t1 - t0) / 60.0
            print("\nDONE on %s\nTotal time for generating %s mask: %.2f minutes"
                  % (time.ctime(), segfile_tag, total_time))
            del slice_axis, max_patches, segfile_tag, rotation, ch

    if args.run_ensemble:
        print("\nStarting ensemble mode ...\n")
        t0 = time.time()

        # get the d_shape of one of the masks
        temp_fname = os.path.join(args.seg_path, df_params.loc[0, "mask_name"])
        if not args.tiff_output:
            temp_fname = temp_fname + ".hdf5"
        temp_ds = data_io.DataFile(temp_fname, data_tag="SEG",
                                   tiff=args.tiff_output, VERBOSITY=0)
        mask_shape = temp_ds.d_shape
        chunk_shape = temp_ds.chunk_shape
        if not args.run_seg:
            temp_ds.show_stats()
        del temp_ds, temp_fname

        if args.stats_only:
            print("\nSet stats_only = False and start over to run the program.")
            sys.exit()

        vote_fname = os.path.join(args.seg_path, args.vote_maskname)
        if not args.tiff_output:
            vote_fname = vote_fname + ".hdf5"
        vote_dfile = data_io.DataFile(vote_fname,
                                      tiff=args.tiff_output,
                                      data_tag="SEG",
                                      d_shape=mask_shape,
                                      d_type=np.uint8,
                                      chunk_shape=chunk_shape,
                                      VERBOSITY=args.rw_verbosity)
        vote_dfile.create_new(overwrite=args.overwrite_OK)

        slice_start = 0
        n_masks = len(df_params)
        pbar = tqdm(total=mask_shape[0])
        while slice_start < mask_shape[0]:
            ch = [0] * len(df_params)
            for idx, row in df_params.iterrows():
                seg_fname = os.path.join(args.seg_path, row["mask_name"])
                if not args.tiff_output:
                    seg_fname = seg_fname + ".hdf5"
                seg_dfile = data_io.DataFile(seg_fname,
                                             tiff=args.tiff_output,
                                             data_tag="SEG",
                                             VERBOSITY=args.rw_verbosity)
                if mask_shape != seg_dfile.d_shape:
                    raise ValueError("Shape of all masks must be the same")
                ch[idx], s = seg_dfile.read_chunk(axis=0,
                                                  slice_start=slice_start,
                                                  max_GB=args.mem_thres / n_masks)
            ch = np.asarray(ch)
            ch = np.median(ch, axis=0).astype(np.uint8)
            vote_dfile.write_chunk(ch, axis=0, s=s)
            del ch
            slice_start = s.stop
            pbar.update(s.stop - s.start)
        pbar.close()

        t1 = time.time()
        total_time = (t1 - t0) / 60.0
        print("\nDONE on %s\nTotal time for ensemble mask %s: %.2f minutes"
              % (time.ctime(), args.vote_maskname, total_time))

        if args.remove_masks:
            print("Intermediate masks will be removed.")
            for idx, row in df_params.iterrows():  # iterate over masks
                seg_fname = os.path.join(args.seg_path, row["mask_name"])
                if not args.tiff_output:
                    seg_fname = seg_fname + ".hdf5"
                    os.remove(seg_fname)
                else:
                    rmtree(seg_fname)

    if args.morpho_filt:
        print("\nApplying morphological operations on ensemble vote...")
        vote_fname = os.path.join(args.seg_path, args.vote_maskname)
        if not args.tiff_output:
            vote_fname = vote_fname + ".hdf5"
        vote_dfile = data_io.DataFile(vote_fname,
                                      tiff=args.tiff_output,
                                      data_tag="SEG",
                                      VERBOSITY=args.rw_verbosity)
        from ct_segnet.morpho import morpho_filter
        vol = vote_dfile.read_full()
        vol = morpho_filter(vol, radius=args.radius,
                            ops=args.ops,
                            crops=args.crops,
                            invert_mask=args.invert_mask)
        vote_dfile.write_full(vol)
    return
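# ---------------------------------------------------------------------------
# Hedged example: the --preprocess branch above does
# `from preprocessor import preprocessor` on a user-supplied preprocessor.py
# in the working directory, so the only contract visible here is: take the CT
# volume as a numpy array, return an array of the same shape. A minimal
# sketch of such a file follows; the percentile normalization is purely
# illustrative and not part of this repo.
#
# contents of a hypothetical preprocessor.py:
import numpy as np

def preprocessor(vol):
    """Clip intensity outliers and rescale the volume to [0, 1]."""
    lo, hi = np.percentile(vol, (1, 99))
    vol = np.clip(vol, lo, hi).astype(np.float32)
    return (vol - lo) / (hi - lo)
# ---------------------------------------------------------------------------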
def main(args, args_summary):
    mpl.use('Agg')

    # Inputs from argparser
    args.kern_size = _make_tuples(args.kern_size)
    args.kern_size_upconv = _make_tuples(args.kern_size_upconv)
    args.pool_size = _make_tuples(args.pool_size)

    data_io.show_header()

    ############### LOAD OR REBUILD fCNN MODEL ####################
    model_file = os.path.join(args.model_path, args.model_name + ".hdf5")
    model_history = os.path.join(args.model_path, 'history', args.model_name)
    if not os.path.exists(model_history):
        os.makedirs(model_history)

    if args.rebuild:
        if args.config_model is None:
            raise Exception("A model config file must be provided when rebuilding")
        img_shape = (args.model_size, args.model_size, 1)
        segmenter = build_Unet(img_shape,
                               n_depth=args.n_depth,
                               n_pools=args.n_pools,
                               activation=args.activation,
                               batch_norm=args.is_batch_norm,
                               kern_size=args.kern_size,
                               kern_size_upconv=args.kern_size_upconv,
                               pool_size=args.pool_size,
                               dropout_level=args.dropout_level,
                               loss=args.loss_def,
                               stdinput=args.stdinput)
        if args.initialize_from is not None:
            print("\nInitializing weights from file:\n%s" % args.initialize_from)
            segmenter.load_weights(args.initialize_from)
        segmenter.save(model_file)
        df_prev = None
    elif os.path.exists(os.path.join(model_history, args.model_name + ".csv")):
        segmenter = load_model(model_file, custom_objects=custom_objects_dict)
        df_prev = pd.read_csv(os.path.join(model_history, args.model_name + ".csv"))
    else:
        raise ValueError("Model history not available to retrain. Rebuild required.")

    ###### LOAD TRAINING AND TESTING DATA #############
    train_path = os.path.join(args.data_path, args.train_fname)
    X = data_io.DataFile(train_path, tiff=False, data_tag='X', VERBOSITY=0)
    Y = data_io.DataFile(train_path, tiff=False, data_tag='Y', VERBOSITY=0)

    # used to evaluate accuracy metrics during training
    test_path = os.path.join(args.data_path, args.test_fname)
    Xtest = data_io.DataFile(test_path, tiff=False, data_tag='X', VERBOSITY=0)
    Ytest = data_io.DataFile(test_path, tiff=False, data_tag='Y', VERBOSITY=0)
    print("\nTotal test data shape: " + str(Ytest.d_shape))
    print("Total training data shape: " + str(Y.d_shape))
    # save_datasnaps(data_generator(X, Y, args.n_save), model_history)

    ##### DO SOME LOGGING #####################
    rw = "w+" if args.rebuild else "a+"
    logfile = os.path.join(model_history, "README_" + args.model_name + ".txt")
    with open(logfile, rw) as f:
        if args.rebuild:
            f.write("\n")
            with redirect_stdout(f):
                segmenter.summary()
        f.write("\nNew Entry\n")
        f.write("Training Started: " + time.ctime() + "\n")
        if not args.rebuild:
            args_summary = args_summary.split("Defaults")[0]
        f.write(args_summary)
        f.write("\nTotal train data size: %s" % str(Y.d_shape))
        f.write("\nTotal test data size: %s" % str(Ytest.d_shape))
        f.write("\n")

    ######### START TRAINING ##################
    model_paths = {'name': args.model_name,
                   'history': model_history,
                   'file': model_file}
    logger = Logger(Xtest, Ytest, model_paths, args.autosave_freq,
                    df_prev=df_prev, n_test=args.n_test,
                    check_norm=args.check_norm, augmenter_todo=None)
    n_train = min(args.nmax_train, X.d_shape[0])

    t0 = time.time()
    try:
        if args.fit_generator:
            steps_per_epoch = int((1 - args.validation_split) * n_train // args.batch_size)
            validation_steps = int(args.validation_split * n_train // args.batch_size)
            # the data generator works in 'inplace' mode (faster) and yields batches of size batch_size
            dg = data_generator(X, Y, args.batch_size,
                                check_norm=args.check_norm,
                                augmenter_todo=args.augmenter_todo,
                                min_SNR=args.min_SNR,
                                inplace=True, nprocs=1)
            hist = segmenter.fit_generator(dg,
                                           steps_per_epoch=steps_per_epoch,
                                           validation_data=data_generator(X, Y, args.batch_size),
                                           validation_steps=validation_steps,
                                           epochs=args.n_epochs,
                                           verbose=2,
                                           callbacks=[logger])
        else:
            dg = data_generator(X, Y, n_train,
                                check_norm=args.check_norm,
                                augmenter_todo=args.augmenter_todo,
                                min_SNR=args.min_SNR, inplace=False)
            x_train, y_train = next(dg)
            hist = segmenter.fit(x=x_train, y=y_train,
                                 verbose=2, initial_epoch=0,
                                 validation_split=args.validation_split,
                                 epochs=args.n_epochs,
                                 batch_size=args.batch_size,
                                 callbacks=[logger])
        print("\nModel training complete")

        with open(logfile, "a+") as f:
            f.write("\nTraining Completed: " + time.ctime() + "\n")
            hours = (time.time() - t0) / (60 * 60.0)
            f.write("\nTotal time: %.4f hours\n" % hours)
            f.write("\nEnd of Entry")
        print("\nTotal time: %.4f hours" % hours)
        segmenter.save(model_file)

        ########### EVALUATE MODEL AND SAVE SOME TEST IMAGES ############
        print("\nSaving some test images...")
        model_results = os.path.join(args.model_path, 'history', args.model_name, 'testing')
        dg = data_generator(Xtest, Ytest, args.n_save,
                            augmenter_todo=args.augmenter_todo,
                            min_SNR=args.min_SNR,
                            inplace=False, nprocs=1)
        save_results(dg, model_results, segmenter)

    # Log keyboard interrupt exception
    except KeyboardInterrupt:
        print("\nModel training interrupted")
        with open(logfile, "a+") as f:
            f.write("\nTraining Interrupted after %i epochs: " % logger.i
                    + time.ctime() + "\n")
            f.write("\nEnd of Entry")
    return
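# ---------------------------------------------------------------------------
# Hedged sketch: _make_tuples is used above but not shown in this excerpt.
# From its use on kern_size, kern_size_upconv and pool_size, it plausibly
# converts flat lists parsed from the config file into the tuples that Keras
# layer constructors expect, e.g. [3, 3] -> (3, 3) and
# [[3, 3], [3, 3]] -> [(3, 3), (3, 3)]. The definition below is an assumed
# reconstruction, not the repo's actual helper.
def _make_tuples(arg):
    if isinstance(arg, (list, tuple)) and len(arg) > 0 \
            and isinstance(arg[0], (list, tuple)):
        return [tuple(item) for item in arg]  # one tuple per layer
    return tuple(arg)  # a single flat list becomes one tuple
# ---------------------------------------------------------------------------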
def main(args):
    data_io.show_header()
    print("\nLet's make some training data.\n")

    # Understand input data format
    ct_istiff = data_io._istiff(args.ct_fpath, args.ct_data_tag)
    seg_istiff = data_io._istiff(args.seg_path, args.seg_data_tag)
    dfile_recon = data_io.DataFile(args.ct_fpath, data_tag=args.ct_data_tag,
                                   tiff=ct_istiff, VERBOSITY=0)
    dfile_seg = data_io.DataFile(args.seg_path, data_tag=args.seg_data_tag,
                                 tiff=seg_istiff, VERBOSITY=0)

    if args.output_fpath == "":
        args.output_fpath = os.path.split(args.seg_path)[0]
    args.output_fname = args.output_fname.split('.')[0] + ".hdf5"
    output_fname = os.path.join(args.output_fpath, args.output_fname)

    print("Grayscale CT data:\n%s\n" % dfile_recon.fname)
    dfile_recon.show_stats()
    print("Manually segmented mask:\n%s\n" % dfile_seg.fname)
    dfile_seg.show_stats()

    if np.prod(dfile_seg.d_shape) != np.prod(dfile_recon.d_shape):
        raise ValueError("CT data and segmentation mask must have exactly the same shape / size")

    # Decide domain extents
    crop_shape = data_io.get_cropped_shape(args.crops, dfile_seg.d_shape)

    # Decide patching / slicing strategy
    df_params = pd.DataFrame({"slice axis": args.slice_axis,
                              "max patches": args.n_patches})
    skip_fac = args.skip_fac

    # Estimate the shapes of the hdf5 datasets to be written
    n_images = 0
    for idx, row in df_params.iterrows():
        _len = np.ceil(crop_shape[row["slice axis"]] / skip_fac)
        _len = _len * np.prod(row["max patches"])
        print("Calculated length of set %i: %i" % (idx + 1, _len))
        n_images = _len + n_images
    write_dshape = (int(n_images), args.model_size, args.model_size)

    # Create DataFile objects for writing the X / Y datasets
    w_dfile_seg = data_io.DataFile(output_fname, data_tag='Y', tiff=False,
                                   chunked_slice_size=None,
                                   d_shape=write_dshape, d_type=np.uint8)
    w_dfile_recon = data_io.DataFile(output_fname, data_tag='X', tiff=False,
                                     chunked_slice_size=None,
                                     d_shape=write_dshape, d_type=np.float32)
    w_dfile_seg.create_new(overwrite=False)
    w_dfile_recon.create_new(overwrite=False)

    # Work on seg data
    print("Working on the segmentation map...")
    d_seg = dfile_seg.read_full()
    d_seg = d_seg[slice(*args.crops[0]), slice(*args.crops[1]), slice(*args.crops[2])]
    gc.collect()

    slice_start = 0
    pbar = tqdm(total=n_images)
    for idx, row in df_params.iterrows():
        p = process_data(d_seg, skip_fac=skip_fac, nprocs=args.nprocs,
                         patch_size=args.model_size,
                         n_patches=row['max patches'],
                         axis=row['slice axis'])
        slice_end = p.shape[0] + slice_start
        s = slice(slice_start, slice_end)
        w_dfile_seg.write_chunk(p, axis=0, s=s)
        slice_start = slice_end
        del p
        pbar.update(s.stop - s.start)
        gc.collect()
    pbar.close()
    del d_seg
    gc.collect()

    # Work on recon data
    print("Working on the grayscale CT volume...")
    slice_start = 0
    d_recon = dfile_recon.read_full()
    d_recon = d_recon[slice(*args.crops[0]), slice(*args.crops[1]), slice(*args.crops[2])]
    pbar = tqdm(total=n_images)
    for idx, row in df_params.iterrows():
        p = process_data(d_recon, skip_fac=skip_fac, nprocs=args.nprocs,
                         patch_size=args.model_size,
                         n_patches=row['max patches'],
                         axis=row['slice axis'])
        slice_end = p.shape[0] + slice_start
        s = slice(slice_start, slice_end)
        w_dfile_recon.write_chunk(p, axis=0, s=s)
        slice_start = slice_end
        del p
        pbar.update(s.stop - s.start)
        gc.collect()
    pbar.close()
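# ---------------------------------------------------------------------------
# Standalone illustration of how args.crops is applied above: each entry is a
# (start, stop) pair per axis, slice(*pair) expands it into a slice object,
# and None leaves that axis uncropped. The shapes and crop values below are
# hypothetical.
def _crops_example():
    import numpy as np
    vol = np.zeros((100, 200, 300))
    crops = [(10, 90), (None, None), (50, 250)]
    cropped = vol[slice(*crops[0]), slice(*crops[1]), slice(*crops[2])]
    assert cropped.shape == (80, 200, 200)
# ---------------------------------------------------------------------------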
def main(args):
    data_io.show_header()

    # Understand input data format
    if os.path.isdir(args.input_fname):
        tiff_input = True
        if args.dataset_name == "":
            args.dataset_name = "data"
    elif args.input_fname.split('.')[-1] in ("hdf5", "h5"):
        tiff_input = False
        if args.dataset_name == "":
            raise ArgumentTypeError("dataset-name required for hdf5")
    else:
        raise ArgumentTypeError("input file type not recognized; must be a tiff folder or an hdf5 file")
    input_fname = args.input_fname

    # set up output file name / path / chunks parameter
    if args.output_fpath == "":
        args.output_fpath = os.path.split(args.input_fname)[0]
    args.output_fname = args.output_fname.split('.')[0] + ".hdf5"
    output_fname = os.path.join(args.output_fpath, args.output_fname)

    if type(args.chunk_param) in (int, float):
        chunk_size = args.chunk_param / 1e3  # convert to GB
        chunk_shape = None
    elif type(args.chunk_param) == tuple:
        chunk_shape = args.chunk_param
        chunk_size = None
    chunked_slice_size = args.chunked_slice_size

    # Define DataFile instances - quit here if stats_only is requested
    r_dfile = data_io.DataFile(input_fname, tiff=tiff_input,
                               data_tag=args.dataset_name,
                               VERBOSITY=args.verbosity)
    print("Input data stats:")
    r_dfile.show_stats()
    if args.stats_only:
        sys.exit()

    w_shape = r_dfile.d_shape  # a future implementation could allow resampling the dataset
    w_dtype = r_dfile.d_type   # a future implementation could allow changing dtype (with renormalization)
    w_dfile = data_io.DataFile(output_fname, tiff=False,
                               data_tag=args.dataset_name,
                               VERBOSITY=args.verbosity,
                               d_shape=w_shape, d_type=w_dtype,
                               chunk_shape=chunk_shape,
                               chunk_size=chunk_size,
                               chunked_slice_size=chunked_slice_size)
    print("\nChunking scheme estimated as: %s" % str(w_dfile.chunk_shape))
    input("\nHDF5 file will be saved to the following location.\n%s\nPress Enter to continue." % output_fname)
    w_dfile.create_new(overwrite=args.overwrite_OK)

    t0 = time.time()
    slice_start = 0
    print("\n")
    pbar = tqdm(total=r_dfile.d_shape[0])
    while slice_start < r_dfile.d_shape[0]:
        dd, s = r_dfile.read_chunk(axis=0, slice_start=slice_start,
                                   max_GB=mem_thres,
                                   chunk_shape=w_dfile.chunk_shape)
        w_dfile.write_chunk(dd, axis=0, s=s)
        slice_start = s.stop
        pbar.update(s.stop - s.start)
    pbar.close()
    total_time = (time.time() - t0) / 60.0  # minutes
    print("\nTotal time: %.2f minutes" % total_time)

    if args.delete:
        input("Delete the old file? Press Enter to confirm.")
        if tiff_input:
            rmtree(input_fname)
        else:
            os.remove(input_fname)
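# ---------------------------------------------------------------------------
# Hedged sketch: the chunked copy loop above, written with plain h5py for an
# hdf5-to-hdf5 copy. data_io.DataFile's read_chunk/write_chunk wrap this kind
# of slab-by-slab streaming; the function name, tag, and slab size below are
# hypothetical, not this repo's API.
def _stream_copy(src_fname, dst_fname, tag="data", slab=64):
    import h5py
    with h5py.File(src_fname, "r") as fin, h5py.File(dst_fname, "w") as fout:
        d_in = fin[tag]
        d_out = fout.create_dataset(tag, shape=d_in.shape,
                                    dtype=d_in.dtype, chunks=True)
        for start in range(0, d_in.shape[0], slab):
            s = slice(start, min(start + slab, d_in.shape[0]))
            d_out[s] = d_in[s]  # copy one slab along axis 0
# ---------------------------------------------------------------------------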
def main(args):
    # Build mask parameters DataFrame
    df_params = pd.DataFrame({"mask_name": args.mask_name,
                              "slice_axis": args.slice_axis,
                              "n_patches": args.n_patches,
                              "overlap": args.overlap})
    mpl.use(args.mpl_agg)
    data_io.show_header()

    if not os.path.exists(args.seg_path):
        os.makedirs(args.seg_path)

    if args.run_seg:
        ct_dfile = data_io.DataFile(args.ct_fpath,
                                    tiff=False,
                                    data_tag=args.ct_data_tag,
                                    VERBOSITY=args.rw_verbosity)
        ct_dfile.show_stats()
        if args.stats_only:
            print("\nSet stats_only = False and start over to run the program.")
            sys.exit()
        chunk_shape = ct_dfile.chunk_shape

        # Load model from the model repository
        model_filename = os.path.join(args.model_path, args.model_name + '.hdf5')
        print("\nStarting segmentation mode ...")
        segmenter = Segmenter(model_filename=model_filename)

        for idx, row in df_params.iterrows():  # iterate over masks
            # assign arguments from df_params for this mask
            slice_axis = row["slice_axis"]
            max_patches = row["n_patches"]
            segfile_tag = row["mask_name"]
            overlap = row["overlap"]

            # define DataFile object for this mask
            seg_fname = os.path.join(args.seg_path, segfile_tag + ".hdf5")
            seg_dfile = data_io.DataFile(seg_fname,
                                         data_tag="SEG",
                                         tiff=False,
                                         d_shape=ct_dfile.d_shape,
                                         d_type=np.uint8,
                                         chunk_shape=chunk_shape,
                                         VERBOSITY=args.rw_verbosity)
            seg_dfile.create_new(overwrite=args.overwrite_OK)

            t0 = time.time()
            slice_start = 0
            print("\nWorking on %s\n" % segfile_tag)
            pbar = tqdm(total=seg_dfile.d_shape[slice_axis])
            while slice_start < seg_dfile.d_shape[slice_axis]:
                ch, s = ct_dfile.read_chunk(axis=slice_axis,
                                            slice_start=slice_start,
                                            max_GB=args.mem_thres)
                ch = segmenter.seg_chunk(ch,
                                         max_patches=max_patches,
                                         overlap=overlap,
                                         nprocs=args.nprocs,
                                         arr_split=args.arr_split)
                seg_dfile.write_chunk(ch, axis=slice_axis, s=s)
                del ch
                slice_start = s.stop
                pbar.update(s.stop - s.start)
            pbar.close()
            t1 = time.time()
            total_time = (t1 - t0) / 60.0
            print("\nDONE on %s\nTotal time for generating %s mask: %.2f minutes"
                  % (time.ctime(), segfile_tag, total_time))
            del slice_axis, max_patches, segfile_tag

    if args.run_ensemble:
        print("\nStarting ensemble mode ...\n")
        t0 = time.time()

        # get the d_shape of one of the masks
        temp_fname = os.path.join(args.seg_path, df_params.loc[0, "mask_name"] + ".hdf5")
        temp_ds = data_io.DataFile(temp_fname, data_tag="SEG", tiff=False, VERBOSITY=0)
        mask_shape = temp_ds.d_shape
        chunk_shape = temp_ds.chunk_shape
        if not args.run_seg:
            temp_ds.show_stats()
        del temp_ds, temp_fname

        if args.stats_only:
            print("\nSet stats_only = False and start over to run the program.")
            sys.exit()

        vote_fname = os.path.join(args.seg_path, args.vote_maskname)
        if not args.tiff_output:
            vote_fname = vote_fname + ".hdf5"
        vote_dfile = data_io.DataFile(vote_fname,
                                      tiff=args.tiff_output,
                                      data_tag="SEG",
                                      d_shape=mask_shape,
                                      d_type=np.uint8,
                                      chunk_shape=chunk_shape,
                                      VERBOSITY=args.rw_verbosity)
        vote_dfile.create_new(overwrite=args.overwrite_OK)

        slice_start = 0
        n_masks = len(df_params)
        pbar = tqdm(total=mask_shape[0])
        while slice_start < mask_shape[0]:
            ch = [0] * len(df_params)
            for idx, row in df_params.iterrows():
                seg_fname = os.path.join(args.seg_path, row["mask_name"]) + ".hdf5"
                seg_dfile = data_io.DataFile(seg_fname,
                                             tiff=False,
                                             data_tag="SEG",
                                             VERBOSITY=args.rw_verbosity)
                if mask_shape != seg_dfile.d_shape:
                    raise ValueError("Shape of all masks must be the same")
                ch[idx], s = seg_dfile.read_chunk(axis=0,
                                                  slice_start=slice_start,
                                                  max_GB=args.mem_thres / n_masks)
            ch = np.asarray(ch)
            ch = np.median(ch, axis=0).astype(np.uint8)
            vote_dfile.write_chunk(ch, axis=0, s=s)
            del ch
            slice_start = s.stop
            pbar.update(s.stop - s.start)
        pbar.close()

        t1 = time.time()
        total_time = (t1 - t0) / 60.0
        print("\nDONE on %s\nTotal time for ensemble mask %s: %.2f minutes"
              % (time.ctime(), args.vote_maskname, total_time))

        if args.remove_masks:
            print("Intermediate masks will be removed.")
            for idx, row in df_params.iterrows():  # iterate over masks
                seg_fname = os.path.join(args.seg_path, row["mask_name"]) + ".hdf5"
                os.remove(seg_fname)
    return
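# ---------------------------------------------------------------------------
# Standalone illustration of the median vote used in the ensemble loop above:
# for an odd number of 0/1 masks, the per-pixel median is exactly a majority
# vote; with an even count, ties give 0.5, which astype(np.uint8) floors to 0.
def _vote_example():
    import numpy as np
    rng = np.random.default_rng(0)
    masks = rng.integers(0, 2, size=(3, 4, 4), dtype=np.uint8)  # 3 binary masks
    vote = np.median(masks, axis=0).astype(np.uint8)  # per-pixel majority
    return vote
# ---------------------------------------------------------------------------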