def restore_checkpoint(load_dir, state):
    """Load params, optimizer state and step counter saved under ``load_dir``.

    Reads ``flax_model.msgpack`` and ``opt_state.msgpack`` (each decoded with
    ``from_bytes`` against the matching field of ``state``) and the ``"step"``
    entry of ``training_state.json``.

    Returns:
        A tuple ``(new_state, step)`` where ``new_state`` is
        ``state.replace(step=..., params=..., opt_state=...)`` built from the
        restored values.
    """
    logger.info(f"Restoring checkpoint from {load_dir}")

    def _in_dir(name):
        return os.path.join(load_dir, name)

    with open(_in_dir("flax_model.msgpack"), "rb") as fh:
        restored_params = from_bytes(state.params, fh.read())
    with open(_in_dir("opt_state.msgpack"), "rb") as fh:
        restored_opt_state = from_bytes(state.opt_state, fh.read())
    with open(_in_dir("training_state.json"), "r") as fh:
        restored_step = json.load(fh)["step"]

    logger.info(f"Checkpoint restored at step {restored_step}")
    new_state = state.replace(
        step=restored_step,
        params=restored_params,
        opt_state=restored_opt_state,
    )
    return new_state, restored_step
def load_from_zip(agent_dict, load_path):
    """Restore agent components stored as entries of a zip archive.

    Args:
        agent_dict: mapping from archive entry name to a template object; each
            template is passed as the ``target`` of ``from_bytes``. Names that
            are not present in the archive are left untouched.
        load_path: path to the archive, or an already-open file-like object
            accepted by ``zipfile.ZipFile``. For a string path that does not
            exist, ``load_path + ".zip"`` is tried instead.

    Returns:
        ``agent_dict`` itself, updated in place with the deserialized
        components.

    Raises:
        ValueError: if ``load_path`` is a string and neither it nor
            ``load_path + ".zip"`` exists.
    """
    # Only string paths are checked on disk; file-like objects go straight to ZipFile.
    if isinstance(load_path, str) and not os.path.exists(load_path):
        if os.path.exists(load_path + ".zip"):
            load_path += ".zip"
        else:
            raise ValueError(
                "Error: the file {:} could not be found.".format(load_path))

    with zipfile.ZipFile(load_path, "r") as archive:
        saved_names = set(archive.namelist())
        for name, target in agent_dict.items():
            # Skip components that were not saved.
            if name not in saved_names:
                continue
            # Write back into the caller's mapping (re-assigning existing keys
            # while iterating items() is safe: the dict size does not change).
            agent_dict[name] = from_bytes(target, archive.read(name))
    return agent_dict
