Exemplo n.º 1
0
    def testBasicCollective(self):
        local_devices = list(jax.local_devices())
        if len(local_devices) < 4:
            raise SkipTest("Test requires at least 4 local devices")

        def f(a, b):
            return lax.psum(a * 2, 'a'), b * 4

        devices = np.array(local_devices[:4]).reshape((2, 2))
        with mesh(devices, ('x', 'y')):
            fm = xmap(f,
                      in_axes=[A({
                          'a': 0,
                          'b': 1
                      }), A({'c': 0})],
                      out_axes=[A({'b': 0}), A({'c': 0})],
                      schedule=[
                          ('a', 'x'),
                          ('b', 'y'),
                          ('c', 'x'),
                          ('a', 'vectorize'),
                          ('b', 'vectorize'),
                      ])
            ashape = (16, 8, 5)
            a = jnp.arange(np.prod(ashape)).reshape(ashape)
            bshape = (2, 7)
            b = jnp.arange(np.prod(bshape)).reshape(bshape)
            c, d = fm(a, b)
            self.assertAllClose(c, (a * 2).sum(0))
            self.assertAllClose(d, b * 4)
Exemplo n.º 2
0
    def testBasic(self):
        local_devices = list(jax.local_devices())
        if len(local_devices) < 4:
            raise SkipTest("Test requires at least 4 local devices")

        def f(a, b):
            return a * 2, b * 4

        devices = np.array(local_devices[:4]).reshape((2, 2))
        with mesh(devices, ('x', 'y')):
            fm = xmap(f,
                      in_axes=[{
                          0: 'a',
                          1: 'b'
                      }, ['c', ...]],
                      out_axes=[{
                          0: 'a',
                          1: 'b'
                      }, ['c', ...]],
                      axis_resources={
                          'a': 'x',
                          'b': 'y',
                          'c': 'x'
                      })
            ashape = (16, 8, 5)
            a = jnp.arange(np.prod(ashape)).reshape(ashape)
            bshape = (2, 7)
            b = jnp.arange(np.prod(bshape)).reshape(bshape)
            c, d = fm(a, b)
            self.assertAllClose(c, a * 2)
            self.assertAllClose(d, b * 4)
Exemplo n.º 3
0
        def f(x):
            """Apply an identity xmap over named axis 'b' inside an empty mesh.

            The mesh is built from a 0-d object array with no axis names, and
            no ``axis_resources`` are passed to xmap.
            """
            # ``np.object`` was only a deprecated alias of the builtin
            # ``object`` and was removed in NumPy 1.24; use the builtin.
            with mesh(np.empty((), dtype=object), ()):

                @partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
                def h(x):
                    return x

                return h(x)
Exemplo n.º 4
0
 def new_f(*args, **kwargs):
   """Call ``f`` inside a device mesh described by the enclosing ``named_shape``.

   ``named_shape`` is split by ``unzip2`` into axis names and sizes; the first
   ``size`` local devices are reshaped into that mesh. Raises ``SkipTest`` when
   fewer local devices are available than the mesh needs.
   """
   axis_names, shape = unzip2(named_shape)
   # np.prod returns the float 1.0 for an empty shape, which is not a valid
   # slice bound below; coerce to int so a 0-d mesh works as well.
   size = int(np.prod(shape))
   local_devices = list(jax.local_devices())
   if len(local_devices) < size:
     raise SkipTest(f"Test requires {size} local devices")
   mesh_devices = np.array(local_devices[:size]).reshape(shape)
   with mesh(mesh_devices, axis_names):
     return f(*args, **kwargs)
Exemplo n.º 5
0
  def testCaching(self):
    def f(x):
      assert should_be_tracing
      return jnp.sin(x) * 2

    x = np.arange(16).reshape(4, 4)
    devices = np.array(list(jax.local_devices())[:4])
    if devices.size < 4:
      raise unittest.SkipTest("Test requires 4 devices")
    devices = devices.reshape((2, 2))
    with mesh(devices, ('x', 'y')):
      should_be_tracing = True
      pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
      should_be_tracing = False
      pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
    # Re-create the mesh to make sure that has no influence on caching
    with mesh(devices, ('x', 'y')):
      should_be_tracing = False
      pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
