예제 #1
0
    def __init__(self, parameters: dict = None, tags: list = None, disabled: bool = False, extras: dict = None,
                 config_file: Path = Path('config') / 'neptune.yaml'):
        """
        Create a Neptune experiment for the current run, unless tracking is disabled.

        :param parameters: run parameters, sent as the experiment params. Its optional 'name' key is used as the
                           experiment name ('Unnamed' if absent).
        :param tags: tags attached to the experiment.
        :param disabled: if true, Neptune is not contacted and only a log message is emitted.
        :param extras: extra keyword arguments forwarded to ``neptune.create_experiment``.
        :param config_file: YAML file with the Neptune ``project`` and ``token`` keys (default: config/neptune.yaml).
                            A ``str`` path is accepted too.
        :raises AttributeError: if the config file is missing, empty, or lacks the ``project``/``token`` keys.
        """
        if parameters is None:
            parameters = {}
        if tags is None:
            tags = []
        if extras is None:
            extras = {}
        self.disabled = disabled

        if disabled:
            logger.info("Tracking is disabled.")
            return

        self.name = parameters.get('name', 'Unnamed')

        config_file = Path(config_file)
        if not config_file.exists():
            raise AttributeError('Missing Neptune config. Create config/neptune.yaml with token and project keys.')

        # `or {}`: if load_yaml yields None (presumably for an empty file), report the missing keys below
        # instead of failing with a TypeError on the `in` checks.
        config = load_yaml(config_file) or {}
        if 'project' not in config or 'token' not in config:
            raise AttributeError('Missing Neptune config. Create config/neptune.yaml with token and project keys.')

        self.project_name = config['project']
        self.api_token = config['token']

        logger.info("Experiment tracked with Neptune:")
        neptune.init(project_qualified_name=self.project_name, api_token=self.api_token)

        self.exp = neptune.create_experiment(name=self.name, tags=tags, params=parameters, **extras)
예제 #2
0
    def load_pretrained(
        self,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        force_download: bool = False,
        strict: bool = False,
        print_incompatible_keys: bool = False,
        fix_dict_keys: bool = True,
    ):
        """
        Instance method to (optionally) download and load the weights of the given pre-trained model.
        If the `pretrained_model_name_or_path` is a string (the name/id of the model), it will check if it is present in
        the cache (in user folder ~/.solarnet/models). If not in cache, it will be downloaded from the MinIO models bucket.
        If the `pretrained_model_name_or_path` is a path, it will use the model.pt from this folder.
        Use this method to load (some) weights from a pre-trained model while customizing the final architecture, for
        example to finetune on a downstream task. Use strict=False if the model architecture is not exactly the same,
        otherwise an error will be raised.
        The model is put in eval mode after loading.

        :param pretrained_model_name_or_path: the name/id of the model for download, or a path where the model exists in
                                              the local filesystem.
        :param force_download: if true, do not search in the cache and download the model.
        :param strict: if true, an error will be raised if the architecture is not the same and weights fail to be loaded.
        :param print_incompatible_keys: if true, print the missing/unexpected keys reported by ``load_state_dict``.
        :param fix_dict_keys: if true, some known state keys will be renamed (ex: encoder to backbone).
        :raises RuntimeError: if the pretrained config names a backbone different from this model's backbone.
        """

        path, config_path = download_or_cached_or_local_model_path(
            pretrained_model_name_or_path, force_download=force_download)

        # Config (optional): only used to check the backbone and for logging.
        config = {}
        backbone = None
        if config_path is not None:
            # `or {}` guards against load_yaml returning None (presumably for an empty file).
            config = load_yaml(config_path) or {}
            backbone = config.get("backbone", None)
        logger.info(
            f"Model {pretrained_model_name_or_path} loaded with config:")
        logger.info(config)

        if backbone is not None and self.backbone_name != "undefined" and backbone != self.backbone_name:
            raise RuntimeError(
                "The backbone of the pretrained model is different.")

        # Map to CPU: load_state_dict copies the tensors into this module's parameters wherever they live,
        # and this lets checkpoints saved from a GPU be read on CPU-only machines.
        state_dict = torch.load(path, map_location="cpu")

        # Fix some known dict keys: encoder -> backbone.
        # NOTE(review): str.replace rewrites every "encoder." occurrence in a key, not only a leading prefix
        # (e.g. "backbone.encoder.x" would become "backbone.backbone.x") — confirm this is intended.
        if fix_dict_keys:
            state_dict = {
                k.replace("encoder.", "backbone."): v
                for k, v in state_dict.items()
            }

        incompatible_keys = self.load_state_dict(state_dict, strict=strict)
        if print_incompatible_keys:
            print_incompatible_keys_fn(incompatible_keys)

        self.eval()
