Ejemplo n.º 1
0
    def _load(self) -> None:
        self.multiplexer_load_state = None  # type: Optional[Dict]
        if not self.load_path:
            return

        # Find model code path, we check multiple naming conventions for backwards compatibility.
        if self.load_path.joinpath("determined-keras-model.h5").exists():
            self._load_model_and_optimizer_weights_v2()
        elif self.load_path.joinpath(
                "determined-keras-optimizer-weights.h5").exists():
            self._load_model_and_optimizer_weights_v3()
        else:
            self._load_model_and_optimizer_weights_v1()

        # Load RNG state.
        try:
            with open(self.load_path.joinpath("rng_state.pkl"), "rb") as f:
                rng_state = pickle.load(f)

            set_rng_state(rng_state)
        except IOError:
            logging.warning("Checkpoint did not include RNG state.")

        # Load callbacks.
        cb_state_path = self.load_path.joinpath("determined-callbacks.v1.pkl")
        if cb_state_path.exists():
            with cb_state_path.open("rb") as f:
                self.multiplexer_load_state = pickle.load(f)
Ejemplo n.º 2
0
    def _load(self, load_path: pathlib.Path) -> None:
        """Restore trial state from the checkpoint directory ``load_path``.

        Restores, in order:

        1. model and optimizer weights, using a loader chosen by which
           checkpoint files exist;
        2. RNG state, if ``rng_state.pkl`` can be read (otherwise a warning
           is logged);
        3. pickled callback state into ``self.multiplexer_load_state``, if
           ``determined-callbacks.v1.pkl`` exists (the attribute is not
           touched otherwise);
        4. workload sequencer state, if ``self.wlsq`` is not ``None`` and
           ``workload_sequencer.pkl`` exists.

        Args:
            load_path: Directory containing the checkpoint files.
        """
        # Find model code path, we check multiple naming conventions for backwards compatibility.
        # The loader version is selected by which weight file is present;
        # when neither newer file exists, fall back to the v1 loader.
        if load_path.joinpath("determined-keras-model.h5").exists():
            self._load_model_and_optimizer_weights_v2(load_path)
        elif load_path.joinpath("determined-keras-optimizer-weights.h5").exists():
            self._load_model_and_optimizer_weights_v3(load_path)
        else:
            self._load_model_and_optimizer_weights_v1(load_path)

        # Load RNG state. Only the file read is guarded, so an OSError raised
        # while applying the state is not misreported as a missing file.
        # NOTE(review): pickle.load can execute arbitrary code; the checkpoint
        # directory must come from a trusted source.
        try:
            with open(load_path.joinpath("rng_state.pkl"), "rb") as f:
                rng_state = pickle.load(f)
        except IOError:
            logging.warning("Checkpoint did not include RNG state.")
        else:
            set_rng_state(rng_state)

        # Load callbacks.
        cb_state_path = load_path.joinpath("determined-callbacks.v1.pkl")
        if cb_state_path.exists():
            with cb_state_path.open("rb") as f:
                self.multiplexer_load_state = pickle.load(f)

        # Load WorkloadSequencer state (only when a sequencer is attached).
        wlsq_path = load_path.joinpath("workload_sequencer.pkl")
        if self.wlsq is not None and wlsq_path.exists():
            with wlsq_path.open("rb") as f:
                self.wlsq.load_state(pickle.load(f))
Ejemplo n.º 3
0
    def load_rng_state_from_checkpoint(self, checkpoint_dir: str) -> None:
        """Restore the RNG state saved in ``checkpoint_dir``, if any.

        Reads ``rng_state.pkl`` from ``checkpoint_dir`` and passes its
        unpickled contents to ``set_rng_state``. If the file cannot be opened
        or read, a warning is logged and nothing is restored. A pickled
        ``None`` is also treated as "nothing to restore".

        Args:
            checkpoint_dir: Directory containing the checkpoint files. Any
                path accepted by ``os.path.join`` (``str`` or ``os.PathLike``)
                works.
        """
        import os

        rng_state_path = os.path.join(checkpoint_dir, "rng_state.pkl")
        try:
            with open(rng_state_path, "rb") as f:
                # NOTE(review): pickle.load can execute arbitrary code; only
                # load checkpoints from a trusted source.
                rng_state = pickle.load(f)
        except IOError:
            # Backward compatibility: this is expected for checkpoints written
            # before the on_checkpoint_end hook that saves RNG state was added.
            # logging.warning replaces the deprecated logging.warn alias.
            logging.warning("No RNG state found in checkpoint_dir")
            return

        if rng_state is not None:
            logging.info("Restoring RNG state from checkpoint")
            set_rng_state(rng_state)