Example no. 1
0
def save_optimizer_weights(model, filepath, overwrite=True, **kwargs):
    """Write `model`'s metadata and optimizer weights to an HDF5 file.

    Args:
        model: Keras model whose optimizer state should be serialized.
        filepath: Destination path string, or an already-open `h5py.File`
            (in which case the caller keeps ownership and it is not closed).
        overwrite: When False and `filepath` names an existing file, prompt
            the user before replacing it.
        **kwargs: Accepted for interface compatibility; currently unused.
    """
    must_close = not isinstance(filepath, h5py.File)
    if must_close:
        # Existing file and overwrite disabled: ask before clobbering.
        if not overwrite and os.path.isfile(filepath):
            if not hdf5_format.ask_to_proceed_with_overwrite(filepath):
                return
        f = h5py.File(filepath, mode='w')
    else:
        f = filepath
    try:
        metadata = saving_utils.model_metadata(
            model, include_optimizer=True, require_config=False)
        for key, value in metadata.items():
            # Structured metadata is stored as JSON-encoded UTF-8 bytes;
            # scalars go into the HDF5 attributes directly.
            if isinstance(value, (dict, list, tuple)):
                f.attrs[key] = json.dumps(
                    value, default=serialization.get_json_type).encode('utf8')
            else:
                f.attrs[key] = value
        if not isinstance(model.optimizer, optimizers.TFOptimizer):
            hdf5_format.save_optimizer_weights_to_hdf5_group(f, model.optimizer)
        f.flush()
    finally:
        # Only close handles we opened ourselves.
        if must_close:
            f.close()
Example no. 2
0
    def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:
        """Persist model weights, optimizer state, RNG state, user code and
        callback state under `path`. Only the chief worker writes; other
        workers return `workload.Skipped()`."""
        if not self.is_chief:
            return workload.Skipped()

        path.mkdir(parents=True, exist_ok=True)

        # Model weights go out in `tf` format: `h5` cannot serialize models
        # that subclass `tf.keras.Model` and define custom `call()` and/or
        # `train_step()` functions.
        weights_path = path.joinpath("determined-keras-model-weights")
        self.model.save_weights(str(weights_path), save_format="tf")

        # One HDF5 group per optimizer.
        opt_file = path.joinpath("determined-keras-optimizer-weights.h5")
        with h5py.File(opt_file, "w") as h5file:
            for i, opt in enumerate(self.context._optimizers):
                group = h5file.create_group(f"optimizer-{i}")
                save_optimizer_weights_to_hdf5_group(group, opt)

        # Capture every RNG stream so a restore can resume deterministically.
        rng_state = {
            "np_rng_state": np.random.get_state(),
            "random_rng_state": random.getstate(),
        }
        if version.parse(tf.__version__) < version.parse("2.0.0"):
            rng_state["tf_rng_global_seed"] = tf.compat.v1.random.get_seed(0)[0]
        else:
            gen = tf.random.get_global_generator()
            rng_state["tf2_rng_global_algorithm"] = gen.algorithm
            rng_state["tf2_rng_global_state"] = gen.state
        with open(path.joinpath("rng_state.pkl"), "wb") as out:
            pickle.dump(rng_state, out)

        # Snapshot the user's code alongside the checkpoint.
        det.util.write_user_code(path)

        # Persist callback state, then let callbacks react to the checkpoint.
        with path.joinpath("determined-callbacks.v1.pkl").open("wb") as out:
            pickle.dump(self.multiplexer._get_state(), out)

        self.multiplexer._checkpoint_end(path)

        return {
            "framework": f"tensorflow-{tf.__version__}",
            "format": "saved_weights",
        }
Example no. 3
0
    def _get_optimizer_state(self, optimizer, optimizer_name=None):
        """Serialize `optimizer`'s weights into an in-memory HDF5 file.

        Returns the backing `io.BytesIO` buffer; note the stream position is
        left at the end of the written data. `optimizer_name` is accepted for
        interface compatibility but is not used here.
        """
        buffer = io.BytesIO()
        with h5py.File(buffer, 'w') as hdf5_file:
            hdf5_format.save_optimizer_weights_to_hdf5_group(
                hdf5_file, optimizer)
        return buffer
Example no. 4
0
 def save_optimizer_to_hdf5(model, filepath):
     """Write `model.optimizer`'s weights to a fresh HDF5 file at `filepath`.

     Nothing is written when the model has no optimizer, or when the
     optimizer is a `TFOptimizer` (which this path does not serialize).
     """
     with h5py.File(filepath, mode='w') as f:
         savable = bool(model.optimizer) and not isinstance(
             model.optimizer, optimizers.TFOptimizer)
         if savable:
             save_optimizer_weights_to_hdf5_group(f, model.optimizer)
         f.flush()