Exemplo n.º 1
0
    def _save_checkpoint(self, path: pathlib.Path) -> None:
        """Write a complete TF Keras trial checkpoint into ``path``.

        The directory (and any missing parents) is created first. It is then
        populated, in this order, with:

        * model weights (TF format),
        * optimizer weights (one HDF5 group per optimizer),
        * the pickled RNG state,
        * the user's code,
        * the pickled callback state, followed by the callbacks'
          checkpoint-end hook,
        * the pickled workload-sequencer state (only if a sequencer exists),
        * a ``load_data.json`` manifest describing the trial.

        Args:
            path: Destination checkpoint directory.
        """
        path.mkdir(parents=True, exist_ok=True)

        # The `tf` format is used instead of `h5` because `h5` cannot represent
        # models that subclass `tf.keras.Model` and override `call()` and/or
        # `train_step()`.
        weights_prefix = path / "determined-keras-model-weights"
        self.model.save_weights(str(weights_prefix), save_format="tf")

        # One HDF5 group per optimizer, named after its position in the context.
        optimizer_file = path / "determined-keras-optimizer-weights.h5"
        with h5py.File(optimizer_file, "w") as h5_out:
            for position, opt in enumerate(self.context._optimizers):
                group = h5_out.create_group(f"optimizer-{position}")
                save_optimizer_weights_to_hdf5_group(group, opt)

        # RNG state is captured before its file is opened.
        rng_snapshot = get_rng_state()
        with (path / "rng_state.pkl").open("wb") as rng_out:
            pickle.dump(rng_snapshot, rng_out)

        # Copy of the user's code.
        det.util.write_user_code(path, self.env.on_cluster)

        # Callback state, then let the callbacks react to the finished checkpoint.
        cb_state = self.multiplexer._get_state()
        with (path / "determined-callbacks.v1.pkl").open("wb") as cb_out:
            pickle.dump(cb_state, cb_out)

        self.multiplexer._checkpoint_end(path)

        # Workload-sequencer state is optional.
        if self.wlsq is not None:
            with (path / "workload_sequencer.pkl").open("wb") as seq_out:
                pickle.dump(self.wlsq.get_state(), seq_out)

        # Manifest: trial type, experiment config, hparams and the
        # `module:qualname` spec of the user's trial class.
        trial_class = type(self.trial)
        with (path / "load_data.json").open("w") as manifest_out:
            manifest = {
                "trial_type": "TFKerasTrial",
                "experiment_config": self.context.env.experiment_config,
                "hparams": self.context.env.hparams,
                "trial_cls_spec": f"{trial_class.__module__}:{trial_class.__qualname__}",
            }
            json.dump(manifest, manifest_out)
Exemplo n.º 2
0
    def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:
        """Write a TF Keras checkpoint into ``path`` (chief worker only).

        Non-chief workers write nothing and return ``workload.Skipped()``.
        On the chief, the directory (and any missing parents) is created. It is
        then filled with the model weights (TF format), the optimizer weights
        (HDF5), the pickled RNG state, the user's code and the pickled callback
        state. After that, the callbacks' checkpoint-end hook runs.

        Args:
            path: Destination checkpoint directory.

        Returns:
            ``workload.Skipped()`` on non-chief workers. Otherwise, a metadata
            dict holding ``"tensorflow-<version>"`` under ``"framework"`` and
            ``"saved_weights"`` under ``"format"``.
        """
        # Only the chief writes checkpoint files.
        if not self.is_chief:
            return workload.Skipped()

        path.mkdir(parents=True, exist_ok=True)

        # `tf` format rather than `h5`: `h5` cannot represent models that
        # subclass `tf.keras.Model` and override `call()` and/or `train_step()`.
        weights_prefix = path / "determined-keras-model-weights"
        self.model.save_weights(str(weights_prefix), save_format="tf")

        # One HDF5 group per optimizer, named after its position in the context.
        optimizer_file = path / "determined-keras-optimizer-weights.h5"
        with h5py.File(optimizer_file, "w") as h5_out:
            for position, opt in enumerate(self.context._optimizers):
                group = h5_out.create_group(f"optimizer-{position}")
                save_optimizer_weights_to_hdf5_group(group, opt)

        # RNG state is captured before its file is opened.
        rng_snapshot = get_rng_state()
        with (path / "rng_state.pkl").open("wb") as rng_out:
            pickle.dump(rng_snapshot, rng_out)

        # Copy of the user's code.
        det.util.write_user_code(path, self.env.on_cluster)

        # Callback state, then let the callbacks react to the finished checkpoint.
        cb_state = self.multiplexer._get_state()
        with (path / "determined-callbacks.v1.pkl").open("wb") as cb_out:
            pickle.dump(cb_state, cb_out)

        self.multiplexer._checkpoint_end(path)

        # NOTE(review): a plain dict is returned although the annotation says
        # `workload.Response`; presumably Response admits dicts — confirm.
        return {
            "framework": f"tensorflow-{tf.__version__}",
            "format": "saved_weights",
        }
Exemplo n.º 3
0
    def save_rng_state_with_checkpoint(self, checkpoint_dir: str) -> None:
        """Pickle the current RNG state into ``<checkpoint_dir>/rng_state.pkl``.

        The state is obtained from ``get_rng_state()``. The directory must
        already exist; this method does not create it.

        Args:
            checkpoint_dir: Existing checkpoint directory. Besides ``str``, any
                path-like object (e.g. ``pathlib.Path``) is accepted.
        """
        import os  # function-scope import keeps this method self-contained

        # Capture the state before opening the file, so a failure here does
        # not leave an empty pickle behind.
        rng_state = get_rng_state()

        # os.path.join instead of `checkpoint_dir + "/..."`: string
        # concatenation raises TypeError for pathlib.Path / path-like input
        # and doubles the separator when checkpoint_dir already ends with "/".
        with open(os.path.join(checkpoint_dir, "rng_state.pkl"), "wb") as f:
            pickle.dump(rng_state, f)