def process_iterator(tag: str,
                     item_ids: Sequence[str],
                     iterator,
                     rng: types.PRNGKey,
                     state: model_utils.TrainState,
                     step: int,
                     render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
    """Process a dataset iterator and compute metrics.

    Each (item_id, batch) pair is rendered through `process_batch`. On process
    index 0 the per-item statistics are accumulated and their means are written
    to `summary_writer` under `metrics-eval/<name>/<tag>`.

    Args:
      tag: Name of the split being processed. When it equals 'test', metadata
        is sampled and attached to every batch before rendering.
      item_ids: Identifiers paired positionally with the batches of `iterator`.
      iterator: Iterable yielding batches (dicts; 'origins' is read for the
        'test' tag).
      rng: PRNG key forwarded to `process_batch`.
      state: Training state forwarded to `process_batch`.
      step: Current training step; used for the output directory name, as the
        seed of the test-time metadata key and as the summary step.
      render_fn: Rendering function forwarded to `process_batch`.
      summary_writer: Writer receiving the averaged scalar metrics.
      save_dir: Optional root directory; results go to `<save_dir>/<step>/<tag>`.
      datasource: Data source providing the metadata id lists and the
        `use_*` flags.
    """
    save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
    meters = collections.defaultdict(utils.ValueMeter)
    for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
        logging.info('[%s:%d/%d] Processing %s ', tag, i+1, len(item_ids),
                     item_id)
        if tag == 'test':
            # The key depends only on `step`, so every item (and every draw
            # below) uses the same key within one call.
            test_rng = random.PRNGKey(step)
            shape = batch['origins'][..., :1].shape
            metadata = {}
            if datasource.use_appearance_id:
                appearance_id = random.choice(
                    test_rng, jnp.asarray(datasource.appearance_ids))
                logging.info('\tUsing appearance_id = %d', appearance_id)
                metadata['appearance'] = jnp.full(
                    shape, fill_value=appearance_id, dtype=jnp.uint32)
            if datasource.use_warp_id:
                warp_id = random.choice(
                    test_rng, jnp.asarray(datasource.warp_ids))
                logging.info('\tUsing warp_id = %d', warp_id)
                metadata['warp'] = jnp.full(
                    shape, fill_value=warp_id, dtype=jnp.uint32)
            if datasource.use_camera_id:
                camera_id = random.choice(
                    test_rng, jnp.asarray(datasource.camera_ids))
                logging.info('\tUsing camera_id = %d', camera_id)
                metadata['camera'] = jnp.full(
                    shape, fill_value=camera_id, dtype=jnp.uint32)
            if datasource.use_time:
                timestamp = random.uniform(test_rng, minval=0.0, maxval=1.0)
                # The timestamp is a float; '%d' would log a truncated integer.
                logging.info('\tUsing time = %f', timestamp)
                # NOTE(review): a float drawn from [0, 1) is stored as uint32
                # here, which presumably collapses it to 0 -- confirm the dtype
                # the model expects for 'time'.
                metadata['time'] = jnp.full(
                    shape, fill_value=timestamp, dtype=jnp.uint32)
            batch['metadata'] = metadata

        stats = process_batch(batch=batch,
                              rng=rng,
                              state=state,
                              tag=tag,
                              item_id=item_id,
                              step=step,
                              render_fn=render_fn,
                              summary_writer=summary_writer,
                              save_dir=save_dir,
                              datasource=datasource)
        # Only process 0 accumulates statistics.
        if jax.process_index() == 0:
            for k, v in stats.items():
                meters[k].update(v)

    if jax.process_index() == 0:
        for meter_name, meter in meters.items():
            summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                                  value=meter.reduce('mean'),
                                  step=step)
def test_summarywriter_clipped_audio(self):
    """Out-of-range audio samples must be written clipped to [-1, 1]."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    # Samples are drawn from [-2, 2), i.e. deliberately outside [-1, 1].
    raw_samples = onp.random.uniform(low=-2., high=2., size=(2, 48000, 2))
    writer.audio(
        tag='audio_test', audiodata=raw_samples, step=1, max_outputs=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'audio_test')

    # Exactly one audio clip is expected in the summary.
    self.assertLen(value.tensor.string_val, 1)

    # The decoded clip must differ from the raw input...
    decoded = tf.audio.decode_wav(value.tensor.string_val[0]).audio
    first_clip = raw_samples[0]
    matches_raw = onp.allclose(first_clip, decoded, atol=1e-04)
    self.assertFalse(matches_raw)

    # ...and match the input once it is clipped to [-1, 1].
    expected = onp.clip(onp.array(first_clip), -1, 1)
    matches_clipped = onp.allclose(expected, decoded, atol=1e-04)
    self.assertTrue(matches_clipped)
def test_summarywriter_audio(self):
    """Two in-range audio clips must round-trip through the summary."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    samples = onp.random.uniform(low=-1., high=1., size=(2, 48000, 2))
    writer.audio(tag='audio_test', audiodata=samples, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'audio_test')

    # Both clips are expected in the summary.
    self.assertLen(value.tensor.string_val, 2)

    # Each decoded clip must match its source within tolerance.
    for clip_index in (0, 1):
        decoded = tf.audio.decode_wav(
            value.tensor.string_val[clip_index]).audio
        self.assertTrue(
            onp.allclose(samples[clip_index], decoded, atol=1e-04))
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer,
    state: flax.nn.Collection,
    prng_key: jnp.ndarray,
    pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
    """Trains the model for one epoch.

    Args:
      dataset_source: Container for the training dataset.
      optimizer: The optimizer targeting the model to train.
      state: Current state associated with the model (contains the batch norm
        MA).
      prng_key: A PRNG key to use for stochasticity (e.g. for sampling an
        eventual dropout mask). Is not used for shuffling the dataset.
      pmapped_train_step: A pmapped version of the `train_step` function (see
        its documentation for more details).
      pmapped_update_ema: Function to update the parameter moving average. Can
        be None if we don't use EMA.
      moving_averages: Parameters moving average if used.
      summary_writer: A Tensorboard SummaryWriter to use to log metrics.

    Returns:
      The updated optimizer (with the associated updated model), the updated
      state and the updated moving averages (None if EMA is not used). If the
      dataset yields no batch, the inputs are returned unchanged and nothing is
      logged to `summary_writer`.
    """
    start_time = time.time()
    cnt = 0
    train_metrics = []
    for batch in dataset_source.get_train(use_augmentations=True):
        # Generate a PRNG key that will be rolled into the batch.
        step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
        # Load and shard the TF batch.
        batch = tensorflow_to_numpy(batch)
        batch = shard_batch(batch)
        # Shard the step PRNG key.
        sharded_keys = common_utils.shard_prng_key(step_key)

        optimizer, state, metrics, lr = pmapped_train_step(
            optimizer, state, batch, sharded_keys)
        cnt += 1

        if moving_averages is not None:
            moving_averages = pmapped_update_ema(
                optimizer, state, moving_averages)

        train_metrics.append(metrics)

    if not train_metrics:
        # Without a single step there is no learning rate and no metric to
        # summarize; bail out instead of failing on the unbound `lr`.
        logging.warning('Training dataset yielded no batch; epoch skipped.')
        return optimizer, state, moving_averages

    train_metrics = common_utils.get_metrics(train_metrics)
    # Get training epoch summary for logging.
    train_summary = jax.tree_util.tree_map(lambda x: x.mean(), train_metrics)
    # Learning rate reported is the one returned by the last step.
    train_summary['learning_rate'] = lr[0]
    current_step = int(optimizer.state.step[0])
    info = 'Whole training step done in {} ({} steps)'.format(
        time.time()-start_time, cnt)
    logging.info(info)
    for metric_name, metric_value in train_summary.items():
        summary_writer.scalar(metric_name, metric_value, current_step)
    summary_writer.flush()
    return optimizer, state, moving_averages
def test_summarywriter_image(self):
    """A uint8 RGB image must round-trip through the image summary."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    source_img = onp.random.uniform(
        low=0., high=255., size=(30, 30, 3)).astype(onp.uint8)
    writer.image(tag='image_test', image=source_img, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'image_test')

    # The encoded image is read from index 2 of the string tensor
    # (presumably preceded by two header entries -- verify against the writer).
    encoded = value.tensor.string_val[2]
    decoded_img = tf.image.decode_image(encoded)
    self.assertTrue(onp.allclose(decoded_img, source_img))
def test_summarywriter_text(self):
    """A text summary must be read back as the same UTF-8 string."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    message = 'hello world.'
    writer.text(tag='text_test', textdata=message, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'text_test')

    # The tensor holds a single bytes item; decode it before comparing.
    raw_bytes = tensor_util.make_ndarray(value.tensor).item()
    self.assertEqual(raw_bytes.decode('utf-8'), message)
def test_summarywriter_scalar(self):
    """A scalar summary must be read back with the same tag and value."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)

    # Write one scalar, then parse the event back and compare tag and data.
    written = 99.1232
    writer.scalar(tag='scalar_test', value=written, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'scalar_test')
    read_back = tensor_util.make_ndarray(value.tensor).item()
    self.assertTrue(onp.allclose(read_back, written))
def test_summarywriter_single_channel_image_scaled(self):
    """A single-channel image must be stored as a 3-channel image."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    gray_img = onp.random.uniform(
        low=0., high=255., size=(30, 30, 1)).astype(onp.uint8)
    writer.image(tag='2dimage_1channel_test', image=gray_img, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, '2dimage_1channel_test')

    encoded = value.tensor.string_val[2]
    decoded_img = tf.image.decode_image(encoded)
    # The channel dimension must have been expanded from 1 to 3.
    self.assertEqual(decoded_img.shape, (30, 30, 3))
