Beispiel #1
0
def process_iterator(tag: str,
                     item_ids: Sequence[str],
                     iterator,
                     rng: types.PRNGKey,
                     state: model_utils.TrainState,
                     step: int,
                     render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
  """Process a dataset iterator and compute metrics.

  Every batch is handed to `process_batch`; on process 0 the returned stats are
  accumulated and their means are written to `summary_writer` once the iterator
  is exhausted.

  Args:
    tag: Name of the split. Used in log messages, in the output path and in the
      summary names. When it equals 'test', random metadata is attached to each
      batch before it is processed.
    item_ids: IDs paired one-to-one (via `zip`) with the items of `iterator`.
    iterator: Iterable of batches. For the 'test' tag each batch must provide
      'origins' and gets a 'metadata' entry assigned.
    rng: PRNG key forwarded to `process_batch`.
    state: Training state forwarded to `process_batch`.
    step: Current step. Names the output sub-directory, seeds the metadata
      sampling for the 'test' tag and is the step of the written summaries.
    render_fn: Forwarded to `process_batch`.
    summary_writer: Writer that receives the mean of every collected metric.
    save_dir: Optional base directory. When set, `save_dir/<step>/<tag>` is
      passed on to `process_batch`; otherwise None is passed.
    datasource: Provides the `use_*` flags and the ID lists to sample from.
  """
  save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
  meters = collections.defaultdict(utils.ValueMeter)
  for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
    logging.info('[%s:%d/%d] Processing %s ', tag, i+1, len(item_ids), item_id)
    if tag == 'test':
      # Seeded from `step`, so the sampled metadata is reproducible per step.
      # NOTE: the same key is reused for every draw below.
      test_rng = random.PRNGKey(step)
      shape = batch['origins'][..., :1].shape
      metadata = {}
      if datasource.use_appearance_id:
        appearance_id = random.choice(
            test_rng, jnp.asarray(datasource.appearance_ids))
        logging.info('\tUsing appearance_id = %d', appearance_id)
        metadata['appearance'] = jnp.full(shape, fill_value=appearance_id,
                                          dtype=jnp.uint32)
      if datasource.use_warp_id:
        warp_id = random.choice(test_rng, jnp.asarray(datasource.warp_ids))
        logging.info('\tUsing warp_id = %d', warp_id)
        metadata['warp'] = jnp.full(shape, fill_value=warp_id, dtype=jnp.uint32)
      if datasource.use_camera_id:
        camera_id = random.choice(test_rng, jnp.asarray(datasource.camera_ids))
        logging.info('\tUsing camera_id = %d', camera_id)
        metadata['camera'] = jnp.full(shape, fill_value=camera_id,
                                      dtype=jnp.uint32)
      if datasource.use_time:
        timestamp = random.uniform(test_rng, minval=0.0, maxval=1.0)
        # The timestamp is a float in [0, 1): an integer format or dtype would
        # truncate it to 0, so log and store it as a float.
        logging.info('\tUsing time = %f', timestamp)
        metadata['time'] = jnp.full(
            shape, fill_value=timestamp, dtype=jnp.float32)
      batch['metadata'] = metadata

    stats = process_batch(batch=batch,
                          rng=rng,
                          state=state,
                          tag=tag,
                          item_id=item_id,
                          step=step,
                          render_fn=render_fn,
                          summary_writer=summary_writer,
                          save_dir=save_dir,
                          datasource=datasource)
    # Only process 0 accumulates and reports metrics.
    if jax.process_index() == 0:
      for k, v in stats.items():
        meters[k].update(v)

  if jax.process_index() == 0:
    for meter_name, meter in meters.items():
      summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                            value=meter.reduce('mean'),
                            step=step)
Beispiel #2
0
    def test_summarywriter_clipped_audio(self):
        """Audio samples outside [-1, 1] are clipped when written."""
        tmp_dir = tempfile.mkdtemp()
        writer = SummaryWriter(log_dir=tmp_dir)
        # Samples are drawn from [-2, 2) so that clipping is observable.
        raw_samples = onp.random.uniform(low=-2., high=2.,
                                         size=(2, 48000, 2))
        writer.audio(tag='audio_test',
                     audiodata=raw_samples,
                     step=1,
                     max_outputs=1)

        value = self.parse_and_return_summary_value(path=tmp_dir)
        self.assertEqual(value.tag, 'audio_test')

        # Exactly one encoded audio entry is expected.
        self.assertLen(value.tensor.string_val, 1)

        decoded = tf.audio.decode_wav(value.tensor.string_val[0]).audio
        first_clip = raw_samples[0]

        # The decoded audio must differ from the unclipped input ...
        self.assertFalse(onp.allclose(first_clip, decoded, atol=1e-04))

        # ... and match the input once it is clipped to [-1, 1].
        expected = onp.clip(onp.array(first_clip), -1, 1)
        self.assertTrue(onp.allclose(expected, decoded, atol=1e-04))
Beispiel #3
0
    def test_summarywriter_audio(self):
        """Two in-range audio clips round-trip through the writer."""
        tmp_dir = tempfile.mkdtemp()
        writer = SummaryWriter(log_dir=tmp_dir)
        clips = onp.random.uniform(low=-1., high=1., size=(2, 48000, 2))
        writer.audio(tag='audio_test', audiodata=clips, step=1)

        value = self.parse_and_return_summary_value(path=tmp_dir)
        self.assertEqual(value.tag, 'audio_test')

        # Both clips are expected in the summary.
        encoded_clips = value.tensor.string_val
        self.assertLen(encoded_clips, 2)

        # Each decoded clip must match its source samples within tolerance.
        for index in range(2):
            decoded = tf.audio.decode_wav(encoded_clips[index]).audio
            self.assertTrue(
                onp.allclose(clips[index], decoded, atol=1e-04))
Beispiel #4
0
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer, state: flax.nn.Collection,
    prng_key: jnp.ndarray, pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an eventual
      dropout mask). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), the updated
    state and the updated parameter moving averages (None if EMA is not used).
    If the dataset yields no batch, the inputs are returned unchanged and
    nothing is written to `summary_writer`.
  """
  start_time = time.time()
  num_steps = 0
  train_metrics = []
  lr = None
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    num_steps += 1

    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)

    train_metrics.append(metrics)

  if not train_metrics:
    # Without a single step there are no metrics and no learning rate to
    # report; bail out instead of failing on the unbound `lr` below.
    logging.warning('Training dataset yielded no batches; nothing to log.')
    return optimizer, state, moving_averages

  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  # `lr` comes from the last step; index 0 selects the first replica's value.
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  logging.info('Whole training step done in %s (%s steps)',
               time.time() - start_time, num_steps)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
Beispiel #5
0
  def test_summarywriter_image(self):
    """A uint8 three-channel image round-trips through the writer unchanged."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    pixels = onp.random.uniform(
        low=0., high=255., size=(30, 30, 3)).astype(onp.uint8)
    writer.image(tag='image_test', image=pixels, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'image_test')

    # The encoded image is read from index 2 of string_val (the preceding
    # entries are presumably image dimensions -- confirm against the writer).
    encoded = value.tensor.string_val[2]
    decoded = tf.image.decode_image(encoded)
    self.assertTrue(onp.allclose(decoded, pixels))
Beispiel #6
0
  def test_summarywriter_text(self):
    """A text summary round-trips through the writer as UTF-8."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    message = 'hello world.'
    writer.text(tag='text_test', textdata=message, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'text_test')

    # The stored tensor holds a single bytes item; decode it before comparing.
    stored_bytes = tensor_util.make_ndarray(value.tensor).item()
    self.assertEqual(stored_bytes.decode('utf-8'), message)
Beispiel #7
0
  def test_summarywriter_scalar(self):
    """A scalar summary is written and read back with the same value."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    expected = 99.1232
    writer.scalar(tag='scalar_test', value=expected, step=1)

    # Read the event back and check both its tag and its payload.
    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'scalar_test')

    stored = tensor_util.make_ndarray(value.tensor).item()
    self.assertTrue(onp.allclose(stored, expected))
Beispiel #8
0
  def test_summarywriter_single_channel_image_scaled(self):
    """A single-channel image is stored as a three-channel image."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    gray = onp.random.uniform(
        low=0., high=255., size=(30, 30, 1)).astype(onp.uint8)
    writer.image(tag='2dimage_1channel_test', image=gray, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, '2dimage_1channel_test')

    # The encoded image is read from index 2 of string_val (the preceding
    # entries are presumably image dimensions -- confirm against the writer).
    decoded = tf.image.decode_image(value.tensor.string_val[2])
    # The channel dimension must have grown from 1 to 3.
    self.assertEqual(decoded.shape, (30, 30, 3))
Beispiel #9
0
  def test_summarywriter_histogram_defaultbins(self):
    """A histogram written without `bins` uses the default of 30 bins."""
    log_dir = tempfile.mkdtemp()
    summary_writer = SummaryWriter(log_dir=log_dir)
    histogram = onp.arange(1000)
    # Histogram will be created for 30 (default) bins.
    summary_writer.histogram(tag='histogram_test', values=histogram, step=1)

    summary_value = self.parse_and_return_summary_value(path=log_dir)
    self.assertEqual(summary_value.tag, 'histogram_test')
    actual_histogram = tensor_util.make_ndarray(summary_value.tensor)
    # assertEqual, not assertTrue: assertTrue would treat the expected shape
    # as the failure message and pass for any non-empty shape.
    self.assertEqual(actual_histogram.shape, (30, 3))
    # First row of the histogram: values 0..33 fall into [0.0, 33.3].
    self.assertTrue(
        onp.allclose(actual_histogram[0], (0.0, 33.3, 34.0), atol=1e-01))
Beispiel #10
0
  def test_summarywriter_image_float_pixel_values(self):
    """A float image in [0, 1) is stored as its uint8-scaled equivalent."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    float_pixels = onp.random.uniform(low=0., high=1., size=(30, 30, 3))
    writer.image(tag='image_test', image=float_pixels, step=1)
    value = self.parse_and_return_summary_value(path=tmp_dir)

    # Scale the float input to uint8 to obtain the expected stored pixels.
    expected_pixels = tf.image.convert_image_dtype(
        image=float_pixels, dtype=onp.uint8)

    self.assertEqual(value.tag, 'image_test')
    # The encoded image is read from index 2 of string_val (the preceding
    # entries are presumably image dimensions -- confirm against the writer).
    decoded = tf.image.decode_image(value.tensor.string_val[2])
    self.assertTrue(onp.allclose(decoded, expected_pixels))