def test_namedtuple_serialization(self):
    """A namedtuple round-trips through to_bytes/from_bytes into a template."""
    foo_cls = collections.namedtuple('Foo', 'a b c')
    original = foo_cls(a=1, b=2, c=3)
    payload = serialization.to_bytes(original)
    template = foo_cls(a=0, b=0, c=0)
    self.assertEqual(original, serialization.from_bytes(template, payload))
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
    """Restore last/best checkpoint from checkpoints in path.

    Sorts the checkpoint files naturally, returning the highest-valued
    file, e.g.:
      ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
      ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
      ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

    Args:
      ckpt_dir: str: directory of checkpoints to restore from.
      target: matching object to rebuild via deserialized state-dict.
      step: int: step number to load or None to load latest. A step of 0 is
        a real step and is loaded explicitly.
      prefix: str: name prefix of checkpoint files.

    Returns:
      Restored `target` updated from checkpoint file, or if no step specified
      and no checkpoint files present, returns the passed-in `target`
      unchanged.

    Raises:
      ValueError: if an explicit `step` is given but its checkpoint file does
        not exist.
    """
    # Compare against None: a truthiness test would treat step=0 as
    # "load latest" instead of loading checkpoint 0.
    if step is not None:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        glob_path = os.path.join(ckpt_dir, f'{prefix}*')
        checkpoint_files = natural_sort(gfile.glob(glob_path))
        # Exclude the '<prefix>tmp' path (presumably an in-progress save).
        ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
        checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
        if not checkpoint_files:
            return target
        ckpt_path = checkpoint_files[-1]

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        return serialization.from_bytes(target, fp.read())
def restore_checkpoint(save_dir, state):
    """Load a training checkpoint written to ``save_dir``.

    Returns:
        ``(params, opt_state, step, args, data_collator)``. ``params`` and
        ``opt_state`` are decoded with ``from_bytes`` against ``state.params``
        and ``state.opt_state``; ``args`` and ``data_collator`` are loaded
        with ``joblib``; ``step`` is the ``"step"`` entry of
        ``training_state.json``.
    """
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")

    def _in_dir(name):
        return os.path.join(save_dir, name)

    with open(_in_dir("flax_model.msgpack"), "rb") as blob:
        params = from_bytes(state.params, blob.read())
    with open(_in_dir("opt_state.msgpack"), "rb") as blob:
        opt_state = from_bytes(state.opt_state, blob.read())

    # NOTE(review): joblib.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    args = joblib.load(_in_dir("args.joblib"))
    data_collator = joblib.load(_in_dir("data_collator.joblib"))

    with open(_in_dir("training_state.json"), "r") as meta:
        step = json.load(meta)["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator
def test_model_serialization_to_bytes(self):
    """Serializing an nn.Model and restoring into it preserves its params."""
    dense = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
    _, params = dense.init_by_shape(random.PRNGKey(0), [((1, 1), jnp.float32)])
    model = nn.Model(dense, params)
    restored = serialization.from_bytes(model, serialization.to_bytes(model))
    self.assertEqual(restored.params, model.params)
def load_weights(weight_path: str) -> Dict[str, Any]:
    """Load and deserialize a weight dictionary.

    The file is read through ``gfile`` and decoded with
    ``serialization.from_bytes`` using ``None`` as the target. Every leaf of
    the result is then converted with ``jnp.asarray``.

    Args:
        weight_path: path of the serialized weights file.

    Returns:
        The decoded weight tree with ``jax.numpy`` array leaves.

    Raises:
        ValueError: if ``weight_path`` does not exist.
    """
    if not gfile.exists(weight_path):
        raise ValueError('Matching checkpoint not found: {}'.format(weight_path))
    logging.info('Loading weights from %s', weight_path)
    with gfile.GFile(weight_path, 'rb') as fp:
        params = serialization.from_bytes(None, fp.read())
    # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(jnp.asarray, params)
def test_serialization_chunking2(self):
    """An array larger than MAX_CHUNK_SIZE survives a to_bytes/from_bytes round trip."""
    old_chunksize = serialization.MAX_CHUNK_SIZE
    # 91 * 8 = 728 bytes is smaller than the 10x10 float64 array (800 bytes),
    # presumably to force the chunked encoding path.
    serialization.MAX_CHUNK_SIZE = 91 * 8
    try:
        tmp = {'a': np.ones((10, 10))}
        tmpbytes = serialization.to_bytes(tmp)
        newtmp = serialization.from_bytes(tmp, tmpbytes)
    finally:
        # Always restore the module-level setting so other tests are unaffected.
        serialization.MAX_CHUNK_SIZE = old_chunksize
    # jax.tree_multimap was removed from JAX; tree_map accepts multiple trees.
    jax.tree_util.tree_map(np.testing.assert_array_equal, tmp, newtmp)
def test_optimizer_serialization_to_bytes(self):
    """A Momentum optimizer round-trips through to_bytes/from_bytes unchanged."""
    dense = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
    _, params = dense.init_by_shape(random.PRNGKey(0), [((1, 1), jnp.float32)])
    model = nn.Model(dense, params)
    optimizer = optim.Momentum(learning_rate=1.).create(model)
    blob = serialization.to_bytes(optimizer)
    self.assertEqual(serialization.from_bytes(optimizer, blob), optimizer)
def restore_from_path(ckpt_dir, target, step=None, prefix='ckpt_'):
    """Restore the newest checkpoint in ``ckpt_dir``, if one exists.

    Returns:
        ``(train_state, step)`` decoded into a ``SaveState(target, step)``
        template. If ``latest_checkpoint_path`` finds nothing, the inputs
        ``(target, step)`` are returned unchanged.
    """
    latest = latest_checkpoint_path(ckpt_dir, prefix)
    if latest is None:
        logging.info('No checkpoints found, starting from the beginning.')
        return target, step

    logging.info('Restoring checkpoint: %s', latest)
    template = SaveState(target, step)
    with gfile.GFile(latest, 'rb') as reader:
        restored = serialization.from_bytes(template, reader.read())
    return restored.train_state, restored.step
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
    """Load a Flax msgpack checkpoint file into a PyTorch ``model``.

    The Flax counterpart class is looked up on ``transformers`` as
    ``"Flax" + <PyTorch class name>`` and used as the ``from_bytes`` target.
    The decoded state is handed to ``load_flax_weights_in_pytorch_model``.

    Raises:
        EnvironmentError: if decoding raises ``UnpicklingError``.
    """
    checkpoint_file = os.path.abspath(flax_checkpoint_path)
    logger.info(f"Loading Flax weights from {checkpoint_file}")

    # Resolve the Flax class matching the PyTorch model's class name.
    flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)

    with open(checkpoint_file, "rb") as handle:
        serialized = handle.read()
    try:
        flax_state_dict = from_bytes(flax_cls, serialized)
    except UnpicklingError:
        raise EnvironmentError(f"Unable to convert {checkpoint_file} to Flax deserializable object. ")

    return load_flax_weights_in_pytorch_model(model, flax_state_dict)