def test_summarywriter_histogram_defaultbins(self):
    """A histogram written without `bins` must contain 30 rows of 3 values."""
    log_dir = tempfile.mkdtemp()
    summary_writer = SummaryWriter(log_dir=log_dir)
    histogram = onp.arange(1000)
    # Histogram will be created for 30 (default) bins.
    summary_writer.histogram(tag='histogram_test', values=histogram, step=1)

    summary_value = self.parse_and_return_summary_value(path=log_dir)
    self.assertEqual(summary_value.tag, 'histogram_test')
    actual_histogram = tensor_util.make_ndarray(summary_value.tensor)
    # assertEqual, not assertTrue: assertTrue(x, msg) would treat the expected
    # shape as the failure message and never compare anything.
    self.assertEqual(actual_histogram.shape, (30, 3))
    # First row compared against (0.0, 33.3, 34.0) with 0.1 tolerance.
    self.assertTrue(
        onp.allclose(actual_histogram[0], (0.0, 33.3, 34.0), atol=1e-01))
def test_summarywriter_image_float_pixel_values(self):
    """A float image in [0, 1) must be stored as its uint8 conversion."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    float_img = onp.random.uniform(low=0., high=1., size=(30, 30, 3))
    writer.image(tag='image_test', image=float_img, step=1)
    value = self.parse_and_return_summary_value(path=tmp_dir)

    # Reference: the same image converted and scaled to uint8 by TensorFlow.
    reference_img = tf.image.convert_image_dtype(
        image=float_img, dtype=onp.uint8)
    self.assertEqual(value.tag, 'image_test')

    encoded = value.tensor.string_val[2]
    decoded_img = tf.image.decode_image(encoded)
    self.assertTrue(onp.allclose(decoded_img, reference_img))
def test_summarywriter_histogram_2bins(self):
    """A histogram written with `bins=2` must contain 2 rows of 3 values."""
    log_dir = tempfile.mkdtemp()
    summary_writer = SummaryWriter(log_dir=log_dir)
    histogram = onp.arange(1000)
    summary_writer.histogram(
        tag='histogram_test', values=histogram, step=1, bins=2)

    summary_value = self.parse_and_return_summary_value(path=log_dir)
    self.assertEqual(summary_value.tag, 'histogram_test')
    actual_histogram = tensor_util.make_ndarray(summary_value.tensor)
    # assertEqual, not assertTrue: assertTrue(x, msg) would treat the expected
    # shape as the failure message and never compare anything.
    self.assertEqual(actual_histogram.shape, (2, 3))
    # Rows compared against the expected triples with 0.1 tolerance.
    self.assertTrue(
        onp.allclose(actual_histogram[0], (0.0, 499.5, 500.0), atol=1e-01))
    self.assertTrue(
        onp.allclose(actual_histogram[1], (499.5, 999.0, 500.0), atol=1e-01))
def test_summarywriter_multiple_2dimages_scaled(self):
    """A stack of two 2-D images must be stored as two 3-channel images."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    stacked = np.random.uniform(
        low=0., high=255., size=(2, 30, 30)).astype(np.uint8)
    writer.image(tag='multiple_2dimages_test', image=stacked, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'multiple_2dimages_test')

    # Every entry from index 2 onwards is decoded as one image.
    decoded = []
    for encoded in value.tensor.string_val[2:]:
        decoded.append(tf.image.decode_image(encoded))

    # Each 2-D image must have gained a 3-channel last dimension.
    result_shape = np.stack(decoded, axis=0).shape
    self.assertEqual(result_shape, (2, 30, 30, 3))
def test_summarywriter_multiple_images(self):
    """A batch of two uint8 RGB images must round-trip through the summary."""
    tmp_dir = tempfile.mkdtemp()
    writer = SummaryWriter(log_dir=tmp_dir)
    source_imgs = np.random.uniform(
        low=0., high=255., size=(2, 30, 30, 3)).astype(np.uint8)
    writer.image(tag='multiple_images_test', image=source_imgs, step=1)

    value = self.parse_and_return_summary_value(path=tmp_dir)
    self.assertEqual(value.tag, 'multiple_images_test')

    # Every entry from index 2 onwards is decoded as one image.
    decoded = []
    for encoded in value.tensor.string_val[2:]:
        decoded.append(tf.image.decode_image(encoded))

    restacked = np.stack(decoded, axis=0)
    self.assertTrue(np.allclose(restacked, source_imgs))
def test_summarywriter_audio_sampled_output(self):
    """With `max_outputs=1` only the first of two audio clips is written."""
    log_dir = tempfile.mkdtemp()
    summary_writer = SummaryWriter(log_dir=log_dir)
    expected_audio_samples = onp.random.uniform(
        low=-1., high=1., size=(2, 48000, 2))
    summary_writer.audio(
        tag='audio_test',
        audiodata=expected_audio_samples,
        step=1,
        max_outputs=1)

    summary_value = self.parse_and_return_summary_value(path=log_dir)
    self.assertEqual(summary_value.tag, 'audio_test')

    # Assert only the first audio clip is available: two clips were passed
    # but `max_outputs=1` caps the summary at a single encoded clip.
    self.assertLen(summary_value.tensor.string_val, 1)

    # Assert values.
    actual_audio = tf.audio.decode_wav(
        summary_value.tensor.string_val[0]).audio
    self.assertTrue(onp.allclose(
        expected_audio_samples[0], actual_audio, atol=1e-04))
def _log_histograms(writer: tensorboard.SummaryWriter,
                    model: models.NerfModel,
                    state: model_utils.TrainState):
    """Write histograms of the embedding tables found in the model params.

    Args:
      writer: Summary writer receiving the histograms.
      model: Model whose `warp_metadata_encoder_type` decides whether the warp
        embedding is logged.
      state: Training state; the step and the parameters under the 'model' key
        are read from its optimizer.
    """
    step = int(state.optimizer.state.step)
    params = state.optimizer.target['model']

    # Encoders that are logged whenever they are present in the params.
    simple_encoders = (
        ('appearance_encoder', 'appearance_embedding'),
        ('camera_encoder', 'camera_embedding'),
    )
    for encoder_key, histogram_name in simple_encoders:
        if encoder_key not in params:
            continue
        table = params[encoder_key]['embed']['embedding']
        writer.histogram(histogram_name, table, step)

    # The warp embedding is only logged for the 'glo' metadata encoder.
    has_warp_field = 'warp_field' in params
    if has_warp_field and model.warp_metadata_encoder_type == 'glo':
        warp_encoder = params['warp_field']['metadata_encoder']
        writer.histogram(
            'warp_embedding', warp_encoder['embed']['embedding'], step)
