コード例 #1
0
    def get_filtered_configs(
        self, head_filter: Text = "", only_trained: bool = False
    ) -> List[ConfigFileInfo]:
        """Returns filtered subset of loaded configs.

        Configs are examined in the order they are stored in ``self._configs``.

        Args:
            head_filter: If non-empty, keep only configs whose ``head_name``
                equals this value. An empty (falsy) value disables head
                filtering.
            only_trained: If True, keep only configs whose
                ``has_trained_model`` is truthy.

        Returns:
            The matching ``ConfigFileInfo`` objects, in their original order.
            For any directory other than the packaged
            ``sleap/training_profiles`` directory, only the first matching
            config seen in that directory is kept. Every matching config in
            the packaged profiles directory is kept.
        """
        base_config_dir = os.path.realpath(
            sleap_utils.get_package_file("sleap/training_profiles")
        )

        cfgs_to_return = []
        # A set keeps the "already took a config from this directory?" check
        # O(1) per config instead of scanning a growing list.
        paths_included = set()

        for cfg_info in self._configs:
            # Skip configs for a different head type (only when filtering).
            if head_filter and cfg_info.head_name != head_filter:
                continue
            # Skip untrained configs when only trained ones were requested.
            if only_trained and not cfg_info.has_trained_model:
                continue

            # We just want a single config from each model directory.
            # Taking the first config we see in the directory means
            # we'll get the *trained* config if there is one, since
            # it will be newer and we've sorted by desc date modified.

            # TODO: check filenames since timestamp sort could be off
            #  if files were copied

            # realpath resolves symlinks so that different spellings of the
            # same directory compare equal.
            cfg_dir = os.path.realpath(os.path.dirname(cfg_info.path))

            # The packaged profiles directory is exempt from the
            # one-config-per-directory rule (presumably because it holds many
            # independent profiles rather than one model's configs).
            if cfg_dir == base_config_dir or cfg_dir not in paths_included:
                paths_included.add(cfg_dir)
                cfgs_to_return.append(cfg_info)

        return cfgs_to_return
コード例 #2
0
    def make_from_labels_filename(cls,
                                  labels_filename: Text,
                                  head_filter: Optional[Text] = None):
        """Build an instance that looks in the dataset's models dir and the
        packaged training profiles.

        Args:
            labels_filename: Path to a labels file. If truthy, the ``models``
                subdirectory next to it is added as the first directory.
            head_filter: Passed unchanged to the constructor.

        Returns:
            ``cls(dir_paths=..., head_filter=head_filter)``. ``dir_paths``
            holds the optional dataset ``models`` directory followed by the
            packaged ``sleap/training_profiles`` directory.
        """
        search_dirs = []
        if labels_filename:
            dataset_dir = os.path.dirname(labels_filename)
            search_dirs.append(os.path.join(dataset_dir, "models"))

        # Packaged profiles always come last.
        search_dirs.append(
            sleap_utils.get_package_file("sleap/training_profiles"))

        return cls(dir_paths=search_dirs, head_filter=head_filter)
コード例 #3
0
    def make_from_labels_filename(
        cls, labels_filename: Text, head_filter: Optional[Text] = None
    ) -> "TrainingConfigsGetter":
        """
        Makes object which checks for models in default subdir for dataset.

        Args:
            labels_filename: Path to the dataset's labels file. When truthy,
                the ``models`` directory beside it is searched first.
            head_filter: Forwarded unchanged to the constructor.

        Returns:
            A new instance built with the dataset ``models`` directory (if
            any) followed by the packaged ``sleap/training_profiles``
            directory.
        """
        search_dirs = (
            [os.path.join(os.path.dirname(labels_filename), "models")]
            if labels_filename
            else []
        )
        search_dirs.append(sleap_utils.get_package_file("sleap/training_profiles"))
        return cls(dir_paths=search_dirs, head_filter=head_filter)
コード例 #4
0
    def from_name(cls, form_name: str, *args, **kwargs) -> "YamlFormWidget":
        """
        Create a form widget from the short name of its YAML definition.

        The short name (e.g., "suggestions") is mapped to the packaged file
        ``sleap/config/<form_name>.yaml``. That path is passed as the first
        positional argument to the class initializer.

        Args:
            form_name: Base name (without extension) of the form's YAML file.
            args: Extra positional args forwarded to the initializer.
            kwargs: Extra keyword args forwarded to the initializer.

        Returns:
            Instance of `YamlFormWidget` class.
        """
        form_yaml = get_package_file(f"sleap/config/{form_name}.yaml")
        widget = cls(form_yaml, *args, **kwargs)
        return widget
