Example #1
def _save_as_bytes(path: str, obj: Any):
    path = os.path.abspath(path)
    folder = os.path.dirname(path)
    os.makedirs(folder, exist_ok=True)

    with open(path, "wb") as f:
        f.write(to_bytes(obj))
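
The matching loader is not shown in this example. A minimal sketch of the inverse, assuming a template object with the same pytree structure as the one that was saved (flax's from_bytes restores into such a template):

import os
from flax.serialization import from_bytes

def _load_from_bytes(path: str, template):
    # from_bytes needs a target with the same structure as the object
    # that was serialized; it returns the restored copy.
    with open(os.path.abspath(path), "rb") as f:
        return from_bytes(template, f.read())
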
Example #2
    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return
        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well
        self.config.architectures = [self.__class__.__name__[4:]]
        self.config.save_pretrained(save_directory)

        # save model
        with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
            model_bytes = to_bytes(self.params)
            f.write(model_bytes)
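
The inverse path is not shown here; a hypothetical sketch of it (the real transformers from_pretrained does considerably more work, e.g. downloading and dtype handling):

# Inside the same class: restore the weights into the existing
# parameter pytree, using self.params as the template.
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "rb") as f:
    self.params = from_bytes(self.params, f.read())
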
Example #3
 def test_namedtuple_serialization(self):
     foo_class = collections.namedtuple('Foo', 'a b c')
     x1 = foo_class(a=1, b=2, c=3)
     x1_serialized = serialization.to_bytes(x1)
     x2 = foo_class(a=0, b=0, c=0)
     restored_x1 = serialization.from_bytes(x2, x1_serialized)
     self.assertEqual(x1, restored_x1)
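
Note that from_bytes restores into a template with the same structure; here x2 plays that role. When no template object is at hand, serialization.msgpack_restore (used in a later example) returns the raw deserialized structure instead. A hedged sketch:

# Template-free restore: yields flax's internal state-dict layout
# (plain dicts/arrays) rather than a Foo instance.
raw_state = serialization.msgpack_restore(x1_serialized)
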
Example #4
def test_variables_from_tar(vstate, tmp_path):
    fname = str(tmp_path) + "/file.tar"

    with tarfile.TarFile(fname, "w") as f:
        for i in range(10):
            save_binary_to_tar(f, serialization.to_bytes(vstate.variables),
                               f"{i}.mpack")

    for name in [fname, fname[:-4]]:
        vstate2 = nk.variational.MCState(vstate.sampler,
                                         vstate.model,
                                         n_samples=10,
                                         seed=SEED + 100)

        for j in [0, 3, 8]:
            vstate2.variables = nk.variational.experimental.variables_from_tar(
                name, vstate2.variables, j)

            # check
            jax.tree_multimap(np.testing.assert_allclose, vstate.parameters,
                              vstate2.parameters)

        with pytest.raises(KeyError):
            nk.variational.experimental.variables_from_tar(
                name, vstate2.variables, 15)
Example #5
 def test_model_serialization_to_bytes(self):
     rng = random.PRNGKey(0)
     module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
     _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
     model = nn.Model(module, initial_params)
     serialized_bytes = serialization.to_bytes(model)
     restored_model = serialization.from_bytes(model, serialized_bytes)
     self.assertEqual(restored_model.params, model.params)
Example #6
    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        params=None,
                        push_to_hub=False,
                        **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~FlaxPreTrainedModel.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                <Tip warning={true}>

                Using `push_to_hub=True` will synchronize the repository you are pushing to with
                `save_directory`, which requires `save_directory` to be a local clone of the repo you are
                pushing to if it's an existing folder. Pass along `temp_dir=True` to use a temporary directory
                instead.

                </Tip>

            kwargs:
                Additional keyword arguments passed along to the
                [`~file_utils.PushToHubMixin.push_to_hub`] method.
        """
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well
        self.config.architectures = [self.__class__.__name__[4:]]
        self.config.save_pretrained(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            params = params if params is not None else self.params
            model_bytes = to_bytes(params)
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")
Example #7
    def save_pretrained(self, folder):
        folder_abs = os.path.abspath(folder)

        os.makedirs(folder_abs, exist_ok=True)

        with open(os.path.join(folder_abs, f"{self._config.model_type}.flax"), "wb") as f:
            model_bytes = to_bytes(self.params)
            f.write(model_bytes)
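
A sketch of the matching loader under the same naming convention (load_pretrained is a hypothetical name; the current self.params serves as the restore template, and from_bytes is assumed to be imported alongside to_bytes):

    def load_pretrained(self, folder):
        # Hypothetical counterpart to save_pretrained above.
        path = os.path.join(os.path.abspath(folder),
                            f"{self._config.model_type}.flax")
        with open(path, "rb") as f:
            self.params = from_bytes(self.params, f.read())
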
Example #8
    def _flush_params(self, variational_state):
        if not self._save_params:
            return

        binary_data = serialization.to_bytes(variational_state.variables)
        with open(self._prefix + ".mpack", "wb") as outfile:
            outfile.write(binary_data)

        self._steps_notflushed_pars = 0
Example #9
 def test_optimizer_serialization_to_bytes(self):
   rng = random.PRNGKey(0)
   module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
   _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
   model = nn.Model(module, initial_params)
   optim_def = optim.Momentum(learning_rate=1.)
   optimizer = optim_def.create(model)
   serialized_bytes = serialization.to_bytes(optimizer)
   restored_optimizer = serialization.from_bytes(optimizer, serialized_bytes)
   self.assertEqual(restored_optimizer, optimizer)
Example #10
 def test_serialization_chunking2(self):
     old_chunksize = serialization.MAX_CHUNK_SIZE
     serialization.MAX_CHUNK_SIZE = 91 * 8
     try:
         tmp = {'a': np.ones((10, 10))}
         tmpbytes = serialization.to_bytes(tmp)
         newtmp = serialization.from_bytes(tmp, tmpbytes)
     finally:
         serialization.MAX_CHUNK_SIZE = old_chunksize
     jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp)
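
Both chunking tests shrink serialization.MAX_CHUNK_SIZE so that even a small array is split across several msgpack chunks; the limit exists because msgpack caps a single bin object at 2**32 - 1 bytes. Chunking is transparent to callers, as in this minimal round-trip sketch:

import numpy as np
from flax import serialization

# Identical result whether or not the array was chunked internally.
tmp = {'a': np.ones((10, 10))}
restored = serialization.from_bytes(tmp, serialization.to_bytes(tmp))
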
Example #11
    def test_restore_chunked(self):
        old_chunksize = serialization.MAX_CHUNK_SIZE
        serialization.MAX_CHUNK_SIZE = 91 * 8
        try:
            tmp = np.random.uniform(-100, 100, size=(21, 37))
            serialized = serialization.to_bytes(tmp)
            restored = serialization.msgpack_restore(serialized)
        finally:
            serialization.MAX_CHUNK_SIZE = old_chunksize

        np.testing.assert_array_equal(restored, tmp)
Example #12
 def save_checkpoint(self, save_dir, state):
     state = jax_utils.unreplicate(state)
     print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
     self.model_save_fn(save_dir, params=state.params)
     with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
         f.write(to_bytes(state.opt_state))
     joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
     joblib.dump(self.data_collator,
                 os.path.join(save_dir, "data_collator.joblib"))
     with open(os.path.join(save_dir, "training_state.json"), "w") as f:
         json.dump({"step": state.step.item()}, f)
     print("DONE")
Example #13
    def _save_variables(self, variational_state):
        if self._init is False:
            self._init_output()

        _time = time.time()
        binary_data = serialization.to_bytes(variational_state.variables)
        if self._tar:
            save_binary_to_tar(
                self._tar_file, binary_data, str(self._file_step) + ".mpack"
            )
        else:
            with open(self._prefix + str(self._file_step) + ".mpack", "wb") as f:
                f.write(binary_data)

        self._file_step += 1
        self._runtime_taken += time.time() - _time
Example #14
def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.
    overwrite: bool: allow overwriting when writing a checkpoint.

  Returns:
    Filename of saved checkpoint.
  """
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_path))

    logging.info('Writing to temporary checkpoint location: %s', ckpt_tmp_path)
    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', ckpt_path)

    # Remove old checkpoint files.
    base_path = os.path.join(ckpt_dir, f'{prefix}')
    checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
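
A brief usage sketch (the directory and the optimizer object are illustrative):

# Keeps the three most recent checkpoints under /tmp/ckpts.
ckpt_path = save_checkpoint('/tmp/ckpts', optimizer, step=1000, keep=3)
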
Example #15
    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        params=None,
                        push_to_hub=False,
                        **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.
            kwargs:
                Additional keyword arguments passed along to the
                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
        """
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return
        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well
        self.config.architectures = [self.__class__.__name__[4:]]
        self.config.save_pretrained(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            params = params if params is not None else self.params
            model_bytes = to_bytes(params)
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            saved_files = [
                os.path.join(save_directory, CONFIG_NAME), output_model_file
            ]
            url = self._push_to_hub(save_files=saved_files, **kwargs)
            logger.info(f"Model pushed to the hub in this commit: {url}")
Example #16
def test_serialization(vstate):
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    old_params = vstate.parameters
    old_samples = vstate.samples
    old_nsamples = vstate.n_samples
    old_ndiscard = vstate.n_discard_per_chain

    vstate = nk.vqs.MCState(vstate.sampler, vstate.model, n_samples=10, seed=SEED + 100)

    vstate = serialization.from_bytes(vstate, bdata)

    jax.tree_multimap(np.testing.assert_allclose, vstate.parameters, old_params)
    np.testing.assert_allclose(vstate.samples, old_samples)
    assert vstate.n_samples == old_nsamples
    assert vstate.n_discard_per_chain == old_ndiscard
Example #17
def test_variables_from_file(vstate, tmp_path):
    fname = str(tmp_path) + "/file.mpack"

    with open(fname, "wb") as f:
        f.write(serialization.to_bytes(vstate.variables))

    for name in [fname, fname[:-6]]:
        vstate2 = nk.vqs.MCState(vstate.sampler,
                                 vstate.model,
                                 n_samples=10,
                                 seed=SEED + 100)

        vstate2.variables = nkx.vqs.variables_from_file(
            name, vstate2.variables)

        # check
        jax.tree_map(np.testing.assert_allclose, vstate.parameters,
                     vstate2.parameters)
Example #18
def test_serialization(vstate):
    from flax import serialization

    bdata = serialization.to_bytes(vstate)

    vstate_new = nk.variational.MCMixedState(
        vstate.sampler, vstate.model, n_samples=10, seed=SEED + 313
    )

    vstate_new = serialization.from_bytes(vstate_new, bdata)

    jax.tree_multimap(
        np.testing.assert_allclose, vstate.parameters, vstate_new.parameters
    )
    np.testing.assert_allclose(vstate.samples, vstate_new.samples)
    np.testing.assert_allclose(vstate.diagonal.samples, vstate_new.diagonal.samples)
    assert vstate.n_samples == vstate_new.n_samples
    assert vstate.n_discard == vstate_new.n_discard
    assert vstate.n_samples_diag == vstate_new.n_samples_diag
    assert vstate.n_discard_diag == vstate_new.n_discard_diag
Example #19
def save_checkpoint(ckpt_dir, state, step, keep=10, prefix='ckpt_'):
    """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: Path to store checkpoint files in.
    state: Serializable flax object, usually a flax optimizer.
    step: Training step number or other metric number.
    keep: Number of checkpoints to keep.
    prefix: Checkpoint filename prefix.

  Returns:
    Filename of saved checkpoint.
  """
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    ckpt_tmp_path = os.path.join(ckpt_dir, 'tmp')
    ckpt_destination_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_destination_path))

    save_state = SaveState(state, step)
    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(save_state))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_destination_path, overwrite=True)

    logging.info('Saved checkpoint at %s', ckpt_destination_path)

    # Remove old checkpoint files.
    base_path = os.path.join(ckpt_dir, f'{prefix}')
    checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint: %s', path)
            gfile.remove(path)

    return ckpt_destination_path
Example #20
def save_checkpoint(
    model,
    save_dir,
    state,
    cur_step: int,
    with_opt: bool = True,
    push_to_hub: bool = False,
):
    state = jax_utils.unreplicate(state)
    if with_opt:
        logger.info(f"Saving optimizer and training state in {save_dir}...")
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
    logger.info(
        f'Saving model in {save_dir} {"and pushing it to HF Hub" if push_to_hub else ""}'
    )
    model.save_pretrained(
        save_dir,
        params=state.params,
        push_to_hub=push_to_hub,
        commit_message=f"Saving weights and logs of step {cur_step}",
    )
Example #21
def save_model(filename: str, model: nn.Module) -> None:
    gfile.makedirs(os.path.dirname(filename))
    with gfile.GFile(filename, "wb") as fp:
        fp.write(serialization.to_bytes(model))
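
A matching loader, sketched under the assumption that a module instance of the same structure is available to act as the restore template (the save path implies the object round-trips through flax serialization):

def load_model(filename: str, template: nn.Module) -> nn.Module:
    # Inverse of save_model above; gfile handles local and remote paths.
    with gfile.GFile(filename, "rb") as fp:
        return serialization.from_bytes(template, fp.read())
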
Example #22
    def train(self, training_data, training_z, batch_size=5000, niter=2000):
        """Trains the classifier

        Parameters
        ----------
        training_data: numpy array, size Ngalaxies x Nbands
          training data; each row is a galaxy, each column is a band, as
          per the bands defined above
        training_z: numpy array, size Ngalaxies
          true redshift for the training sample

        """
        # create scaler

        features = self.features_scaler.fit_transform(training_data)
        features = np.clip(features, -4, 4)
        labels = training_z

        # If model is already trained, we just load the weights
        if os.path.exists(self.export_name):
            with open(self.export_name, 'rb') as file:
                self.model = serialization.from_bytes(self.model,
                                                      pickle.load(file))
            return

        lr = 0.001
        optimizer = optim.Adam(learning_rate=lr).create(self.model)

        @jax.jit
        def train_step(optimizer, batch):
            # This is the loss function
            def loss_fn(model):
                # Apply classifier to features
                w = model(batch['features'])
                # returns - score, because we want to maximize score
                if self.metric == 'SNR':
                    return -metrics.compute_snr_score(w, batch['labels'])
                elif self.metric == 'FOM':
                    # Minimizing the Area
                    return 1. / metrics.compute_fom(w, batch['labels'])
                elif self.metric == 'FOM_DETF':
                    # Minimizing the Area
                    return 1. / metrics.compute_fom(
                        w, batch['labels'], inds=[5, 6])
                else:
                    raise NotImplementedError

            # Compute gradients
            loss, g = jax.value_and_grad(loss_fn)(optimizer.target)
            # Perform gradient descent
            optimizer = optimizer.apply_gradient(g)
            return optimizer, loss

        # This function provides random batches of data, TODO: convert to JAX
        print("Size of dataset", len(labels))

        def get_batch():
            inds = onp.random.choice(len(labels), batch_size)
            return {'labels': labels[inds], 'features': features[inds]}

        losses = []
        for i in range(niter):
            optimizer, loss = train_step(optimizer, get_batch())
            losses.append(loss)
            if i % 100 == 0:
                print('iter: %d; Loss : %f' % (i, loss))

        # Export model to disk
        with open(self.export_name, 'wb') as file:
            pickle.dump(serialization.to_bytes(optimizer.target), file)

        self.model = optimizer.target
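
Since to_bytes already returns a plain bytes object, the pickle wrapper in this example is redundant; the bytes could be written and read directly (a sketch keeping the same file name, with the load side becoming serialization.from_bytes(self.model, file.read())):

# Equivalent export without the pickle layer.
with open(self.export_name, 'wb') as file:
    file.write(serialization.to_bytes(optimizer.target))
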
Example #23
def save_model(filename, model):
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "wb") as fp:
        fp.write(serialization.to_bytes(model))
Example #24
def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str or pathlib-like path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.
    overwrite: overwrite existing checkpoint files if a checkpoint
      at the current or a later step already exists (default: False).
  Returns:
    Filename of saved checkpoint.
  """
    ckpt_dir = os.fspath(ckpt_dir)  # Pathlib -> str
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    if ckpt_dir.startswith('./'):
        ckpt_dir = ckpt_dir[2:]  # gfile.glob() can remove leading './'
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_path))
    base_path = os.path.join(ckpt_dir, prefix)
    checkpoint_files = gfile.glob(base_path + '*')

    if ckpt_path in checkpoint_files:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)
    else:
        checkpoint_files.append(ckpt_path)

    checkpoint_files = natural_sort(checkpoint_files)
    if checkpoint_files[-1] == ckpt_tmp_path:
        checkpoint_files.pop(-1)
    if ckpt_path != checkpoint_files[-1]:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)

    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', ckpt_path)

    # Remove newer checkpoints
    if overwrite:
        ind = checkpoint_files.index(ckpt_path) + 1
        newer_ckpts = checkpoint_files[ind:]
        checkpoint_files = checkpoint_files[:ind]
        for path in newer_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    # Remove old checkpoint files.
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
Example #25
def save_weights(weight_path: str, model_params: Dict[str, Any]):
  """Save serialized weight dictionary."""
  serialized_params = serialization.to_bytes(model_params)
  gfile.makedirs(os.path.dirname(weight_path))
  with gfile.GFile(weight_path, 'wb') as fp:
    fp.write(serialized_params)
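
A sketched inverse (load_weights is a hypothetical name; the file written by save_weights is restored into a template dictionary of the same structure):

def load_weights(weight_path: str, template: Dict[str, Any]) -> Dict[str, Any]:
  """Load a weight dictionary serialized by save_weights above."""
  with gfile.GFile(weight_path, 'rb') as fp:
    return serialization.from_bytes(template, fp.read())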