예제 #3
0
파일: main.py 프로젝트: jdonzallaz/solarnet
def dataset_command(
        config_file: Path = typer.Argument(Path("config") / "dataset.yaml"),
        verbose: bool = typer.Option(False, "--verbose", "-v"),
):
    """
    CLI command: build the dataset described by a YAML parameter file.

    :param config_file: path of the dataset YAML config (default: config/dataset.yaml).
    :param verbose: if true, set the log level to INFO before building.
    """
    if verbose:
        set_log_level(logging.INFO)

    # The whole parsed YAML is handed to the dataset builder.
    make_dataset(load_yaml(config_file))
예제 #4
0
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        force_download: bool = False,
        strict: bool = False,
        print_incompatible_keys: bool = False,
        **kwargs,
    ) -> pl.LightningModule:
        """
        Class method which will build a model, and download and load the weights of the given pre-trained model.
        If the `pretrained_model_name_or_path` is a string (the name/id of the model), it will check if it is present in
        the cache (in user folder ~/.solarnet/models). If not in cache, it will be downloaded from the MinIO models bucket.
        If the `pretrained_model_name_or_path` is a path, it will use the model.pt from this folder.
        kwargs are used to override model config.
        Use this method to create a model and load weights trained with the same model architecture.
        The returned model is in eval mode.

        :param pretrained_model_name_or_path: the name/id of the model for download, or a path where the model exists in
                                              the local filesystem.
        :param force_download: if true, do not search in the cache and download the model.
        :param strict: if true, an error will be raised if the architecture is not the same and weights fail to be loaded.
        :param print_incompatible_keys: if true, print the missing/unexpected keys reported by ``load_state_dict``.
        :param kwargs: used to override model config (the stored "hparams").
        :return: a pl.LightningModule instance
        :raises AttributeError: if the pretrained config names a backbone different from the built model's backbone.
        """

        path, config_path = download_or_cached_or_local_model_path(
            pretrained_model_name_or_path, force_download=force_download)

        # Config
        config = {}
        backbone = None
        if config_path is not None:
            # `or {}` guards against load_yaml returning None (presumably for an empty file).
            config = load_yaml(config_path) or {}
            backbone = config.get("backbone", None)
        logger.info(
            f"Model {pretrained_model_name_or_path} loaded with config:")
        logger.info(config)
        # Stored hparams, overridden by kwargs. `or {}` also covers an explicit `hparams: null` entry,
        # which would otherwise make the ** unpacking fail.
        hparams = config.pop("hparams", None) or {}
        hparams = {**hparams, **kwargs}

        # load
        model = cls(**hparams)
        if backbone is not None and model.backbone_name != "undefined" and backbone != model.backbone_name:
            raise AttributeError(
                f"Model {pretrained_model_name_or_path} is not compatible with class {cls}."
            )
        # A freshly built module normally lives on CPU. Mapping the checkpoint there also lets
        # GPU-saved weights load on CPU-only machines; load_state_dict copies them into the parameters.
        state_dict = torch.load(path, map_location="cpu")
        incompatible_keys = model.load_state_dict(state_dict, strict=strict)
        if print_incompatible_keys:
            print_incompatible_keys_fn(incompatible_keys)

        model.eval()

        return model
예제 #5
0
    def objective(trial):
        """
        Optuna objective: sample hyper-parameters, run the DVC pipeline and return the resulting metric.

        Uses ``model`` and ``metric`` from the enclosing scope: ``model`` selects the DVC stages
        (``*@{model}``) and the ``models/<model>/metrics.yaml`` file, ``metric`` is the key read from it.

        :param trial: the Optuna trial used to suggest values.
        :return: ``metrics[metric]`` from the freshly produced metrics file.
        :raises subprocess.CalledProcessError: if ``dvc repro`` fails. Pass
            ``catch=(subprocess.CalledProcessError,)`` to ``study.optimize`` to keep the study running.
        """
        # Load parameters
        config_path = Path("config") / "config.yaml"
        parameters = load_yaml(config_path)

        # Suggest parameters.
        # `step` is passed by keyword: it works on all Optuna versions, while positional use is deprecated in newer ones.
        parameters["model"]["learning_rate"] = trial.suggest_float(
            "learning_rate", 1e-5, 6e-3, log=True
        )
        parameters["trainer"]["epochs"] = trial.suggest_int("epochs", 5, 30, step=5)
        parameters["trainer"]["batch_size"] = trial.suggest_int("batch_size", 32, 256, step=32)
        parameters["data"]["channel"] = trial.suggest_categorical(
            "channel",
            [
                "94",
                "131",
                "171",
                "193",
                "211",
                "304",
                "335",
                "1700",
                "continuum",
                "magnetogram",
            ],
        )
        # parameters["data"]["size"] = trial.suggest_int("size", 128, 256)
        parameters["model"]["activation"] = trial.suggest_categorical(
            "activation", ["relu", "selu", "relu6", "tanh", "prelu", "leakyrelu"]
        )

        # Write parameters (overwrites config/config.yaml; presumably read by the DVC stages)
        write_yaml(config_path, parameters)

        # Run pipeline. check=True: if the run fails, raise instead of falling through and
        # returning the stale metrics.yaml left by a previous trial.
        subprocess.run(["dvc", "repro", "--glob", f"*@{model}", "-q"], check=True)

        # Check metric
        metrics = load_yaml(Path("models") / model / "metrics.yaml")
        return metrics[metric]