Exemplo n.º 6
0
 def testCaching(self):
   def f(x):
     assert python_should_be_executing
     return x * 2
   devices = np.array(jax.local_devices()[:2])
   if devices.size < 2:
     raise SkipTest("Test requires 2 devices")
   x = np.arange(8).reshape((2, 2, 2))
   with mesh(devices, ('x',)):
     python_should_be_executing = True
     xmap(f, in_axes=['a', ...], out_axes=['a', ...],
          axis_resources={'a': 'x'})(x)
     python_should_be_executing = False
     xmap(f, in_axes=['a', ...], out_axes=['a', ...],
          axis_resources={'a': 'x'})(x)
   with mesh(devices, ('x',)):
     python_should_be_executing = False
     xmap(f, in_axes=['a', ...], out_axes=['a', ...],
          axis_resources={'a': 'x'})(x)
Exemplo n.º 7
0
        def f(x):
            """Apply an identity pjit inside a one-device mesh with axis 'x'."""
            # ``('x')`` is just the string 'x'; pass a real 1-tuple of axis
            # names so the mesh does not depend on iterating a string.
            with mesh(np.array([jax.local_devices()[0]]), ('x',)):

                @partial(pjit,
                         in_axis_resources=P('x'),
                         out_axis_resources=None)
                def h(x):
                    return x

                return h(x)
Exemplo n.º 8
0
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
    """Test utility for setting up meshes given mesh data from `schedules`.

    ``named_shape`` is split by ``unzip2`` into axis names and axis sizes. The
    first ``prod(sizes)`` local devices are reshaped to those sizes and the
    resulting mesh is active while this generator is suspended at ``yield``.
    Raises ``SkipTest`` when too few local devices are available.
    """
    # This is similar to the `with_mesh` function above, but isn't a decorator.
    names, dims = unzip2(named_shape)
    needed = prod(dims)
    available = list(jax.local_devices())
    if needed > len(available):
        raise SkipTest(f"Test requires {needed} local devices")
    grid = np.array(available[:needed]).reshape(dims)
    with mesh(grid, names):
        yield
Exemplo n.º 9
0
        def f(x):
            """Apply an identity xmap over named axis 'b' inside an empty mesh.

            The mesh is built from a 0-d object array with no axis names, and
            'b' is scheduled only for vectorization.
            """
            # ``np.object`` was only a deprecated alias of the builtin
            # ``object`` and was removed in NumPy 1.24; use the builtin.
            with mesh(np.empty((), dtype=object), ()):

                @partial(xmap,
                         in_axes=A({'b': 0}),
                         out_axes=A({'b': 0}),
                         schedule=[('b', 'vectorize')])
                def h(x):
                    return x

                return h(x)
Exemplo n.º 10
0
  def test_from_gda_duplicates(self):
    """pjit must accept a second, distinct FROM_GDA singleton instance.

    It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
    ``in_axis_resources`` is pickled and sent to other processes). This checks
    that using such a duplicate doesn't raise, to avoid user confusion.
    """
    global_mesh = create_global_mesh((1, 2), ('x', 'y'))
    input_gda = create_gda((8, 2), global_mesh, ['x', 'y'])

    duplicate = pjit_lib._FromGsdaSingleton()
    with mesh(global_mesh.devices, global_mesh.axis_names):
      identity = pjit(lambda x: x, in_axis_resources=duplicate,
                      out_axis_resources=None)
      identity(input_gda)
Exemplo n.º 11
0
    def _pjit(inp):
        """Run ``inp`` through an identity pjit and return process-local data.

        A fully replicated GlobalDeviceArray short-circuits to its first local
        shard. Other GDAs keep their own mesh and use ``FROM_GDA``. Anything
        else is placed on a ('processes', 'local_devices') mesh built from all
        devices and sharded along 'processes'. It is first given a leading
        axis when it is 0-d or when ``titled`` (presumably a flag from the
        enclosing scope) is falsy.
        """
        if not isinstance(inp, GlobalDeviceArray):
            # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
            # Shape of local_mesh will always be (1, local_device_count())
            grid = np.array(jax.devices()).reshape(jax.process_count(),
                                                   jax.local_device_count())
            global_mesh = maps.Mesh(grid, ('processes', 'local_devices'))
            in_axis_resources = P('processes')
            if inp.ndim == 0 or not titled:
                inp = np.expand_dims(inp, axis=0)
        elif inp.is_fully_replicated:
            return inp.local_data(0).to_py()
        else:
            global_mesh = inp._global_mesh
            in_axis_resources = FROM_GDA

        with maps.mesh(global_mesh.devices, global_mesh.axis_names):
            gathered = pjit(lambda x: x,
                            in_axis_resources=in_axis_resources,
                            out_axis_resources=None)(inp)
        return gathered.local_data(0).to_py()
