def main(args, max_workers=3):
    """Classify detected cell candidates and save the results to disk.

    Runs the trained network over cubes extracted around each detected
    point, assigns every cell its predicted type, and writes the
    classified cells out.

    Returns True when classified cells were produced, False when none
    were found (MissingCellsError).
    """
    signal_paths = args.signal_planes_paths[args.signal_channel]
    background_paths = args.background_planes_path[0]
    signal_images = get_sorted_file_paths(signal_paths, file_extension="tif")
    background_images = get_sorted_file_paths(
        background_paths, file_extension="tif"
    )

    # Too many workers doesn't increase speed, and uses huge amounts of RAM
    workers = get_num_processes(
        min_free_cpu_cores=args.n_free_cpus,
        n_max_processes=max_workers,
    )

    logging.debug("Initialising cube generator")
    inference_generator = CubeGeneratorFromFile(
        args.paths.detected_points,
        signal_images,
        background_images,
        args.voxel_sizes,
        args.network_voxel_sizes,
        batch_size=args.batch_size,
        cube_width=args.cube_width,
        cube_height=args.cube_height,
        cube_depth=args.cube_depth,
    )

    model = get_model(
        existing_model=args.trained_model,
        model_weights=args.model_weights,
        network_depth=models[args.network_depth],
        inference=True,
    )

    logging.info("Running inference")
    predictions = model.predict(
        inference_generator,
        use_multiprocessing=True,
        workers=workers,
        verbose=True,
    )
    # Collapse the per-class scores into a single class index per cell.
    class_indices = np.argmax(predictions.round().astype("uint16"), axis=1)

    # only go through the "extractable" cells
    cells_list = list(inference_generator.ordered_cells)
    for position, cell in enumerate(cells_list):
        # Cell types are 1-based; network output is 0-based.
        cell.type = class_indices[position] + 1

    logging.info("Saving classified cells")
    save_cells(cells_list, args.paths.classified_points, save_csv=args.save_csv)

    try:
        get_cells(args.paths.classified_points, cells_only=True)
    except MissingCellsError:
        return False
    return True
def main():
    """Entry point for training the cell classification network.

    Parses the training arguments, gathers tiff datasets from the yaml
    files, optionally splits off a validation set, wires up the Keras
    callbacks (TensorBoard, checkpoints, CSV progress log), runs
    ``model.fit`` and finally saves either the full model or just its
    weights.
    """
    from cellfinder.main import suppress_tf_logging

    # Suppression must be configured before tensorflow is imported.
    suppress_tf_logging(tf_suppress_log_messages)

    from tensorflow.keras.callbacks import (
        TensorBoard,
        ModelCheckpoint,
        CSVLogger,
    )
    from cellfinder.tools.prep import prep_training
    from cellfinder.classify.tools import make_lists, get_model
    from cellfinder.classify.cube_generator import CubeGeneratorFromDisk

    start_time = datetime.now()
    args = training_parse()
    output_dir = Path(args.output_dir)
    ensure_directory_exists(output_dir)
    args = prep_training(args)

    fancylog.start_logging(
        args.output_dir,
        program_for_log,
        variables=[args],
        log_header="CELLFINDER TRAINING LOG",
    )

    yaml_contents = parse_yaml(args.yaml_file)
    tiff_files = get_tiff_files(yaml_contents)
    n_images = sum(len(imlist) for imlist in tiff_files)
    logging.info(
        f"Found {n_images} images "
        f"from {len(yaml_contents)} datasets "
        f"in {len(args.yaml_file)} yaml files"
    )

    model = get_model(
        existing_model=args.trained_model,
        model_weights=args.model_weights,
        network_depth=models[args.network_depth],
        learning_rate=args.learning_rate,
        continue_training=args.continue_training,
    )

    signal_train, background_train, labels_train = make_lists(tiff_files)

    if args.test_fraction > 0:
        logging.info("Splitting data into training and validation datasets")
        (
            signal_train,
            signal_test,
            background_train,
            background_test,
            labels_train,
            labels_test,
        ) = train_test_split(
            signal_train,
            background_train,
            labels_train,
            test_size=args.test_fraction,
        )
        logging.info(
            f"Using {len(signal_train)} images for training and "
            f"{len(signal_test)} images for validation"
        )
        validation_generator = CubeGeneratorFromDisk(
            signal_test,
            background_test,
            labels=labels_test,
            batch_size=args.batch_size,
            train=True,
        )
        # Validation runs, so val_loss exists to embed in checkpoint names.
        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"
    else:
        logging.info("No validation data selected.")
        validation_generator = None
        # No validation pass -> no val_loss to format into the filename.
        base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"

    training_generator = CubeGeneratorFromDisk(
        signal_train,
        background_train,
        labels=labels_train,
        batch_size=args.batch_size,
        shuffle=True,
        train=True,
        augment=not args.no_augment,
    )

    callbacks = []

    if args.tensorboard:
        tensorboard_dir = output_dir / "tensorboard"
        ensure_directory_exists(tensorboard_dir)
        callbacks.append(
            TensorBoard(
                log_dir=tensorboard_dir,
                histogram_freq=0,
                write_graph=True,
                update_freq="epoch",
            )
        )

    if not args.no_save_checkpoints:
        # Checkpoint either the full model or only its weights.
        prefix = "weight" if args.save_weights else "model"
        checkpoint_path = str(output_dir / (prefix + base_checkpoint_file_name))
        callbacks.append(
            ModelCheckpoint(
                checkpoint_path,
                save_weights_only=args.save_weights,
            )
        )

    if args.save_progress:
        callbacks.append(CSVLogger(str(output_dir / "training.csv")))

    logging.info("Beginning training.")
    model.fit(
        training_generator,
        validation_data=validation_generator,
        use_multiprocessing=False,
        epochs=args.epochs,
        callbacks=callbacks,
    )

    if args.save_weights:
        logging.info("Saving model weights")
        model.save_weights(str(output_dir / "model_weights.h5"))
    else:
        logging.info("Saving model")
        model.save(output_dir / "model.h5")

    logging.info(
        "Finished training, " "Total time taken: %s",
        datetime.now() - start_time,
    )