예제 #6
0
def s3_write_config():
    """
    Load and validate the MinIO credentials config used for writing, read from ``WRITE_CONFIG_PATH``.

    :return: the parsed config, guaranteed to contain the ``aws_access_key_id``, ``aws_secret_access_key``
             and ``endpoint_url`` keys.
    :raises AttributeError: if the file is missing, empty, or lacks one of the required keys.
    """
    missing_config_message = (
        "Missing Minio config. Create config/minio.yaml with aws_access_key_id, aws_secret_access_key and endpoint_url keys."
    )

    # Check MinIO config
    if not WRITE_CONFIG_PATH.exists():
        raise AttributeError(missing_config_message)

    # `or {}`: if load_yaml yields None (presumably for an empty file), raise the explicit error below
    # rather than a TypeError from the membership test.
    config = load_yaml(WRITE_CONFIG_PATH) or {}
    required_keys = ("aws_access_key_id", "aws_secret_access_key", "endpoint_url")
    if not all(key in config for key in required_keys):
        raise AttributeError(missing_config_message)

    return config
예제 #7
0
    def __init__(self, parameters: dict = None, tags: list = None, disabled: bool = False, extras: dict = None, config_file: Path = Path("config") / "neptune.yaml"):
        """
        Start a new Neptune run, resume an existing one, or do nothing if tracking is disabled.

        :param parameters: run parameters, written to the run's 'parameters' field. Its optional 'name' key names
                           a new run.
        :param tags: tags attached to a new run (not used when resuming).
        :param disabled: if true, Neptune is not contacted and only a log message is emitted.
        :param extras: extra keyword arguments forwarded to ``neptune_new.init``. If it contains ``run_id``, that run
                       is resumed instead of a new one being created. The caller's dict is left unchanged.
        :param config_file: YAML file with the Neptune ``project`` and ``token`` keys (a ``str`` path is accepted too).
        :raises AttributeError: if the config file is missing, empty, or lacks the ``project``/``token`` keys.
        """
        if parameters is None:
            parameters = {}
        if tags is None:
            tags = []
        self.disabled = disabled
        # Work on a shallow copy: "run_id" is popped below and must not vanish from the caller's dict.
        extras = {} if extras is None else dict(extras)

        if not disabled:
            neptune_config_file = Path(config_file)
            if not neptune_config_file.exists():
                raise AttributeError('Missing Neptune config. Create config/neptune.yaml with token and project keys.')

            # `or {}`: if load_yaml yields None (presumably for an empty file), report the missing keys
            # instead of failing with a TypeError on the `in` checks.
            config = load_yaml(neptune_config_file) or {}
            if 'project' not in config or 'token' not in config:
                raise AttributeError('Missing Neptune config. Create config/neptune.yaml with token and project keys.')

            self.project_name = config['project']
            self.api_token = config['token']

            logger.info("Experiment tracked with Neptune:")

            if "run_id" in extras:
                run_id = extras.pop("run_id")
                # Resume logging (do not add name and tags)
                self.run = neptune_new.init(
                    project=self.project_name,
                    api_token=self.api_token,
                    run=run_id,
                    **extras,
                )
            else:
                self.run = neptune_new.init(
                    project=self.project_name,
                    api_token=self.api_token,
                    name=parameters.get('name'),
                    tags=tags,
                    **extras,
                )

            self.run['parameters'] = parameters
        else:
            logger.info("Tracking is disabled.")