Exemplo n.º 12
0
  def test_no_recompilation_due_to_in_axis_resources(self):
    """Feeding pjit's own GDA output back in must hit the lowering cache.

    With GDA outputs enabled, the second call on the output GDA must add one
    hit and no misses to ``pjit_lib._pjit_lower``'s cache statistics.
    """
    global_mesh = create_global_mesh((1, 2), ('x', 'y'))
    mesh_axes = P(None,)
    input_gda = create_gda((8, 2), global_mesh, mesh_axes)

    with jax._src.config.parallel_functions_output_gda(True):
      @partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=mesh_axes)
      def f(x):
        return x

      with mesh(global_mesh.devices, global_mesh.axis_names):
        out_gda = f(input_gda)
        self.assertEqual(out_gda._mesh_axes, ())

        read_stats = pjit_lib._pjit_lower.cache_info
        stats_before = read_stats()
        f(out_gda)
        stats_after = read_stats()

        self.assertNotEqual(id(stats_before), id(stats_after))
        self.assertEqual(stats_before.hits + 1, stats_after.hits)
        self.assertEqual(stats_before.misses, stats_after.misses)
Exemplo n.º 13
0
    def testInfeed(self):
        """Compare infeed results between ``jax.jit`` and ``pjit``.

        For the ``jax.jit`` call, all three infeed values are transferred to
        the first local device. For the ``pjit`` call on a 1-D mesh ``'d'``
        over all local devices, ``y`` is transferred whole to every device
        (replicated). ``z`` and ``w`` are sliced so each device receives only
        its own block, along the first and second axis respectively. Both
        calls must produce the same result.
        """
        devices = np.array(jax.local_devices())
        nr_devices = len(devices)
        # Per-device block sizes; the sharded-infeed slices below rely on them.
        rows_per_device, cols_per_device = 3, 5
        shape = (nr_devices * rows_per_device, nr_devices * cols_per_device)

        def f_for_jit(x):
            token = lax.create_token(x)
            (y, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))
            (z, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))
            (w, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))

            return x + y + z + w

        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
        y = x * 2.
        z = x * 3.
        w = x * 4.

        # Transfer data to infeed before executing the function. For GPUs, the
        # execution of the compiled function is blocking, so transferring data
        # to infeed before executing ensures that the execution does not deadlock
        # waiting for the infeed data.
        logging.info('Transfering to infeed for the jit call')
        d = devices[0]
        d.transfer_to_infeed((y, ))
        d.transfer_to_infeed((z, ))
        d.transfer_to_infeed((w, ))

        # JIT
        logging.info('Making jit call')
        res0 = jax.jit(f_for_jit)(x)
        self.assertAllClose(res0, x + y + z + w, check_dtypes=True)

        # PJIT
        def f_for_pjit(x):
            token = lax.create_token(x)
            # A replicated infeed
            (y, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(None, ))
            # An infeed sharded on first axis
            (z, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(P(nr_devices, 1), ))
            # An infeed sharded on second axis
            (w, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(P(1, nr_devices), ))
            return x + y + z + w

        logging.info('Transfering to infeed for the pjit call')
        for didx, d in enumerate(devices):
            rows = slice(rows_per_device * didx, rows_per_device * (didx + 1))
            cols = slice(cols_per_device * didx, cols_per_device * (didx + 1))
            # Transfer the whole array to all devices for replicated.
            d.transfer_to_infeed((y, ))
            # For sharded infeed, transfer only the needed slices to each
            # device. Every transfer is a 1-tuple, matching the 1-tuple of
            # shapes each infeed in f_for_pjit declares.
            d.transfer_to_infeed((z[rows, :], ))
            d.transfer_to_infeed((w[:, cols], ))

        with mesh(devices, ['d']):
            logging.info('Making pjit call')
            res = pjit(f_for_pjit,
                       in_axis_resources=(P('d'), ),
                       out_axis_resources=P('d'))(x)

        self.assertAllClose(res0, res, check_dtypes=True)
