Example #1
def main(args, max_workers=3):
    signal_paths = args.signal_planes_paths[args.signal_channel]
    background_paths = args.background_planes_path[0]
    signal_images = get_sorted_file_paths(signal_paths, file_extension="tif")
    background_images = get_sorted_file_paths(background_paths,
                                              file_extension="tif")

    # Using too many workers doesn't increase speed, and uses huge amounts of RAM
    workers = get_num_processes(min_free_cpu_cores=args.n_free_cpus,
                                n_max_processes=max_workers)

    logging.debug("Initialising cube generator")
    inference_generator = CubeGeneratorFromFile(
        args.paths.detected_points,
        signal_images,
        background_images,
        args.voxel_sizes,
        args.network_voxel_sizes,
        batch_size=args.batch_size,
        cube_width=args.cube_width,
        cube_height=args.cube_height,
        cube_depth=args.cube_depth,
    )

    model = get_model(
        existing_model=args.trained_model,
        model_weights=args.model_weights,
        network_depth=models[args.network_depth],
        inference=True,
    )

    logging.info("Running inference")
    predictions = model.predict(
        inference_generator,
        use_multiprocessing=True,
        workers=workers,
        verbose=True,
    )
    # Round the network output and keep the most likely class per cell
    predictions = predictions.round().astype("uint16")
    predictions = np.argmax(predictions, axis=1)
    cells_list = []

    # only go through the "extractable" cells
    for idx, cell in enumerate(inference_generator.ordered_cells):
        # shift the 0-based class index to the 1-based cell type code
        cell.type = predictions[idx] + 1
        cells_list.append(cell)

    logging.info("Saving classified cells")
    save_cells(cells_list,
               args.paths.classified_points,
               save_csv=args.save_csv)
    try:
        get_cells(args.paths.classified_points, cells_only=True)
        return True
    except MissingCellsError:
        return False
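
A minimal sketch of the post-processing step above, with dummy prediction values (the array contents and the two-class layout are illustrative assumptions, not real cellfinder output):

import numpy as np

# Dummy network output for four candidate cells over two classes;
# the values are made up for illustration.
predictions = np.array([
    [0.9, 0.1],
    [0.2, 0.8],
    [0.6, 0.4],
    [0.1, 0.9],
])

# Same post-processing as in the example: round, cast, take the most
# likely class, then shift to the 1-based type codes used when saving.
cell_types = np.argmax(predictions.round().astype("uint16"), axis=1) + 1
print(cell_types)  # [1 2 1 2]
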
Example #2
def main():
    from cellfinder.main import suppress_tf_logging

    suppress_tf_logging(tf_suppress_log_messages)

    from tensorflow.keras.callbacks import (
        TensorBoard,
        ModelCheckpoint,
        CSVLogger,
    )

    from cellfinder.tools.prep import prep_training
    from cellfinder.classify.tools import make_lists, get_model
    from cellfinder.classify.cube_generator import CubeGeneratorFromDisk

    start_time = datetime.now()
    args = training_parse()
    output_dir = Path(args.output_dir)
    ensure_directory_exists(output_dir)
    args = prep_training(args)

    fancylog.start_logging(
        args.output_dir,
        program_for_log,
        variables=[args],
        log_header="CELLFINDER TRAINING LOG",
    )

    yaml_contents = parse_yaml(args.yaml_file)

    tiff_files = get_tiff_files(yaml_contents)
    logging.info(f"Found {sum(len(imlist) for imlist in tiff_files)} images "
                 f"from {len(yaml_contents)} datasets "
                 f"in {len(args.yaml_file)} yaml files")

    model = get_model(
        existing_model=args.trained_model,
        model_weights=args.model_weights,
        network_depth=models[args.network_depth],
        learning_rate=args.learning_rate,
        continue_training=args.continue_training,
    )

    signal_train, background_train, labels_train = make_lists(tiff_files)

    if args.test_fraction > 0:
        logging.info("Splitting data into training and validation datasets")
        (
            signal_train,
            signal_test,
            background_train,
            background_test,
            labels_train,
            labels_test,
        ) = train_test_split(
            signal_train,
            background_train,
            labels_train,
            test_size=args.test_fraction,
        )

        logging.info(f"Using {len(signal_train)} images for training and "
                     f"{len(signal_test)} images for validation")
        validation_generator = CubeGeneratorFromDisk(
            signal_test,
            background_test,
            labels=labels_test,
            batch_size=args.batch_size,
            train=True,
        )

        # for saving checkpoints
        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"

    else:
        logging.info("No validation data selected.")
        validation_generator = None
        base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"
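
    # Note: ModelCheckpoint (used below) formats this template each epoch,
    # so e.g. epoch 3 with val_loss 0.125 yields
    # "model-epoch.03-loss-0.125.h5" once prefixed.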

    training_generator = CubeGeneratorFromDisk(
        signal_train,
        background_train,
        labels=labels_train,
        batch_size=args.batch_size,
        shuffle=True,
        train=True,
        augment=not args.no_augment,
    )
    callbacks = []

    if args.tensorboard:
        logdir = output_dir / "tensorboard"
        ensure_directory_exists(logdir)
        tensorboard = TensorBoard(
            log_dir=logdir,
            histogram_freq=0,
            write_graph=True,
            update_freq="epoch",
        )
        callbacks.append(tensorboard)

    if not args.no_save_checkpoints:
        if args.save_weights:
            filepath = str(output_dir / ("weight" + base_checkpoint_file_name))
        else:
            filepath = str(output_dir / ("model" + base_checkpoint_file_name))

        checkpoints = ModelCheckpoint(
            filepath,
            save_weights_only=args.save_weights,
        )
        callbacks.append(checkpoints)

    if args.save_progress:
        filepath = str(output_dir / "training.csv")
        csv_logger = CSVLogger(filepath)
        callbacks.append(csv_logger)

    logging.info("Beginning training.")
    model.fit(
        training_generator,
        validation_data=validation_generator,
        use_multiprocessing=False,
        epochs=args.epochs,
        callbacks=callbacks,
    )

    if args.save_weights:
        logging.info("Saving model weights")
        model.save_weights(str(output_dir / "model_weights.h5"))
    else:
        logging.info("Saving model")
        model.save(output_dir / "model.h5")

    logging.info(
        "Finished training. Total time taken: %s",
        datetime.now() - start_time,
    )
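
The split in this example leans on scikit-learn's train_test_split accepting several parallel sequences at once and keeping them aligned. A minimal sketch with made-up file names:

from sklearn.model_selection import train_test_split

# Three parallel lists, mirroring the example: signal paths,
# background paths, and labels (dummy values).
signal = [f"signal_{i}.tif" for i in range(10)]
background = [f"background_{i}.tif" for i in range(10)]
labels = [i % 2 for i in range(10)]

# The outputs come back interleaved as train/test pairs,
# in the same order as the inputs.
(
    signal_train, signal_test,
    background_train, background_test,
    labels_train, labels_test,
) = train_test_split(signal, background, labels, test_size=0.1)

print(len(signal_train), len(signal_test))  # 9 1
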
Example #3
def main(max_workers=3):
    from cellfinder.main import suppress_tf_logging

    suppress_tf_logging(tf_suppress_log_messages)

    from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

    from cellfinder.tools import system
    from cellfinder.tools.prep import prep_training
    from cellfinder.classify.tools import make_lists, get_model
    from cellfinder.classify.cube_generator import CubeGeneratorFromDisk

    start_time = datetime.now()
    args = training_parse()
    output_dir = Path(args.output_dir)
    system.ensure_directory_exists(output_dir)
    args = prep_training(args)
    tiff_files = parse_yaml(args.yaml_file)

    # Using too many workers doesn't increase speed, and uses huge amounts of RAM
    workers = system.get_num_processes(min_free_cpu_cores=args.n_free_cpus,
                                       n_max_processes=max_workers)

    model = get_model(
        existing_model=args.trained_model,
        model_weights=args.model_weights,
        network_depth=models[args.network_depth],
        learning_rate=args.learning_rate,
        continue_training=args.continue_training,
    )

    signal_train, background_train, labels_train = make_lists(tiff_files)

    if args.test_fraction > 0:
        (
            signal_train,
            signal_test,
            background_train,
            background_test,
            labels_train,
            labels_test,
        ) = train_test_split(
            signal_train,
            background_train,
            labels_train,
            test_size=args.test_fraction,
        )
        validation_generator = CubeGeneratorFromDisk(
            signal_test,
            background_test,
            labels=labels_test,
            batch_size=args.batch_size,
            train=True,
        )
    else:
        validation_generator = None

    training_generator = CubeGeneratorFromDisk(
        signal_train,
        background_train,
        labels=labels_train,
        batch_size=args.batch_size,
        shuffle=True,
        train=True,
        augment=not args.no_augment,
    )
    callbacks = []

    if args.tensorboard:
        logdir = output_dir / "tensorboard"
        system.ensure_directory_exists(logdir)
        tensorboard = TensorBoard(
            log_dir=logdir,
            histogram_freq=0,
            write_graph=True,
            update_freq="epoch",
        )
        callbacks.append(tensorboard)

    if args.save_checkpoints:
        if args.save_weights:
            filepath = str(output_dir /
                           "weights.{epoch:02d}-{val_loss:.3f}.h5")
        else:
            filepath = str(output_dir / "model.{epoch:02d}-{val_loss:.3f}.h5")

        checkpoints = ModelCheckpoint(filepath,
                                      save_weights_only=args.save_weights)
        callbacks.append(checkpoints)

    model.fit(
        training_generator,
        validation_data=validation_generator,
        use_multiprocessing=True,
        workers=workers,
        epochs=args.epochs,
        callbacks=callbacks,
    )

    if args.save_weights:
        print("Saving model weights")
        model.save_weights(str(output_dir / "model_weights.h5"))
    else:
        print("Saving model")
        model.save(output_dir / "model.h5")

    print(
        "Finished training. "
        f"Total time taken: {datetime.now() - start_time}"
    )
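
get_num_processes comes from cellfinder's tooling; here is a rough, illustrative stand-in (an assumption about its behaviour, not the library's implementation) showing the capping the comment above describes:

import os

def capped_workers(min_free_cpu_cores: int, n_max_processes: int) -> int:
    # Leave some cores free, and never exceed the requested maximum.
    available = max(1, (os.cpu_count() or 1) - min_free_cpu_cores)
    return min(available, n_max_processes)

print(capped_workers(min_free_cpu_cores=2, n_max_processes=3))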