コード例 #5
0
def main():
    """Parse command-line arguments, build a trainer from them, and train.

    The first positional argument names a training job profile (JSON). If it
    is not an existing path, it is looked up by name inside the packaged
    ``sleap/training_profiles`` directory. The second positional argument is
    the training labels file. Optional flags override output settings of the
    loaded job config. Arguments the parser does not recognize are ignored
    (``parse_known_args``).

    Raises:
        FileNotFoundError: If the training profile cannot be found either as
            given or inside the packaged training profiles directory.
    """
    import argparse

    cli = argparse.ArgumentParser()

    # Positional inputs.
    cli.add_argument(
        "training_job_path", help="Path to training job profile JSON file."
    )
    cli.add_argument("labels_path", help="Path to labels file to use for training.")

    # Optional data inputs.
    cli.add_argument(
        "--video-paths",
        type=str,
        default="",
        help="List of paths for finding videos in case paths inside labels file need fixing.",
    )
    cli.add_argument(
        "--val_labels",
        "--val",
        help="Path to labels file to use for validation (overrides training job path if set).",
    )
    cli.add_argument(
        "--test_labels",
        "--test",
        help="Path to labels file to use for test (overrides training job path if set).",
    )

    # Output / logging switches.
    cli.add_argument(
        "--tensorboard",
        action="store_true",
        help="Enables TensorBoard logging to the run path.",
    )
    cli.add_argument(
        "--save_viz",
        action="store_true",
        help="Enables saving of prediction visualizations to the run folder.",
    )
    cli.add_argument(
        "--zmq", action="store_true", help="Enables ZMQ logging (for GUI)."
    )

    # Run naming.
    cli.add_argument(
        "--run_name",
        default="",
        help="Run name to use when saving file, overrides other run name settings.",
    )
    cli.add_argument("--prefix", default="", help="Prefix to prepend to run name.")
    cli.add_argument("--suffix", default="", help="Suffix to append to run name.")

    options, _unknown = cli.parse_known_args()

    # Use the profile path as given if it exists; otherwise treat it as a
    # name relative to the packaged training profiles.
    job_filename = options.training_job_path
    if not os.path.exists(job_filename):
        packaged_profile = os.path.join(
            get_package_file("sleap/training_profiles"), job_filename
        )
        if not os.path.exists(packaged_profile):
            raise FileNotFoundError(
                f"Could not find training profile: {job_filename}"
            )
        job_filename = packaged_profile

    job_config = TrainingJobConfig.load_json(job_filename)

    # CLI-driven runs always save outputs; the remaining flags map directly
    # onto output settings.
    outputs = job_config.outputs
    outputs.save_outputs = True
    outputs.tensorboard.write_logs = options.tensorboard
    outputs.zmq.publish_updates = options.zmq
    outputs.zmq.subscribe_to_controller = options.zmq
    # Naming overrides apply only when a non-empty value was supplied.
    for field_name, cli_value in (
        ("run_name", options.run_name),
        ("run_name_prefix", options.prefix),
        ("run_name_suffix", options.suffix),
    ):
        if cli_value != "":
            setattr(outputs, field_name, cli_value)
    outputs.save_visualizations = options.save_viz

    logger.info(f"Training labels file: {options.labels_path}")
    logger.info(f"Training profile: {job_filename}")
    logger.info("")

    # Echo the parsed arguments and the effective job config to the log.
    logger.info("Arguments:")
    logger.info(json.dumps(vars(options), indent=4))
    logger.info("")
    logger.info("Training job:")
    logger.info(job_config.to_json())
    logger.info("")

    logger.info("Initializing trainer...")
    # NOTE(review): "".split(",") == [""], so the default --video-paths passes
    # [""] rather than []; confirm that is what Trainer.from_config expects.
    search_paths = options.video_paths.split(",")
    trainer = Trainer.from_config(
        job_config,
        training_labels=options.labels_path,
        validation_labels=options.val_labels,
        test_labels=options.test_labels,
        video_search_paths=search_paths,
    )
    trainer.train()