def main(max_workers=3):
    """Train the cell classification network.

    Parses training arguments, builds training (and optionally
    validation) cube generators, fits the model with multiprocessing
    workers, and saves either the full model or only its weights.

    Parameters
    ----------
    max_workers : int
        Upper bound on worker processes for ``model.fit``. Too many
        workers doesn't increase speed, and uses huge amounts of RAM.

    Fixes
    -----
    - Checkpoint filenames only reference ``val_loss`` when a validation
      set exists; previously ModelCheckpoint would fail formatting the
      name when ``args.test_fraction == 0``.
    - The final timing message used logging-style ``%s`` with ``print``,
      which never interpolates; it now formats the elapsed time.
    """
    from cellfinder.main import suppress_tf_logging

    # Suppression must be configured before tensorflow is imported.
    suppress_tf_logging(tf_suppress_log_messages)

    from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
    from cellfinder.tools import system
    from cellfinder.tools.prep import prep_training
    from cellfinder.classify.tools import make_lists, get_model
    from cellfinder.classify.cube_generator import CubeGeneratorFromDisk

    start_time = datetime.now()
    args = training_parse()
    output_dir = Path(args.output_dir)
    system.ensure_directory_exists(output_dir)
    args = prep_training(args)

    # NOTE(review): the sibling training entry point runs get_tiff_files()
    # on the parsed yaml before make_lists(); here the raw parse_yaml()
    # result is passed straight through — confirm make_lists accepts it.
    tiff_files = parse_yaml(args.yaml_file)

    # Too many workers doesn't increase speed, and uses huge amounts of RAM
    workers = system.get_num_processes(
        min_free_cpu_cores=args.n_free_cpus,
        n_max_processes=max_workers,
    )

    model = get_model(
        existing_model=args.trained_model,
        model_weights=args.model_weights,
        network_depth=models[args.network_depth],
        learning_rate=args.learning_rate,
        continue_training=args.continue_training,
    )

    signal_train, background_train, labels_train = make_lists(tiff_files)

    if args.test_fraction > 0:
        (
            signal_train,
            signal_test,
            background_train,
            background_test,
            labels_train,
            labels_test,
        ) = train_test_split(
            signal_train,
            background_train,
            labels_train,
            test_size=args.test_fraction,
        )
        validation_generator = CubeGeneratorFromDisk(
            signal_test,
            background_test,
            labels=labels_test,
            batch_size=args.batch_size,
            train=True,
        )
    else:
        validation_generator = None

    training_generator = CubeGeneratorFromDisk(
        signal_train,
        background_train,
        labels=labels_train,
        batch_size=args.batch_size,
        shuffle=True,
        train=True,
        augment=not args.no_augment,
    )

    callbacks = []

    if args.tensorboard:
        logdir = output_dir / "tensorboard"
        system.ensure_directory_exists(logdir)
        tensorboard = TensorBoard(
            log_dir=logdir,
            histogram_freq=0,
            write_graph=True,
            update_freq="epoch",
        )
        callbacks.append(tensorboard)

    if args.save_checkpoints:
        # Only reference val_loss in the checkpoint filename when a
        # validation set exists — otherwise ModelCheckpoint raises a
        # KeyError when formatting the name at the end of each epoch.
        if validation_generator is not None:
            suffix = ".{epoch:02d}-{val_loss:.3f}.h5"
        else:
            suffix = ".{epoch:02d}.h5"
        prefix = "weights" if args.save_weights else "model"
        filepath = str(output_dir / (prefix + suffix))
        checkpoints = ModelCheckpoint(
            filepath,
            save_weights_only=args.save_weights,
        )
        callbacks.append(checkpoints)

    model.fit(
        training_generator,
        validation_data=validation_generator,
        use_multiprocessing=True,
        workers=workers,
        epochs=args.epochs,
        callbacks=callbacks,
    )

    if args.save_weights:
        print("Saving model weights")
        model.save_weights(str(output_dir / "model_weights.h5"))
    else:
        print("Saving model")
        model.save(output_dir / "model.h5")

    # print() does not interpolate "%s" like logging does; format explicitly.
    print(f"Finished training, Total time taken: {datetime.now() - start_time}")