Exemple #1
0
    def from_trained_models(
        cls,
        centroid_model_path: Optional[Text] = None,
        confmap_model_path: Optional[Text] = None,
        batch_size: int = 1,
        peak_threshold: float = 0.2,
        integral_refinement: bool = True,
        integral_patch_size: int = 5,
    ) -> "TopdownPredictor":
        """Create predictor from saved models.

        One of the two models can be left as None to perform inference with ground
        truth data. This will only work with LabelsReader as the provider.

        Args:
            centroid_model_path: Path to centroid model folder, or None to skip
                loading a centroid model.
            confmap_model_path: Path to topdown confidence map model folder, or
                None to skip loading a confidence map model.
            batch_size: Passed through unchanged to the predictor constructor.
            peak_threshold: Passed through unchanged to the predictor constructor
                (presumably the minimum confidence for a peak to be kept —
                confirm against the class definition).
            integral_refinement: Passed through unchanged to the predictor
                constructor.
            integral_patch_size: Passed through unchanged to the predictor
                constructor.

        Returns:
            An instance of TopdownPredictor with the loaded models.

        Raises:
            ValueError: If both `centroid_model_path` and `confmap_model_path`
                are None.
        """
        if centroid_model_path is None and confmap_model_path is None:
            raise ValueError(
                "Either the centroid or topdown confidence map model must be provided."
            )

        def load_trained_model(model_path):
            # Load the training config and the saved Keras model from a model
            # folder. Returns (None, None) when no path was given.
            if model_path is None:
                return None, None
            config = TrainingJobConfig.load_json(model_path)
            keras_model_path = get_keras_model_path(model_path)
            model = Model.from_config(config.model)
            # compile=False: only the network is needed, not the training
            # configuration (optimizer/loss) saved with it.
            model.keras_model = tf.keras.models.load_model(
                keras_model_path, compile=False
            )
            return config, model

        centroid_config, centroid_model = load_trained_model(centroid_model_path)
        confmap_config, confmap_model = load_trained_model(confmap_model_path)

        return cls(
            centroid_config=centroid_config,
            centroid_model=centroid_model,
            confmap_config=confmap_config,
            confmap_model=confmap_model,
            batch_size=batch_size,
            peak_threshold=peak_threshold,
            integral_refinement=integral_refinement,
            integral_patch_size=integral_patch_size,
        )
Exemple #2
0
    def from_trained_models(cls, model_path: Text) -> "VisualPredictor":
        """Build a VisualPredictor from a trained model folder.

        Args:
            model_path: Location handed to `TrainingJobConfig.load_json` and
                `get_keras_model_path` to find the training config and the saved
                Keras model.

        Returns:
            A new instance constructed from the loaded config and model.
        """
        job_config = TrainingJobConfig.load_json(model_path)
        saved_keras_path = get_keras_model_path(model_path)

        wrapped_model = Model.from_config(job_config.model)
        # Load the network only; compile=False skips restoring training settings.
        wrapped_model.keras_model = tf.keras.models.load_model(
            saved_keras_path,
            compile=False,
        )

        return cls(config=job_config, model=wrapped_model)
Exemple #3
0
 def from_config_file(cls, path):
     """Construct an instance describing the training job config at `path`.

     The config is loaded with `TrainingJobConfig.load_json`. The instance
     also records the path, its base filename, and the name of the model head
     that is set in the config.
     """
     job_config = TrainingJobConfig.load_json(path)
     active_head = job_config.model.heads.which_oneof_attrib_name()
     return cls(
         config=job_config,
         path=path,
         filename=os.path.basename(path),
         head_name=active_head,
     )
Exemple #4
0
    def from_trained_models(cls, bottomup_model_path: Text) -> "BottomupPredictor":
        """Build a BottomupPredictor from a saved bottom-up model.

        Args:
            bottomup_model_path: Location used to find both the training job
                config (via `TrainingJobConfig.load_json`) and the saved Keras
                model (via `get_keras_model_path`).

        Returns:
            A new instance wrapping the loaded config and model.
        """
        job_config = TrainingJobConfig.load_json(bottomup_model_path)
        saved_keras_path = get_keras_model_path(bottomup_model_path)

        network = Model.from_config(job_config.model)
        # compile=False: restore the network without its training configuration.
        network.keras_model = tf.keras.models.load_model(
            saved_keras_path,
            compile=False,
        )

        return cls(bottomup_config=job_config, bottomup_model=network)
Exemple #5
0
    def try_loading_path(self, path: Text):
        """Attempt to read a training job config from `path`.

        Any error raised while loading is printed and swallowed (best-effort).

        Returns:
            A `ConfigFileInfo` for the config when it loads and its head name
            passes `self.head_filter` (None accepts every head). Otherwise None.
        """
        try:
            job_config = TrainingJobConfig.load_json(path)
        except Exception as load_error:
            # Unreadable config: report it and treat the path as unusable.
            print(load_error)
            return None

        # The head describes what the model will predict.
        head_name = job_config.model.heads.which_oneof_attrib_name()

        if self.head_filter not in (None, head_name):
            return None

        return ConfigFileInfo(
            path=path,
            filename=os.path.basename(path),
            config=job_config,
            head_name=head_name,
        )
Exemple #6
0
def find_heads_for_model_paths(paths) -> Dict[str, str]:
    """Map each model's head name to its model path.

    Args:
        paths: Iterable of model paths (folders or config JSON files), or None.

    Returns:
        Dict keyed by head name (from the loaded training config). A path that
        ends in ".json" is stored as its parent directory. When several paths
        share a head, the last one wins. None yields an empty dict.
    """
    heads_to_paths: Dict[str, str] = {}
    if paths is None:
        return heads_to_paths

    for path in paths:
        job_config = TrainingJobConfig.load_json(path)
        head_name = job_config.model.heads.which_oneof_attrib_name()
        # A config-file path is reported as the model folder that contains it.
        heads_to_paths[head_name] = (
            os.path.dirname(path) if path.endswith(".json") else path
        )

    return heads_to_paths