def variables_from_file(filename: str, variables: _PyTree):
    """
    Loads the variables of a variational state from a `.mpack` file.

    When ``filename`` is not an existing file and does not already end in
    ``.mpack``, that extension is appended before the file is opened.

    Args:
        filename: the file containing the variables. Assumes a .mpack
            extension and adds it if missing and no file exists.
        variables: An object variables with the same structure and shape
            of the object to be deserialized.

    Returns:
        a PyTree like variables

    Examples:
       Serializing the data:

       >>> import netket as nk
       >>> import flax
       >>> # construct an RBM model on 10 spins
       >>> vstate = nk.variational.MCState(
       ...      nk.sampler.MetropolisLocal(nk.hilbert.Spin(0.5)**10),
       ...      nk.models.RBM())
       >>> with open("test.mpack", 'wb') as file:
       ...     bytes_written = file.write(flax.serialization.to_bytes(vstate.variables))
       >>> print(bytes_written)
       1052
       >>>
       >>> # Deserialize the data
       >>>
       >>> del vstate
       >>> # construct an RBM model on 10 spins
       >>> vstate2 = nk.variational.MCState(
       ...      nk.sampler.MetropolisLocal(nk.hilbert.Spin(0.5)**10),
       ...      nk.models.RBM())
       >>> # Load the data by passing the model
       >>> vars = nk.variational.experimental.variables_from_file("test.mpack",
       ...                                                       vstate2.variables)
       >>> # update the variables of vstate with the loaded data.
       >>> vstate2.variables = vars
    """
    if not _path.isfile(filename) and not filename.endswith(".mpack"):
        filename += ".mpack"

    with open(filename, "rb") as stream:
        return _serialization.from_bytes(variables, stream.read())
def test_serialization(vstate):
    """An MCState restored via flax serialization matches the original.

    ``vstate`` is serialized. A fresh ``nk.vqs.MCState`` with the same sampler
    and model (``n_samples=10``, a different seed) receives the bytes. Then
    parameters, samples and sampling settings are compared with the
    originals.
    """
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    old_params = vstate.parameters
    old_samples = vstate.samples
    old_nsamples = vstate.n_samples
    old_ndiscard = vstate.n_discard_per_chain

    vstate = nk.vqs.MCState(vstate.sampler, vstate.model, n_samples=10, seed=SEED + 100)
    vstate = serialization.from_bytes(vstate, bdata)

    # jax.tree_multimap was removed from JAX; tree_map accepts several trees.
    jax.tree_util.tree_map(np.testing.assert_allclose, vstate.parameters, old_params)
    np.testing.assert_allclose(vstate.samples, old_samples)
    assert vstate.n_samples == old_nsamples
    assert vstate.n_discard_per_chain == old_ndiscard
def variables_from_tar(filename: str, variables: _PyTree, i: int):
    """
    Loads the variables of a variational state from the i-th element of a `.tar` archive.

    Args:
        filename: the tar archive name. Assumes a .tar extension and adds it
            if missing and no file exists.
        variables: An object variables with the same structure and shape of
            the object to be deserialized.
        i: the index of the variables to load; the archive member read is
            named ``"<i>.mpack"``.

    Returns:
        The result of ``from_bytes(variables, ...)`` on that member's bytes.
    """
    if not (_path.isfile(filename) or filename.endswith(".tar")):
        filename += ".tar"

    member_name = f"{i!s}.mpack"
    with _tarfile.TarFile(filename, "r") as archive:
        member = archive.getmember(member_name)
        with archive.extractfile(member) as stream:
            return _serialization.from_bytes(variables, stream.read())
def test_serialization(vstate):
    """A mixed state restored via flax serialization matches the original.

    ``vstate`` is serialized. A fresh ``nk.variational.MCMixedState`` with the
    same sampler and model (``n_samples=10``, a different seed) receives the
    bytes. Then parameters, samples (including the diagonal state's samples)
    and sampling settings are compared.
    """
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    vstate_new = nk.variational.MCMixedState(
        vstate.sampler, vstate.model, n_samples=10, seed=SEED + 313
    )
    vstate_new = serialization.from_bytes(vstate_new, bdata)

    # jax.tree_multimap was removed from JAX; tree_map accepts several trees.
    jax.tree_util.tree_map(
        np.testing.assert_allclose, vstate.parameters, vstate_new.parameters
    )
    np.testing.assert_allclose(vstate.samples, vstate_new.samples)
    np.testing.assert_allclose(vstate.diagonal.samples, vstate_new.diagonal.samples)
    assert vstate.n_samples == vstate_new.n_samples
    assert vstate.n_discard == vstate_new.n_discard
    assert vstate.n_samples_diag == vstate_new.n_samples_diag
    assert vstate.n_discard_diag == vstate_new.n_discard_diag