Exemplo n.º 14
0
 def dispatch():
   """Run ``f`` on ``x`` through pjit inside a 1-D mesh named 'd'.

   ``devices``, ``f`` and ``x`` come from the enclosing scope. Input and
   output are both partitioned along mesh axis 'd'.
   """
   axis_spec = P('d')
   with mesh(devices, ['d']):
     logging.info('Making pjit call')
     sharded_f = pjit(f, in_axis_resources=(axis_spec,),
                      out_axis_resources=axis_spec)
     sharded_f(x)
Exemplo n.º 15
0
def main():
    """Fine-tune a causal language model with pjit-based model parallelism.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` from a single ``.json`` path or from the command
    line. Loads and tokenizes the dataset, then groups it into ``block_size``
    chunks. Parameters and AdamW optimizer state are sharded over a
    ``("dp", "mp")`` mesh of shape ``(1, local_device_count)``. Training then
    runs with periodic metric logging, optional evaluation and checkpointing.

    Raises:
        ValueError: if ``--do_train`` is not given; if the output directory is
            non-empty without ``--overwrite_output_dir``; if no dataset name
            or data files are given; if no tokenizer can be determined; or if
            a required train/validation split is missing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # The learning-rate schedule, step counts and the training loop below all
    # require the train split; fail fast instead of with a NameError much later.
    if not training_args.do_train:
        raise ValueError("This script requires --do_train.")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )

        if "validation" not in dataset.keys():
            # Carve a validation split out of the head of the train split.
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        if not data_files:
            raise ValueError("Either --dataset_name or --train_file/--validation_file must be provided.")
        # Infer the loader from whichever file was given (train_file may be None).
        extension = next(iter(data_files.values())).split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained config and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        """Tokenize the text column, silencing the over-length warning."""
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            # Fall back to 1024, but never beyond what the model's position
            # embeddings can represent.
            block_size = min(1024, config.max_position_embeddings)
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                f"Picking {block_size} instead. You can change that default value by passing --block_size xxx."
            )
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        """Concatenate a batch of tokenized texts and split it into block_size chunks."""
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when tensorboard is actually missing, not merely because
        # this is a non-zero process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
    # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
    # GPT2's vocab is odd, we need to resize it for fine-tuning
    model = FlaxAutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
    )

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    def get_initial_state(params):
        """Return ``(optimizer_state_as_tuple, params)`` for ``params``."""
        state = optimizer.init(params)
        return tuple(state), params

    # Get PartitionSpec for model params
    param_spec = set_partitions(unfreeze(model.params))

    # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
    params_shapes = jax.tree_map(lambda x: x.shape, model.params)
    state_shapes = jax.eval_shape(get_initial_state, params_shapes)

    # get PartitionSpec for opt_state, this is very specific to adamw
    # TODO: optax returns different state for different optimizers, how can we handle this generically ?
    # or maybe we don't since in our examples we just use adamw or adafactor
    def get_opt_spec(x):
        """Map param-shaped dict leaves to ``param_spec``; everything else is replicated (None)."""
        if isinstance(x, dict):
            return param_spec
        return None

    opt_state_spec, param_spec = jax.tree_map(
        get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
    )

    # pjit the get_initial_state function to shard params and init
    # optimizer state in sharded way
    p_get_initial_state = pjit(
        get_initial_state,
        in_axis_resources=None,
        out_axis_resources=(opt_state_spec, param_spec),
    )

    # hack: move the initial params to CPU to free up device memory
    # TODO: allow loading weights on CPU in pre-trained model
    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)

    # mesh definition
    # NOTE(review): reshape(1, local_device_count()) only succeeds when
    # jax.devices() contains exactly the local devices (single process);
    # confirm before multi-host use.
    mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())

    # actually initialize the opt_state
    with mesh(mesh_devices, ("dp", "mp")):
        opt_state, params = p_get_initial_state(freeze(model.params))

    # cross-entropy with z loss
    def loss_fn(logits, labels, z_loss=0):
        """Mean next-token cross-entropy, plus ``1e-4 * log_z**2`` scaled by ``z_loss``."""
        # Predict token t+1 from position t.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        shift_labels = onehot(shift_labels, shift_logits.shape[-1])

        # Subtract the (non-differentiated) max for a numerically stable log-softmax.
        shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
        log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
        log_softmax = shift_logits - log_z
        loss = -jnp.sum(shift_labels * log_softmax, axis=-1)

        loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss

        return loss.mean()

    # Define gradient update step fn
    # TODO: try to use TrainState instead of passing params and opt_state individually
    def train_step(params, opt_state, dropout_rng, batch, step):
        """One AdamW step; returns updated params/state, next rng, metrics and ``step + 1``."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, z_loss=1.0)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(params)

        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)}
        return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1

    # Define eval fn
    def eval_step(input_ids, labels, params):
        """Return ``{"loss": ...}`` for one evaluation batch (no z-loss)."""
        logits = model(input_ids=input_ids, params=params, train=False)[0]
        loss = loss_fn(logits, labels)
        # metrics
        return {"loss": loss}

    p_train_step = pjit(
        train_step,
        in_axis_resources=(param_spec, opt_state_spec, None, None, None),
        out_axis_resources=(param_spec, opt_state_spec, None, None, None),
        donate_argnums=(0, 1),
    )

    p_eval_step = pjit(
        eval_step,
        in_axis_resources=(None, None, param_spec),
        out_axis_resources=None,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    global_step = 0
    # we are not doing 2D parallelism (yet!), this just does model parallelism
    with mesh(mesh_devices, ("dp", "mp")):
        for _ in epochs:
            # ======================== Training ================================
            train_start = time.time()

            # Create sampling rng
            rng, input_rng = jax.random.split(rng)

            # Generate an epoch by shuffling sampling indices from the train dataset
            train_metrics = []
            train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)

            # train
            for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
                batch = next(train_loader)
                params, opt_state, dropout_rng, train_metric, global_step = p_train_step(
                    params,
                    opt_state,
                    dropout_rng,
                    batch,
                    global_step,
                )
                train_metrics.append(train_metric)

                cur_step = global_step

                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                    # Save metrics. Add only the wall-clock time since the last
                    # measurement and restart the clock, so train_time is a
                    # running total instead of re-adding the whole epoch so far.
                    # (Time spent in eval/checkpointing is included.)
                    now = time.time()
                    train_time += now - train_start
                    train_start = now
                    if has_tensorboard and jax.process_index() == 0:
                        write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                    epochs.write(
                        f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                    )

                    train_metrics = []

                # eval_dataset only exists when --do_eval was passed.
                if training_args.do_eval and cur_step % training_args.eval_steps == 0 and cur_step > 0:
                    # ======================== Evaluating ==============================
                    eval_metrics = []
                    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                    eval_steps = len(eval_dataset) // eval_batch_size

                    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                        batch = next(eval_loader)
                        metrics = p_eval_step(batch["input_ids"], batch["labels"], params)
                        eval_metrics.append(metrics)

                    # normalize eval metrics
                    eval_metrics = stack_forest(eval_metrics)
                    eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                    try:
                        eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                    except OverflowError:
                        eval_metrics["perplexity"] = float("inf")

                    logger.info(
                        f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
                    )

                if cur_step % training_args.save_steps == 0 and cur_step > 0:
                    # save checkpoint after each epoch and push checkpoint to the hub
                    if jax.process_index() == 0:
                        # Fetch a host copy for saving without rebinding
                        # `params`: training must keep using the sharded
                        # on-device parameters.
                        host_params = jax.device_get(params)
                        model.save_pretrained(
                            training_args.output_dir,
                            params=host_params,
                            push_to_hub=training_args.push_to_hub,
                            commit_message=f"Saving weights and logs of step {cur_step}",
                        )