Beispiel #11
0
  def test_summarywriter_histogram_2bins(self):
    """A histogram written with `bins=2` has two rows with the right counts."""
    log_dir = tempfile.mkdtemp()
    summary_writer = SummaryWriter(log_dir=log_dir)
    histogram = onp.arange(1000)
    summary_writer.histogram(
        tag='histogram_test', values=histogram, step=1, bins=2)

    summary_value = self.parse_and_return_summary_value(path=log_dir)
    self.assertEqual(summary_value.tag, 'histogram_test')
    actual_histogram = tensor_util.make_ndarray(summary_value.tensor)
    # assertEqual, not assertTrue: assertTrue would treat the expected shape
    # as the failure message and pass for any non-empty shape.
    self.assertEqual(actual_histogram.shape, (2, 3))
    # Values 0..999 split evenly: 500 in each half of the range.
    self.assertTrue(
        onp.allclose(actual_histogram[0], (0.0, 499.5, 500.0), atol=1e-01))
    self.assertTrue(
        onp.allclose(actual_histogram[1], (499.5, 999.0, 500.0), atol=1e-01))
Beispiel #12
0
    def test_summarywriter_multiple_2dimages_scaled(self):
        """A batch of 2-D images is stored as three-channel images."""
        tmp_dir = tempfile.mkdtemp()
        writer = SummaryWriter(log_dir=tmp_dir)
        batch = np.random.uniform(
            low=0., high=255., size=(2, 30, 30)).astype(np.uint8)
        writer.image(tag='multiple_2dimages_test', image=batch, step=1)

        value = self.parse_and_return_summary_value(path=tmp_dir)
        self.assertEqual(value.tag, 'multiple_2dimages_test')

        # Encoded images are read from index 2 onwards (the preceding entries
        # are presumably image dimensions -- confirm against the writer).
        decoded = []
        for encoded in value.tensor.string_val[2:]:
            decoded.append(tf.image.decode_image(encoded))

        # Each 30x30 image must have gained a channel dimension of size 3.
        stacked = np.stack(decoded, axis=0)
        self.assertEqual(stacked.shape, (2, 30, 30, 3))
Beispiel #13
0
    def test_summarywriter_multiple_images(self):
        """A batch of uint8 RGB images round-trips through the writer."""
        tmp_dir = tempfile.mkdtemp()
        writer = SummaryWriter(log_dir=tmp_dir)
        batch = np.random.uniform(
            low=0., high=255., size=(2, 30, 30, 3)).astype(np.uint8)
        writer.image(tag='multiple_images_test', image=batch, step=1)

        value = self.parse_and_return_summary_value(path=tmp_dir)
        self.assertEqual(value.tag, 'multiple_images_test')

        # Encoded images are read from index 2 onwards (the preceding entries
        # are presumably image dimensions -- confirm against the writer).
        decoded = []
        for encoded in value.tensor.string_val[2:]:
            decoded.append(tf.image.decode_image(encoded))

        stacked = np.stack(decoded, axis=0)
        self.assertTrue(np.allclose(stacked, batch))