예제 #8
0
파일: test.py 프로젝트: jdonzallaz/solarnet
def test(parameters: dict, verbose: bool = False):
    """
    Evaluate a trained model on the test set, then save its metrics and diagnostic plots.

    Loads ``model.ckpt`` (and ``metadata.yaml`` when present) from ``parameters["path"]``, writes
    ``metrics.yaml`` there and saves plots under ``<path>/test_plots``. If tracking is enabled and the
    metadata holds a ``tracking_id``, that run is resumed, receives the metrics and plots, and is ended
    afterwards, also when evaluation or plotting raises.

    :param parameters: experiment configuration (reads the "seed", "path", "data", "system" and "tracking" keys).
                       Note: ``parameters["system"]["gpus"]`` is capped in place.
    :param verbose: forwarded to ``trainer.test``.
    """
    logger.info("Testing...")

    seed_everything(parameters["seed"])

    model_path = Path(parameters["path"])
    plot_path = model_path / "test_plots"
    plot_path.mkdir(parents=True, exist_ok=True)
    metadata_path = model_path / "metadata.yaml"
    metadata = load_yaml(metadata_path) if metadata_path.exists() else None

    regression = parameters["data"]["targets"] == "regression"
    labels = None if regression else [
        list(x.keys())[0] for x in parameters["data"]["targets"]["classes"]
    ]
    n_class = 1 if regression else len(
        parameters["data"]["targets"]["classes"])
    # Evaluate on at most one GPU (mutates the caller's dict).
    # NOTE(review): -1 ("all GPUs" in Lightning) is not capped by min() — confirm configs never use it.
    parameters["system"]["gpus"] = min(1, parameters["system"]["gpus"])

    # Tracking: resume the run recorded in the metadata, if any. .get() so that metadata without a
    # "tracking_id" entry simply disables tracking instead of raising a KeyError.
    tracking: Optional[Tracking] = None
    if parameters["tracking"] and metadata and metadata.get("tracking_id") is not None:
        tracking = NeptuneNewTracking.resume(metadata["tracking_id"])

    try:
        datamodule = datamodule_from_config(parameters)
        datamodule.setup("test")
        logger.info(f"Data format: {datamodule.size()}")

        model_class = ImageRegression if regression else ImageClassification
        model = model_class.load_from_checkpoint(str(model_path / "model.ckpt"))
        logger.info(f"Model: {model}")

        trainer = pl.Trainer(
            gpus=parameters["system"]["gpus"],
            logger=None,
        )

        # Evaluate model
        raw_metrics = trainer.test(model, datamodule=datamodule, verbose=verbose)
        raw_metrics = raw_metrics[0]

        if regression:
            metrics = {
                "mae": raw_metrics["test_mae"],
                "mse": raw_metrics["test_mse"],
            }
        else:
            # The confusion counts are consumed by stats_metrics; the remaining "test_*" entries are
            # copied with their prefix stripped.
            tp = raw_metrics.pop("test_tp")  # hits
            fp = raw_metrics.pop("test_fp")  # false alarm
            tn = raw_metrics.pop("test_tn")  # correct negative
            fn = raw_metrics.pop("test_fn")  # miss

            metrics = {
                "balanced_accuracy": raw_metrics.pop("test_recall"),
                **stats_metrics(tp, fp, tn, fn),
            }

            for key, value in raw_metrics.items():
                metrics[key[len("test_"):]] = value
            metrics = dict(sorted(metrics.items()))

        write_yaml(model_path / "metrics.yaml", metrics)
        if tracking:
            tracking.log_metrics(metrics, "metrics/test")

        # Prepare a set of test samples
        model.freeze()
        nb_image_grid = 10
        dataset_image, dataloader = get_random_test_samples_dataloader(
            parameters,
            transform=datamodule.transform,
            nb_sample=nb_image_grid,
            classes=None if regression else list(range(n_class)),
        )
        y, y_pred, y_proba = predict(model,
                                     dataloader,
                                     regression,
                                     return_proba=True)
        images, _ = map(list, zip(*dataset_image))
        test_samples_path = plot_path / "test_samples.png"
        plot_image_grid(
            images,
            y,
            y_pred,
            y_proba,
            labels=labels,
            save_path=test_samples_path,
            max_images=nb_image_grid,
        )
        if tracking:
            tracking.log_artifact(test_samples_path,
                                  "metrics/test/test_samples")

        # Confusion matrix or regression line, computed on the whole test dataloader
        y, y_pred, y_proba = predict(model,
                                     datamodule.test_dataloader(),
                                     regression,
                                     return_proba=True)

        if regression:
            regression_line_path = plot_path / "regression_line.png"
            plot_regression_line(y, y_pred, save_path=regression_line_path)

            if tracking:
                tracking.log_artifact(regression_line_path, "metrics/test/regression_line")
        else:
            # Confusion matrix
            confusion_matrix_path = plot_path / "confusion_matrix.png"
            plot_confusion_matrix(y,
                                  y_pred,
                                  labels,
                                  save_path=confusion_matrix_path)
            # Roc curve (only plotted when n_class <= 2)
            if n_class <= 2:
                roc_curve_path = plot_path / "roc_curve.png"
                plot_roc_curve(y,
                               y_proba,
                               n_class=n_class,
                               save_path=roc_curve_path,
                               figsize=(7, 5))

            if tracking:
                tracking.log_artifact(confusion_matrix_path,
                                      "metrics/test/confusion_matrix")
                if n_class <= 2:
                    tracking.log_artifact(roc_curve_path, "metrics/test/roc_curve")
    finally:
        # Always close the resumed run, even if evaluation or plotting raised.
        if tracking:
            tracking.end()