def _load_from_bytes(path: str, target: Any):
    """Decode the file at ``path`` with ``from_bytes``, using ``target`` as the template."""
    with open(path, "rb") as stream:
        return from_bytes(target, stream.read())
def load_model(filename: str, model: nn.Module) -> nn.Module:
    """Restore ``model`` from the serialized bytes stored at ``filename`` (read via gfile)."""
    with gfile.GFile(filename, "rb") as reader:
        return serialization.from_bytes(model, reader.read())
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
    r"""
    Instantiate a pretrained flax model from a pre-trained model configuration.

    The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            Can be either:

                - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                  a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `pt index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In this
                  case, ``from_pt`` should be set to :obj:`True`.
        dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
            Stored in the model keyword arguments as ``dtype`` and forwarded to the model's ``__init__``.
        model_args (sequence of positional arguments, `optional`):
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            When ``config`` is not a :class:`~transformers.PretrainedConfig` instance they are also forwarded to
            the configuration class's ``from_pretrained``.
        config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
            Can be either:

                - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

                - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                  model).
                - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                  by supplying the save directory.
                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                  configuration JSON file named `config.json` is found in the directory.
        cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Load the model weights from a PyTorch checkpoint save file (see docstring of
            ``pretrained_model_name_or_path`` argument).
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
            file exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only look at local files (i.e., do not try to download the model). Forced to
            :obj:`True` when ``is_offline_mode()`` reports offline mode.
        use_auth_token (`optional`):
            Forwarded unchanged to the configuration loader and to ``cached_path`` when fetching the weights.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git. Popped with a default of :obj:`None` here; presumably ``None`` resolves to
            ``"main"`` downstream (confirm).
        kwargs (remaining dictionary of keyword arguments, `optional`):
            Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
            :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
            automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                  attribute will be passed to the underlying model's ``__init__`` function.

    Returns:
        An instance of ``cls`` whose ``params`` come from the checkpoint. Parameters missing from the checkpoint
        keep the model's randomly initialized values; checkpoint entries the model does not require are dropped.

    Raises:
        EnvironmentError: if a local directory contains no usable weights file, if the weights cannot be
            resolved/downloaded, or if the Flax weights file cannot be decoded (``UnpicklingError``).

    Examples::

        >>> from transformers import BertConfig, FlaxBertModel
        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
        >>> config = BertConfig.from_json_file('./pt_model/config.json')
        >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
    """
    # Pop loader options so that only config/model kwargs remain in `kwargs`.
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)

    # Telemetry-style metadata passed to cached_path below.
    user_agent = {
        "file_type": "model",
        "framework": "flax",
        "from_auto_class": from_auto_class
    }
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        # Keys not consumed by the config come back as model_kwargs.
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Add the dtype to model_kwargs
    model_kwargs["dtype"] = dtype

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if from_pt and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                    f"{pretrained_model_name_or_path} or `from_pt` set to False"
                )
        elif os.path.isfile(
                pretrained_model_name_or_path) or is_remote_url(
                    pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            # Treat the argument as a hub model id and build its download URL.
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                revision=revision,
            )

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info(f"loading weights file {archive_file}")
        else:
            logger.info(
                f"loading weights file {archive_file} from cache at {resolved_archive_file}"
            )
    else:
        # NOTE(review): with no path and from_pt False, open(None) below raises
        # TypeError, and `archive_file` is unbound for the error message.
        resolved_archive_file = None

    # init random models
    model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        state = load_pytorch_checkpoint_in_flax_state_dict(
            model, resolved_archive_file)
    else:
        with open(resolved_archive_file, "rb") as state_f:
            try:
                # NOTE(review): the class itself is the from_bytes target; the
                # result is used below as a nested dict (prefix lookup,
                # flatten_dict).
                state = from_bytes(cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(
                    f"Unable to convert {archive_file} to Flax deserializable object. "
                )

    # if model is base model only use model_prefix key
    if cls.base_model_prefix not in dict(
            model.params) and cls.base_model_prefix in state:
        state = state[cls.base_model_prefix]

    # flatten dicts
    state = flatten_dict(state)

    random_state = flatten_dict(unfreeze(model.params))

    missing_keys = model.required_params - set(state.keys())
    unexpected_keys = set(state.keys()) - model.required_params

    # add missing keys as random parameters
    for missing_key in missing_keys:
        state[missing_key] = random_state[missing_key]

    # remove unexpected keys to not be saved again
    for unexpected_key in unexpected_keys:
        del state[unexpected_key]

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.info(
            f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
        )

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )

    # set correct parameters
    model.params = unflatten_dict(state)

    return model
