Example #1
    def test_checkpointing(self, tmp_path: Path,
                           xor_trial_controller: Callable) -> None:
        checkpoint_dir = str(tmp_path.joinpath("checkpoint"))
        latest_checkpoint = None
        steps_completed = 0
        old_loss = -1

        def make_workloads_1() -> workload.Stream:
            nonlocal old_loss

            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=10, validation_freq=10)
            training_metrics, validation_metrics = trainer.result()
            old_loss = validation_metrics[-1]["val_loss"]

            interceptor = workload.WorkloadResponseInterceptor()
            yield from interceptor.send(workload.checkpoint_workload())
            nonlocal latest_checkpoint, steps_completed
            latest_checkpoint = interceptor.metrics_result()["uuid"]
            steps_completed = trainer.get_steps_completed()

        controller = xor_trial_controller(
            self.hparams,
            make_workloads_1(),
            trial_seed=self.trial_seed,
            checkpoint_dir=checkpoint_dir,
        )
        controller.run()

        # Restore the checkpoint on a new trial instance and recompute
        # validation. The validation error should be the same as it was
        # previously.
        def make_workloads_2() -> workload.Stream:
            interceptor = workload.WorkloadResponseInterceptor()

            yield from interceptor.send(workload.validation_workload())
            metrics = interceptor.metrics_result()

            new_loss = metrics["metrics"]["validation_metrics"]["val_loss"]
            assert new_loss == pytest.approx(old_loss)

        controller = xor_trial_controller(
            self.hparams,
            make_workloads_2(),
            trial_seed=self.trial_seed,
            checkpoint_dir=checkpoint_dir,
            latest_checkpoint=latest_checkpoint,
            steps_completed=steps_completed,
        )
        controller.run()

        # Verify that we can load the model from a checkpoint dir
        ckpt_path = str(tmp_path / "checkpoint" / latest_checkpoint)
        model = keras.load_model_from_checkpoint_path(ckpt_path)
        assert isinstance(model, tf.keras.models.Model), type(model)
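Once restored, the object returned by keras.load_model_from_checkpoint_path behaves like any tf.keras model. A minimal follow-on sketch, not part of the original test, assuming the XOR trial's model accepts 2-element float inputs:

import numpy as np

# Illustrative only: run a forward pass on XOR-style inputs with the restored
# model. The input shape is an assumption about the XOR trial's architecture.
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
predictions = model.predict(xor_inputs)
assert predictions.shape[0] == 4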
Example #2
def export_model(trial_id: int, latest: bool = False) -> tf.keras.Model:
    trial = client.get_trial(trial_id)
    checkpoint: client.Checkpoint = (
        trial.select_checkpoint(latest=True) if latest else trial.top_checkpoint()
    )
    print(f"Checkpoint {checkpoint.uuid}")
    try:
        # Checkpoints from AWS deployment don't have these attributes
        print(f"Trial {checkpoint.trial_id}")
        print(f"Batch {checkpoint.batch_number}")
    except AttributeError:
        pass
    path = checkpoint.download()
    model = keras.load_model_from_checkpoint_path(path)
    return model
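A hedged usage sketch for the helper above, assuming the client module is already configured and authenticated against a Determined master; the trial ID is a placeholder, not taken from the example:

# Hypothetical usage: 42 is a placeholder trial ID. Assumes the client module
# is already pointed at a running master with a completed trial.
model = export_model(trial_id=42, latest=True)
model.summary()  # The loaded object is a regular tf.keras.Model.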
Example #3
def test_checkpoint_loading(ckpt_ver):
    checkpoint_dir = os.path.join(utils.fixtures_path("ancient-checkpoints"),
                                  f"{ckpt_ver}-keras")
    model = keras.load_model_from_checkpoint_path(checkpoint_dir)
    assert isinstance(model, tf.keras.models.Model), type(model)
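The ckpt_ver argument is presumably supplied by pytest parametrization over the checkpoint versions stored under the ancient-checkpoints fixture directory. A minimal sketch of how such a test could be driven, with illustrative (assumed) version strings:

import pytest

# The version strings below are illustrative assumptions, not the actual
# fixture versions shipped with the test suite.
@pytest.mark.parametrize("ckpt_ver", ["0.13.13", "0.17.7"])
def test_checkpoint_loading(ckpt_ver):
    ...  # body as in the example above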
Example #4
    def load_from_path(path: str,
                       tags: Optional[List[str]] = None,
                       **kwargs: Any) -> Any:
        """Loads a Determined checkpoint from a local file system path into memory.

        For PyTorch checkpoints, the return type is an object that inherits from
        ``determined.pytorch.PyTorchTrial`` as defined by the ``entrypoint`` field
        in the experiment config.

        For TensorFlow checkpoints, the return type is a TensorFlow autotrackable object.

        Arguments:
            path (string): Local path to the checkpoint directory.
            tags (list of strings, optional): Only relevant for TensorFlow
                SavedModel checkpoints. Specifies which tags are loaded from
                the TensorFlow SavedModel. See documentation for
                `tf.compat.v1.saved_model.load_v2
                <https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/saved_model/load_v2>`_.

        .. warning::

           Checkpoint.load_from_path() has been deprecated and will be removed in a future version.

           Please use one of the following instead to load your checkpoint:
             - ``det.pytorch.load_trial_from_checkpoint_path()``
             - ``det.keras.load_model_from_checkpoint_path()``
             - ``det.estimator.load_estimator_from_checkpoint_path()``
        """
        warnings.warn(
            "Checkpoint.load_from_path() has been deprecated and will be removed in a future "
            "version.\n"
            "\n"
            "Please use one of the following instead to load your checkpoint:\n"
            "  - det.pytorch.load_trial_from_checkpoint_path()\n"
            "  - det.keras.load_model_from_checkpoint_path()\n"
            "  - det.estimator.load_estimator_from_checkpoint_path()\n",
            FutureWarning,
        )
        checkpoint_dir = pathlib.Path(path)
        metadata = Checkpoint._parse_metadata(checkpoint_dir)
        checkpoint_type = Checkpoint._get_type(metadata)

        if checkpoint_type == ModelFramework.PYTORCH:
            from determined import pytorch

            return pytorch.load_trial_from_checkpoint_path(path, **kwargs)

        if checkpoint_type == ModelFramework.TENSORFLOW:
            save_format = metadata.get("format", "saved_model")

            # For tf.estimators we save the entire model using the saved_model format.
            # For tf.keras we save only the weights, also in the saved_model format,
            # which we call saved_weights.
            if cast(str, save_format) == "saved_model":
                from determined import estimator

                return estimator.load_estimator_from_checkpoint_path(
                    path, tags)

            if save_format in ("saved_weights", "h5"):
                from determined import keras

                return keras.load_model_from_checkpoint_path(path, tags)

        raise AssertionError("Unknown checkpoint format at {}".format(path))
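As the deprecation notice recommends, new code can call the framework-specific loader directly instead of going through Checkpoint.load_from_path(). A minimal sketch for the Keras case, where the local path is a placeholder assumption:

from determined import keras

# "/tmp/my-checkpoint" is a placeholder for a checkpoint directory that has
# already been downloaded locally (e.g., via Checkpoint.download()).
model = keras.load_model_from_checkpoint_path("/tmp/my-checkpoint")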
Example #5
def _export_and_load_model(experiment_id: int, master_url: str) -> None:
    # Normally verifying that we can load a model would be a good unit test, but making this an e2e
    # test ensures that our model saving and loading works with all the versions of tf that we test.
    ckpt = client.Determined(master_url).get_experiment(
        experiment_id).top_checkpoint()
    _ = keras.load_model_from_checkpoint_path(ckpt.download())