Beispiel #14
0
  def test_summarywriter_audio_sampled_output(self):
    """With `max_outputs=1` only the first of two audio clips is written."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    clips = onp.random.uniform(low=-1., high=1., size=(2, 48000, 2))
    writer.audio(tag='audio_test', audiodata=clips, step=1, max_outputs=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'audio_test')

    # Two string entries are expected for the single retained clip
    # (presumably the encoded audio plus one companion entry -- confirm
    # against the writer's tensor layout).
    entries = value.tensor.string_val
    self.assertLen(entries, 2)

    # The first entry decodes to the first input clip.
    decoded = tf.audio.decode_wav(entries[0]).audio
    self.assertTrue(onp.allclose(clips[0], decoded, atol=1e-04))
Beispiel #15
0
def _log_histograms(writer: tensorboard.SummaryWriter, model: models.NerfModel,
                    state: model_utils.TrainState):
    """Write histograms of the embedding tables found in the model parameters.

    Args:
      writer: Summary writer whose `histogram` method receives the embeddings.
      model: Model whose `warp_metadata_encoder_type` decides whether the warp
        embedding is logged (only when it equals 'glo').
      state: Training state; the parameters are read from
        `state.optimizer.target['model']` and the step from
        `state.optimizer.state.step`.
    """
    step = int(state.optimizer.state.step)
    params = state.optimizer.target['model']

    # Encoders that keep their embedding table at the same nested location.
    simple_encoders = (
        ('appearance_encoder', 'appearance_embedding'),
        ('camera_encoder', 'camera_embedding'),
    )
    for param_key, histogram_name in simple_encoders:
        if param_key in params:
            table = params[param_key]['embed']['embedding']
            writer.histogram(histogram_name, table, step)

    # The warp field keeps its embedding one level deeper and is only logged
    # for the 'glo' metadata encoder.
    if 'warp_field' in params:
        if model.warp_metadata_encoder_type == 'glo':
            warp_encoder = params['warp_field']['metadata_encoder']
            writer.histogram('warp_embedding',
                             warp_encoder['embed']['embedding'], step)
Beispiel #16
0
def local_train_loop(key,
                     init_params,
                     loss_fn,
                     summarize_fn=default_summarize,
                     lr=1e-4,
                     num_steps=int(1e5),
                     summarize_every=100,
                     checkpoint_every=5000,
                     clobber_checkpoint=False,
                     logdir="/tmp/lda_inference"):
    """Run an Adam training loop with periodic checkpoints and summaries.

    Args:
      key: PRNG key. It is split every step to obtain the key passed to
        `loss_fn`, and split again on summary steps for `summarize_fn`.
      init_params: Initial parameters used to create the optimizer.
      loss_fn: Called as `loss_fn(params, key)`; differentiated with respect to
        its first argument.
      summarize_fn: Called on summary steps as
        `summarize_fn(summary_writer, step, params, key)`.
      lr: Base learning rate handed to the learning rate scheduler.
      num_steps: The loop runs from the optimizer's current step up to, but
        excluding, this step.
      summarize_every: Summaries are written when the step is a multiple of
        this value.
      checkpoint_every: A checkpoint is saved when the step is a multiple of
        this value, except on the very first step of this run.
      clobber_checkpoint: Forwarded to `util.maybe_load_checkpoint`.
      logdir: Directory used for checkpoints and summaries.

    If a training step raises FloatingPointError, the error is printed, the
    optimizer from before the failing step is checkpointed and the process
    exits with status 1.
    """
    optimizer = util.maybe_load_checkpoint(
        logdir,
        optim.Adam().create(init_params),
        clobber_checkpoint=clobber_checkpoint)
    schedule = util.create_learning_rate_scheduler(base_learning_rate=lr)

    @jit
    def train_step(optimizer, key):
        grad_fn = jax.value_and_grad(loss_fn, argnums=0)
        loss_val, loss_grad = grad_fn(optimizer.target, key)
        step_lr = schedule(optimizer.state.step)
        updated = optimizer.apply_gradient(loss_grad, learning_rate=step_lr)
        return loss_val, updated

    def save_and_report(current):
        # Persist `current` under its own step counter and report it.
        checkpoints.save_checkpoint(logdir,
                                    current,
                                    current.state.step,
                                    keep=3)
        print("Checkpoint saved for step %d" % current.state.step)

    writer = SummaryWriter(logdir)

    window_start = timeit.default_timer()
    first_step = optimizer.state.step
    for t in range(first_step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            save_and_report(optimizer)

        key, subkey = jax.random.split(key)
        try:
            loss_val, new_optimizer = train_step(optimizer, subkey)
        except FloatingPointError as e:
            print("Exception on step %d" % t)
            print(e)
            traceback.print_exc()
            # `optimizer` is still the state from before the failing step.
            save_and_report(optimizer)
            print("key ", subkey)
            sys.stdout.flush()
            sys.exit(1)
        optimizer = new_optimizer

        if t % summarize_every != 0:
            continue

        key, subkey = jax.random.split(key)
        print("Step %d loss: %0.4f" % (t, loss_val))
        writer.scalar("loss", loss_val, step=t)
        summarize_fn(writer, t, optimizer.target, subkey)

        window_end = timeit.default_timer()
        # At step 0 only a single step has run since the timer started.
        steps_in_window = 1. if t == 0 else summarize_every
        steps_per_sec = steps_in_window / (window_end - window_start)
        print("Steps/sec: %0.2f" % steps_per_sec)
        writer.scalar("steps_per_sec", steps_per_sec, step=t)
        window_start = window_end
        writer.flush()
        sys.stdout.flush()
Beispiel #17
0
def parallel_train_loop(key,
                        init_params,
                        loss_fn,
                        summarize_fn=default_summarize,
                        lr=1e-4,
                        num_steps=int(1e5),
                        summarize_every=100,
                        checkpoint_every=5000,
                        clobber_checkpoint=False,
                        logdir="/tmp/lda_inference"):
    """Run a pmapped Adam training loop with checkpoints and summaries.

    The optimizer is replicated across local devices; every device computes a
    gradient with its own PRNG key and the gradients are averaged with `pmean`
    before the update.

    Args:
      key: PRNG key. It is split into one key per local device; those keys
        drive the training steps.
      init_params: Initial parameters used to create the optimizer.
      loss_fn: Called as `loss_fn(params, key)`; differentiated with respect to
        its first argument.
      summarize_fn: Called on summary steps as
        `summarize_fn(summary_writer, step, params, key)`.
      lr: Base learning rate handed to the learning rate scheduler.
      num_steps: The loop runs from the optimizer's current step up to, but
        excluding, this step.
      summarize_every: Summaries are written when the step is a multiple of
        this value.
      checkpoint_every: A checkpoint is saved when the step is a multiple of
        this value, except on the very first step of this run.
      clobber_checkpoint: Forwarded to `util.maybe_load_checkpoint`.
      logdir: Directory used for checkpoints and summaries.
    """

    loss_fn = jax.jit(loss_fn)

    optimizer_def = optim.Adam()
    local_optimizer = optimizer_def.create(init_params)
    local_optimizer = util.maybe_load_checkpoint(
        logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint)
    first_step = local_optimizer.state.step
    repl_optimizer = jax_utils.replicate(local_optimizer)

    lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

    @functools.partial(jax.pmap, axis_name="batch")
    def train_step(optimizer, key):
        # One half of the split is consumed by this step, the other half is
        # returned as the device's key for the next step.
        key, subkey = jax.random.split(key)
        loss_grad = jax.grad(loss_fn, argnums=0)(optimizer.target, key)
        loss_grad = jax.lax.pmean(loss_grad, "batch")
        new_optimizer = optimizer.apply_gradient(loss_grad,
                                                 learning_rate=lr_fn(
                                                     optimizer.state.step))
        return new_optimizer, subkey

    sw = SummaryWriter(logdir)

    # Derive the per-device keys from the caller's key. Previously they were
    # hard-coded as PRNGKey(0..n-1), so the `key` argument had no effect on
    # training.
    repl_key = jax.random.split(key, jax.local_device_count())
    start = timeit.default_timer()
    for t in range(first_step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            optimizer = jax_utils.unreplicate(repl_optimizer)
            checkpoints.save_checkpoint(logdir,
                                        optimizer,
                                        optimizer.state.step,
                                        keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)

        repl_optimizer, repl_key = train_step(repl_optimizer, repl_key)

        if t % summarize_every == 0:
            # Evaluation keys are derived from the first device's key; the
            # replicated training keys themselves are left untouched.
            eval_key, subkey = jax.random.split(
                jax_utils.unreplicate(repl_key))
            optimizer = jax_utils.unreplicate(repl_optimizer)
            loss_val = loss_fn(optimizer.target, eval_key)
            print("Step %d loss: %0.4f" % (t, loss_val))
            sw.scalar("loss", loss_val, step=t)
            summarize_fn(sw, t, optimizer.target, subkey)
            end = timeit.default_timer()
            # At step 0 only a single step has run since the timer started.
            if t == 0:
                steps_per_sec = 1. / (end - start)
            else:
                steps_per_sec = summarize_every / (end - start)
            print("Steps/sec: %0.2f" % steps_per_sec)
            sw.scalar("steps_per_sec", steps_per_sec, step=t)
            start = end
            sw.flush()
            sys.stdout.flush()
Beispiel #18
0
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained config and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
    # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
    # GPT2's vocab is odd, we need to resize it for fine-tuning
    model = FlaxAutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
    )

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    def get_initial_state(params):
        """Create the optimizer state for ``params``.

        Returns a pair ``(opt_state, params)`` where ``opt_state`` is the result
        of ``optimizer.init(params)`` converted to a tuple and ``params`` is
        handed back unchanged.
        """
        opt_state = tuple(optimizer.init(params))
        return opt_state, params

    # Get PartitionSpec for model params
    param_spec = set_partitions(unfreeze(model.params))

    # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
    params_shapes = jax.tree_map(lambda x: x.shape, model.params)
    state_shapes = jax.eval_shape(get_initial_state, params_shapes)

    # get PartitionSpec for opt_state, this is very specific to adamw
    # TODO: optax returns different state for different optimizers, how can we handle this generically ?
    # or maybe we don't since in our examples we just use adamw or adafactor
    def get_opt_spec(x):
        """Pick the partition spec for one node of the optimizer-state tree.

        ``dict`` nodes are given ``param_spec`` from the enclosing scope; any
        other node gets ``None``.
        """
        return param_spec if isinstance(x, dict) else None

    opt_state_spec, param_spec = jax.tree_map(
        get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
    )

    # pjit the get_initial_state function to shard params and init
    # optimizer state in sharded way
    p_get_initial_state = pjit(
        get_initial_state,
        in_axis_resources=None,
        out_axis_resources=(opt_state_spec, param_spec),
    )

    # hack: move the initial params to CPU to free up device memory
    # TODO: allow loading weights on CPU in pre-trained model
    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)

    # mesh definition
    mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())

    # actually initialize the opt_state
    with mesh(mesh_devices, ("dp", "mp")):
        opt_state, params = p_get_initial_state(freeze(model.params))

    # cross-entropy with z loss
    def loss_fn(logits, labels, z_loss=0):
        """Next-token cross-entropy with an optional auxiliary z-loss.

        The logits at position ``t`` are scored against the label at position
        ``t + 1`` (the last logit and the first label are dropped).

        Args:
            logits: model outputs; the last axis is the vocabulary axis and the
                second-to-last axis is the sequence axis.
            labels: integer token ids, sliced along their last axis.
            z_loss: multiplier for the auxiliary term
                ``1e-4 * log(Z) ** 2``, where ``Z`` is the softmax partition
                function of the (unshifted) logits. ``0`` disables the term.

        Returns:
            The mean over all scored positions of the per-token loss.
        """
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        shift_labels = onehot(shift_labels, shift_logits.shape[-1])

        # Subtract the per-position max before exponentiating for numerical
        # stability. It is treated as a constant so it does not contribute to
        # the gradient.
        max_logit = jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
        stable_logits = shift_logits - max_logit
        log_sum_exp = jnp.log(jnp.sum(jnp.exp(stable_logits), axis=-1, keepdims=True))
        log_softmax = stable_logits - log_sum_exp
        loss = -jnp.sum(shift_labels * log_softmax, axis=-1)

        # The z-loss must penalise the log-partition of the *original* logits,
        # so the max that was subtracted above is added back here. Using the
        # shifted value would only ever see a number in [0, log(vocab_size)].
        log_z = (log_sum_exp + max_logit).squeeze(-1)
        loss += (1e-4 * jnp.square(log_z)) * z_loss

        return loss.mean()

    # Define gradient update step fn
    # TODO: try to use TrainState instead of passing params and opt_state individually
    def train_step(params, opt_state, dropout_rng, batch, step):
        """Run a single optimizer update.

        Args:
            params: current model parameters.
            opt_state: current optimizer state.
            dropout_rng: PRNG key; it is split into a key used for this step
                and a key returned for the next step.
            batch: mapping of model inputs that also carries a ``"labels"``
                entry (removed from the mapping before the model is called).
            step: step counter, used to look up the learning rate.

        Returns:
            ``(new_params, new_opt_state, new_dropout_rng, metrics, step + 1)``
            where ``metrics`` holds ``"loss"`` and ``"learning_rate"``.
        """
        step_rng, next_rng = jax.random.split(dropout_rng)

        def compute_loss(params):
            # Pop the labels so that only model inputs remain in **batch.
            targets = batch.pop("labels")
            outputs = model(**batch, params=params, dropout_rng=step_rng, train=True)
            return loss_fn(outputs[0], targets, z_loss=1.0)

        loss, grads = jax.value_and_grad(compute_loss)(params)

        updates, next_opt_state = optimizer.update(grads, opt_state, params)
        next_params = optax.apply_updates(params, updates)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(step),
        }
        return next_params, tuple(next_opt_state), next_rng, metrics, step + 1

    # Define eval fn
    def eval_step(input_ids, labels, params):
        """Compute the evaluation loss for one batch.

        Calls the model with ``train=False`` and returns ``{"loss": ...}``
        using ``loss_fn`` with its default ``z_loss``.
        """
        outputs = model(input_ids=input_ids, params=params, train=False)
        return {"loss": loss_fn(outputs[0], labels)}

    p_train_step = pjit(
        train_step,
        in_axis_resources=(param_spec, opt_state_spec, None, None, None),
        out_axis_resources=(param_spec, opt_state_spec, None, None, None),
        donate_argnums=(0, 1),
    )

    p_eval_step = pjit(
        eval_step,
        in_axis_resources=(None, None, param_spec),
        out_axis_resources=None,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    global_step = 0
    # we are not doing 2D parallelism (yet!), this just does model parallelism
    with mesh(mesh_devices, ("dp", "mp")):
        for _ in epochs:
            # ======================== Training ================================
            train_start = time.time()

            # Create sampling rng
            rng, input_rng = jax.random.split(rng)

            # Generate an epoch by shuffling sampling indices from the train dataset
            train_metrics = []
            train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
            steps_per_epoch = len(train_dataset) // train_batch_size

            # train
            for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
                batch = next(train_loader)
                params, opt_state, dropout_rng, train_metric, global_step = p_train_step(
                    params,
                    opt_state,
                    dropout_rng,
                    batch,
                    global_step,
                )
                train_metrics.append(train_metric)

                cur_step = global_step

                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                    # Save metrics
                    train_time += time.time() - train_start
                    if has_tensorboard and jax.process_index() == 0:
                        write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                    epochs.write(
                        f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                    )

                    train_metrics = []

                if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                    # ======================== Evaluating ==============================
                    eval_metrics = []
                    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                    eval_steps = len(eval_dataset) // eval_batch_size

                    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                        batch = next(eval_loader)
                        metrics = p_eval_step(batch["input_ids"], batch["labels"], params)
                        eval_metrics.append(metrics)

                    # normalize eval metrics
                    eval_metrics = stack_forest(eval_metrics)
                    eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                    try:
                        eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                    except OverflowError:
                        eval_metrics["perplexity"] = float("inf")

                    logger.info(
                        f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
                    )

                if cur_step % training_args.save_steps == 0 and cur_step > 0:
                    # save a checkpoint every `save_steps` steps (pushed to the hub if `push_to_hub` is set)
                    if jax.process_index() == 0:
                        params = jax.device_get(params)
                        model.save_pretrained(
                            training_args.output_dir,
                            params=params,
                            push_to_hub=training_args.push_to_hub,
                            commit_message=f"Saving weights and logs of step {cur_step}",
                        )
            truncation=True,
            max_length=data_args.max_seq_length,
        )

    tokenized_datasets = datasets.map(
        tokenize_function,
        input_columns=[text_column_name],
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.host_id() == 0:
        summary_writer = SummaryWriter(
            log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # Setup optimizer
    optimizer = Adam(
        learning_rate=training_args.learning_rate,
        weight_decay=training_args.weight_decay,
        beta1=training_args.adam_beta1,
        beta2=training_args.adam_beta2,
    ).create(model.params)

    # Create learning rate scheduler
Beispiel #20
0
def process_batch(*,
                  batch: Dict[str, jnp.ndarray],
                  rng: types.PRNGKey,
                  state: model_utils.TrainState,
                  tag: str,
                  item_id: str,
                  step: int,
                  summary_writer: tensorboard.SummaryWriter,
                  render_fn: Any,
                  save_dir: Optional[gpath.GPath],
                  datasource: datasets.DataSource):
  """Render one batch, write its images and return its metrics.

  Every process calls `render_fn`, but only process 0 goes on to save files,
  write summaries and compute metrics; all other processes return an empty
  dict right after rendering.

  Args:
    batch: the batch to render. If it contains 'rgb' and/or 'depth' entries
      they are used as targets for the error metrics and error images.
    rng: random key forwarded to `render_fn`.
    state: training state forwarded to `render_fn`.
    tag: label used as the middle component of every summary name.
    item_id: identifier of the item; '/' is replaced by '_' before use.
    step: step under which the summaries are written.
    summary_writer: writer receiving the image summaries.
    render_fn: callable invoked as `render_fn(state, batch, rng=rng)`; its
      result must provide 'rgb', 'acc', 'depth' and 'med_depth'.
    save_dir: if truthy, directory (created when missing) that receives the
      PNG files.
    datasource: supplies the `near` and `far` bounds for depth colorization.

  Returns:
    A dict with 'mse', 'psnr' and 'ssim' when the batch has 'rgb', and
    'depth_abs' when it has 'depth'.
  """
  item_id = item_id.replace('/', '_')
  render = render_fn(state, batch, rng=rng)
  if jax.process_index() != 0:
    return {}

  out = {}
  rgb = render['rgb']
  acc = render['acc']
  depth_exp = render['depth']
  depth_med = render['med_depth']

  near = datasource.near
  far = datasource.far

  def colorize_depth(depth):
    # Depth maps are colorized over the [near, far] range, inverted.
    return viz.colorize(depth, cmin=near, cmax=far, invert=True)

  def log_image(name, image):
    summary_writer.image(f'{name}/{tag}/{item_id}', image, step)

  depth_exp_viz = colorize_depth(depth_exp)
  depth_med_viz = colorize_depth(depth_med)
  disp_exp_viz = viz.colorize(1.0 / depth_exp)
  disp_med_viz = viz.colorize(1.0 / depth_med)
  acc_viz = viz.colorize(acc, cmin=0.0, cmax=1.0)

  if save_dir:
    save_dir.mkdir(parents=True, exist_ok=True)
    image_utils.save_image(save_dir / f'rgb_{item_id}.png',
                           image_utils.image_to_uint8(rgb))
    depth_outputs = (
        ('expected', depth_exp, depth_exp_viz),
        ('median', depth_med, depth_med_viz),
    )
    for kind, depth, depth_viz in depth_outputs:
      image_utils.save_image(save_dir / f'depth_{kind}_viz_{item_id}.png',
                             image_utils.image_to_uint8(depth_viz))
      image_utils.save_depth(save_dir / f'depth_{kind}_{item_id}.png', depth)

  rendered_images = (
      ('rgb', rgb),
      ('depth-expected', depth_exp_viz),
      ('depth-median', depth_med_viz),
      ('disparity-expected', disp_exp_viz),
      ('disparity-median', disp_med_viz),
      ('acc', acc_viz),
  )
  for name, image in rendered_images:
    log_image(name, image)

  if 'rgb' in batch:
    rgb_target = batch['rgb']
    mse = ((rgb - rgb_target)**2).mean()
    psnr = utils.compute_psnr(mse)
    ssim = compute_multiscale_ssim(rgb_target, rgb)
    out.update(mse=mse, psnr=psnr, ssim=ssim)
    logging.info('\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f',
                 mse, psnr, ssim)

    # Per-pixel error maps, summed over the channel axis.
    rgb_abs_error = viz.colorize(
        abs(rgb_target - rgb).sum(axis=-1), cmin=0, cmax=1)
    rgb_sq_error = viz.colorize(
        ((rgb_target - rgb)**2).sum(axis=-1), cmin=0, cmax=1)
    log_image('rgb-target', rgb_target)
    log_image('rgb-abs-error', rgb_abs_error)
    log_image('rgb-sq-error', rgb_sq_error)

  if 'depth' in batch:
    depth_target = batch['depth']
    depth_target_viz = colorize_depth(depth_target[..., 0])
    out['depth_abs'] = jnp.nanmean(jnp.abs(depth_target - depth_med))
    log_image('depth-target', depth_target_viz)
    for kind, depth in (('median', depth_med), ('expected', depth_exp)):
      depth_error = viz.colorize(
          abs(depth_target - depth).squeeze(axis=-1), cmin=0, cmax=1)
      log_image(f'depth-{kind}-error', depth_error)

  return out
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
    #
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(data_args.dataset_name,
                               data_args.dataset_config_name,
                               cache_dir=model_args.cache_dir,
                               keep_in_memory=False)
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        dataset = load_dataset(extension,
                               data_files=data_files,
                               cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name,
                                            cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path,
                                            cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForSeq2SeqLM.from_config(config,
                                                      seed=training_args.seed,
                                                      dtype=getattr(
                                                          jnp,
                                                          model_args.dtype))

    if model.config.decoder_start_token_id is None:
        raise ValueError(
            "Make sure that `config.decoder_start_token_id` is correctly defined"
        )

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info(
            "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`."
        )
        return

    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name,
                                                     None)
    if data_args.text_column is None:
        text_column = dataset_columns[
            0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[
            1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
    # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
    # for that dynamically import the `shift_tokens_right` function from the model file
    model_module = __import__(model.__module__,
                              fromlist=["shift_tokens_tight"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    def preprocess_function(examples):
        """Tokenize a batch of examples into fixed-length model inputs.

        Source texts (with ``prefix`` prepended) are tokenized to
        ``data_args.max_source_length`` and target texts to the current value
        of ``max_target_length``; both are padded to their maximum length and
        truncated. The returned encoding additionally carries ``"labels"``,
        ``"decoder_input_ids"`` (labels passed through
        ``shift_tokens_right_fn``) and ``"decoder_attention_mask"``.
        """
        sources = examples[text_column]
        summaries = examples[summary_column]
        sources = [prefix + text for text in sources]

        # Fixed-length numpy output is shared by both tokenizer calls.
        tokenize_kwargs = {
            "padding": "max_length",
            "truncation": True,
            "return_tensors": "np",
        }
        model_inputs = tokenizer(
            sources, max_length=data_args.max_source_length, **tokenize_kwargs
        )

        # Targets are tokenized inside the tokenizer's target context.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                summaries, max_length=max_target_length, **tokenize_kwargs
            )

        label_ids = labels["input_ids"]
        model_inputs["labels"] = label_ids
        shifted_ids = shift_tokens_right_fn(
            label_ids, config.pad_token_id, config.decoder_start_token_id
        )
        model_inputs["decoder_input_ids"] = np.asarray(shifted_ids)

        # The decoder attention mask lets the loss ignore pad tokens.
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(
                range(data_args.max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    metric = load_metric("rouge")

    def postprocess_text(preds, labels):
        """Strip each text and put every sentence on its own line.

        Returns ``(preds, labels)`` as two new lists; the newline-separated
        sentence layout is what rougeLSum expects.
        """
        stripped_preds = [text.strip() for text in preds]
        stripped_labels = [text.strip() for text in labels]

        def one_sentence_per_line(text):
            return "\n".join(nltk.sent_tokenize(text))

        line_preds = [one_sentence_per_line(text) for text in stripped_preds]
        line_labels = [one_sentence_per_line(text) for text in stripped_labels]
        return line_preds, line_labels

    def compute_metrics(preds, labels):
        """Decode token ids and score the predictions with ``metric``.

        Returns a dict with each metric entry's ``mid.fmeasure`` scaled by 100,
        plus ``"gen_len"`` (mean count of non-pad tokens per prediction); all
        values are rounded to 4 decimals.
        """
        pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
        label_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Light clean-up before scoring.
        pred_texts, label_texts = postprocess_text(pred_texts, label_texts)

        raw_scores = metric.compute(
            predictions=pred_texts,
            references=label_texts,
            use_stemmer=True,
        )

        scores = {}
        for name, aggregate in raw_scores.items():
            scores[name] = aggregate.mid.fmeasure * 100

        generated_lengths = [
            np.count_nonzero(seq != tokenizer.pad_token_id) for seq in preds
        ]
        scores["gen_len"] = np.mean(generated_lengths)

        return {name: round(value, 4) for name, value in scores.items()}

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBart.
    # For FlaxT5, one should correct the layer norm parameter naming
    # accordingly - see `run_t5_mlm_flax.py` e.g.
    def decay_mask_fn(params):
        """Build the weight-decay mask for ``params``.

        The result has the same nested structure as ``params``. A leaf is
        ``False`` when its path ends in ``"bias"`` or in one of the listed
        layer-norm ``"scale"`` entries, and ``True`` otherwise.
        """
        no_decay_suffixes = [
            ("self_attn_layer_norm", "scale"),
            ("layernorm_embedding", "scale"),
            ("final_layer_norm", "scale"),
        ]

        mask = {}
        for path in traverse_util.flatten_dict(params):
            excluded = path[-1] == "bias" or path[-2:] in no_decay_suffixes
            mask[path] = not excluded
        return traverse_util.unflatten_dict(mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=adamw,
                              dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        Label-smoothed softmax cross entropy, averaged over unmasked tokens.

        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

        The per-token loss is multiplied by ``padding_mask`` and the sum is
        divided by ``padding_mask.sum()``.
        """
        num_classes = logits.shape[-1]
        on_value = 1.0 - label_smoothing_factor
        off_value = (1.0 - on_value) / (num_classes - 1)

        # Constant subtracted from every token's loss; it depends only on the
        # smoothing factor and the vocabulary size.
        on_term = on_value * jnp.log(on_value)
        off_term = (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
        offset = -(on_term + off_term)

        smoothed_targets = onehot(
            labels, num_classes, on_value=on_value, off_value=off_value
        )
        token_loss = optax.softmax_cross_entropy(logits, smoothed_targets) - offset

        # Padded positions contribute nothing to the numerator or denominator.
        masked_loss = token_loss * padding_mask
        return masked_loss.sum() / padding_mask.sum()

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        """Run one training step under the ``"batch"`` mapped axis.

        Splits ``state.dropout_rng``, computes the loss and its gradient with
        respect to ``state.params``, averages the gradient over the ``"batch"``
        axis and applies it.

        Returns:
            ``(new_state, metrics)`` where ``metrics`` holds the ``"loss"`` and
            ``"learning_rate"`` averaged over the ``"batch"`` axis.
        """
        step_rng, next_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            # Pop the labels so that only model inputs remain in **batch.
            targets = batch.pop("labels")
            outputs = state.apply_fn(
                **batch, params=params, dropout_rng=step_rng, train=True
            )
            return loss_fn(
                outputs[0],
                targets,
                batch["decoder_attention_mask"],
                label_smoothing_factor,
            )

        loss, grad = jax.value_and_grad(compute_loss)(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=next_rng)

        local_metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step),
        }
        return new_state, jax.lax.pmean(local_metrics, axis_name="batch")

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        """Compute the evaluation loss for one batch.

        Meant to be wrapped in ``jax.pmap`` with ``axis_name="batch"``; the
        loss is averaged over that axis with ``jax.lax.pmean``.

        Args:
            params: model parameters to evaluate.
            batch: mapping of model inputs. Must contain ``"labels"`` and
                ``"decoder_attention_mask"``. The mapping is not modified.
            label_smoothing_factor: value forwarded to ``loss_fn``.

        Returns:
            Dict with the averaged ``"loss"``.
        """
        # Build the model inputs without mutating the caller's ``batch``.
        labels = batch["labels"]
        model_inputs = {
            key: value
            for key, value in batch.items() if key != "labels"
        }
        # The first element of the model output is used as the logits.
        logits = model(**model_inputs, params=params, train=False)[0]
        loss = loss_fn(logits, labels, model_inputs["decoder_attention_mask"],
                       label_smoothing_factor)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Define generation function
    max_length = (data_args.val_max_target_length
                  if data_args.val_max_target_length is not None else
                  model.config.max_length)
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        """Generate output token ids for one batch.

        Assigns ``params`` to the shared ``model`` object, then calls
        ``model.generate`` on ``batch["input_ids"]`` with the batch's
        attention mask and the module-level ``gen_kwargs``.

        Returns:
            The ``sequences`` attribute of the generation output.
        """
        # NOTE(review): this rebinds ``model.params`` on the shared model
        # object as a side effect of every call.
        model.params = params
        generation = model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            **gen_kwargs,
        )
        return generation.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(partial(
        train_step,
        label_smoothing_factor=training_args.label_smoothing_factor),
                            "batch",
                            donate_argnums=(0, ))
    p_eval_step = jax.pmap(
        partial(eval_step,
                label_smoothing_factor=training_args.label_smoothing_factor),
        "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng,
                                   train_dataset,
                                   train_batch_size,
                                   shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for _ in tqdm(range(steps_per_epoch),
                      desc="Training...",
                      position=1,
                      leave=False):
            batch = next(train_loader)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_preds = []
        eval_labels = []

        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
        eval_steps = len(eval_dataset) // eval_batch_size
        for _ in tqdm(range(eval_steps),
                      desc="Evaluating...",
                      position=2,
                      leave=False):
            # Model forward
            batch = next(eval_loader)
            labels = batch["labels"]

            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = p_generate_step(state.params, batch)
                eval_preds.extend(
                    jax.device_get(
                        generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(
                    jax.device_get(labels.reshape(-1, labels.shape[-1])))

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(eval_preds, eval_labels)
            eval_metrics.update(rouge_metrics)
            rouge_desc = " ".join([
                f"Eval {key}: {value} |"
                for key, value in rouge_metrics.items()
            ])

        # Print metrics and update progress bar
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics,
                         train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(
                    commit_message=f"Saving weights and logs of epoch {epoch}",
                    blocking=False)

    # ======================== Prediction loop ==============================
    if training_args.do_predict:
        logger.info("*** Predict ***")

        pred_metrics = []
        pred_generations = []
        pred_labels = []

        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
        pred_steps = len(predict_dataset) // eval_batch_size
        for _ in tqdm(range(pred_steps),
                      desc="Predicting...",
                      position=2,
                      leave=False):
            # Model forward
            batch = next(pred_loader)
            labels = batch["labels"]

            metrics = p_eval_step(state.params, batch)
            pred_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = p_generate_step(state.params, batch)
                pred_generations.extend(
                    jax.device_get(
                        generated_ids.reshape(-1, gen_kwargs["max_length"])))
                pred_labels.extend(
                    jax.device_get(labels.reshape(-1, labels.shape[-1])))

        # normalize prediction metrics
        pred_metrics = get_metrics(pred_metrics)
        pred_metrics = jax.tree_map(jnp.mean, pred_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(pred_generations, pred_labels)
            pred_metrics.update(rouge_metrics)
            rouge_desc = " ".join([
                f"Predict {key}: {value} |"
                for key, value in rouge_metrics.items()
            ])

        # Print metrics
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
        logger.info(desc)

        # save final metrics in json
        if jax.process_index() == 0:
            rouge_metrics = {
                f"test_{metric_name}": value
                for metric_name, value in rouge_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(rouge_metrics, f, indent=4, sort_keys=True)
def main():
    """Train and evaluate a Flax image-classification model.

    Steps performed:

    1. Parse ``ModelArguments``, ``DataTrainingArguments`` and
       ``TrainingArguments`` from the command line, or from a JSON file when
       the only argument is a path ending in ``.json``.
    2. Build torchvision ``ImageFolder`` datasets and PyTorch data loaders.
    3. Create the config/model, an AdamW optimizer with a learning-rate
       schedule and a replicated train state.
    4. Run the ``jax.pmap``-parallel train/eval loop. After each epoch,
       process 0 writes metrics (if TensorBoard is usable), saves the
       checkpoint and optionally pushes it to the Hub.

    Raises:
        ValueError: if the output directory exists, is not empty, training
            was requested and ``--overwrite_output_dir`` was not passed.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        # Trailing space keeps the two sentences from running together.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Log the full configuration (visible on the main process only, see above).
    logger.info(f"Training/evaluation parameters {training_args}")

    # set seed for random transforms and torch dataloaders
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Initialize datasets and pre-processing transforms
    # We use torchvision here for faster pre-processing
    # Note that here we are using some default pre-processing, for maximum accuracy
    # one should tune this part and carefully select what transformations to use.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_dataset = torchvision.datasets.ImageFolder(
        data_args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(data_args.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    eval_dataset = torchvision.datasets.ImageFolder(
        data_args.validation_dir,
        transforms.Compose([
            transforms.Resize(data_args.image_size),
            transforms.CenterCrop(data_args.image_size),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    # Load pretrained model and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.model_name_or_path:
        model = FlaxAutoModelForImageClassification.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForImageClassification.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    def collate_fn(examples):
        """Stack ``(image, label)`` examples into a dict of numpy arrays."""
        pixel_values = torch.stack([example[0] for example in examples])
        labels = torch.tensor([example[1] for example in examples])

        batch = {"pixel_values": pixel_values, "labels": labels}
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node.
    # The "not installed" warning is tied to availability alone, so it is no
    # longer emitted on non-master processes that do have TensorBoard.
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=adamw,
                              dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        """Mean softmax cross-entropy against one-hot encoded ``labels``."""
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        """One pmapped optimization step; returns ``(new_state, metrics)``."""
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        # Split labels from the model inputs without mutating ``batch``;
        # popping inside the traced ``compute_loss`` is a side effect that
        # breaks if the function is traced more than once.
        labels = batch["labels"]
        model_inputs = {
            key: value
            for key, value in batch.items() if key != "labels"
        }

        def compute_loss(params):
            logits = state.apply_fn(**model_inputs,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            return loss_fn(logits, labels)

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients over devices so all replicas stay in sync.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        """Pmapped evaluation step returning averaged loss and accuracy."""
        labels = batch["labels"]
        model_inputs = {
            key: value
            for key, value in batch.items() if key != "labels"
        }
        logits = model(**model_inputs, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
        metrics = {"loss": loss, "accuracy": accuracy}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        train_step_progress_bar = tqdm(total=steps_per_epoch,
                                       desc="Training...",
                                       position=1,
                                       leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        # Metrics of the last training step of the epoch, taken off-device.
        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_steps = len(eval_dataset) // eval_batch_size
        eval_step_progress_bar = tqdm(total=eval_steps,
                                      desc="Evaluating...",
                                      position=2,
                                      leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        # ``jax.tree_map`` is deprecated; ``jax.tree_util.tree_map`` is the
        # stable spelling of the same function.
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
            f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})")
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            # NOTE(review): this is the step count *before* the epoch that
            # just finished (0 after the first epoch); confirm against
            # ``write_metric`` whether ``(epoch + 1) * steps_per_epoch`` is
            # intended.
            cur_step = epoch * steps_per_epoch
            write_metric(summary_writer, train_metrics, eval_metrics,
                         train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            if training_args.push_to_hub:
                repo.push_to_hub(
                    commit_message=f"Saving weights and logs of epoch {epoch}",
                    blocking=False)
Beispiel #23
0
def main():
    """Train and evaluate a hybrid CLIP model (text + vision) with Flax.

    Steps performed:

    1. Parse ``ModelArguments``, ``DataTrainingArguments`` and
       ``TrainingArguments`` from the command line, or from a JSON file when
       the only argument is a path ending in ``.json``.
    2. Load the tokenizer and build ``FlaxHybridCLIP`` from pretrained text
       and vision checkpoints.
    3. Build image-text datasets and PyTorch data loaders whose collate
       function tokenizes the captions.
    4. Run the ``jax.pmap``-parallel train/eval loop with a symmetric
       contrastive loss. After each epoch, process 0 writes metrics (when
       ``has_tensorboard`` is set) and saves the checkpoint.

    Raises:
        ValueError: if the output directory exists, is not empty, training
            was requested and ``--overwrite_output_dir`` was not passed; or
            if neither a tokenizer name nor a text model path is given.
    """
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        # Trailing space keeps the two sentences from running together.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Log the full configuration (visible on the main process only, see above).
    logger.info(f"Training/evaluation parameters {training_args}")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.text_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.text_model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    model = FlaxHybridCLIP.from_text_vision_pretrained(
        model_args.text_model_name_or_path,
        model_args.vision_model_name_or_path,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
        text_from_pt=model_args.from_pt,
        vision_from_pt=model_args.from_pt,
    )
    config = model.config
    # set seed for torch dataloaders
    set_seed(training_args.seed)

    # Initialize torchvision transforms and jit them for faster processing
    preprocess = Transform(config.vision_config.image_size)
    preprocess = torch.jit.script(preprocess)

    # Initialize the image-text dataset
    train_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.train_file,
        captions_per_image=2,
        transform=preprocess,
    )

    eval_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.validation_file,
        captions_per_image=1,
        transform=preprocess,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Use collate function to tokenize the text and convert the processed images to numpy
    def collate_fn(examples):
        """Collate ``(image, caption)`` examples into a dict of numpy arrays."""
        # Stacked images are permuted from axis order (0, 1, 2, 3) to
        # (0, 2, 3, 1) before conversion to numpy.
        pixel_values = torch.stack([example[0] for example in examples
                                    ]).permute(0, 2, 3, 1).numpy()
        captions = [example[1] for example in examples]
        inputs = tokenizer(captions,
                           max_length=data_args.max_seq_length,
                           padding="max_length",
                           return_tensors="np")

        batch = {
            "pixel_values": pixel_values,
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }

        return batch

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.process_index() == 0:
        summary_writer = SummaryWriter(
            log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=adamw,
                              dropout_rng=dropout_rng)

    def cross_entropy(logits, axis):
        """Negative mean of the diagonal of ``log_softmax(logits, axis)``."""
        logprobs = jax.nn.log_softmax(logits, axis=axis)
        nll = jnp.diag(logprobs)
        ce = -jnp.mean(nll)
        return ce

    def clip_loss(similarity):
        """Average of the cross-entropies taken along axis 0 and axis 1."""
        loss = (cross_entropy(similarity, axis=0) +
                cross_entropy(similarity, axis=1)) / 2
        return loss

    # Define gradient update step fn
    def train_step(state, batch):
        """One pmapped optimization step; returns ``(new_state, metrics)``."""
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = clip_loss(logits)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients over devices so all replicas stay in sync.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        """Pmapped evaluation step returning the averaged loss."""
        logits = model(**batch, params=params, train=False)[0]
        loss = clip_loss(logits)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0

    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        train_step_progress_bar = tqdm(total=steps_per_epoch,
                                       desc="Training...",
                                       position=1,
                                       leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        # Metrics of the last training step of the epoch, taken off-device.
        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_steps = len(eval_dataset) // eval_batch_size
        eval_step_progress_bar = tqdm(total=eval_steps,
                                      desc="Evaluating...",
                                      position=2,
                                      leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)

        # ``jax.tree_map`` is deprecated; ``jax.tree_util.tree_map`` is the
        # stable spelling of the same function.
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            # NOTE(review): this is the step count *before* the epoch that
            # just finished (0 after the first epoch); confirm against
            # ``write_metric`` whether ``(epoch + 1) * steps_per_epoch`` is
            # intended.
            cur_step = epoch * steps_per_epoch
            write_metric(summary_writer, train_metrics, eval_metrics,
                         train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )
Beispiel #24
0
def process_batch(
    *,
    batch: Dict[str, jnp.ndarray],
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    tag: str,
    item_id: str,
    step: int,
    summary_writer: tensorboard.SummaryWriter,
    render_fn: Any,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Render a single batch, write images/summaries and compute metrics.

    Args:
        batch: the batch to render. If it contains an ``"rgb"`` entry, colour
            metrics and error images are produced; if it contains a
            ``"depth"`` entry, depth metrics and error images are produced.
        rng: random key forwarded to ``render_fn``.
        state: the train state forwarded to ``render_fn``.
        tag: name used as the middle component of every summary key.
        item_id: identifier used in summary keys and output file names.
        step: step value attached to every summary.
        summary_writer: writer that receives the image summaries.
        render_fn: callable invoked as ``render_fn(state, batch, rng=rng)``
            that returns ``(rgb, depth_exp, depth_med, acc)``.
        save_dir: if truthy, the directory (created on demand) where the
            rendered images and depth maps are saved.
        datasource: provides ``near`` and ``far``, used as the colour range
            when visualising depth.

    Returns:
        A dict of metrics. It holds ``"mse"``, ``"psnr"`` and ``"ssim"`` when
        the batch has an ``"rgb"`` entry and ``"depth_abs"`` when it has a
        ``"depth"`` entry. It is empty on every host other than host 0.
    """
    # Rendering happens before the host check, so every host takes part in
    # it; only host 0 continues on to write files and summaries.
    rgb, depth_exp, depth_med, acc = render_fn(state, batch, rng=rng)
    out = {}
    # NOTE(review): `jax.host_id` is deprecated in newer JAX releases in
    # favour of `jax.process_index` -- confirm the pinned JAX version before
    # switching.
    if jax.host_id() != 0:
        return out

    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)

    depth_exp_viz = colorize_depth(depth_exp[..., 0])
    depth_med_viz = colorize_depth(depth_med[..., 0])
    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f"rgb_{item_id}.png",
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(
            save_dir / f"depth_expected_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_exp_viz),
        )
        # The expected-depth file must hold the expected depth; it previously
        # duplicated the median depth.
        image_utils.save_depth(save_dir / f"depth_expected_{item_id}.png",
                               depth_exp[..., 0])
        image_utils.save_image(
            save_dir / f"depth_median_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_med_viz),
        )
        image_utils.save_depth(save_dir / f"depth_median_{item_id}.png",
                               depth_med[..., 0])

    summary_writer.image(f"rgb/{tag}/{item_id}", rgb, step)
    summary_writer.image(f"depth-expected/{tag}/{item_id}", depth_exp_viz,
                         step)
    summary_writer.image(f"depth-median/{tag}/{item_id}", depth_med_viz, step)
    summary_writer.image(f"acc/{tag}/{item_id}", acc, step)

    if "rgb" in batch:
        rgb_target = batch["rgb"]
        mse = ((rgb - rgb_target)**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out["mse"] = mse
        out["psnr"] = psnr
        out["ssim"] = ssim
        logging.info("\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f", mse, psnr,
                     ssim)

        # Per-pixel error maps, summed over the last (channel) axis.
        rgb_abs_error = viz.colorize(abs(rgb_target - rgb).sum(axis=-1),
                                     cmin=0,
                                     cmax=1)
        rgb_sq_error = viz.colorize(((rgb_target - rgb)**2).sum(axis=-1),
                                    cmin=0,
                                    cmax=1)
        summary_writer.image(f"rgb-target/{tag}/{item_id}", rgb_target, step)
        summary_writer.image(f"rgb-abs-error/{tag}/{item_id}", rgb_abs_error,
                             step)
        summary_writer.image(f"rgb-sq-error/{tag}/{item_id}", rgb_sq_error,
                             step)

    if "depth" in batch:
        depth_target = batch["depth"]
        depth_target_viz = colorize_depth(depth_target[..., 0])
        # nanmean: NaN entries in the difference are ignored.
        out["depth_abs"] = jnp.nanmean(jnp.abs(depth_target - depth_med))
        summary_writer.image(f"depth-target/{tag}/{item_id}", depth_target_viz,
                             step)
        depth_med_error = viz.colorize(abs(depth_target -
                                           depth_med).squeeze(axis=-1),
                                       cmin=0,
                                       cmax=1)
        summary_writer.image(f"depth-median-error/{tag}/{item_id}",
                             depth_med_error, step)
        depth_exp_error = viz.colorize(abs(depth_target -
                                           depth_exp).squeeze(axis=-1),
                                       cmin=0,
                                       cmax=1)
        summary_writer.image(f"depth-expected-error/{tag}/{item_id}",
                             depth_exp_error, step)
        rel_disp_pred = viz.colorize(1.0 / depth_exp[..., 0])
        summary_writer.image(f"relative-disparity/{tag}/{item_id}",
                             rel_disp_pred, step)

    return out
Beispiel #25
0
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
        tokenized_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
Beispiel #26
0
def main():
    """Train and periodically evaluate a Flax causal language model.

    The function runs the whole script, in this order:

    1. Parse ``ModelArguments``, ``DataTrainingArguments`` and
       ``TrainingArguments`` from the command line, or from a single JSON
       file when that is the only argument.
    2. Configure logging and seed the run.
    3. Load the dataset (from the hub or from local files), carving a
       validation split out of the train split when none is provided.
    4. Load the config, tokenizer and model.
    5. Tokenize the text and group it into chunks of ``block_size`` tokens.
    6. Build the learning-rate schedule, optimizer and train state.
    7. Run the pmapped training loop, logging every ``logging_steps``,
       evaluating every ``eval_steps`` and saving every ``save_steps``.

    Raises:
        ValueError: if the output directory is non-empty and may not be
            overwritten, if neither a dataset name nor a data file is given,
            if no tokenizer can be loaded, or if a requested train or
            validation split is missing.
    """
    # Imported locally: only `group_texts` below needs it.
    from itertools import chain

    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    #  Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(data_args.dataset_name,
                               data_args.dataset_config_name,
                               cache_dir=model_args.cache_dir,
                               keep_in_memory=False)

        if "validation" not in dataset.keys():
            # No validation split: use the first N% of train for validation
            # and the remainder for training.
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Derive the loader from whichever file was supplied; the train file
        # may be absent when only a validation file is given.
        source_file = (data_args.train_file if data_args.train_file is not None
                       else data_args.validation_file)
        if source_file is None:
            raise ValueError(
                "Need either a dataset name or a training/validation file.")
        extension = source_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension,
                               data_files=data_files,
                               cache_dir=model_args.cache_dir)

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name,
                                            cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path,
                                            cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForCausalLM.from_config(config,
                                                     seed=training_args.seed,
                                                     dtype=getattr(
                                                         jnp,
                                                         model_args.dtype))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger(
        "transformers.tokenization_utils_base")

    def tokenize_function(examples):
        """Tokenize the text column, softening the 'sequence too long' warning."""
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        """Concatenate every column and re-split it into `block_size` chunks."""
        # Concatenate all texts. `chain.from_iterable` flattens in linear
        # time; `sum(lists, [])` copies the accumulator on every addition.
        concatenated_examples = {
            k: list(chain.from_iterable(examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k:
            [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        # Only report a missing package when it is actually missing, not on
        # every non-main process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    # NOTE(review): `train_dataset` (and `eval_dataset` further down) are only
    # bound when --do_train / --do_eval are set, yet they are used
    # unconditionally from here on.
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxGPT2.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        """Return a params-shaped mask that is True where weight decay applies."""
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"),
                                                            ("ln_2", "scale"),
                                                            ("ln_f", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=optimizer,
                              dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        """Mean next-token cross entropy: position i predicts label i + 1."""
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = optax.softmax_cross_entropy(
            shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        """Run one optimizer update and return the new state and metrics."""
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients over the pmapped "batch" axis.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        """Compute the loss for one evaluation batch."""
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng,
                                   train_dataset,
                                   train_batch_size,
                                   shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for step in tqdm(range(steps_per_epoch),
                         desc="Training...",
                         position=1,
                         leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            cur_step = epoch * (len(train_dataset) // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                # Accumulate only the time since the previous measurement;
                # without the reset, every logging step would re-add the
                # whole time elapsed since the start of the epoch.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = []
                eval_loader = data_loader(input_rng, eval_dataset,
                                          eval_batch_size)
                eval_steps = len(eval_dataset) // eval_batch_size
                for _ in tqdm(range(eval_steps),
                              desc="Evaluating...",
                              position=2,
                              leave=False):
                    # Model forward
                    batch = next(eval_loader)
                    batch = shard(batch)
                    metrics = p_eval_step(state.params, batch)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                # `jax.tree_util.tree_map` is the supported spelling; the
                # top-level `jax.tree_map` alias is deprecated.
                eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

                # Print metrics and update progress bar
                desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
                epochs.write(desc)
                epochs.desc = desc

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint every `save_steps` steps and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(
                        training_args.output_dir,
                        params=params,
                        push_to_hub=training_args.push_to_hub,
                        commit_message=
                        f"Saving weights and logs of step {cur_step}",
                    )
Beispiel #27
0
def main():
    """Fine-tune a Flax sequence-classification model on a GLUE task or on local CSV/JSON files.

    The function runs the whole pipeline in order: parse the command-line arguments, configure
    logging, optionally clone the Hub repository, load the raw datasets, derive the label set,
    load config/tokenizer/model, tokenize the data, and then for every epoch train, evaluate,
    write TensorBoard scalars and save a checkpoint. After the last epoch, process 0 writes the
    final evaluation metrics to ``eval_results.json`` inside ``args.output_dir``.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        repo = Repository(args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)

    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # The file extension selects the loader; fall back to the validation file when no
        # training file was given (the attribute is `validation_file`, as used just above).
        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # The message is built by implicit string concatenation: passing the parts as separate
            # positional arguments would make logging treat them as %-format arguments.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None and not is_regression:
        # `label_list` only exists for classification data, so regression must be excluded here.
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        """Tokenize one batch of examples and attach a `labels` column when labels are present."""
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
    )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(args.output_dir)
            summary_writer.hparams(vars(args))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when the package is really missing, not merely because this is a non-main process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    def write_metric(train_metrics, eval_metrics, train_time, step):
        """Write the training time, per-step training metrics and evaluation metrics to TensorBoard.

        The values collected for each training metric are written at consecutive steps ending at
        `step`; evaluation metrics and `train_time` are written at `step` itself.
        """
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)
    # One dropout key per local device, because the train step is pmapped.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
    )

    state = create_train_state(
        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
    )

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Run one optimizer step on `batch`.

        Returns a triple `(new_state, metrics, new_dropout_rng)` where `metrics` holds the loss and
        the current learning rate, both averaged over the "batch" pmap axis.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        """Run the model in inference mode on `batch` and post-process the logits with `state.logits_fn`."""
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        train_start = time.time()
        train_metrics = []
        rng, input_rng = jax.random.split(rng)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            # Flatten the leading per-device axis before handing values to the metric.
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(unreplicate(state), batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        # Global step reached at the end of this epoch (epochs are 1-based here).
        cur_step = epoch * (len(train_dataset) // train_batch_size)

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            write_metric(train_metrics, eval_metric, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # Take the first replica of every parameter before saving.
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(args.output_dir, params=params)
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # save the eval metrics in json
    if jax.process_index() == 0:
        eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
        path = os.path.join(args.output_dir, "eval_results.json")
        with open(path, "w") as f:
            json.dump(eval_metric, f, indent=4, sort_keys=True)
Beispiel #28
0
def main():
    """Fine-tune a Flax token-classification model (e.g. NER) on a Hub dataset or local files.

    The function runs the whole pipeline in order: parse arguments (from a JSON file or the
    command line), configure logging, optionally clone the Hub repository, load the raw datasets,
    derive the label set, load config/tokenizer/model, tokenize and align labels, then train with
    periodic logging, evaluation and checkpointing driven by `training_args.logging_steps`,
    `eval_steps` and `save_steps`. When `training_args.do_eval` is set, a final evaluation is run
    and process 0 writes its metrics to ``eval_results.json`` in `training_args.output_dir`.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_ner", model_args, data_args, framework="flax")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # The file extension selects the loader; fall back to the validation file when no
        # training file was given (the attribute is `validation_file`, as used just above).
        extension = (data_args.train_file if data_args.train_file is not None
                     else data_args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # NOTE(review): indexing a missing split presumably raises instead of returning None, and the
    # code further down requires the "train" split anyway — confirm whether the else branch is reachable.
    if raw_datasets["train"] is not None:
        column_names = raw_datasets["train"].column_names
        features = raw_datasets["train"].features
    else:
        column_names = raw_datasets["validation"].column_names
        features = raw_datasets["validation"].features

    if data_args.text_column_name is not None:
        text_column_name = data_args.text_column_name
    elif "tokens" in column_names:
        text_column_name = "tokens"
    else:
        text_column_name = column_names[0]

    if data_args.label_column_name is not None:
        label_column_name = data_args.label_column_name
    elif f"{data_args.task_name}_tags" in column_names:
        label_column_name = f"{data_args.task_name}_tags"
    else:
        label_column_name = column_names[1]

    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
    # unique labels.
    def get_label_list(labels):
        """Return the sorted list of distinct labels found across all label sequences."""
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
        # No need to convert the labels since they are already ints.
        label_to_id = {i: i for i in range(len(label_list))}
    else:
        label_list = get_label_list(raw_datasets["train"][label_column_name])
        label_to_id = {l: i for i, l in enumerate(label_list)}
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        label2id=label_to_id,
        id2label={i: l
                  for l, i in label_to_id.items()},
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    if config.model_type in {"gpt2", "roberta"}:
        # These model types get `add_prefix_space=True` because the inputs are pre-split into words.
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            add_prefix_space=True,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    model = FlaxAutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Preprocessing the datasets
    # Tokenize all texts and align the labels with them.
    def tokenize_and_align_labels(examples):
        """Tokenize a batch of pre-split words and build one label id per produced token.

        Tokens without a word id get -100; the first token of each word gets the word's label id;
        remaining tokens of a word get the label id only when `data_args.label_all_tokens` is set.
        """
        tokenized_inputs = tokenizer(
            examples[text_column_name],
            max_length=data_args.max_seq_length,
            padding="max_length",
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
        )

        labels = []

        for i, label in enumerate(examples[label_column_name]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label_to_id[label[word_idx]])
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
                    label_ids.append(label_to_id[label[word_idx]] if data_args.
                                     label_all_tokens else -100)
                previous_word_idx = word_idx

            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    processed_raw_datasets = raw_datasets.map(
        tokenize_and_align_labels,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=raw_datasets["train"].column_names,
        desc="Running tokenizer on dataset",
    )

    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(training_args.output_dir)
            summary_writer.hparams({
                **training_args.to_dict(),
                **vars(model_args),
                **vars(data_args)
            })
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when the package is really missing, not merely because this is a non-main process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        """Write the training time at `step` and each collected training metric at consecutive steps ending at `step`."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        """Write every evaluation metric as an `eval_<name>` scalar at `step`."""
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    # One dropout key per local device, because the train step is pmapped.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count(
    )
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count(
    )

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model,
                               learning_rate_fn,
                               num_labels=num_labels,
                               training_args=training_args)

    # define step functions
    def train_step(
            state: train_state.TrainState, batch: Dict[str, Array],
            dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Run one optimizer step on `batch`.

        Returns a triple `(new_state, metrics, new_dropout_rng)` where `metrics` holds the loss and
        the current learning rate, both averaged over the "batch" pmap axis.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step,
                            axis_name="batch",
                            donate_argnums=(0, ))

    def eval_step(state, batch):
        """Run the model in inference mode on `batch` and post-process the logits with `state.logits_fn`."""
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    metric = evaluate.load("seqeval")

    def get_labels(y_pred, y_true):
        """Convert predicted and gold ids to label names, dropping positions whose gold id is -100."""
        # Remove ignored index (special tokens)
        true_predictions = [[
            label_list[p] for (p, l) in zip(pred, gold_label) if l != -100
        ] for pred, gold_label in zip(y_pred, y_true)]
        true_labels = [[
            label_list[l] for (p, l) in zip(pred, gold_label) if l != -100
        ] for pred, gold_label in zip(y_pred, y_true)]
        return true_predictions, true_labels

    def compute_metrics():
        """Compute the accumulated metric and return either flattened entity-level or overall scores."""
        results = metric.compute()
        if data_args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")

    # make sure weights are replicated on each device
    state = replicate(state)

    train_time = 0
    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:

        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
                tqdm(
                    train_data_collator(input_rng, train_dataset,
                                        train_batch_size),
                    total=step_per_epoch,
                    desc="Training...",
                    position=1,
                )):
            state, train_metric, dropout_rngs = p_train_step(
                state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            # Global step counter (epochs are 0-based here, steps are made 1-based).
            cur_step = (epoch * step_per_epoch) + (step + 1)

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                # NOTE(review): train_start is only reset at the start of each epoch, so the whole
                # elapsed epoch time is added at every logging step — confirm this is intended.
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
                    f" {train_metric['learning_rate']})")

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:

                eval_metrics = {}
                # evaluate
                for batch in tqdm(
                        eval_data_collator(eval_dataset, eval_batch_size),
                        total=math.ceil(len(eval_dataset) / eval_batch_size),
                        desc="Evaluating ...",
                        position=2,
                ):
                    labels = batch.pop("labels")
                    predictions = pad_shard_unpad(p_eval_step)(
                        state,
                        batch,
                        min_device_batch=per_device_eval_batch_size)
                    predictions = np.array(predictions)
                    # Mark padded positions as ignored. The mask must be a real array: wrapping it in
                    # itertools.chain yields a 0-d object array and turns the assignment into a no-op.
                    # Assumes labels and attention_mask share the same shape — TODO confirm.
                    labels[np.array(batch["attention_mask"]) == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                eval_metrics = compute_metrics()

                if data_args.return_entity_level_metrics:
                    logger.info(
                        f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}"
                    )
                else:
                    logger.info(
                        f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:"
                        f" {eval_metrics['accuracy']})")

                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0
                    and cur_step > 0) or (cur_step == total_steps):
                # save a checkpoint every `save_steps` steps and at the final step, and push it to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

    # Eval after training
    if training_args.do_eval:
        eval_metrics = {}
        eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
        # Use the same ceil-based total as the in-training evaluation loop.
        for batch in tqdm(eval_loader,
                          total=math.ceil(len(eval_dataset) / eval_batch_size),
                          desc="Evaluating ...",
                          position=2):
            labels = batch.pop("labels")
            predictions = pad_shard_unpad(p_eval_step)(
                state, batch, min_device_batch=per_device_eval_batch_size)
            predictions = np.array(predictions)
            # Mark padded positions as ignored (see the in-training evaluation loop for the rationale).
            labels[np.array(batch["attention_mask"]) == 0] = -100
            preds, refs = get_labels(predictions, labels)
            metric.add_batch(predictions=preds, references=refs)

        eval_metrics = compute_metrics()

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
def main():
    """Pre-train a Flax T5 model with span-masked language modeling.

    In order, this function: parses the model/data/training arguments (from a
    single JSON file or from the command line), loads and tokenizes the
    datasets, groups the token ids into fixed-length chunks, builds the model,
    the learning-rate schedule and the optimizer, and then runs the pmapped
    training loop with periodic logging, evaluation and checkpointing. When
    ``training_args.do_eval`` is set, a final evaluation is run and written to
    ``eval_results.json`` in the output directory (process 0 only).

    Raises:
        ValueError: if the output directory is non-empty and may not be
            overwritten, if no dataset name and no data file is given, or if
            neither a tokenizer name nor a model name is given.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_t5_mlm",
                           model_args,
                           data_args,
                           framework="flax")

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        # NOTE: the two literals are concatenated, so the first one must end
        # with a space to keep the sentences apart.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log a small summary of the run configuration on each process.
    logger = logging.getLogger(__name__)
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # The same auth setting and split expressions are used by every loader
    # call below, so compute them once.
    auth_token = True if model_args.use_auth_token else None
    validation_split = f"train[:{data_args.validation_split_percentage}%]"
    train_split = f"train[{data_args.validation_split_percentage}%:]"

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=auth_token,
        )

        if "validation" not in datasets.keys():
            # No validation split: carve one out of the head of the train split.
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=validation_split,
                cache_dir=model_args.cache_dir,
                use_auth_token=auth_token,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=train_split,
                cache_dir=model_args.cache_dir,
                use_auth_token=auth_token,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file

        # Infer the loader from the file extension. Fall back to the
        # validation file so that an evaluation-only run (no train file) does
        # not crash on `None.split(...)`.
        reference_file = data_args.train_file
        if reference_file is None:
            reference_file = data_args.validation_file
        if reference_file is None:
            raise ValueError(
                "Need either a dataset name or a training/validation file.")
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=auth_token,
        )

        if "validation" not in datasets.keys():
            # No validation file: carve a split out of the head of the train data.
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=validation_split,
                cache_dir=model_args.cache_dir,
                use_auth_token=auth_token,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=train_split,
                cache_dir=model_args.cache_dir,
                use_auth_token=auth_token,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=auth_token,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=auth_token,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            use_auth_token=auth_token,
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=auth_token,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # We tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name],
                         return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length //
                            expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [
                t[i:i + expanded_inputs_length]
                for i in range(0, total_length, expanded_inputs_length)
            ]
            for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node. The "not installed" warning
    # is emitted only when the package really is unavailable, not on every
    # non-zero process.
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=auth_token,
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    num_of_hosts = jax.process_count()
    current_host_idx = jax.process_index()

    # Create learning rate schedule: linear warmup followed by linear decay.
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in layer_norm_named_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create the optimizer (adafactor or adamw)
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    # NOTE(review): argument 0 (the params) is marked as donated although the
    # params are reused by later steps - confirm this donation is intended.
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    def run_evaluation(params):
        """Run `p_eval_step` over the whole validation split.

        Returns `get_metrics` applied to the list of per-batch metrics; the
        caller decides how to reduce them.
        """
        num_eval_samples = len(tokenized_datasets["validation"])
        # Avoid using jax.numpy here in case of TPU training
        eval_samples_idx = np.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size,
                                               drop_last=False)

        batch_metrics = []
        for batch_idx in tqdm(eval_batch_idx,
                              desc="Evaluating ...",
                              position=2):
            samples = [
                tokenized_datasets["validation"][int(idx)]
                for idx in batch_idx
            ]
            eval_inputs = data_collator(samples)

            # Model forward
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                params,
                eval_inputs.data,
                min_device_batch=per_device_eval_batch_size)
            batch_metrics.append(metrics)

        return get_metrics(batch_metrics)

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    # Start of the interval that has not yet been added to `train_time`. It is
    # advanced on every logging step so no interval is counted twice.
    train_start = time.time()
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        # Avoid using jax.numpy here in case of TPU training
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Keep only the slice of the global batch that belongs to this host.
            local_host_model_inputs = {
                key: np.split(value, num_of_hosts, axis=0)[current_host_idx]
                for key, value in model_inputs.data.items()
            }

            # Model forward
            model_inputs = shard(local_host_model_inputs)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
                    f" {train_metric['learning_rate'].mean()})")

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = run_evaluation(state.params)
                eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

                # Update progress bar
                epochs.write(
                    f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
                )

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save a checkpoint every `save_steps` steps and push it to the hub
                if jax.process_index() == 0:
                    # Take the first replica of every parameter.
                    params = jax.device_get(
                        jax.tree_util.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        eval_metrics = run_evaluation(state.params)
        # `.item()` turns each mean into a plain Python number for json.dump.
        eval_metrics = jax.tree_util.tree_map(
            lambda metric: jnp.mean(metric).item(), eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
Beispiel #30
0
def process_iterator(
    tag: str,
    item_ids: Sequence[str],
    iterator,
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    step: int,
    render_fn: Any,
    summary_writer: tensorboard.SummaryWriter,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process a dataset iterator and compute metrics.

    Each batch is passed to `process_batch`; on process 0 the returned stats
    are accumulated per key and their means are written to `summary_writer`
    under `metrics-eval/<name>/<tag>`.

    Args:
        tag: name of the split being processed. For `"test"`, the batch
            metadata is replaced by ids sampled from `datasource`.
        item_ids: ids paired one-to-one with the batches of `iterator`;
            iteration stops at the shorter of the two.
        iterator: iterable of batches.
        rng: forwarded unchanged to `process_batch`.
        state: forwarded unchanged to `process_batch`.
        step: current step; used for the output sub-directory, the scalar
            step, and as the seed of the test-metadata PRNG key.
        render_fn: forwarded unchanged to `process_batch`.
        summary_writer: receives the mean of every accumulated metric.
        save_dir: base output directory, or a falsy value to disable saving.
        datasource: provides the `use_*_id` flags and the id collections.
    """
    save_dir = save_dir / f"{step:08d}" / tag if save_dir else None
    meters = collections.defaultdict(utils.ValueMeter)
    # `jax.process_index()` replaces the deprecated `jax.host_id()`.
    is_lead_host = jax.process_index() == 0
    is_test = tag == "test"
    # The key depends only on `step`, so build it once. NOTE(review): the same
    # key is reused (not split) for every draw below; kept as-is so the
    # sampled ids do not change.
    test_rng = random.PRNGKey(step) if is_test else None

    for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
        logging.info("[%s:%d/%d] Processing %s ", tag, i + 1, len(item_ids),
                     item_id)
        if is_test:
            # One metadata value per leading index of `origins`.
            shape = batch["origins"][..., :1].shape
            metadata = {}
            if datasource.use_appearance_id:
                appearance_id = random.choice(
                    test_rng, jnp.asarray(datasource.appearance_ids))
                logging.info("\tUsing appearance_id = %d", appearance_id)
                metadata["appearance"] = jnp.full(shape,
                                                  fill_value=appearance_id,
                                                  dtype=jnp.uint32)
            if datasource.use_warp_id:
                warp_id = random.choice(test_rng,
                                        jnp.asarray(datasource.warp_ids))
                logging.info("\tUsing warp_id = %d", warp_id)
                metadata["warp"] = jnp.full(shape,
                                            fill_value=warp_id,
                                            dtype=jnp.uint32)
            if datasource.use_camera_id:
                camera_id = random.choice(test_rng,
                                          jnp.asarray(datasource.camera_ids))
                logging.info("\tUsing camera_id = %d", camera_id)
                metadata["camera"] = jnp.full(shape,
                                              fill_value=camera_id,
                                              dtype=jnp.uint32)
            batch["metadata"] = metadata

        stats = process_batch(
            batch=batch,
            rng=rng,
            state=state,
            tag=tag,
            item_id=item_id,
            step=step,
            render_fn=render_fn,
            summary_writer=summary_writer,
            save_dir=save_dir,
            datasource=datasource,
        )
        if is_lead_host:
            for k, v in stats.items():
                meters[k].update(v)

    if is_lead_host:
        for meter_name, meter in meters.items():
            summary_writer.scalar(
                tag=f"metrics-eval/{meter_name}/{tag}",
                value=meter.reduce("mean"),
                step=step,
            )