def local_train_loop(key,
                     init_params,
                     loss_fn,
                     summarize_fn=default_summarize,
                     lr=1e-4,
                     num_steps=int(1e5),
                     summarize_every=100,
                     checkpoint_every=5000,
                     clobber_checkpoint=False,
                     logdir="/tmp/lda_inference"):
    """Run a single-device Adam training loop with checkpoints and summaries.

    Args:
      key: PRNG key; split once per step and once more per summary.
      init_params: Initial parameters the optimizer is created from.
      loss_fn: Callable `(params, key) -> loss`, differentiated w.r.t. params.
      summarize_fn: Callable `(summary_writer, step, params, key)` invoked at
        every summary step.
      lr: Base learning rate passed to the learning-rate scheduler.
      num_steps: Step index at which training stops.
      summarize_every: Summaries are written when `step % summarize_every == 0`.
      checkpoint_every: Checkpoints are saved when
        `step % checkpoint_every == 0`, except on the first step of the run.
      clobber_checkpoint: Forwarded to `util.maybe_load_checkpoint`.
      logdir: Directory used for checkpoints and summaries.

    On a FloatingPointError the current optimizer is checkpointed and the
    process exits with status 1.
    """
    optimizer_def = optim.Adam()
    optimizer = optimizer_def.create(init_params)
    optimizer = util.maybe_load_checkpoint(
        logdir, optimizer, clobber_checkpoint=clobber_checkpoint)
    lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

    def train_step(optimizer, key):
        loss_val, loss_grad = jax.value_and_grad(
            loss_fn, argnums=0)(optimizer.target, key)
        new_optimizer = optimizer.apply_gradient(
            loss_grad, learning_rate=lr_fn(optimizer.state.step))
        return loss_val, new_optimizer

    train_step = jit(train_step)

    sw = SummaryWriter(logdir)

    start = timeit.default_timer()
    first_step = optimizer.state.step
    # Steps actually executed since the last summary. Counting them keeps the
    # steps/sec figure right when resuming from a checkpoint whose step is not
    # aligned with `summarize_every`.
    steps_since_summary = 0
    for t in range(optimizer.state.step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            checkpoints.save_checkpoint(
                logdir, optimizer, optimizer.state.step, keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)
        key, subkey = jax.random.split(key)
        try:
            loss_val, new_optimizer = train_step(optimizer, subkey)
        except FloatingPointError as e:
            # Keep the last good optimizer and the offending key for debugging.
            print("Exception on step %d" % t)
            print(e)
            traceback.print_exc()
            checkpoints.save_checkpoint(
                logdir, optimizer, optimizer.state.step, keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)
            print("key ", subkey)
            sys.stdout.flush()
            sys.exit(1)
        optimizer = new_optimizer
        steps_since_summary += 1
        if t % summarize_every == 0:
            key, subkey = jax.random.split(key)
            print("Step %d loss: %0.4f" % (t, loss_val))
            sw.scalar("loss", loss_val, step=t)
            summarize_fn(sw, t, optimizer.target, subkey)
            end = timeit.default_timer()
            steps_per_sec = steps_since_summary / (end - start)
            print("Steps/sec: %0.2f" % steps_per_sec)
            sw.scalar("steps_per_sec", steps_per_sec, step=t)
            start = end
            steps_since_summary = 0
            sw.flush()
            sys.stdout.flush()
def parallel_train_loop(key,
                        init_params,
                        loss_fn,
                        summarize_fn=default_summarize,
                        lr=1e-4,
                        num_steps=int(1e5),
                        summarize_every=100,
                        checkpoint_every=5000,
                        clobber_checkpoint=False,
                        logdir="/tmp/lda_inference"):
    """Run a pmapped Adam training loop with checkpoints and summaries.

    Gradients are averaged across devices with `pmean` over the "batch" axis.

    Args:
      key: Accepted for signature parity with `local_train_loop`; the per-device
        keys are built from the device indices and this value is overwritten
        before it is ever read.
      init_params: Initial parameters the optimizer is created from.
      loss_fn: Callable `(params, key) -> loss`; jitted and differentiated
        w.r.t. params.
      summarize_fn: Callable `(summary_writer, step, params, key)` invoked at
        every summary step.
      lr: Base learning rate passed to the learning-rate scheduler.
      num_steps: Step index at which training stops.
      summarize_every: Summaries are written when `step % summarize_every == 0`.
      checkpoint_every: Checkpoints are saved when
        `step % checkpoint_every == 0`, except on the first step of the run.
      clobber_checkpoint: Forwarded to `util.maybe_load_checkpoint`.
      logdir: Directory used for checkpoints and summaries.
    """
    loss_fn = jax.jit(loss_fn)

    optimizer_def = optim.Adam()
    local_optimizer = optimizer_def.create(init_params)
    local_optimizer = util.maybe_load_checkpoint(
        logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint)
    first_step = local_optimizer.state.step
    repl_optimizer = jax_utils.replicate(local_optimizer)

    lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

    @functools.partial(jax.pmap, axis_name="batch")
    def train_step(optimizer, key):
        # One half of the split drives this step, the other is carried over.
        key, subkey = jax.random.split(key)
        loss_grad = jax.grad(loss_fn, argnums=0)(optimizer.target, key)
        loss_grad = jax.lax.pmean(loss_grad, "batch")
        new_optimizer = optimizer.apply_gradient(
            loss_grad, learning_rate=lr_fn(optimizer.state.step))
        return new_optimizer, subkey

    sw = SummaryWriter(logdir)

    # One PRNG key per local device, seeded with the device index.
    repl_key = jax.pmap(jax.random.PRNGKey)(
        jnp.arange(jax.local_device_count()))
    start = timeit.default_timer()
    # Steps actually executed since the last summary. Counting them keeps the
    # steps/sec figure right when resuming from a checkpoint whose step is not
    # aligned with `summarize_every`.
    steps_since_summary = 0
    for t in range(first_step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            optimizer = jax_utils.unreplicate(repl_optimizer)
            checkpoints.save_checkpoint(
                logdir, optimizer, optimizer.state.step, keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)

        repl_optimizer, repl_key = train_step(repl_optimizer, repl_key)
        steps_since_summary += 1

        if t % summarize_every == 0:
            key, subkey = jax.random.split(jax_utils.unreplicate(repl_key))
            optimizer = jax_utils.unreplicate(repl_optimizer)
            # The loss is recomputed here because the pmapped step only
            # returns the optimizer and the key.
            loss_val = loss_fn(optimizer.target, key)
            print("Step %d loss: %0.4f" % (t, loss_val))
            sw.scalar("loss", loss_val, step=t)
            summarize_fn(sw, t, optimizer.target, subkey)
            end = timeit.default_timer()
            steps_per_sec = steps_since_summary / (end - start)
            print("Steps/sec: %0.2f" % steps_per_sec)
            sw.scalar("steps_per_sec", steps_per_sec, step=t)
            start = end
            steps_since_summary = 0
            sw.flush()
            sys.stdout.flush()
def main():
    """Train a causal language model with model parallelism (pjit).

    Steps, in order: parse the arguments, configure logging, load and tokenize
    the dataset, group the tokens into blocks of `block_size`, build the model,
    the AdamW optimizer and their partition specs, then run the training loop
    with periodic metric logging, evaluation and checkpointing.

    Evaluation inside the training loop only runs when `--do_eval` is set,
    since the evaluation dataset is only built in that case.

    Raises:
      ValueError: If the output directory exists, is not empty and must not be
        overwritten; if no tokenizer name or model path is given; or if a
        requested train/validation split is missing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )

        if "validation" not in dataset.keys():
            # No validation split available: carve one out of the train split.
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Derive the loader from whichever file was given, so that providing
        # only a validation file does not fail on `None.split`.
        reference_file = (
            data_args.train_file if data_args.train_file is not None else data_args.validation_file
        )
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained config and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
    # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
    # GPT2's vocab is odd, we need to resize it for fine-tuning
    model = FlaxAutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
    )

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    def get_initial_state(params):
        state = optimizer.init(params)
        return tuple(state), params

    # Get PartitionSpec for model params
    param_spec = set_partitions(unfreeze(model.params))

    # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
    params_shapes = jax.tree_map(lambda x: x.shape, model.params)
    state_shapes = jax.eval_shape(get_initial_state, params_shapes)

    # get PartitionSpec for opt_state, this is very specific to adamw
    # TODO: optax returns different state for different optimizers, how can we handle this generically ?
    # or maybe we don't since in our examples we just use adamw or adafactor
    def get_opt_spec(x):
        if isinstance(x, dict):
            return param_spec
        return None

    opt_state_spec, param_spec = jax.tree_map(
        get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
    )

    # pjit the get_initial_state function to shard params and init
    # optimizer state in sharded way
    p_get_initial_state = pjit(
        get_initial_state,
        in_axis_resources=None,
        out_axis_resources=(opt_state_spec, param_spec),
    )

    # hack: move the initial params to CPU to free up device memory
    # TODO: allow loading weights on CPU in pre-trained model
    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)

    # mesh definition
    mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())

    # actually initialize the opt_state
    with mesh(mesh_devices, ("dp", "mp")):
        opt_state, params = p_get_initial_state(freeze(model.params))

    # cross-entropy with z loss
    def loss_fn(logits, labels, z_loss=0):
        # Position i predicts token i + 1, hence the shift.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        shift_labels = onehot(shift_labels, shift_logits.shape[-1])

        # Subtract the max before exponentiating, for numerical stability.
        shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
        log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
        log_softmax = shift_logits - log_z
        loss = -jnp.sum(shift_labels * log_softmax, axis=-1)

        loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss

        return loss.mean()

    # Define gradient update step fn
    # TODO: try to use TrainState instead of passing params and opt_state individually
    def train_step(params, opt_state, dropout_rng, batch, step):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, z_loss=1.0)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(params)

        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)}
        return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1

    # Define eval fn
    def eval_step(input_ids, labels, params):
        logits = model(input_ids=input_ids, params=params, train=False)[0]
        loss = loss_fn(logits, labels)
        # metrics
        return {"loss": loss}

    p_train_step = pjit(
        train_step,
        in_axis_resources=(param_spec, opt_state_spec, None, None, None),
        out_axis_resources=(param_spec, opt_state_spec, None, None, None),
        donate_argnums=(0, 1),
    )

    p_eval_step = pjit(
        eval_step,
        in_axis_resources=(None, None, param_spec),
        out_axis_resources=None,
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    global_step = 0
    # we are not doing 2D parallelism (yet!), this just does model parallelism
    with mesh(mesh_devices, ("dp", "mp")):
        for _ in epochs:
            # ======================== Training ================================
            train_start = time.time()

            # Create sampling rng
            rng, input_rng = jax.random.split(rng)

            # Generate an epoch by shuffling sampling indices from the train dataset
            train_metrics = []
            train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
            steps_per_epoch = len(train_dataset) // train_batch_size

            # train
            for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
                batch = next(train_loader)
                params, opt_state, dropout_rng, train_metric, global_step = p_train_step(
                    params,
                    opt_state,
                    dropout_rng,
                    batch,
                    global_step,
                )
                train_metrics.append(train_metric)

                cur_step = global_step

                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                    # Save metrics
                    train_time += time.time() - train_start
                    if has_tensorboard and jax.process_index() == 0:
                        write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                    epochs.write(
                        f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                    )

                    train_metrics = []

                # `eval_dataset` only exists when --do_eval is set; skip the
                # periodic evaluation otherwise instead of raising NameError.
                if training_args.do_eval and cur_step % training_args.eval_steps == 0 and cur_step > 0:
                    # ======================== Evaluating ==============================
                    eval_metrics = []
                    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                    eval_steps = len(eval_dataset) // eval_batch_size

                    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                        batch = next(eval_loader)
                        metrics = p_eval_step(batch["input_ids"], batch["labels"], params)
                        eval_metrics.append(metrics)

                    # normalize eval metrics
                    eval_metrics = stack_forest(eval_metrics)
                    eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                    try:
                        eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                    except OverflowError:
                        eval_metrics["perplexity"] = float("inf")

                    logger.info(
                        f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
                    )

                if cur_step % training_args.save_steps == 0 and cur_step > 0:
                    # save a checkpoint every `save_steps` steps and push it to the hub if requested
                    if jax.process_index() == 0:
                        params = jax.device_get(params)
                        model.save_pretrained(
                            training_args.output_dir,
                            params=params,
                            push_to_hub=training_args.push_to_hub,
                            commit_message=f"Saving weights and logs of step {cur_step}",
                        )
truncation=True, max_length=data_args.max_seq_length, ) tokenized_datasets = datasets.map( tokenize_function, input_columns=[text_column_name], batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) # Enable tensorboard only on the master node if has_tensorboard and jax.host_id() == 0: summary_writer = SummaryWriter( log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()) # Data collator # This one will take care of randomly masking the tokens. data_collator = FlaxDataCollatorForLanguageModeling( tokenizer=tokenizer, mlm_probability=data_args.mlm_probability) # Setup optimizer optimizer = Adam( learning_rate=training_args.learning_rate, weight_decay=training_args.weight_decay, beta1=training_args.adam_beta1, beta2=training_args.adam_beta2, ).create(model.params) # Create learning rate scheduler
def process_batch(*,
                  batch: Dict[str, jnp.ndarray],
                  rng: types.PRNGKey,
                  state: model_utils.TrainState,
                  tag: str,
                  item_id: str,
                  step: int,
                  summary_writer: tensorboard.SummaryWriter,
                  render_fn: Any,
                  save_dir: Optional[gpath.GPath],
                  datasource: datasets.DataSource):
  """Render one batch, write its visualizations and return its metrics.

  `render_fn` is called before the process-index check, so it runs on every
  process. Everything after it (image files, image summaries, metrics) only
  happens on process 0; all other processes return an empty dict.

  Args:
    batch: the batch handed to `render_fn`. If it contains an 'rgb' and/or
      a 'depth' entry, those are used as targets to compare the render with.
    rng: random key forwarded to `render_fn`.
    state: train state forwarded to `render_fn`.
    tag: label embedded in every summary name.
    item_id: identifier of the item; every '/' is replaced by '_' before it
      is used in file names and summary names.
    step: step under which the image summaries are written.
    summary_writer: receives the image summaries.
    render_fn: called as `render_fn(state, batch, rng=rng)`; its result is
      indexed with 'rgb', 'acc', 'depth' and 'med_depth'.
    save_dir: when truthy, the directory (created if needed) that receives
      the PNG outputs.
    datasource: provides `near` and `far`, used as the depth color range.

  Returns:
    A dict holding 'mse', 'psnr' and 'ssim' when the batch has an 'rgb'
    entry, and 'depth_abs' when it has a 'depth' entry.
  """
  item_id = item_id.replace('/', '_')
  rendered = render_fn(state, batch, rng=rng)
  metrics = {}
  if jax.process_index() != 0:
    return metrics

  pred_rgb = rendered['rgb']
  pred_acc = rendered['acc']
  expected_depth = rendered['depth']
  median_depth = rendered['med_depth']

  # Read the depth range once so every depth image shares the same colormap.
  near, far = datasource.near, datasource.far

  def depth_to_color(depth):
    return viz.colorize(depth, cmin=near, cmax=far, invert=True)

  def write_image(name, image):
    summary_writer.image(name, image, step)

  expected_depth_viz = depth_to_color(expected_depth)
  median_depth_viz = depth_to_color(median_depth)
  expected_disp_viz = viz.colorize(1.0 / expected_depth)
  median_disp_viz = viz.colorize(1.0 / median_depth)
  acc_viz = viz.colorize(pred_acc, cmin=0.0, cmax=1.0)

  if save_dir:
    save_dir.mkdir(parents=True, exist_ok=True)
    rgb_path = save_dir / f'rgb_{item_id}.png'
    image_utils.save_image(rgb_path, image_utils.image_to_uint8(pred_rgb))
    expected_viz_path = save_dir / f'depth_expected_viz_{item_id}.png'
    image_utils.save_image(expected_viz_path,
                           image_utils.image_to_uint8(expected_depth_viz))
    expected_path = save_dir / f'depth_expected_{item_id}.png'
    image_utils.save_depth(expected_path, expected_depth)
    median_viz_path = save_dir / f'depth_median_viz_{item_id}.png'
    image_utils.save_image(median_viz_path,
                           image_utils.image_to_uint8(median_depth_viz))
    median_path = save_dir / f'depth_median_{item_id}.png'
    image_utils.save_depth(median_path, median_depth)

  # Summaries of the render itself, written in a fixed order.
  render_summaries = (
      (f'rgb/{tag}/{item_id}', pred_rgb),
      (f'depth-expected/{tag}/{item_id}', expected_depth_viz),
      (f'depth-median/{tag}/{item_id}', median_depth_viz),
      (f'disparity-expected/{tag}/{item_id}', expected_disp_viz),
      (f'disparity-median/{tag}/{item_id}', median_disp_viz),
      (f'acc/{tag}/{item_id}', acc_viz),
  )
  for name, image in render_summaries:
    write_image(name, image)

  if 'rgb' in batch:
    rgb_target = batch['rgb']
    mse = ((pred_rgb - rgb_target)**2).mean()
    psnr = utils.compute_psnr(mse)
    ssim = compute_multiscale_ssim(rgb_target, pred_rgb)
    metrics['mse'] = mse
    metrics['psnr'] = psnr
    metrics['ssim'] = ssim
    logging.info('\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f',
                 mse, psnr, ssim)

    # Per-pixel error maps: channel-summed absolute and squared differences.
    abs_error_viz = viz.colorize(
        abs(rgb_target - pred_rgb).sum(axis=-1), cmin=0, cmax=1)
    sq_error_viz = viz.colorize(
        ((rgb_target - pred_rgb)**2).sum(axis=-1), cmin=0, cmax=1)
    write_image(f'rgb-target/{tag}/{item_id}', rgb_target)
    write_image(f'rgb-abs-error/{tag}/{item_id}', abs_error_viz)
    write_image(f'rgb-sq-error/{tag}/{item_id}', sq_error_viz)

  if 'depth' in batch:
    depth_target = batch['depth']
    depth_target_viz = depth_to_color(depth_target[..., 0])
    # nanmean: NaN entries of the difference do not contribute to the metric.
    metrics['depth_abs'] = jnp.nanmean(jnp.abs(depth_target - median_depth))
    write_image(f'depth-target/{tag}/{item_id}', depth_target_viz)
    median_error_viz = viz.colorize(
        abs(depth_target - median_depth).squeeze(axis=-1), cmin=0, cmax=1)
    write_image(f'depth-median-error/{tag}/{item_id}', median_error_viz)
    expected_error_viz = viz.colorize(
        abs(depth_target - expected_depth).squeeze(axis=-1), cmin=0, cmax=1)
    write_image(f'depth-expected-error/{tag}/{item_id}', expected_error_viz)

  return metrics
def main():
    """Fine-tune, evaluate and optionally run prediction with a Flax seq2seq model.

    Arguments are parsed into `ModelArguments`, `DataTrainingArguments` and
    `TrainingArguments`, either from a single ``.json`` file given as the only
    command-line argument or from the regular command line.

    The function then loads the dataset, config, tokenizer and model,
    tokenizes the requested splits, builds an AdamW optimizer with a learning
    rate schedule, and runs a pmapped train/eval loop for
    `training_args.num_train_epochs` epochs. After each epoch, process 0
    saves the model and tokenizer to `training_args.output_dir` (and pushes
    to the hub when requested). If `training_args.do_predict` is set, a final
    pass over the test split writes ``test_results.json`` in the output
    directory.

    Returns:
        None. Returns early (after logging) when none of `do_train`,
        `do_eval`, `do_predict` is set.

    Raises:
        ValueError: if the output directory is non-empty and must not be
            overwritten, if no tokenizer source is given, if
            `config.decoder_start_token_id` is undefined, if a requested
            text/summary column is missing, or if a requested split is not
            part of the dataset.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
    #
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(data_args.dataset_name,
                               data_args.dataset_config_name,
                               cache_dir=model_args.cache_dir,
                               keep_in_memory=False)
    else:
        data_files = {}
        # The loader type is taken from the extension of the last file given.
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        dataset = load_dataset(extension,
                               data_files=data_files,
                               cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name,
                                            cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path,
                                            cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForSeq2SeqLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))

    if model.config.decoder_start_token_id is None:
        raise ValueError(
            "Make sure that `config.decoder_start_token_id` is correctly defined"
        )

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info(
            "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`."
        )
        return

    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name,
                                                     None)
    if data_args.text_column is None:
        text_column = dataset_columns[
            0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[
            1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training.
    # `preprocess_function` reads this variable when it is called, so the
    # reassignments in the eval/predict branches below affect those splits.
    max_target_length = data_args.max_target_length

    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
    # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
    # for that dynamically import the `shift_tokens_right` function from the model file
    # (fromlist previously misspelled the name as "shift_tokens_tight").
    model_module = __import__(model.__module__,
                              fromlist=["shift_tokens_right"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    def preprocess_function(examples):
        """Tokenize a batch of examples into fixed-length model inputs."""
        inputs = examples[text_column]
        targets = examples[summary_column]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs,
                                 max_length=data_args.max_source_length,
                                 padding="max_length",
                                 truncation=True,
                                 return_tensors="np")

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets,
                               max_length=max_target_length,
                               padding="max_length",
                               truncation=True,
                               return_tensors="np")

        model_inputs["labels"] = labels["input_ids"]
        decoder_input_ids = shift_tokens_right_fn(
            labels["input_ids"], config.pad_token_id,
            config.decoder_start_token_id)
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

        # We need decoder_attention_mask so we can ignore pad tokens from loss
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(
                range(data_args.max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    metric = load_metric("rouge")

    def postprocess_text(preds, labels):
        """Strip each text and put every sentence on its own line."""
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    def compute_metrics(preds, labels):
        """Decode token ids and return rounded ROUGE scores plus `gen_len`."""
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels,
                                                skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(
            decoded_preds, decoded_labels)

        result = metric.compute(predictions=decoded_preds,
                                references=decoded_labels,
                                use_stemmer=True)
        # Extract a few results from ROUGE
        result = {
            key: value.mid.fmeasure * 100
            for key, value in result.items()
        }

        # Generated length = number of non-pad tokens per prediction.
        prediction_lens = [
            np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
        ]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        # NOTE: this branch is also reached on processes other than 0, whose
        # logger level was set to ERROR above.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBart.
    # For FlaxT5, one should correct the layer norm parameter naming
    # accordingly - see `run_t5_mlm_flax.py` e.g.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        layer_norm_params = [(name, "scale") for name in [
            "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"
        ]]
        flat_mask = {
            path:
            (path[-1] != "bias" and path[-2:] not in layer_norm_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=adamw,
                              dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(confidence * jnp.log(confidence) +
                                 (vocab_size - 1) * low_confidence *
                                 jnp.log(low_confidence + 1e-20))
        soft_labels = onehot(labels,
                             vocab_size,
                             on_value=confidence,
                             off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum() / padding_mask.sum()
        return loss

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = loss_fn(logits, labels, batch["decoder_attention_mask"],
                           label_smoothing_factor)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients over the "batch" pmap axis.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels, batch["decoder_attention_mask"],
                       label_smoothing_factor)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Define generation function
    max_length = (data_args.val_max_target_length
                  if data_args.val_max_target_length is not None else
                  model.config.max_length)
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        model.params = params
        output_ids = model.generate(batch["input_ids"],
                                    attention_mask=batch["attention_mask"],
                                    **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(partial(
        train_step,
        label_smoothing_factor=training_args.label_smoothing_factor),
                            "batch",
                            donate_argnums=(0, ))
    p_eval_step = jax.pmap(
        partial(eval_step,
                label_smoothing_factor=training_args.label_smoothing_factor),
        "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng,
                                   train_dataset,
                                   train_batch_size,
                                   shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for _ in tqdm(range(steps_per_epoch),
                      desc="Training...",
                      position=1,
                      leave=False):
            batch = next(train_loader)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        # Only the last step's metrics are reported in the progress message.
        train_metric = unreplicate(train_metric)

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_preds = []
        eval_labels = []

        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
        eval_steps = len(eval_dataset) // eval_batch_size
        for _ in tqdm(range(eval_steps),
                      desc="Evaluating...",
                      position=2,
                      leave=False):
            # Model forward
            batch = next(eval_loader)
            # Grab the labels first: eval_step pops "labels" from its batch.
            labels = batch["labels"]

            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = p_generate_step(state.params, batch)
                eval_preds.extend(
                    jax.device_get(
                        generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(
                    jax.device_get(labels.reshape(-1, labels.shape[-1])))

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(eval_preds, eval_labels)
            eval_metrics.update(rouge_metrics)
            rouge_desc = " ".join([
                f"Eval {key}: {value} |"
                for key, value in rouge_metrics.items()
            ])

        # Print metrics and update progress bar
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics,
                         train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # Take the first replica of every parameter before saving.
            params = jax.device_get(
                jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(
                    commit_message=f"Saving weights and logs of epoch {epoch}",
                    blocking=False)

    # ======================== Prediction loop ==============================
    if training_args.do_predict:
        logger.info("*** Predict ***")

        pred_metrics = []
        pred_generations = []
        pred_labels = []

        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
        pred_steps = len(predict_dataset) // eval_batch_size
        for _ in tqdm(range(pred_steps),
                      desc="Predicting...",
                      position=2,
                      leave=False):
            # Model forward
            batch = next(pred_loader)
            labels = batch["labels"]

            metrics = p_eval_step(state.params, batch)
            pred_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = p_generate_step(state.params, batch)
                pred_generations.extend(
                    jax.device_get(
                        generated_ids.reshape(-1, gen_kwargs["max_length"])))
                pred_labels.extend(
                    jax.device_get(labels.reshape(-1, labels.shape[-1])))

        # normalize prediction metrics
        pred_metrics = get_metrics(pred_metrics)
        pred_metrics = jax.tree_map(jnp.mean, pred_metrics)

        # compute ROUGE metrics
        # Start from an empty dict so the JSON dump below also works when
        # generation is disabled; previously `rouge_metrics` was unbound in
        # that case and the dump raised NameError.
        rouge_metrics = {}
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(pred_generations, pred_labels)
            pred_metrics.update(rouge_metrics)
            rouge_desc = " ".join([
                f"Predict {key}: {value} |"
                for key, value in rouge_metrics.items()
            ])

        # Print metrics
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
        logger.info(desc)

        # save final metrics in json
        if jax.process_index() == 0:
            rouge_metrics = {
                f"test_{metric_name}": value
                for metric_name, value in rouge_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(rouge_metrics, f, indent=4, sort_keys=True)
def main():
    """Fine-tune and evaluate a Flax image-classification model.

    Arguments are parsed into `ModelArguments`, `DataTrainingArguments` and
    `TrainingArguments`, either from a single ``.json`` file given as the only
    command-line argument or from the regular command line.

    The function builds torchvision `ImageFolder` datasets and torch data
    loaders, creates the model and an AdamW optimizer, then runs a pmapped
    train/eval loop for `training_args.num_train_epochs` epochs. After each
    epoch, process 0 writes TensorBoard metrics (when available), saves the
    model to `training_args.output_dir`, and pushes to the hub when
    `training_args.push_to_hub` is set.

    Raises:
        ValueError: if the output directory exists, is not empty,
            `do_train` is set and `overwrite_output_dir` is not.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # set seed for random transforms and torch dataloaders
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            # No explicit hub id: derive the repo name from the output directory name.
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Initialize datasets and pre-processing transforms
    # We use torchvision here for faster pre-processing
    # Note that here we are using some default pre-processing, for maximum accuracy
    # one should tune this part and carefully select what transformations to use.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    # Training images get random crop/flip augmentation.
    train_dataset = torchvision.datasets.ImageFolder(
        data_args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(data_args.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    # Evaluation images use a deterministic resize + center crop.
    eval_dataset = torchvision.datasets.ImageFolder(
        data_args.validation_dir,
        transforms.Compose([
            transforms.Resize(data_args.image_size),
            transforms.CenterCrop(data_args.image_size),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    # Load the model configuration; the number of labels comes from the
    # classes found in the training folder.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    else:
        # NOTE(review): this branch does not pass num_labels/image_size to the
        # config - confirm that is intended.
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.model_name_or_path:
        model = FlaxAutoModelForImageClassification.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForImageClassification.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    # Global batch sizes: per-device size times the number of devices.
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    def collate_fn(examples):
        """Stack (image, label) examples into a dict of numpy arrays."""
        pixel_values = torch.stack([example[0] for example in examples])
        labels = torch.tensor([example[1] for example in examples])

        batch = {"pixel_values": pixel_values, "labels": labels}
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    # Create data loaders
    # drop_last=True on both loaders: incomplete final batches are skipped.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        # NOTE: this branch is also reached on processes other than 0, whose
        # logger level was set to ERROR above.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=adamw,
                              dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        """Mean softmax cross-entropy between logits and one-hot labels."""
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        # Use one half of the split for this step and store the other half
        # in the new state.
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            # "labels" is removed from the batch so the remaining entries can
            # be passed to the model as keyword arguments.
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients over the "batch" pmap axis.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
        metrics = {"loss": loss, "accuracy": accuracy}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        # NOTE(review): input_rng is not used further in this function.
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        steps_per_epoch = len(train_dataset) // train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch,
                                       desc="Training...",
                                       position=1,
                                       leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        # Only the last step's metrics are reported in the progress message.
        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_steps = len(eval_dataset) // eval_batch_size
        eval_step_progress_bar = tqdm(total=eval_steps,
                                      desc="Evaluating...",
                                      position=2,
                                      leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
            f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})")
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            # The step is computed from the zero-based epoch index.
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics,
                         train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # Take the first replica of every parameter before saving.
            params = jax.device_get(
                jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            if training_args.push_to_hub:
                repo.push_to_hub(
                    commit_message=f"Saving weights and logs of epoch {epoch}",
                    blocking=False)
def main():
    """Train and evaluate a hybrid CLIP model with Flax/JAX.

    The function parses the model/data/training arguments (from the command
    line or from a single ``.json`` file), builds the tokenizer, the
    ``FlaxHybridCLIP`` model, the image-text datasets and their torch data
    loaders, and then runs ``num_train_epochs`` epochs. Each epoch trains
    over ``train_loader``, evaluates over ``eval_loader``, optionally writes
    metrics to TensorBoard, and saves a checkpoint from process 0.

    Raises:
        ValueError: if the output directory is non-empty and may not be
            overwritten, if neither a tokenizer name nor a text model path
            is given, or if the train loader yields no batch.
    """
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.text_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.text_model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    model = FlaxHybridCLIP.from_text_vision_pretrained(
        model_args.text_model_name_or_path,
        model_args.vision_model_name_or_path,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
        text_from_pt=model_args.from_pt,
        vision_from_pt=model_args.from_pt,
    )
    config = model.config

    # set seed for torch dataloaders
    set_seed(training_args.seed)

    # Initialize torchvision transforms and jit them for faster processing
    preprocess = Transform(config.vision_config.image_size)
    preprocess = torch.jit.script(preprocess)

    # Initialize the image-text dataset
    train_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.train_file,
        captions_per_image=2,
        transform=preprocess,
    )
    eval_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.validation_file,
        captions_per_image=1,
        transform=preprocess,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Use collate function to tokenize the text and convert the processed images to numpy
    def collate_fn(examples):
        # The permute moves axis 1 of the stacked images to the last position
        # before converting to numpy.
        pixel_values = torch.stack(
            [example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
        captions = [example[1] for example in examples]
        inputs = tokenizer(captions,
                           max_length=data_args.max_seq_length,
                           padding="max_length",
                           return_tensors="np")

        batch = {
            "pixel_values": pixel_values,
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }

        return batch

    # torch's DataLoader raises if `persistent_workers=True` is combined with
    # `num_workers=0`, and it cannot compare a `None` worker count, so
    # normalise the argument and only keep workers alive when there are any.
    num_workers = data_args.preprocessing_num_workers or 0
    use_persistent_workers = num_workers > 0

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=use_persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=use_persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.process_index() == 0:
        summary_writer = SummaryWriter(
            log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=adamw,
                              dropout_rng=dropout_rng)

    def cross_entropy(logits, axis):
        # Mean negative log-probability of the diagonal entries after a
        # log-softmax along `axis`.
        logprobs = jax.nn.log_softmax(logits, axis=axis)
        nll = jnp.diag(logprobs)
        ce = -jnp.mean(nll)
        return ce

    def clip_loss(similarity):
        # Average of the cross entropies taken along both axes.
        loss = (cross_entropy(similarity, axis=0) +
                cross_entropy(similarity, axis=1)) / 2
        return loss

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = clip_loss(logits)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        logits = model(**batch, params=params, train=False)[0]
        loss = clip_loss(logits)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0

    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        train_metrics = []

        train_step_progress_bar = tqdm(total=steps_per_epoch,
                                       desc="Training...",
                                       position=1,
                                       leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        if not train_metrics:
            # With drop_last=True a dataset smaller than one global batch
            # yields nothing; fail with an actionable message instead of a
            # NameError on `train_metric` below.
            raise ValueError(
                "The training loader produced no batches: the train dataset has fewer "
                f"examples than the total train batch size ({train_batch_size}).")

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_steps = len(eval_dataset) // eval_batch_size
        eval_step_progress_bar = tqdm(total=eval_steps,
                                      desc="Evaluating...",
                                      position=2,
                                      leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)

        # `jax.tree_map` was removed from recent JAX releases;
        # `jax.tree_util.tree_map` is the long-standing equivalent.
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * steps_per_epoch
            write_metric(summary_writer, train_metrics, eval_metrics,
                         train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )
def process_batch(
    *,
    batch: Dict[str, jnp.ndarray],
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    tag: str,
    item_id: str,
    step: int,
    summary_writer: tensorboard.SummaryWriter,
    render_fn: Any,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Render a single batch, save/log the images and compute metrics.

    Args:
        batch: the batch passed to ``render_fn``. If it contains an ``"rgb"``
            and/or ``"depth"`` entry, error metrics and error images are
            computed against them.
        rng: random key forwarded to ``render_fn``.
        state: the train state forwarded to ``render_fn``.
        tag: label used inside the summary tags (``<name>/<tag>/<item_id>``).
        item_id: identifier used in summary tags and output file names.
        step: the step at which summaries are written.
        summary_writer: writer receiving the image summaries.
        render_fn: callable invoked as ``render_fn(state, batch, rng=rng)``
            returning ``(rgb, depth_expected, depth_median, acc)``.
        save_dir: if set, directory (created on demand) where the rendered
            images and depth maps are written.
        datasource: provides the ``near``/``far`` bounds used to colorize depth.

    Returns:
        A dict of metrics. It may contain ``mse``, ``psnr`` and ``ssim`` (when
        the batch has ``"rgb"``) and ``depth_abs`` (when it has ``"depth"``).
        It is empty on every process other than process 0.
    """
    rgb, depth_exp, depth_med, acc = render_fn(state, batch, rng=rng)

    out = {}
    # Rendering runs on every process, but only process 0 saves and logs.
    # `jax.process_index()` replaces the deprecated `jax.host_id()`.
    if jax.process_index() != 0:
        return out

    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)

    depth_exp_viz = colorize_depth(depth_exp[..., 0])
    depth_med_viz = colorize_depth(depth_med[..., 0])

    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f"rgb_{item_id}.png",
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(
            save_dir / f"depth_expected_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_exp_viz),
        )
        # The expected-depth file must hold the expected depth; it was
        # previously written from the median depth.
        image_utils.save_depth(save_dir / f"depth_expected_{item_id}.png",
                               depth_exp[..., 0])
        image_utils.save_image(
            save_dir / f"depth_median_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_med_viz),
        )
        image_utils.save_depth(save_dir / f"depth_median_{item_id}.png",
                               depth_med[..., 0])

    summary_writer.image(f"rgb/{tag}/{item_id}", rgb, step)
    summary_writer.image(f"depth-expected/{tag}/{item_id}", depth_exp_viz,
                         step)
    summary_writer.image(f"depth-median/{tag}/{item_id}", depth_med_viz, step)
    summary_writer.image(f"acc/{tag}/{item_id}", acc, step)

    if "rgb" in batch:
        rgb_target = batch["rgb"]
        mse = ((rgb - rgb_target)**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out["mse"] = mse
        out["psnr"] = psnr
        out["ssim"] = ssim
        logging.info("\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f", mse,
                     psnr, ssim)

        # Per-pixel errors summed over the last axis, shown as images.
        rgb_abs_error = viz.colorize(abs(rgb_target - rgb).sum(axis=-1),
                                     cmin=0,
                                     cmax=1)
        rgb_sq_error = viz.colorize(((rgb_target - rgb)**2).sum(axis=-1),
                                    cmin=0,
                                    cmax=1)
        summary_writer.image(f"rgb-target/{tag}/{item_id}", rgb_target, step)
        summary_writer.image(f"rgb-abs-error/{tag}/{item_id}", rgb_abs_error,
                             step)
        summary_writer.image(f"rgb-sq-error/{tag}/{item_id}", rgb_sq_error,
                             step)

    if "depth" in batch:
        depth_target = batch["depth"]
        depth_target_viz = colorize_depth(depth_target[..., 0])
        # nanmean ignores NaN entries in the absolute depth error.
        out["depth_abs"] = jnp.nanmean(jnp.abs(depth_target - depth_med))
        summary_writer.image(f"depth-target/{tag}/{item_id}",
                             depth_target_viz, step)
        depth_med_error = viz.colorize(
            abs(depth_target - depth_med).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-median-error/{tag}/{item_id}",
                             depth_med_error, step)
        depth_exp_error = viz.colorize(
            abs(depth_target - depth_exp).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-expected-error/{tag}/{item_id}",
                             depth_exp_error, step)

    # Reciprocal of the expected depth, logged as "relative-disparity".
    rel_disp_pred = viz.colorize(1.0 / depth_exp[..., 0])
    summary_writer.image(f"relative-disparity/{tag}/{item_id}", rel_disp_pred,
                         step)

    return out
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map tokenized_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." ) # Data collator # This one will take care of randomly masking the tokens. data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
def main():
    """Train and periodically evaluate a causal language model with Flax/JAX.

    The function parses the model/data/training arguments (from the command
    line or from a single ``.json`` file), loads the dataset (from the hub or
    from local files), builds the config, tokenizer and
    ``FlaxAutoModelForCausalLM`` model, tokenizes the texts and groups them
    into chunks of ``block_size`` tokens, and then runs the training loop.
    Every ``logging_steps`` steps the train metrics are reported, every
    ``eval_steps`` steps an evaluation pass is run, and every ``save_steps``
    steps process 0 saves a checkpoint.

    Raises:
        ValueError: if the output directory is non-empty and may not be
            overwritten, if no tokenizer can be loaded, or if a requested
            train/validation split is missing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(data_args.dataset_name,
                               data_args.dataset_config_name,
                               cache_dir=model_args.cache_dir,
                               keep_in_memory=False)

        if "validation" not in dataset.keys():
            # No validation split: carve the first
            # `validation_split_percentage` percent out of the train split.
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Derive the loader type from whichever file was provided, so that an
        # evaluation-only run (no train file) does not crash on `None.split`.
        reference_file = (data_args.train_file
                          if data_args.train_file is not None else
                          data_args.validation_file)
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension,
                               data_files=data_files,
                               cache_dir=model_args.cache_dir)

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name,
                                            cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path,
                                            cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForCausalLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger(
        "transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Imported here (not in the enclosing scope) so the function stays
        # self-contained when `datasets` serialises it for its workers.
        from itertools import chain

        # Concatenate all texts. `chain.from_iterable` is linear in the total
        # length, unlike `sum(lists, [])` which is quadratic.
        concatenated_examples = {
            k: list(chain.from_iterable(examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxGPT2.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"),
                                         ("ln_f", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=optimizer,
                              dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        # Next-token objective: the logits at position i are scored against
        # the label at position i + 1.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = optax.softmax_cross_entropy(
            shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng,
                                   train_dataset,
                                   train_batch_size,
                                   shuffle=True)
        # train
        for step in tqdm(range(steps_per_epoch),
                         desc="Training...",
                         position=1,
                         leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            cur_step = epoch * steps_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                # Add only the time elapsed since the previous measurement and
                # restart the clock; measuring from the epoch start every time
                # would count the same interval more than once per epoch.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = []
                eval_loader = data_loader(input_rng, eval_dataset,
                                          eval_batch_size)
                eval_steps = len(eval_dataset) // eval_batch_size
                for _ in tqdm(range(eval_steps),
                              desc="Evaluating...",
                              position=2,
                              leave=False):
                    # Model forward
                    batch = next(eval_loader)
                    batch = shard(batch)
                    metrics = p_eval_step(state.params, batch)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                # `jax.tree_map` was removed from recent JAX releases;
                # `jax.tree_util.tree_map` is the long-standing equivalent.
                eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

                # Print metrics and update progress bar
                desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
                epochs.write(desc)
                epochs.desc = desc

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint every `save_steps` steps and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(
                        training_args.output_dir,
                        params=params,
                        push_to_hub=training_args.push_to_hub,
                        commit_message=
                        f"Saving weights and logs of step {cur_step}",
                    )
def main():
    """Fine-tune a Flax sequence-classification model and evaluate it after every epoch.

    The run is configured through ``parse_args()``. The function loads either a
    GLUE task or local data files, tokenizes them, trains with a pmapped train
    step for ``args.num_train_epochs`` epochs, evaluates after each epoch,
    writes TensorBoard scalars when the package is available, saves the model
    and tokenizer to ``args.output_dir`` (optionally pushing to the Hub) and
    finally dumps the last evaluation metrics to ``eval_results.json``.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        repo = Repository(args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)

    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # Fall back to the validation file (the same attribute read just above)
        # to infer the loader type when no training file was given.
        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        elif len(non_label_column_names) >= 2:
            sentence1_key, sentence2_key = non_label_column_names[:2]
        else:
            sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # The pieces are concatenated into ONE message: passing them as
            # separate positional arguments would make logging treat the second
            # one as a %-format argument and fail to render the record.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None and not is_regression:
        # `label_list` only exists for classification; regression targets are
        # used as-is, so no mapping is built for them.
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        """Tokenize one batch of examples and attach a ``labels`` column when labels are present."""
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
    )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set (at most 3, fewer if the
    # dataset is smaller, since random.sample rejects k > population size):
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(args.output_dir)
            summary_writer.hparams(vars(args))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    def write_metric(train_metrics, eval_metrics, train_time, step):
        """Write the training time, per-step training metrics and eval metrics as TensorBoard scalars."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                # Spread the buffered values over the steps leading up to `step`.
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
    )

    state = create_train_state(
        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
    )

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Run one optimizer update on `batch`.

        Returns the triple `(new_state, metrics, new_dropout_rng)` where
        `metrics` holds the device-averaged loss and the current learning rate.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        """Run a forward pass without dropout and post-process the logits with `state.logits_fn`."""
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info(" Training...")

        train_start = time.time()
        train_metrics = []
        rng, input_rng = jax.random.split(rng)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f" Done! Training metrics: {unreplicate(metrics)}")

        logger.info(" Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(unreplicate(state), batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()

        logger.info(f" Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            write_metric(train_metrics, eval_metric, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # `unreplicate` takes the first replica of every leaf, same as the
            # previous hand-written tree map, without the deprecated alias.
            params = jax.device_get(unreplicate(state.params))
            model.save_pretrained(args.output_dir, params=params)
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # save the eval metrics in json
    if jax.process_index() == 0:
        eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
        path = os.path.join(args.output_dir, "eval_results.json")
        with open(path, "w") as f:
            json.dump(eval_metric, f, indent=4, sort_keys=True)
def main():
    """Fine-tune and evaluate a Flax token-classification model.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` (from the command line or from a single JSON file).
    The function loads a Hub dataset or local data files, tokenizes the words
    and aligns the labels with the resulting tokens, trains with a pmapped
    train step, periodically logs, evaluates with the ``seqeval`` metric and
    saves checkpoints, and, when ``training_args.do_eval`` is set, runs a final
    evaluation whose metrics are written to ``eval_results.json``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_ner", model_args, data_args, framework="flax")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for token classification task available on the hub at
    # https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Fall back to the validation file (the same attribute read just above)
        # to infer the loader type when no training file was given.
        extension = (
            data_args.train_file if data_args.train_file is not None else data_args.validation_file
        ).split(".")[-1]
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    if raw_datasets["train"] is not None:
        column_names = raw_datasets["train"].column_names
        features = raw_datasets["train"].features
    else:
        column_names = raw_datasets["validation"].column_names
        features = raw_datasets["validation"].features

    if data_args.text_column_name is not None:
        text_column_name = data_args.text_column_name
    elif "tokens" in column_names:
        text_column_name = "tokens"
    else:
        text_column_name = column_names[0]

    if data_args.label_column_name is not None:
        label_column_name = data_args.label_column_name
    elif f"{data_args.task_name}_tags" in column_names:
        label_column_name = f"{data_args.task_name}_tags"
    else:
        label_column_name = column_names[1]

    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
    # unique labels.
    def get_label_list(labels):
        """Return the sorted list of distinct labels found across all label sequences."""
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
        # No need to convert the labels since they are already ints.
        label_to_id = {i: i for i in range(len(label_list))}
    else:
        label_list = get_label_list(raw_datasets["train"][label_column_name])
        label_to_id = {l: i for i, l in enumerate(label_list)}
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        label2id=label_to_id,
        id2label={i: l for l, i in label_to_id.items()},
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    if config.model_type in {"gpt2", "roberta"}:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            add_prefix_space=True,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    model = FlaxAutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Preprocessing the datasets
    # Tokenize all texts and align the labels with them.
    def tokenize_and_align_labels(examples):
        """Tokenize pre-split words and build one label id per produced token."""
        tokenized_inputs = tokenizer(
            examples[text_column_name],
            max_length=data_args.max_seq_length,
            padding="max_length",
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
        )

        labels = []

        for i, label in enumerate(examples[label_column_name]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label_to_id[label[word_idx]])
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
                previous_word_idx = word_idx

            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    processed_raw_datasets = raw_datasets.map(
        tokenize_and_align_labels,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=raw_datasets["train"].column_names,
        desc="Running tokenizer on dataset",
    )

    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set (at most 3, fewer if the
    # dataset is smaller, since random.sample rejects k > population size):
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(training_args.output_dir)
            summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        """Write the training time and the buffered per-step training metrics as TensorBoard scalars."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                # Spread the buffered values over the steps leading up to `step`.
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        """Write every evaluation metric as a TensorBoard scalar at `step`."""
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args)

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Run one optimizer update on `batch`.

        Returns the triple `(new_state, metrics, new_dropout_rng)` where
        `metrics` holds the device-averaged loss and the current learning rate.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        """Run a forward pass without dropout and post-process the logits with `state.logits_fn`."""
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    metric = evaluate.load("seqeval")

    def get_labels(y_pred, y_true):
        """Map prediction/reference ids to label names, dropping positions whose reference is -100."""
        # Transform predictions and references tensors to numpy arrays

        # Remove ignored index (special tokens)
        true_predictions = [
            [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        true_labels = [
            [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        return true_predictions, true_labels

    def compute_metrics():
        """Compute the accumulated metric; return it flattened or reduced to the overall scores."""
        results = metric.compute()
        if data_args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
            tqdm(
                train_data_collator(input_rng, train_dataset, train_batch_size),
                total=step_per_epoch,
                desc="Training...",
                position=1,
            )
        ):
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = (epoch * step_per_epoch) + (step + 1)

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                # Accumulate only the time elapsed since the previous
                # accumulation; without restarting the clock, every logging
                # step would re-add the whole time since the epoch started.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
                    f" {train_metric['learning_rate']})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                eval_metrics = {}
                # evaluate
                for batch in tqdm(
                    eval_data_collator(eval_dataset, eval_batch_size),
                    total=math.ceil(len(eval_dataset) / eval_batch_size),
                    desc="Evaluating ...",
                    position=2,
                ):
                    labels = batch.pop("labels")
                    predictions = pad_shard_unpad(p_eval_step)(
                        state, batch, min_device_batch=per_device_eval_batch_size
                    )
                    predictions = np.array(predictions)
                    # NOTE(review): np.array() over an itertools.chain object does not
                    # expand it element-wise -- confirm this mask really selects the
                    # padded positions.
                    labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                eval_metrics = compute_metrics()

                if data_args.return_entity_level_metrics:
                    logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
                else:
                    logger.info(
                        f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:"
                        f" {eval_metrics['accuracy']})"
                    )

                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

    # Eval after training
    if training_args.do_eval:
        eval_metrics = {}
        eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
        # Same progress-bar total as the in-training evaluation loop above.
        for batch in tqdm(
            eval_loader,
            total=math.ceil(len(eval_dataset) / eval_batch_size),
            desc="Evaluating ...",
            position=2,
        ):
            labels = batch.pop("labels")
            predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
            predictions = np.array(predictions)
            # NOTE(review): same masking expression as in the training loop --
            # confirm it selects the padded positions.
            labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
            preds, refs = get_labels(predictions, labels)
            metric.add_batch(predictions=preds, references=refs)

        eval_metrics = compute_metrics()

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser( (ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file( json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses( ) # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax") if (os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome.") # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO, datefmt="[%X]", ) # Log on each process the small summary: logger = logging.getLogger(__name__) # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # Set seed before initializing model. 
set_seed(training_args.seed) # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name(Path( training_args.output_dir).absolute().name, token=training_args.hub_token) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called # 'text' is found. You can easily tweak this behavior (see below). if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. datasets = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in datasets.keys(): datasets["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) datasets["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" datasets = load_dataset( extension, data_files=data_files, cache_dir=model_args.cache_dir, 
use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in datasets.keys(): datasets["validation"] = load_dataset( extension, data_files=data_files, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) datasets["train"] = load_dataset( extension, data_files=data_files, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # Load pretrained model and tokenizer if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
) if model_args.config_name: config = T5Config.from_pretrained( model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer), use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: config = T5Config.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning( "You are instantiating a new config instance from scratch.") # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: column_names = datasets["train"].column_names else: column_names = datasets["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. # Since we make sure that all sequences are of the same length, no attention_mask is needed. def tokenize_function(examples): return tokenizer(examples[text_column_name], return_attention_mask=False) tokenized_datasets = datasets.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token. # To ensure that the input length is `max_seq_length`, we need to increase the maximum length # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly. 
expanded_inputs_length, targets_length = compute_input_and_target_lengths( inputs_length=max_seq_length, noise_density=data_args.mlm_probability, mean_noise_span_length=data_args.mean_noise_span_length, ) # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length. def group_texts(examples): # Concatenate all texts. concatenated_examples = { k: list(chain(*examples[k])) for k in examples.keys() } total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. if total_length >= expanded_inputs_length: total_length = (total_length // expanded_inputs_length) * expanded_inputs_length # Split by chunks of max_len. result = { k: [ t[i:i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length) ] for k, t in concatenated_examples.items() } return result # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value # might be slower to preprocess. # # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map tokenized_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter( log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable.") # Initialize our training rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) if model_args.model_name_or_path: model = FlaxT5ForConditionalGeneration.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, ) else: config.vocab_size = len(tokenizer) model = FlaxT5ForConditionalGeneration( config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) # Data collator # This one will take care of randomly masking the tokens. 
data_collator = FlaxDataCollatorForT5MLM( tokenizer=tokenizer, noise_density=data_args.mlm_probability, mean_noise_span_length=data_args.mean_noise_span_length, input_length=max_seq_length, target_length=targets_length, pad_token_id=model.config.pad_token_id, decoder_start_token_id=model.config.decoder_start_token_id, ) # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int( training_args.per_device_train_batch_size) * jax.device_count() per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = per_device_eval_batch_size * jax.device_count() num_train_steps = len( tokenized_datasets["train"]) // train_batch_size * num_epochs num_of_hosts = jax.process_count() current_host_idx = jax.process_index() # Create learning rate schedule warmup_fn = optax.linear_schedule( init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps) decay_fn = optax.linear_schedule( init_value=training_args.learning_rate, end_value=0, transition_steps=num_train_steps - training_args.warmup_steps, ) linear_decay_lr_schedule_fn = optax.join_schedules( schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]) # We use Optax's "masking" functionality to not apply weight decay # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. 
def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) # find out all LayerNorm parameters layer_norm_candidates = ["layernorm", "layer_norm", "ln"] layer_norm_named_params = set([ layer[-2:] for layer_norm_name in layer_norm_candidates for layer in flat_params.keys() if layer_norm_name in "".join(layer).lower() ]) flat_mask = { path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer if training_args.adafactor: # We use the default parameters here to initialize adafactor, # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 optimizer = optax.adafactor( learning_rate=linear_decay_lr_schedule_fn, ) else: optimizer = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, weight_decay=training_args.weight_decay, mask=decay_mask_fn, ) # Setup train state state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer) # Define gradient update step fn def train_step(state, batch, dropout_rng): dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) def loss_fn(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] # compute loss loss = optax.softmax_cross_entropy( logits, onehot(labels, logits.shape[-1])).mean() return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean( { "loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step) }, axis_name="batch") return new_state, metrics, new_dropout_rng # Create parallel version of the train step p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, )) # Define eval fn def 
eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] # compute loss loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) # compute accuracy accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) # summarize metrics metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()} metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, )) # Replicate the train state on each device state = jax_utils.replicate(state) train_time = 0 epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(tokenized_datasets["train"]) # Avoid using jax.numpy here in case of TPU training train_samples_idx = np.random.permutation(np.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step for step, batch_idx in enumerate( tqdm(train_batch_idx, desc="Training...", position=1)): samples = [ tokenized_datasets["train"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples) local_host_model_inputs = { key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx] for key, value in model_inputs.data.items() } # Model forward model_inputs = shard(local_host_model_inputs) state, train_metric, dropout_rngs = p_train_step( state, model_inputs, dropout_rngs) train_metrics.append(train_metric) cur_step = epoch * (num_train_samples // train_batch_size) + step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = jax_utils.unreplicate(train_metric) train_time += time.time() - 
train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:" f" {train_metric['learning_rate'].mean()})") train_metrics = [] if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for i, batch_idx in enumerate( tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): samples = [ tokenized_datasets["validation"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples) # Model forward metrics = pad_shard_unpad(p_eval_step, static_return=True)( state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size) eval_metrics.append(metrics) # get eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) # Update progress bar epochs.write( f"Step... 
({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" ) # Save metrics if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get( jax.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub( commit_message= f"Saving weights and logs of step {cur_step}", blocking=False) # Eval after training if training_args.do_eval: num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for i, batch_idx in enumerate( tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): samples = [ tokenized_datasets["validation"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples) # Model forward metrics = pad_shard_unpad(p_eval_step, static_return=True)( state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size) eval_metrics.append(metrics) # get eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics) if jax.process_index() == 0: eval_metrics = { f"eval_{metric_name}": value for metric_name, value in eval_metrics.items() } path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True)
def process_iterator(
    tag: str,
    item_ids: Sequence[str],
    iterator,
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    step: int,
    render_fn: Any,
    summary_writer: tensorboard.SummaryWriter,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process a dataset iterator and compute metrics.

    Every batch drawn from `iterator` is paired with the matching entry of
    `item_ids` and handed to `process_batch`. On process 0 the per-item stats
    are accumulated and their means are written to `summary_writer` under
    `metrics-eval/<name>/<tag>`.

    Args:
      tag: Name of the split being processed. Used in log lines, in the output
        sub-directory and in the summary tags. When it equals `"test"`, the
        batch metadata is synthesized (see below).
      item_ids: Identifiers of the items, in the order `iterator` yields them.
      iterator: Iterable of batches. Iteration stops at the shorter of
        `item_ids` and `iterator`.
      rng: Random key forwarded unchanged to `process_batch`.
      state: Training state forwarded to `process_batch`.
      step: Current training step. Names the output directory, seeds the
        test-time metadata sampling and indexes the written scalars.
      render_fn: Rendering callable forwarded to `process_batch`.
      summary_writer: Writer that receives the averaged metrics.
      save_dir: Optional base output directory. When truthy, results are placed
        under `<save_dir>/<step:08d>/<tag>`; otherwise `None` is forwarded.
      datasource: Data source providing the `use_*_id` flags and the
        corresponding id lists used to build test-time metadata.

    Returns:
      None. For `tag == "test"`, each batch is modified in place: its
      `"metadata"` entry is replaced with the sampled ids.
    """
    save_dir = save_dir / f"{step:08d}" / tag if save_dir else None
    meters = collections.defaultdict(utils.ValueMeter)
    # `jax.host_id()` is a deprecated alias; `jax.process_index()` is the
    # supported spelling and matches the rest of this file.
    is_primary_process = jax.process_index() == 0

    for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
        logging.info("[%s:%d/%d] Processing %s ", tag, i + 1, len(item_ids),
                     item_id)
        if tag == "test":
            # NOTE(review): the key depends only on `step` and is reused for
            # every draw below and for every item, so the sampled ids look to
            # be identical across items at a given step -- confirm intended.
            test_rng = random.PRNGKey(step)
            # Same leading dimensions as the ray origins, with a trailing
            # dimension of size 1 for the id channel.
            shape = batch["origins"][..., :1].shape
            metadata = {}
            if datasource.use_appearance_id:
                appearance_id = random.choice(
                    test_rng, jnp.asarray(datasource.appearance_ids))
                logging.info("\tUsing appearance_id = %d", appearance_id)
                metadata["appearance"] = jnp.full(
                    shape, fill_value=appearance_id, dtype=jnp.uint32)
            if datasource.use_warp_id:
                warp_id = random.choice(
                    test_rng, jnp.asarray(datasource.warp_ids))
                logging.info("\tUsing warp_id = %d", warp_id)
                metadata["warp"] = jnp.full(
                    shape, fill_value=warp_id, dtype=jnp.uint32)
            if datasource.use_camera_id:
                camera_id = random.choice(
                    test_rng, jnp.asarray(datasource.camera_ids))
                logging.info("\tUsing camera_id = %d", camera_id)
                metadata["camera"] = jnp.full(
                    shape, fill_value=camera_id, dtype=jnp.uint32)
            batch["metadata"] = metadata

        stats = process_batch(
            batch=batch,
            rng=rng,
            state=state,
            tag=tag,
            item_id=item_id,
            step=step,
            render_fn=render_fn,
            summary_writer=summary_writer,
            save_dir=save_dir,
            datasource=datasource,
        )
        # Only the primary process accumulates and reports metrics.
        if is_primary_process:
            for k, v in stats.items():
                meters[k].update(v)

    if is_primary_process:
        for meter_name, meter in meters.items():
            summary_writer.scalar(
                tag=f"metrics-eval/{meter_name}/{tag}",
                value=meter.reduce("mean"),
                step=step,
            )