def restore_checkpoint(checkpoint_path, target):
    """Restores a checkpoint.

    Reads ``checkpoint_path`` through ``gfile`` and decodes it with
    ``serialization.from_bytes`` using ``target`` as the structural template.
    """
    with gfile.GFile(checkpoint_path, 'rb') as reader:
        return serialization.from_bytes(target, reader.read())
def restore_from_path(ckpt_path, target):
    """Restore ``target`` from the checkpoint file at ``ckpt_path``.

    The path is first passed through ``check_and_convert_gcs_filepath``
    (presumably validating/normalizing GCS paths; confirm), then read via
    ``gfile`` and decoded with ``serialization.from_bytes``.
    """
    resolved_path = check_and_convert_gcs_filepath(ckpt_path)
    logging.info('Restoring checkpoint from %s', resolved_path)
    with gfile.GFile(resolved_path, 'rb') as reader:
        return serialization.from_bytes(target, reader.read())
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], dtype: jnp.dtype = jnp.float32, *model_args, **kwargs): r""" Instantiate a pretrained flax model from a pre-trained model configuration. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those weights are discarded. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - A path to a *directory* containing model weights saved using [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, `from_pt` should be set to `True`. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`. **Note that this only specifies the dtype of the computation and does not influence the dtype of model parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and [`~FlaxPreTrainedModel.to_bf16`]. model_args (sequence of positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. 
config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): Can be either: - an instance of a class derived from [`PretrainedConfig`], - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when: - The model is a model provided by the library (loaded with the *model id* string of a pretrained model). - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the save directory. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. from_pt (`bool`, *optional*, defaults to `False`): Load the model weights from a PyTorch checkpoint save file (see docstring of `pretrained_model_name_or_path` argument). ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 
local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded: - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function. Examples: ```python >>> from transformers import BertConfig, FlaxBertModel >>> # Download model and configuration from huggingface.co and cache. >>> model = FlaxBertModel.from_pretrained("bert-base-cased") >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). 
>>> config = BertConfig.from_json_file("./pt_model/config.json") >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) ```""" config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", None) from_pt = kwargs.pop("from_pt", False) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) _do_init = kwargs.pop("_do_init", True) user_agent = { "file_type": "model", "framework": "flax", "from_auto_class": from_auto_class } if from_pipeline is not None: user_agent["using_pipeline"] = from_pipeline if is_offline_mode() and not local_files_only: logger.info("Offline mode: forcing local_files_only=True") local_files_only = True # Load config if we don't provide a configuration if not isinstance(config, PretrainedConfig): config_path = config if config is not None else pretrained_model_name_or_path config, model_kwargs = cls.config_class.from_pretrained( config_path, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, _from_auto=from_auto_class, _from_pipeline=from_pipeline, **kwargs, ) else: model_kwargs = kwargs # Add the dtype to model_kwargs model_kwargs["dtype"] = dtype # Load model if pretrained_model_name_or_path is not None: if os.path.isdir(pretrained_model_name_or_path): if from_pt and os.path.isfile( os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = 
os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) elif os.path.isfile( os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) # At this stage we don't have a weight file so we will raise an error. elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " "weights.") else: raise EnvironmentError( f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"{pretrained_model_name_or_path}.") elif os.path.isfile( pretrained_model_name_or_path) or is_remote_url( pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path else: filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME archive_file = hf_bucket_url( pretrained_model_name_or_path, filename=filename, revision=revision, ) # redirect to the cache, if necessary try: resolved_archive_file = cached_path( archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, ) except RepositoryNotFoundError: raise EnvironmentError( f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " "login` and pass `use_auth_token=True`.") except RevisionNotFoundError: raise EnvironmentError( f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " "this model name. 
Check the model page at " f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." ) except EntryNotFoundError: if filename == FLAX_WEIGHTS_NAME: has_file_kwargs = { "revision": revision, "proxies": proxies, "use_auth_token": use_auth_token } if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from " "those weights.") else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} " f"or {WEIGHTS_NAME}.") else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." ) except HTTPError as err: raise EnvironmentError( f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" f"{err}") except ValueError: raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in the cached " f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory " f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" "Checkout your internet connection or see how to run the library in offline mode at " "'https://huggingface.co/docs/transformers/installation#offline-mode'." ) except EnvironmentError: raise EnvironmentError( f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." 
) if resolved_archive_file == archive_file: logger.info(f"loading weights file {archive_file}") else: logger.info( f"loading weights file {archive_file} from cache at {resolved_archive_file}" ) else: resolved_archive_file = None # init random models model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) if from_pt: state = load_pytorch_checkpoint_in_flax_state_dict( model, resolved_archive_file) else: with open(resolved_archive_file, "rb") as state_f: try: state = from_bytes(cls, state_f.read()) except (UnpicklingError, msgpack.exceptions.ExtraData) as e: try: with open(resolved_archive_file) as f: if f.read().startswith("version"): raise OSError( "You seem to have cloned a repository without having git-lfs installed. Please install " "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " "you cloned.") else: raise ValueError from e except (UnicodeDecodeError, ValueError): raise EnvironmentError( f"Unable to convert {archive_file} to Flax deserializable object. 
" ) # make sure all arrays are stored as jnp.arrays # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: # https://github.com/google/flax/issues/1261 if _do_init: state = jax.tree_util.tree_map(jnp.array, state) else: # keep the params on CPU if we don't want to initialize state = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) # if model is base model only use model_prefix key if cls.base_model_prefix not in dict( model.params_shape_tree) and cls.base_model_prefix in state: state = state[cls.base_model_prefix] # if model is head model and we are loading weights from base model # we initialize new params dict with base_model_prefix if cls.base_model_prefix in dict( model.params_shape_tree ) and cls.base_model_prefix not in state: state = {cls.base_model_prefix: state} # flatten dicts state = flatten_dict(state) random_state = flatten_dict( unfreeze(model.params if _do_init else model.params_shape_tree)) missing_keys = model.required_params - set(state.keys()) unexpected_keys = set(state.keys()) - model.required_params if missing_keys and not _do_init: logger.warn( f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " f"Make sure to call model.init_weights to initialize the missing weights." ) cls._missing_keys = missing_keys # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. mismatched_keys = [] for key in state.keys(): if key in random_state and state[key].shape != random_state[ key].shape: if ignore_mismatched_sizes: mismatched_keys.append( (key, state[key].shape, random_state[key].shape)) state[key] = random_state[key] else: raise ValueError( f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. 
" "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " "model.") # add missing keys as random parameters if we are initializing if missing_keys and _do_init: for missing_key in missing_keys: state[missing_key] = random_state[missing_key] # remove unexpected keys to not be saved again for unexpected_key in unexpected_keys: del state[unexpected_key] if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." ) else: logger.info( f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" ) if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " f"and are newly initialized: {missing_keys}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" f"If your task is similar to the task the model of the checkpoint was trained on, " f"you can already use {model.__class__.__name__} for predictions without further training." 
) if len(mismatched_keys) > 0: mismatched_warning = "\n".join([ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys ]) logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) # dictionary of key: dtypes for the model params param_dtypes = jax.tree_map(lambda x: x.dtype, state) # extract keys of parameters not in jnp.float32 fp16_params = [ k for k in param_dtypes if param_dtypes[k] == jnp.float16 ] bf16_params = [ k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16 ] # raise a warning if any of the parameters are not in jnp.float32 if len(fp16_params) > 0: logger.warning( f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" "You should probably UPCAST the model weights to float32 if this was not intended. " "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." ) if len(bf16_params) > 0: logger.warning( f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" "You should probably UPCAST the model weights to float32 if this was not intended. " "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." ) if _do_init: # set correct parameters model.params = unflatten_dict(state) return model else: return model, unflatten_dict(state)
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True):
    """Restore last/best checkpoint from checkpoints in path.

    Sorts the checkpoint files naturally, returning the highest-valued
    file, e.g.:
      ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
      ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
      ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

    Args:
      ckpt_dir: str: checkpoint file or directory of checkpoints to restore from.
      target: matching object to rebuild via deserialized state-dict. If None,
        the deserialized state-dict is returned as-is.
      step: int: step number to load or None to load latest. If specified,
        ckpt_dir must be a directory. ``0`` is a valid step number and selects
        the step-0 checkpoint (it is not treated as "latest").
      prefix: str: name prefix of checkpoint files.
      parallel: bool: whether to load seekable checkpoints in parallel, for speed.

    Returns:
      Restored `target` updated from checkpoint file, or if no step specified and
      no checkpoint files present, returns the passed-in `target` unchanged.
      If a file path is specified and is not found, the passed-in `target` will be
      returned. This is to match the behavior of the case where a directory path
      is specified but the directory has not yet been created.

    Raises:
      ValueError: if ``step`` is given and no checkpoint file exists for it.
    """
    if step is not None:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    elif gfile.isdir(ckpt_dir):
        ckpt_path = latest_checkpoint(ckpt_dir, prefix)
        if not ckpt_path:
            logging.info(f'Found no checkpoint files in {ckpt_dir}')
            return target
    else:
        ckpt_path = ckpt_dir
        if not gfile.exists(ckpt_path):
            logging.info(f'Found no checkpoint file at {ckpt_path}')
            return target

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        if parallel and fp.seekable():
            buf_size = 128 << 20  # 128M buffer.
            file_size = fp.size()
            # Ceiling division: exactly the number of chunks needed to cover the
            # file (0 for an empty file), with no trailing empty read.
            num_bufs = -(-file_size // buf_size)
            logging.debug('num_bufs: %d', num_bufs)
            checkpoint_contents = bytearray(file_size)

            def read_chunk(i):
                # NOTE: We have to re-open the file to read each chunk, otherwise the
                # parallelism has no effect. But we could reuse the file pointers
                # within each thread.
                offset = i * buf_size
                with gfile.GFile(ckpt_path, 'rb') as f:
                    f.seek(offset)
                    buf = f.read(buf_size)
                if buf:
                    # Chunks cover disjoint byte ranges, so threads never overlap.
                    checkpoint_contents[offset:offset + len(buf)] = buf
                return len(buf) / buf_size

            pool_size = 32
            # The context manager joins the worker threads even if a read raises.
            with thread.ThreadPoolExecutor(pool_size) as pool:
                results = list(pool.map(read_chunk, range(num_bufs)))
            logging.debug('results: %s', results)
        else:
            checkpoint_contents = fp.read()

    if target is None:
        return serialization.msgpack_restore(checkpoint_contents)
    return serialization.from_bytes(target, checkpoint_contents)
def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
    r"""
    Instantiate a pretrained Flax model from a pre-trained model configuration.

    The weights file is first decoded as a Flax msgpack archive; if that fails,
    it is loaded as a PyTorch checkpoint and converted with
    ``cls.convert_from_pytorch``.

    Args:
        pretrained_model_name_or_path: local file path, URL, or model identifier
            that is turned into a download URL via ``hf_bucket_url``.
        dtype: stored in ``model_kwargs["dtype"]`` and forwarded to ``cls``.
        *model_args: forwarded to the config loader and to ``cls``.
        **kwargs: loading options ``config``, ``cache_dir``, ``force_download``,
            ``resume_download``, ``proxies``, ``local_files_only`` and
            ``revision``. Remaining keys go to ``cls.config_class.from_pretrained``
            (when no config object is given, whose unused kwargs are then passed
            to the model) or straight to ``cls``.

    Returns:
        ``cls(config, state, *model_args, **model_kwargs)``.

    Raises:
        EnvironmentError: if the weights file cannot be fetched, or cannot be
            decoded as either a Flax msgpack archive or a PyTorch archive.
    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    revision = kwargs.pop("revision", None)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            revision=revision,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Add the dtype to model_kwargs
    model_kwargs["dtype"] = dtype

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg) from err

        if resolved_archive_file == archive_file:
            logger.info(f"loading weights file {archive_file}")
        else:
            logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
    else:
        # NOTE(review): with no path there is no weights file, so the open()
        # below fails on None; confirm whether callers ever reach this branch.
        resolved_archive_file = None

    # Instantiate model.
    from flax.serialization import from_bytes

    with open(resolved_archive_file, "rb") as state_f:
        try:
            # ``from_bytes`` decodes msgpack *bytes*; passing the file object
            # itself always raised TypeError, so msgpack weights never loaded.
            state = from_bytes(cls.model_class, state_f.read())
        except (TypeError, ValueError):
            # Not a Flax msgpack archive. msgpack reports malformed input via
            # ValueError subclasses (e.g. ExtraData for a PyTorch pickle/zip).
            # Rewind, since read() above consumed the file, and try PyTorch.
            state_f.seek(0)
            try:
                import torch

                state = torch.load(state_f)
                state = {k: v.numpy() for k, v in state.items()}
                state = cls.convert_from_pytorch(state, config)
                # Drop the leading key component and nest on the remaining ones.
                state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
            except UnpicklingError as err:
                raise EnvironmentError(
                    f"Unable to convert model {archive_file} to Flax deserializable object. "
                    "Supported format are PyTorch archive or Flax msgpack"
                ) from err

    return cls(config, state, *model_args, **model_kwargs)
def load_model(filename, model):
    """Restore serialized state from ``filename`` onto the structure of ``model``.

    Args:
        filename: path to a binary file whose contents are accepted by
            ``serialization.from_bytes``.
        model: template object whose structure the stored state is restored into.

    Returns:
        The value returned by ``serialization.from_bytes(model, <file bytes>)``.
    """
    with open(filename, "rb") as handle:
        payload = handle.read()
    return serialization.from_bytes(model, payload)
def train(self, training_data, training_z, batch_size=5000, niter=2000):
    """Fit the classifier weights, or load them from ``self.export_name``.

    The feature scaler is always fitted on ``training_data``. If a file already
    exists at ``self.export_name``, its weights are loaded into ``self.model``
    and no optimisation is run. Otherwise Adam runs for ``niter`` steps on
    random mini-batches, and the resulting weights are written to
    ``self.export_name`` and stored in ``self.model``.

    Parameters:
    -----------
    training_data: numpy array, size Ngalaxes x Nbands
      training data, each row is a galaxy, each column is a band as per
      band defined above

    training_z: numpy array, size Ngalaxies
      true redshift for the training sample

    batch_size: int
      number of samples drawn (with replacement) for each optimisation step

    niter: int
      number of optimisation steps
    """
    # Standardise the inputs, then clip outliers into [-4, 4].
    scaled_features = np.clip(self.features_scaler.fit_transform(training_data), -4, 4)
    redshifts = training_z

    # Weights were exported by an earlier run: reuse them instead of retraining.
    if os.path.exists(self.export_name):
        # NOTE(review): pickle.load can execute code from a crafted file; only
        # point export_name at trusted files.
        with open(self.export_name, 'rb') as handle:
            payload = pickle.load(handle)
            self.model = serialization.from_bytes(self.model, payload)
        return

    opt = optim.Adam(learning_rate=0.001).create(self.model)

    @jax.jit
    def train_step(current, batch):
        # Every branch returns a quantity to *minimise*.
        def loss_fn(model):
            weights = model(batch['features'])
            metric = self.metric
            if metric == 'SNR':
                # Negated so that minimising the loss maximises the score.
                return -metrics.compute_snr_score(weights, batch['labels'])
            if metric == 'FOM':
                # Minimizing the Area
                return 1. / metrics.compute_fom(weights, batch['labels'])
            if metric == 'FOM_DETF':
                # Minimizing the Area
                return 1. / metrics.compute_fom(weights, batch['labels'], inds=[5, 6])
            raise NotImplementedError

        loss, grads = jax.value_and_grad(loss_fn)(current.target)
        return current.apply_gradient(grads), loss

    # This function provides random batches of data, TODO: convert to JAX
    print("Size of dataset", len(redshifts))

    def get_batch():
        picks = onp.random.choice(len(redshifts), batch_size)
        return {'labels': redshifts[picks], 'features': scaled_features[picks]}

    losses = []  # per-step losses; not used further in this method
    for it in range(niter):
        opt, loss = train_step(opt, get_batch())
        losses.append(loss)
        if it % 100 == 0:
            print('iter: %d; Loss : %f' % (it, loss))

    # Export model to disk
    with open(self.export_name, 'wb') as handle:
        pickle.dump(serialization.to_bytes(opt.target), handle)
    self.model = opt.target
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""
    Instantiate a pretrained Flax model from a pre-trained model configuration.

    The weights file is first decoded as a Flax msgpack archive; if that fails,
    it is loaded as a PyTorch checkpoint and converted with
    ``cls.convert_from_pytorch``.

    Args:
        pretrained_model_name_or_path: local file path, URL, or model identifier
            that is turned into a download URL via ``hf_bucket_url``.
        *model_args: forwarded to the config loader and to ``cls``.
        **kwargs: loading options ``config``, ``cache_dir``, ``force_download``,
            ``resume_download``, ``proxies``, ``local_files_only`` and
            ``use_cdn``. Remaining keys go to ``cls.config_class.from_pretrained``
            (when no config object is given, whose unused kwargs are then passed
            to the model) or straight to ``cls``.

    Returns:
        ``cls(config, state, *model_args, **model_kwargs)``.

    Raises:
        EnvironmentError: if the weights file cannot be fetched, or cannot be
            decoded as either a Flax msgpack archive or a PyTorch archive.
    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
        except EnvironmentError as err:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
            else:
                msg = (
                    f"Model name '{pretrained_model_name_or_path}' "
                    f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
                    f"We assumed '{archive_file}' was a path or url to model weight files but "
                    "couldn't find any such file at this path or url."
                )
            raise EnvironmentError(msg) from err

        if resolved_archive_file == archive_file:
            logger.info(f"loading weights file {archive_file}")
        else:
            logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
    else:
        # NOTE(review): with no path there is no weights file, so the open()
        # below fails on None; confirm whether callers ever reach this branch.
        resolved_archive_file = None

    # Instantiate model.
    from flax.serialization import from_bytes

    with open(resolved_archive_file, "rb") as state_f:
        try:
            # ``from_bytes`` decodes msgpack *bytes*; passing the file object
            # itself always raised TypeError, so msgpack weights never loaded.
            state = from_bytes(cls.model_class, state_f.read())
        except (TypeError, ValueError):
            # Not a Flax msgpack archive. msgpack reports malformed input via
            # ValueError subclasses (e.g. ExtraData for a PyTorch pickle/zip).
            # Rewind, since read() above consumed the file, and try PyTorch.
            state_f.seek(0)
            try:
                import torch

                state = torch.load(state_f)
                state = {k: v.numpy() for k, v in state.items()}
                state = cls.convert_from_pytorch(state, config)
                # Drop the leading key component and nest on the remaining ones.
                state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
            except UnpicklingError as err:
                raise EnvironmentError(
                    f"Unable to convert model {archive_file} to Flax deserializable object. "
                    "Supported format are PyTorch archive or Flax msgpack"
                ) from err

    return cls(config, state, *model_args, **model_kwargs)