def _save_as_bytes(path: str, obj: Any):
    path = os.path.abspath(path)
    folder = os.path.dirname(path)
    os.makedirs(folder, exist_ok=True)
    with open(path, "wb") as f:
        f.write(to_bytes(obj))
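For completeness, a minimal load counterpart to the helper above could look like the sketch below. The name `_load_from_bytes` is hypothetical (it does not appear in the original code); it only assumes the standard `flax.serialization.from_bytes` round-trip used elsewhere in this section, where a template object supplies the pytree structure to restore into.

from typing import Any
import os

from flax.serialization import from_bytes


def _load_from_bytes(path: str, target: Any) -> Any:
    # Hypothetical counterpart to _save_as_bytes: read the msgpack blob
    # written by to_bytes() and restore it into the structure of `target`.
    path = os.path.abspath(path)
    with open(path, "rb") as f:
        return from_bytes(target, f.read())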
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.

    Arguments:
        save_directory (:obj:`str` or :obj:`os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return
    os.makedirs(save_directory, exist_ok=True)

    # get abs dir
    save_directory = os.path.abspath(save_directory)

    # save config as well
    self.config.architectures = [self.__class__.__name__[4:]]
    self.config.save_pretrained(save_directory)

    # save model
    with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
        model_bytes = to_bytes(self.params)
        f.write(model_bytes)
def test_namedtuple_serialization(self):
    foo_class = collections.namedtuple('Foo', 'a b c')
    x1 = foo_class(a=1, b=2, c=3)
    x1_serialized = serialization.to_bytes(x1)
    x2 = foo_class(a=0, b=0, c=0)
    restored_x1 = serialization.from_bytes(x2, x1_serialized)
    self.assertEqual(x1, restored_x1)
def test_variables_from_tar(vstate, tmp_path):
    fname = str(tmp_path) + "/file.tar"

    with tarfile.TarFile(fname, "w") as f:
        for i in range(10):
            save_binary_to_tar(
                f, serialization.to_bytes(vstate.variables), f"{i}.mpack"
            )

    for name in [fname, fname[:-4]]:
        vstate2 = nk.variational.MCState(
            vstate.sampler, vstate.model, n_samples=10, seed=SEED + 100
        )

        for j in [0, 3, 8]:
            vstate2.variables = nk.variational.experimental.variables_from_tar(
                name, vstate2.variables, j
            )

            # check
            jax.tree_multimap(
                np.testing.assert_allclose, vstate.parameters, vstate2.parameters
            )

        with pytest.raises(KeyError):
            nk.variational.experimental.variables_from_tar(name, vstate2.variables, 15)
def test_model_serialization_to_bytes(self):
    rng = random.PRNGKey(0)
    module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
    _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
    model = nn.Model(module, initial_params)
    serialized_bytes = serialization.to_bytes(model)
    restored_model = serialization.from_bytes(model, serialized_bytes)
    self.assertEqual(restored_model.params, model.params)
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    [`~FlaxPreTrainedModel.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Hugging Face model hub after saving it.

            <Tip warning={true}>

            Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
            which requires `save_directory` to be a local clone of the repo you are pushing to if it's an
            existing folder. Pass along `temp_dir=True` to use a temporary directory instead.

            </Tip>

        kwargs:
            Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    if push_to_hub:
        commit_message = kwargs.pop("commit_message", None)
        repo = self._create_or_get_repo(save_directory, **kwargs)

    os.makedirs(save_directory, exist_ok=True)

    # get abs dir
    save_directory = os.path.abspath(save_directory)

    # save config as well
    self.config.architectures = [self.__class__.__name__[4:]]
    self.config.save_pretrained(save_directory)

    # save model
    output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
    with open(output_model_file, "wb") as f:
        params = params if params is not None else self.params
        model_bytes = to_bytes(params)
        f.write(model_bytes)

    logger.info(f"Model weights saved in {output_model_file}")

    if push_to_hub:
        url = self._push_to_hub(repo, commit_message=commit_message)
        logger.info(f"Model pushed to the hub in this commit: {url}")
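A minimal usage sketch for the method above. It assumes a concrete Flax model class such as `FlaxBertModel` is available from `transformers` and, for the hub case, that the caller is authenticated; the checkpoint identifier and output directory are illustrative only.

from transformers import FlaxBertModel

# Load a pretrained Flax model, then write its config.json and msgpack
# weights (FLAX_WEIGHTS_NAME) into a local directory.
model = FlaxBertModel.from_pretrained("bert-base-uncased")
model.save_pretrained("./my-bert-checkpoint")

# A different parameter pytree (e.g. averaged or fine-tuned params) can be
# saved in place of model.params via the `params` argument.
model.save_pretrained("./my-bert-checkpoint", params=model.params)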
def save_pretrained(self, folder):
    folder_abs = os.path.abspath(folder)

    if not os.path.exists(folder_abs):
        os.mkdir(folder_abs)

    # Note: the mode "wb" belongs to open(), not os.path.join().
    with open(os.path.join(folder_abs, f"{self._config.model_type}.flax"), "wb") as f:
        model_bytes = to_bytes(self.params)
        f.write(model_bytes)
def _flush_params(self, variational_state):
    if not self._save_params:
        return

    binary_data = serialization.to_bytes(variational_state.variables)
    with open(self._prefix + ".mpack", "wb") as outfile:
        outfile.write(binary_data)

    self._steps_notflushed_pars = 0
def test_optimizer_serialization_to_bytes(self):
    rng = random.PRNGKey(0)
    module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
    _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
    model = nn.Model(module, initial_params)
    optim_def = optim.Momentum(learning_rate=1.)
    optimizer = optim_def.create(model)
    serialized_bytes = serialization.to_bytes(optimizer)
    restored_optimizer = serialization.from_bytes(optimizer, serialized_bytes)
    self.assertEqual(restored_optimizer, optimizer)
def test_serialization_chunking2(self):
    old_chunksize = serialization.MAX_CHUNK_SIZE
    serialization.MAX_CHUNK_SIZE = 91 * 8
    try:
        tmp = {'a': np.ones((10, 10))}
        tmpbytes = serialization.to_bytes(tmp)
        newtmp = serialization.from_bytes(tmp, tmpbytes)
    finally:
        serialization.MAX_CHUNK_SIZE = old_chunksize
    jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp)
def test_restore_chunked(self):
    old_chunksize = serialization.MAX_CHUNK_SIZE
    serialization.MAX_CHUNK_SIZE = 91 * 8
    try:
        tmp = np.random.uniform(-100, 100, size=(21, 37))
        serialized = serialization.to_bytes(tmp)
        restored = serialization.msgpack_restore(serialized)
    finally:
        serialization.MAX_CHUNK_SIZE = old_chunksize
    np.testing.assert_array_equal(restored, tmp)
def save_checkpoint(self, save_dir, state):
    state = jax_utils.unreplicate(state)
    print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
    self.model_save_fn(save_dir, params=state.params)
    with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
        f.write(to_bytes(state.opt_state))
    joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
    joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
    with open(os.path.join(save_dir, "training_state.json"), "w") as f:
        json.dump({"step": state.step.item()}, f)
    print("DONE")
def _save_variables(self, variational_state):
    if self._init is False:
        self._init_output()

    _time = time.time()
    binary_data = serialization.to_bytes(variational_state.variables)
    if self._tar:
        save_binary_to_tar(
            self._tar_file, binary_data, str(self._file_step) + ".mpack"
        )
    else:
        with open(self._prefix + str(self._file_step) + ".mpack", "wb") as f:
            f.write(binary_data)

    self._file_step += 1
    self._runtime_taken += time.time() - _time
def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1,
                    overwrite=False):
    """Save a checkpoint of the model.

    Attempts to be pre-emption safe by writing to a temporary location before
    a final rename and cleanup of past files.

    Args:
      ckpt_dir: str: path to store checkpoint files in.
      target: serializable flax object, usually a flax optimizer.
      step: int or float: training step number or other metric number.
      prefix: str: checkpoint file name prefix.
      keep: number of past checkpoint files to keep.
      overwrite: bool: allow overwriting when writing a checkpoint.

    Returns:
      Filename of saved checkpoint.
    """
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_path))

    logging.info('Writing to temporary checkpoint location: %s', ckpt_tmp_path)
    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', ckpt_path)

    # Remove old checkpoint files.
    base_path = os.path.join(ckpt_dir, f'{prefix}')
    checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.

    Arguments:
        save_directory (:obj:`str` or :obj:`os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to push your model to the Hugging Face model hub after saving it.
        kwargs:
            Additional keyword arguments passed along to the
            :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return
    os.makedirs(save_directory, exist_ok=True)

    # get abs dir
    save_directory = os.path.abspath(save_directory)

    # save config as well
    self.config.architectures = [self.__class__.__name__[4:]]
    self.config.save_pretrained(save_directory)

    # save model
    output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
    with open(output_model_file, "wb") as f:
        params = params if params is not None else self.params
        model_bytes = to_bytes(params)
        f.write(model_bytes)

    logger.info(f"Model weights saved in {output_model_file}")

    if push_to_hub:
        saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file]
        url = self._push_to_hub(save_files=saved_files, **kwargs)
        logger.info(f"Model pushed to the hub in this commit: {url}")
def test_serialization(vstate):
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    old_params = vstate.parameters
    old_samples = vstate.samples
    old_nsamples = vstate.n_samples
    old_ndiscard = vstate.n_discard_per_chain

    vstate = nk.vqs.MCState(
        vstate.sampler, vstate.model, n_samples=10, seed=SEED + 100
    )

    vstate = serialization.from_bytes(vstate, bdata)

    jax.tree_multimap(np.testing.assert_allclose, vstate.parameters, old_params)
    np.testing.assert_allclose(vstate.samples, old_samples)
    assert vstate.n_samples == old_nsamples
    assert vstate.n_discard_per_chain == old_ndiscard
def test_variables_from_file(vstate, tmp_path):
    fname = str(tmp_path) + "/file.mpack"

    with open(fname, "wb") as f:
        f.write(serialization.to_bytes(vstate.variables))

    for name in [fname, fname[:-6]]:
        vstate2 = nk.vqs.MCState(
            vstate.sampler, vstate.model, n_samples=10, seed=SEED + 100
        )

        vstate2.variables = nkx.vqs.variables_from_file(name, vstate2.variables)

        # check
        jax.tree_map(
            np.testing.assert_allclose, vstate.parameters, vstate2.parameters
        )
def test_serialization(vstate):
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    vstate_new = nk.variational.MCMixedState(
        vstate.sampler, vstate.model, n_samples=10, seed=SEED + 313
    )

    vstate_new = serialization.from_bytes(vstate_new, bdata)

    jax.tree_multimap(
        np.testing.assert_allclose, vstate.parameters, vstate_new.parameters
    )
    np.testing.assert_allclose(vstate.samples, vstate_new.samples)
    np.testing.assert_allclose(vstate.diagonal.samples, vstate_new.diagonal.samples)
    assert vstate.n_samples == vstate_new.n_samples
    assert vstate.n_discard == vstate_new.n_discard
    assert vstate.n_samples_diag == vstate_new.n_samples_diag
    assert vstate.n_discard_diag == vstate_new.n_discard_diag
def save_checkpoint(ckpt_dir, state, step, keep=10, prefix='ckpt_'):
    """Save a checkpoint of the model.

    Attempts to be pre-emption safe by writing to a temporary location before
    a final rename and cleanup of past files.

    Args:
      ckpt_dir: Path to store checkpoint files in.
      state: Serializable flax object, usually a flax optimizer.
      step: Training step number or other metric number.
      keep: Number of checkpoints to keep.
      prefix: Checkpoint filename prefix.

    Returns:
      Filename of saved checkpoint.
    """
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    ckpt_tmp_path = os.path.join(ckpt_dir, 'tmp')
    ckpt_destination_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_destination_path))

    save_state = SaveState(state, step)
    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(save_state))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_destination_path, overwrite=True)
    logging.info('Saved checkpoint at %s', ckpt_destination_path)

    # Remove old checkpoint files.
    base_path = os.path.join(ckpt_dir, f'{prefix}')
    checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint: %s', path)
            gfile.remove(path)

    return ckpt_destination_path
def save_checkpoint(
    model,
    save_dir,
    state,
    cur_step: int,
    with_opt: bool = True,
    push_to_hub: bool = False,
):
    state = jax_utils.unreplicate(state)

    if with_opt:
        logger.info(f"Saving optimizer and training state in {save_dir}...")
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)

    logger.info(
        f'Saving model in {save_dir} {"and pushing it to HF Hub" if push_to_hub else ""}'
    )
    model.save_pretrained(
        save_dir,
        params=state.params,
        push_to_hub=push_to_hub,
        commit_message=f"Saving weights and logs of step {cur_step}",
    )
def save_model(filename: str, model: nn.Module) -> None:
    gfile.makedirs(os.path.dirname(filename))
    with gfile.GFile(filename, "wb") as fp:
        fp.write(serialization.to_bytes(model))
def train(self, training_data, training_z, batch_size=5000, niter=2000):
    """Trains the classifier.

    Parameters
    ----------
    training_data: numpy array, size Ngalaxies x Nbands
        Training data; each row is a galaxy, each column is a band as per
        the bands defined above.
    training_z: numpy array, size Ngalaxies
        True redshift for the training sample.
    """
    # create scaler
    features = self.features_scaler.fit_transform(training_data)
    features = np.clip(features, -4, 4)
    labels = training_z

    # If model is already trained, we just load the weights
    if os.path.exists(self.export_name):
        with open(self.export_name, 'rb') as file:
            self.model = serialization.from_bytes(self.model, pickle.load(file))
        return

    lr = 0.001
    optimizer = optim.Adam(learning_rate=lr).create(self.model)

    @jax.jit
    def train_step(optimizer, batch):
        # This is the loss function
        def loss_fn(model):
            # Apply classifier to features
            w = model(batch['features'])
            # returns -score, because we want to maximize the score
            if self.metric == 'SNR':
                return -metrics.compute_snr_score(w, batch['labels'])
            elif self.metric == 'FOM':
                # Minimizing the Area
                return 1. / metrics.compute_fom(w, batch['labels'])
            elif self.metric == 'FOM_DETF':
                # Minimizing the Area
                return 1. / metrics.compute_fom(w, batch['labels'], inds=[5, 6])
            else:
                raise NotImplementedError

        # Compute gradients
        loss, g = jax.value_and_grad(loss_fn)(optimizer.target)
        # Perform gradient descent
        optimizer = optimizer.apply_gradient(g)
        return optimizer, loss

    # This function provides random batches of data, TODO: convert to JAX
    print("Size of dataset", len(labels))

    def get_batch():
        inds = onp.random.choice(len(labels), batch_size)
        return {'labels': labels[inds], 'features': features[inds]}

    losses = []
    for i in range(niter):
        optimizer, loss = train_step(optimizer, get_batch())
        losses.append(loss)
        if i % 100 == 0:
            print('iter: %d; Loss : %f' % (i, loss))

    # Export model to disk
    with open(self.export_name, 'wb') as file:
        pickle.dump(serialization.to_bytes(optimizer.target), file)

    self.model = optimizer.target
def save_model(filename, model):
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "wb") as fp:
        fp.write(serialization.to_bytes(model))
def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Save a checkpoint of the model.

    Attempts to be pre-emption safe by writing to a temporary location before
    a final rename and cleanup of past files.

    Args:
      ckpt_dir: str or pathlib-like path to store checkpoint files in.
      target: serializable flax object, usually a flax optimizer.
      step: int or float: training step number or other metric number.
      prefix: str: checkpoint file name prefix.
      keep: number of past checkpoint files to keep.
      overwrite: overwrite existing checkpoint files if a checkpoint at the
        current or a later step already exists (default: False).

    Returns:
      Filename of saved checkpoint.
    """
    ckpt_dir = os.fspath(ckpt_dir)  # Pathlib -> str

    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    if ckpt_dir.startswith('./'):
        ckpt_dir = ckpt_dir[2:]  # gfile.glob() can remove leading './'
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_path))

    base_path = os.path.join(ckpt_dir, prefix)
    checkpoint_files = gfile.glob(base_path + '*')

    if ckpt_path in checkpoint_files:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)
    else:
        checkpoint_files.append(ckpt_path)

    checkpoint_files = natural_sort(checkpoint_files)
    if checkpoint_files[-1] == ckpt_tmp_path:
        checkpoint_files.pop(-1)
    if ckpt_path != checkpoint_files[-1]:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)

    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', ckpt_path)
    print(ckpt_path)

    # Remove newer checkpoints
    if overwrite:
        ind = checkpoint_files.index(ckpt_path) + 1
        newer_ckpts = checkpoint_files[ind:]
        checkpoint_files = checkpoint_files[:ind]
        for path in newer_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    # Remove old checkpoint files.
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
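A usage sketch for the checkpoint saver above. The target pytree, directory, and step are illustrative; any msgpack-serializable pytree (parameters, optimizer state, plain dicts of arrays) works as the `target`, and restoring is done by reading the written file back through `flax.serialization.from_bytes` with a template of the same structure.

import numpy as np

# Illustrative target: a plain pytree of arrays and scalars.
train_state = {"params": {"w": np.zeros((3, 3))}, "step": 0}

# Writes <ckpt_dir>/checkpoint_100 via a temporary file and rename,
# keeping only the 3 most recent checkpoints in the directory.
path = save_checkpoint("/tmp/ckpts", train_state, step=100, keep=3)

# To restore, read the file and deserialize into a template pytree:
#   restored = flax.serialization.from_bytes(train_state, open(path, "rb").read())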
def save_weights(weight_path: str, model_params: Dict[str, Any]):
    """Save serialized weight dictionary."""
    serialized_params = serialization.to_bytes(model_params)
    gfile.makedirs(os.path.dirname(weight_path))
    with gfile.GFile(weight_path, 'wb') as fp:
        fp.write(serialized_params)