Exemple #7
0
    def from_trained_models(
        cls,
        confmap_model_path: Text,
        peak_threshold: float = 0.2,
        integral_refinement: bool = True,
        integral_patch_size: int = 7,
    ) -> "SingleInstancePredictor":
        """Build a SingleInstancePredictor from a saved confidence map model.

        Args:
            confmap_model_path: Location used to find the training job config
                and the saved Keras model.
            peak_threshold: Forwarded unchanged to the constructor.
            integral_refinement: Forwarded unchanged to the constructor.
            integral_patch_size: Forwarded unchanged to the constructor.

        Returns:
            A new instance wrapping the loaded config and model.
        """
        job_config = TrainingJobConfig.load_json(confmap_model_path)
        saved_keras_path = get_keras_model_path(confmap_model_path)

        network = Model.from_config(job_config.model)
        # compile=False: restore the network without its training configuration.
        network.keras_model = tf.keras.models.load_model(
            saved_keras_path, compile=False
        )

        inference_options = dict(
            peak_threshold=peak_threshold,
            integral_refinement=integral_refinement,
            integral_patch_size=integral_patch_size,
        )
        return cls(
            confmap_config=job_config,
            confmap_model=network,
            **inference_options,
        )
Exemple #8
0
def _build_training_parser():
    """Return the argument parser for the training command line interface."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "training_job_path", help="Path to training job profile JSON file."
    )
    parser.add_argument(
        "labels_path", help="Path to labels file to use for training."
    )
    parser.add_argument(
        "--video-paths",
        type=str,
        default="",
        help="List of paths for finding videos in case paths inside labels file need fixing.",
    )
    parser.add_argument(
        "--val_labels",
        "--val",
        help="Path to labels file to use for validation (overrides training job path if set).",
    )
    parser.add_argument(
        "--test_labels",
        "--test",
        help="Path to labels file to use for test (overrides training job path if set).",
    )
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        help="Enables TensorBoard logging to the run path.",
    )
    parser.add_argument(
        "--save_viz",
        action="store_true",
        help="Enables saving of prediction visualizations to the run folder.",
    )
    parser.add_argument(
        "--zmq", action="store_true", help="Enables ZMQ logging (for GUI)."
    )
    parser.add_argument(
        "--run_name",
        default="",
        help="Run name to use when saving file, overrides other run name settings.",
    )
    parser.add_argument(
        "--prefix", default="", help="Prefix to prepend to run name."
    )
    parser.add_argument(
        "--suffix", default="", help="Suffix to append to run name."
    )
    return parser


def _resolve_training_profile(job_filename):
    """Return a path to the training profile named by `job_filename`.

    The argument is used as-is if it exists. Otherwise it is looked up inside
    the packaged "sleap/training_profiles" directory.

    Raises:
        FileNotFoundError: If the profile exists in neither location.
    """
    if os.path.exists(job_filename):
        return job_filename

    profile_dir = get_package_file("sleap/training_profiles")
    packaged_profile = os.path.join(profile_dir, job_filename)
    if os.path.exists(packaged_profile):
        return packaged_profile

    raise FileNotFoundError(f"Could not find training profile: {job_filename}")


def _apply_cli_overrides(job_config, args):
    """Override output settings of `job_config` in place for CLI-based training.

    Run-name settings are only overridden when the matching flag is non-empty.
    """
    outputs = job_config.outputs
    outputs.save_outputs = True
    outputs.tensorboard.write_logs = args.tensorboard
    outputs.zmq.publish_updates = args.zmq
    outputs.zmq.subscribe_to_controller = args.zmq
    if args.run_name != "":
        outputs.run_name = args.run_name
    if args.prefix != "":
        outputs.run_name_prefix = args.prefix
    if args.suffix != "":
        outputs.run_name_suffix = args.suffix
    outputs.save_visualizations = args.save_viz


def main():
    """Create CLI for training and run.

    Parses command line arguments, loads and overrides the training job
    config, logs the configuration, then builds a Trainer and trains.
    """
    args, unknown_args = _build_training_parser().parse_known_args()
    if unknown_args:
        # parse_known_args tolerates extra arguments; surface them so that a
        # mistyped flag does not go silently unnoticed.
        logger.warning(f"Ignoring unrecognized arguments: {unknown_args}")

    # Find and load job configuration, then apply CLI overrides.
    job_filename = _resolve_training_profile(args.training_job_path)
    job_config = TrainingJobConfig.load_json(job_filename)
    _apply_cli_overrides(job_config, args)

    logger.info(f"Training labels file: {args.labels_path}")
    logger.info(f"Training profile: {job_filename}")
    logger.info("")

    # Log configuration to console.
    logger.info("Arguments:")
    logger.info(json.dumps(vars(args), indent=4))
    logger.info("")
    logger.info("Training job:")
    logger.info(job_config.to_json())
    logger.info("")

    logger.info("Initializing trainer...")
    # NOTE(review): with the default --video-paths of "", split(",") yields
    # [""] (one empty search path) rather than []. Kept unchanged because an
    # empty search path may effectively search the working directory
    # downstream — confirm before changing.
    trainer = Trainer.from_config(
        job_config,
        training_labels=args.labels_path,
        validation_labels=args.val_labels,
        test_labels=args.test_labels,
        video_search_paths=args.video_paths.split(","),
    )
    trainer.train()