Example #1
def restore_checkpoint(load_dir, state):
    logger.info(f"Restoring checkpoint from {load_dir}")
    with open(os.path.join(load_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())
    with open(os.path.join(load_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())
    with open(os.path.join(load_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]
    logger.info(f"Checkpoint restored at step {step}")
    return state.replace(step=step, params=params, opt_state=opt_state), step
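For context, here is a minimal sketch of an assumed save-side counterpart (the original saving code is not part of this example). It simply writes the three files that restore_checkpoint reads back; flax.serialization.to_bytes is the inverse of from_bytes.

# Sketch of an assumed save-side counterpart (not part of the original example):
# writes the three files that restore_checkpoint() above reads back.
import json
import os

from flax.serialization import to_bytes

def save_checkpoint(save_dir, state):
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "flax_model.msgpack"), "wb") as f:
        f.write(to_bytes(state.params))
    with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
        f.write(to_bytes(state.opt_state))
    with open(os.path.join(save_dir, "training_state.json"), "w") as f:
        json.dump({"step": int(state.step)}, f)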
Example #2
def load_from_zip(agent_dict, load_path):
    """
    """
    # Check if the file exists
    if isinstance(load_path, str):
        if not os.path.exists(load_path):
            if os.path.exists(load_path + ".zip"):
                load_path += ".zip"
            else:
                raise ValueError(
                    "Error: the file {:} could not be found.".format(
                        load_path))

    # Open file and load the agent components
    with zipfile.ZipFile(load_path, "r") as f:
        namelist = f.namelist()
        for name, target in agent_dict.items():
            # Skip components that were not saved
            if name not in namelist:
                continue

            serialized = f.read(name)
            agent_dict[name] = from_bytes(target, serialized)

    return agent_dict
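A hypothetical call site (the template objects are placeholders): each value in agent_dict supplies the structure that from_bytes uses to rebuild the stored component.

# Hypothetical usage sketch: values are templates with the desired structure.
agent_dict = {
    "params": init_params,        # placeholder: freshly initialized parameters
    "opt_state": init_opt_state,  # placeholder: freshly created optimizer state
}
agent_dict = load_from_zip(agent_dict, "checkpoints/agent")  # ".zip" is appended if missing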
Example #3
def test_namedtuple_serialization(self):
    foo_class = collections.namedtuple('Foo', 'a b c')
    x1 = foo_class(a=1, b=2, c=3)
    x1_serialized = serialization.to_bytes(x1)
    x2 = foo_class(a=0, b=0, c=0)
    restored_x1 = serialization.from_bytes(x2, x1_serialized)
    self.assertEqual(x1, restored_x1)
Example #4
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
    """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict.
    step: int: step number to load or None to load latest.
    prefix: str: name prefix of checkpoint files.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified and
    no checkpoint files present, returns the passed-in `target` unchanged.
  """
    if step is not None:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        glob_path = os.path.join(ckpt_dir, f'{prefix}*')
        checkpoint_files = natural_sort(gfile.glob(glob_path))
        ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
        checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
        if not checkpoint_files:
            return target
        ckpt_path = checkpoint_files[-1]

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        return serialization.from_bytes(target, fp.read())
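The natural_sort and _checkpoint_path helpers are defined elsewhere in the flax checkpoints module. A regex-based sketch of natural_sort, matching the sorting behavior the docstring describes (numbers embedded in file names compared numerically), might look like this; it is an assumption, not the library's exact code:

import re

# Sketch of a natural_sort helper (assumed). Splits names on (possibly signed,
# possibly scientific-notation) numbers and compares those pieces as floats.
def natural_sort(file_list):
    float_re = re.compile(r'(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
    def sort_key(name):
        return [float(part) if i % 2 else part
                for i, part in enumerate(float_re.split(name))]
    return sorted(file_list, key=sort_key)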
Example #5
def restore_checkpoint(save_dir, state):
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    args = joblib.load(os.path.join(save_dir, "args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator
Example #6
def test_model_serialization_to_bytes(self):
    rng = random.PRNGKey(0)
    module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
    _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
    model = nn.Model(module, initial_params)
    serialized_bytes = serialization.to_bytes(model)
    restored_model = serialization.from_bytes(model, serialized_bytes)
    self.assertEqual(restored_model.params, model.params)
Example #7
def load_weights(weight_path: str) -> Dict[str, Any]:
  """Load and deserialize weight dictionary."""
  if not gfile.exists(weight_path):
    raise ValueError('Matching checkpoint not found: {}'.format(weight_path))
  else:
    logging.info('Loading weights from %s', weight_path)
    with gfile.GFile(weight_path, 'rb') as fp:
      params = serialization.from_bytes(None, fp.read())
    return jax.tree_map(jnp.asarray, params)
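Note that passing None as the target, as above, makes from_bytes return the raw deserialized state dict without mapping it onto an existing object; in recent Flax versions this is equivalent to serialization.msgpack_restore (used explicitly in Example #22 below). A minimal illustration, under that assumption:

from flax import serialization

encoded = serialization.to_bytes({'kernel': 1.0, 'bias': 2.0})
# With no target structure, both calls return the plain nested state dict.
assert serialization.from_bytes(None, encoded) == serialization.msgpack_restore(encoded)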
Example #8
def test_serialization_chunking2(self):
    old_chunksize = serialization.MAX_CHUNK_SIZE
    serialization.MAX_CHUNK_SIZE = 91 * 8
    try:
        tmp = {'a': np.ones((10, 10))}
        tmpbytes = serialization.to_bytes(tmp)
        newtmp = serialization.from_bytes(tmp, tmpbytes)
    finally:
        serialization.MAX_CHUNK_SIZE = old_chunksize
    jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp)
Example #9
def test_optimizer_serialization_to_bytes(self):
    rng = random.PRNGKey(0)
    module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
    _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
    model = nn.Model(module, initial_params)
    optim_def = optim.Momentum(learning_rate=1.)
    optimizer = optim_def.create(model)
    serialized_bytes = serialization.to_bytes(optimizer)
    restored_optimizer = serialization.from_bytes(optimizer, serialized_bytes)
    self.assertEqual(restored_optimizer, optimizer)
Example #10
def restore_from_path(ckpt_dir, target, step=None, prefix='ckpt_'):
    """Restores a checkpoint from a directory path, if available."""
    ckpt_destination_path = latest_checkpoint_path(ckpt_dir, prefix)

    if ckpt_destination_path is None:
        logging.info('No checkpoints found, starting from the beginning.')
        return target, step

    logging.info('Restoring checkpoint: %s', ckpt_destination_path)
    save_state = SaveState(target, step)
    with gfile.GFile(ckpt_destination_path, 'rb') as fp:
        save_state = serialization.from_bytes(save_state, fp.read())
        return save_state.train_state, save_state.step
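SaveState and latest_checkpoint_path come from the surrounding codebase and are not shown. A plausible definition of the container (an assumption) is a plain named tuple, which Flax serializes field by field, as the namedtuple test in Example #3 demonstrates:

import collections

# Assumed container definition: flax.serialization handles namedtuples
# natively, so a bare namedtuple is enough for from_bytes to target.
SaveState = collections.namedtuple('SaveState', ['train_state', 'step'])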
Example #11
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
    """Load flax checkpoints in a PyTorch model"""
    flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
    logger.info(f"Loading Flax weights from {flax_checkpoint_path}")

    # import correct flax class
    flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)

    # load flax weight dict
    with open(flax_checkpoint_path, "rb") as state_f:
        try:
            flax_state_dict = from_bytes(flax_cls, state_f.read())
        except UnpicklingError:
            raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")

    return load_flax_weights_in_pytorch_model(model, flax_state_dict)
Example #12
def variables_from_file(filename: str, variables: _PyTree):
    """
    Loads the variables of a variational state from a `.mpack` file.

    Args:
        filename: the file containing the variables. Assumes a .mpack
            extension and adds it if missing and no file exists.
        variables: An object with the same structure and shape as the
            object to be deserialized.

    Returns:
        a PyTree with the same structure as `variables`

    Examples:
       Serializing the data:

       >>> import netket as nk
       >>> import flax
       >>> # construct an RBM model on 10 spins
       >>> vstate = nk.variational.MCState(
       ...      nk.sampler.MetropolisLocal(nk.hilbert.Spin(0.5)**10),
       ...      nk.models.RBM())
       >>> with open("test.mpack", 'wb') as file:
       ...     bytes_written = file.write(flax.serialization.to_bytes(vstate.variables))
       >>> print(bytes_written)
       1052
       >>>
       >>> # Deserialize the data
       >>>
       >>> del vstate
       >>> # construct an RBM model on 10 spins
       >>> vstate2 = nk.variational.MCState(
       ...      nk.sampler.MetropolisLocal(nk.hilbert.Spin(0.5)**10),
       ...      nk.models.RBM())
       >>> # Load the data by passing the model
       >>> vars = nk.variational.experimental.variables_from_file("test.mpack",
       ...                                                        vstate2.variables)
       >>> # update the variables of vstate with the loaded data.
       >>> vstate2.variables = vars
    """
    if not _path.isfile(filename):
        if filename[-6:] != ".mpack":
            filename = filename + ".mpack"

    with open(filename, "rb") as f:
        return _serialization.from_bytes(variables, f.read())
Example #13
def test_serialization(vstate):
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    old_params = vstate.parameters
    old_samples = vstate.samples
    old_nsamples = vstate.n_samples
    old_ndiscard = vstate.n_discard_per_chain

    vstate = nk.vqs.MCState(vstate.sampler, vstate.model, n_samples=10, seed=SEED + 100)

    vstate = serialization.from_bytes(vstate, bdata)

    jax.tree_multimap(np.testing.assert_allclose, vstate.parameters, old_params)
    np.testing.assert_allclose(vstate.samples, old_samples)
    assert vstate.n_samples == old_nsamples
    assert vstate.n_discard_per_chain == old_ndiscard
Example #14
def variables_from_tar(filename: str, variables: _PyTree, i: int):
    """
    Loads the variables of a variational state from the i-th element of a `.tar`
    archive.

    Args:
        filename: the tar archive name. Assumes a .tar
            extension and adds it if missing and no file exists.
        variables: An object with the same structure and shape as the
            object to be deserialized.
        i: the index of the variables to load
    """
    if not _path.isfile(filename):
        if filename[-4:] != ".tar":
            filename = filename + ".tar"

    with _tarfile.TarFile(filename, "r") as file:
        info = file.getmember(str(i) + ".mpack")
        with file.extractfile(info) as f:
            return _serialization.from_bytes(variables, f.read())
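A hypothetical call, mirroring the variables_from_file example above (the archive path and index are placeholders; the archive is assumed to contain one "<i>.mpack" member per saved state):

# Hypothetical usage: load the variables saved as member "3.mpack".
vars3 = variables_from_tar("checkpoints.tar", vstate.variables, i=3)
vstate.variables = vars3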
Example #15
def test_serialization(vstate):
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    vstate_new = nk.variational.MCMixedState(
        vstate.sampler, vstate.model, n_samples=10, seed=SEED + 313
    )

    vstate_new = serialization.from_bytes(vstate_new, bdata)

    jax.tree_multimap(
        np.testing.assert_allclose, vstate.parameters, vstate_new.parameters
    )
    np.testing.assert_allclose(vstate.samples, vstate_new.samples)
    np.testing.assert_allclose(vstate.diagonal.samples, vstate_new.diagonal.samples)
    assert vstate.n_samples == vstate_new.n_samples
    assert vstate.n_discard == vstate_new.n_discard
    assert vstate.n_samples_diag == vstate_new.n_samples_diag
    assert vstate.n_discard_diag == vstate_new.n_discard_diag
Example #16
def _load_from_bytes(path: str, target: Any):
    with open(path, "rb") as f:
        return from_bytes(target, f.read())
Example #17
def load_model(filename: str, model: nn.Module) -> nn.Module:
    with gfile.GFile(filename, "rb") as fp:
        return serialization.from_bytes(model, fp.read())
Example #18
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Union[str, os.PathLike],
                        dtype: jnp.dtype = jnp.float32,
                        *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
                      case, ``from_pt`` should be set to :obj:`True`.
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                      model).
                    - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, FlaxBertModel
            >>> # Download model and configuration from huggingface.co and cache.
            >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
            >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/config.json')
            >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {
            "file_type": "model",
            "framework": "flax",
            "from_auto_class": from_auto_class
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                FLAX_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
                    )
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(
                model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except UnpicklingError:
                    raise EnvironmentError(
                        f"Unable to convert {archive_file} to a Flax-deserializable object."
                    )

        # if model is base model only use model_prefix key
        if cls.base_model_prefix not in dict(
                model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # flatten dicts
        state = flatten_dict(state)
        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        # set correct parameters
        model.params = unflatten_dict(state)
        return model
Example #19
def restore_checkpoint(checkpoint_path, target):
    """Restores a checkpoint."""
    with gfile.GFile(checkpoint_path, 'rb') as f:
        return serialization.from_bytes(target, f.read())
Example #20
def restore_from_path(ckpt_path, target):
    ckpt_path = check_and_convert_gcs_filepath(ckpt_path)
    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        return serialization.from_bytes(target, fp.read())
Example #21
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Union[str, os.PathLike],
                        dtype: jnp.dtype = jnp.float32,
                        *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case,
                      `from_pt` should be set to `True`.
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified all the computation will be performed with the given `dtype`.

                **Note that this only specifies the dtype of the computation and does not influence the dtype of model
                parameters.**

                If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
                [`~FlaxPreTrainedModel.to_bf16`].
            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
                Can be either:

                    - an instance of a class derived from [`PretrainedConfig`],
                    - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                      model).
                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
                checkpoint with 3 labels).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import BertConfig, FlaxBertModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
        >>> config = BertConfig.from_json_file("./pt_model/config.json")
        >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
        ```"""
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        _do_init = kwargs.pop("_do_init", True)

        user_agent = {
            "file_type": "model",
            "framework": "flax",
            "from_auto_class": from_auto_class
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                FLAX_WEIGHTS_NAME)
                # At this stage we don't have a weight file so we will raise an error.
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                        "weights.")
                else:
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"{pretrained_model_name_or_path}.")
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=filename,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )

            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                    "login` and pass `use_auth_token=True`.")
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                if filename == FLAX_WEIGHTS_NAME:
                    has_file_kwargs = {
                        "revision": revision,
                        "proxies": proxies,
                        "use_auth_token": use_auth_token
                    }
                    if has_file(pretrained_model_name_or_path, WEIGHTS_NAME,
                                **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
                            "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
                            "those weights.")
                    else:
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
                            f"or {WEIGHTS_NAME}.")
                else:
                    raise EnvironmentError(
                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
                    )
            except HTTPError as err:
                raise EnvironmentError(
                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                    f"{err}")
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in the cached "
                    f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
                    "Checkout your internet connection or see how to run the library in offline mode at "
                    "'https://huggingface.co/docs/transformers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                )

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # init random models
        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(
                model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                    try:
                        with open(resolved_archive_file) as f:
                            if f.read().startswith("version"):
                                raise OSError(
                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                                    "you cloned.")
                            else:
                                raise ValueError from e
                    except (UnicodeDecodeError, ValueError):
                        raise EnvironmentError(
                            f"Unable to convert {archive_file} to a Flax-deserializable object."
                        )
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            if _do_init:
                state = jax.tree_util.tree_map(jnp.array, state)
            else:
                # keep the params on CPU if we don't want to initialize
                state = jax.tree_util.tree_map(
                    lambda x: jax.device_put(x,
                                             jax.devices("cpu")[0]), state)

        # if model is base model only use model_prefix key
        if cls.base_model_prefix not in dict(
                model.params_shape_tree) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if model is head model and we are loading weights from base model
        # we initialize new params dict with base_model_prefix
        if cls.base_model_prefix in dict(
                model.params_shape_tree
        ) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(
            unfreeze(model.params if _do_init else model.params_shape_tree))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        if missing_keys and not _do_init:
            logger.warning(
                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
                f"Make sure to call model.init_weights to initialize the missing weights."
            )
            cls._missing_keys = missing_keys

        # mismatched_keys contains (key, checkpoint_shape, model_shape) tuples for
        # weights in the checkpoint whose shape does not match the model's.
        mismatched_keys = []
        for key in state.keys():
            if key in random_state and state[key].shape != random_state[
                    key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append(
                        (key, state[key].shape, random_state[key].shape))
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model.")

        # add missing keys as random parameters if we are initializing
        if missing_keys and _do_init:
            for missing_key in missing_keys:
                state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join([
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ])
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        # dictionary of key: dtypes for the model params
        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
        # extract keys of parameters not in jnp.float32
        fp16_params = [
            k for k in param_dtypes if param_dtypes[k] == jnp.float16
        ]
        bf16_params = [
            k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16
        ]

        # raise a warning if any of the parameters are not in jnp.float32
        if len(fp16_params) > 0:
            logger.warning(
                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
                "You should probably UPCAST the model weights to float32 if this was not intended. "
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

        if len(bf16_params) > 0:
            logger.warning(
                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
                "You should probably UPCAST the model weights to float32 if this was not intended. "
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

        if _do_init:
            # set correct parameters
            model.params = unflatten_dict(state)
            return model
        else:
            return model, unflatten_dict(state)
Example #22
def restore_checkpoint(ckpt_dir,
                       target,
                       step=None,
                       prefix='checkpoint_',
                       parallel=True):
    """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: checkpoint file or directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict. If None,
      the deserialized state-dict is returned as-is.
    step: int: step number to load or None to load latest. If specified,
      ckpt_dir must be a directory.
    prefix: str: name prefix of checkpoint files.
    parallel: bool: whether to load seekable checkpoints in parallel, for speed.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified and
    no checkpoint files present, returns the passed-in `target` unchanged.
    If a file path is specified and is not found, the passed-in `target` will be
    returned. This is to match the behavior of the case where a directory path
    is specified but the directory has not yet been created.
  """
    if step is not None:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        if gfile.isdir(ckpt_dir):
            ckpt_path = latest_checkpoint(ckpt_dir, prefix)
            if not ckpt_path:
                logging.info(f'Found no checkpoint files in {ckpt_dir}')
                return target
        else:
            ckpt_path = ckpt_dir
            if not gfile.exists(ckpt_path):
                logging.info(f'Found no checkpoint file at {ckpt_path}')
                return target

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        if parallel and fp.seekable():
            buf_size = 128 << 20  # 128M buffer.
            num_bufs = fp.size() / buf_size
            logging.debug('num_bufs: %d', num_bufs)
            checkpoint_contents = bytearray(fp.size())

            def read_chunk(i):
                # NOTE: We have to re-open the file to read each chunk, otherwise the
                # parallelism has no effect. But we could reuse the file pointers
                # within each thread.
                with gfile.GFile(ckpt_path, 'rb') as f:
                    f.seek(i * buf_size)
                    buf = f.read(buf_size)
                    if buf:
                        checkpoint_contents[i * buf_size:i * buf_size +
                                            len(buf)] = buf
                    return len(buf) / buf_size

            pool_size = 32
            pool = thread.ThreadPoolExecutor(pool_size)
            results = pool.map(read_chunk, range(int(num_bufs) + 1))
            results = list(results)
            pool.shutdown(wait=False)
            logging.debug('results: %s', results)
        else:
            checkpoint_contents = fp.read()

        if target is None:
            return serialization.msgpack_restore(checkpoint_contents)
        else:
            return serialization.from_bytes(target, checkpoint_contents)
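A minimal usage sketch of the target=None path above (the directory is a placeholder): restoring the newest "checkpoint_*" file as a raw nested dict, e.g. for inspection.

# Hypothetical usage: with target=None the raw state dict is returned.
raw_state = restore_checkpoint('/tmp/ckpts', target=None)
if raw_state is not None:
    print(jax.tree_util.tree_map(lambda x: getattr(x, 'shape', x), raw_state))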
Example #23
    def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
        r"""
        Instantiate a pretrained Flax model from a pre-trained model configuration.
        """
        config = kwargs.pop("config", None)
        # state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        # from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        # output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None

        # Instantiate model.
        with open(resolved_archive_file, "rb") as state_f:
            try:
                from flax.serialization import from_bytes

                # from_bytes expects the serialized bytes, not a file object
                state = from_bytes(cls.model_class, state_f.read())
            # Non-msgpack input typically surfaces as TypeError/ValueError
            except (TypeError, ValueError):
                try:
                    import torch

                    # Rewind past the failed msgpack read before handing the
                    # file to the PyTorch loader
                    state_f.seek(0)
                    state = torch.load(state_f)
                    state = {k: v.numpy() for k, v in state.items()}
                    state = cls.convert_from_pytorch(state, config)
                    state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
                except UnpicklingError:
                    raise EnvironmentError(
                        f"Unable to convert model {archive_file} to Flax deserializable object. "
                        "Supported format are PyTorch archive or Flax msgpack"
                    )

        return cls(config, state, *model_args, **model_kwargs)
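
The chunked parallel read at the top of this example can be isolated into a small helper. Below is a minimal, self-contained sketch of the same pattern over plain local files; the helper name parallel_read and the default chunk and pool sizes are illustrative assumptions, not part of the original code.

import os
from concurrent.futures import ThreadPoolExecutor


def parallel_read(path, buf_size=128 * 1024 * 1024, pool_size=32):
    # Pre-allocate the full buffer, then let each worker fill its own slice.
    size = os.path.getsize(path)
    contents = bytearray(size)
    num_bufs = size // buf_size + 1

    def read_chunk(i):
        # Each worker opens its own file handle so the reads can overlap.
        with open(path, 'rb') as f:
            f.seek(i * buf_size)
            buf = f.read(buf_size)
            if buf:
                contents[i * buf_size:i * buf_size + len(buf)] = buf
            return len(buf)

    with ThreadPoolExecutor(pool_size) as pool:
        list(pool.map(read_chunk, range(num_bufs)))
    return bytes(contents)

The returned bytes can then be passed to serialization.msgpack_restore or serialization.from_bytes, exactly as in the example above.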
Beispiel #24
0
def load_model(filename, model):
    with open(filename, "rb") as fp:
        return serialization.from_bytes(model, fp.read())
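
This loader pairs naturally with a symmetric save helper. A minimal counterpart under the same assumptions; the name save_model is hypothetical:

from flax import serialization


def save_model(filename, model):
    # Serialize the model pytree to msgpack bytes and write them to disk.
    with open(filename, "wb") as fp:
        fp.write(serialization.to_bytes(model))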
Beispiel #25
0
    def train(self, training_data, training_z, batch_size=5000, niter=2000):
        """Trains the classifier

        Parameters:
        -----------
        training_data: numpy array, size Ngalaxies x Nbands
          training data; each row is a galaxy, each column is a band as per
          the bands defined above
        training_z: numpy array, size Ngalaxies
          true redshift for the training sample
        batch_size: int
          number of galaxies drawn per training batch
        niter: int
          number of training iterations

        """
        # Fit the scaler on the training data and clip outliers
        features = self.features_scaler.fit_transform(training_data)
        features = np.clip(features, -4, 4)
        labels = training_z

        # If model is already trained, we just load the weights
        if os.path.exists(self.export_name):
            with open(self.export_name, 'rb') as file:
                self.model = serialization.from_bytes(self.model,
                                                      pickle.load(file))
            return

        lr = 0.001
        optimizer = optim.Adam(learning_rate=lr).create(self.model)

        @jax.jit
        def train_step(optimizer, batch):
            # This is the loss function
            def loss_fn(model):
                # Apply classifier to features
                w = model(batch['features'])
                # Return -score, because we want to maximize the score
                if self.metric == 'SNR':
                    return -metrics.compute_snr_score(w, batch['labels'])
                elif self.metric == 'FOM':
                    # Minimizing the Area
                    return 1. / metrics.compute_fom(w, batch['labels'])
                elif self.metric == 'FOM_DETF':
                    # Minimizing the Area
                    return 1. / metrics.compute_fom(
                        w, batch['labels'], inds=[5, 6])
                else:
                    raise NotImplementedError

            # Compute gradients
            loss, g = jax.value_and_grad(loss_fn)(optimizer.target)
            # Perform gradient descent
            optimizer = optimizer.apply_gradient(g)
            return optimizer, loss

        print("Size of dataset", len(labels))

        # This function provides random batches of data.
        # TODO: convert to JAX (see the sketch after this example).
        def get_batch():
            inds = onp.random.choice(len(labels), batch_size)
            return {'labels': labels[inds], 'features': features[inds]}

        losses = []
        for i in range(niter):
            optimizer, loss = train_step(optimizer, get_batch())
            losses.append(loss)
            if i % 100 == 0:
                print('iter: %d; Loss : %f' % (i, loss))

        # Export model to disk
        with open(self.export_name, 'wb') as file:
            pickle.dump(serialization.to_bytes(optimizer.target), file)

        self.model = optimizer.target
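
The TODO on the batching helper above asks for a pure-JAX version. A hedged sketch using an explicit PRNG key; the name get_batch_jax is illustrative:

import jax
import jax.numpy as jnp


def get_batch_jax(key, features, labels, batch_size):
    # Split the key so successive calls draw different random batches.
    key, subkey = jax.random.split(key)
    inds = jax.random.choice(subkey, len(labels), shape=(batch_size,))
    return key, {'labels': jnp.asarray(labels)[inds],
                 'features': jnp.asarray(features)[inds]}

The caller would create key = jax.random.PRNGKey(0) once before the loop and thread the returned key through each iteration.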
Beispiel #26
0
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained Flax model from a pre-trained model configuration.
        """
        config = kwargs.pop("config", None)
        # state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        # from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        # output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                    pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(pretrained_model_name_or_path,
                                             filename=WEIGHTS_NAME,
                                             use_cdn=use_cdn)

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
                else:
                    msg = (
                        f"Model name '{pretrained_model_name_or_path}' "
                        f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
                        f"We assumed '{archive_file}' was a path or url to model weight files but "
                        "couldn't find any such file at this path or url.")
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # Instantiate model.
        with open(resolved_archive_file, "rb") as state_f:
            try:
                from flax.serialization import from_bytes

                # from_bytes expects the serialized bytes, not a file object
                state = from_bytes(cls.model_class, state_f.read())
            # Non-msgpack input typically surfaces as TypeError/ValueError
            except (TypeError, ValueError):
                try:
                    import torch

                    # Rewind past the failed msgpack read before handing the
                    # file to the PyTorch loader
                    state_f.seek(0)
                    state = torch.load(state_f)
                    state = {k: v.numpy() for k, v in state.items()}
                    state = cls.convert_from_pytorch(state, config)
                    state = unflatten_dict(
                        {tuple(k.split(".")[1:]): v
                         for k, v in state.items()})
                except UnpicklingError:
                    raise EnvironmentError(
                        f"Unable to convert model {archive_file} to Flax deserializable object. "
                        "Supported format are PyTorch archive or Flax msgpack")

        return cls(config, state, *model_args, **model_kwargs)
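
Both from_pretrained variants above are meant to be called on a concrete subclass that defines config_class and model_class. A hedged usage sketch; MyFlaxModel and the model identifier are hypothetical:

# Resolves the config and msgpack weights (from cache or download),
# then instantiates the model via cls(config, state, ...).
model = MyFlaxModel.from_pretrained("bert-base-uncased")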