def _save_checkpoint(self, path: pathlib.Path) -> None:
    path.mkdir(parents=True, exist_ok=True)

    # Save model weights. We use `tf` format because `h5` does not support
    # models that subclass `tf.keras.Model` and define custom `call()`
    # and/or `train_step()` functions.
    self.model.save_weights(
        str(path.joinpath("determined-keras-model-weights")), save_format="tf"
    )

    # Save optimizer(s) weights.
    with h5py.File(path.joinpath("determined-keras-optimizer-weights.h5"), "w") as h5file:
        for idx, optimizer in enumerate(self.context._optimizers):
            opt_group = h5file.create_group(f"optimizer-{idx}")
            save_optimizer_weights_to_hdf5_group(opt_group, optimizer)

    # Save RNG state.
    rng_state = get_rng_state()
    with open(path.joinpath("rng_state.pkl"), "wb") as f:
        pickle.dump(rng_state, f)

    # Save user code.
    det.util.write_user_code(path, self.env.on_cluster)

    # Save callback(s) state.
    callbacks_state = self.multiplexer._get_state()
    with path.joinpath("determined-callbacks.v1.pkl").open("wb") as f:
        pickle.dump(callbacks_state, f)

    self.multiplexer._checkpoint_end(path)

    if self.wlsq is not None:
        with path.joinpath("workload_sequencer.pkl").open("wb") as f:
            pickle.dump(self.wlsq.get_state(), f)

    trial_cls = type(self.trial)
    with open(path.joinpath("load_data.json"), "w") as f2:
        json.dump(
            {
                "trial_type": "TFKerasTrial",
                "experiment_config": self.context.env.experiment_config,
                "hparams": self.context.env.hparams,
                "trial_cls_spec": f"{trial_cls.__module__}:{trial_cls.__qualname__}",
            },
            f2,
        )
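# A minimal illustrative sketch (not part of the source above) of reading back the
# load_data.json that _save_checkpoint writes. `inspect_checkpoint_metadata` and
# `checkpoint_path` are hypothetical names introduced here for the example.
import json
import pathlib


def inspect_checkpoint_metadata(checkpoint_path: pathlib.Path) -> dict:
    # load_data.json records the trial type, experiment config, hyperparameters,
    # and the "module:qualname" spec of the trial class that produced the checkpoint.
    with checkpoint_path.joinpath("load_data.json").open("r") as f:
        metadata = json.load(f)
    print(metadata["trial_cls_spec"])  # e.g. "<module>:<TrialClass>"
    return metadata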
def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:
    if not self.is_chief:
        return workload.Skipped()

    path.mkdir(parents=True, exist_ok=True)

    # Save model weights. We use `tf` format because `h5` does not support
    # models that subclass `tf.keras.Model` and define custom `call()`
    # and/or `train_step()` functions.
    self.model.save_weights(
        str(path.joinpath("determined-keras-model-weights")), save_format="tf"
    )

    # Save optimizer(s) weights.
    with h5py.File(path.joinpath("determined-keras-optimizer-weights.h5"), "w") as h5file:
        for idx, optimizer in enumerate(self.context._optimizers):
            opt_group = h5file.create_group(f"optimizer-{idx}")
            save_optimizer_weights_to_hdf5_group(opt_group, optimizer)

    # Save RNG state.
    rng_state = get_rng_state()
    with open(path.joinpath("rng_state.pkl"), "wb") as f:
        pickle.dump(rng_state, f)

    # Save user code.
    det.util.write_user_code(path, self.env.on_cluster)

    # Save callback(s) state.
    callbacks_state = self.multiplexer._get_state()
    with path.joinpath("determined-callbacks.v1.pkl").open("wb") as f:
        pickle.dump(callbacks_state, f)

    self.multiplexer._checkpoint_end(path)

    return {
        "framework": f"tensorflow-{tf.__version__}",
        "format": "saved_weights",
    }
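# A minimal illustrative sketch (not part of the source above) of checking the
# optimizer-weights file produced by _save_checkpoint. It only lists the
# "optimizer-<idx>" groups created during the save; the function name and
# `checkpoint_path` argument are hypothetical.
import pathlib

import h5py


def list_saved_optimizer_groups(checkpoint_path: pathlib.Path) -> list:
    with h5py.File(checkpoint_path.joinpath("determined-keras-optimizer-weights.h5"), "r") as h5file:
        # One group per optimizer: "optimizer-0", "optimizer-1", ...
        return sorted(h5file.keys())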
def save_rng_state_with_checkpoint(self, checkpoint_dir: str) -> None:
    rng_state = get_rng_state()
    with open(checkpoint_dir + "/rng_state.pkl", "wb") as f:
        pickle.dump(rng_state, f)
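# A minimal illustrative sketch (not the library's restore path) of loading the
# rng_state.pkl written above. The function name is hypothetical; applying the
# returned state again is left to whatever restore logic mirrors get_rng_state().
import pickle


def load_rng_state_from_checkpoint(checkpoint_dir: str):
    # Returns whatever structure get_rng_state() produced and the save pickled.
    with open(checkpoint_dir + "/rng_state.pkl", "rb") as f:
        return pickle.load(f)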