# Imports reconstructed from usage in this file. Repo-local helpers
# (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments,
# PathArguments, TqdmLoggingHandler, wrap_global_functions, create_model,
# create_tokenizer, get_lamb_optimizer, get_adamw_optimizer, GradientAccumulator,
# get_checkpoint_paths_from_prefix, get_dataset_from_tfrecords, train_step,
# val_step, run_validation, get_squad_results_while_pretraining, log_example,
# is_wandb_available, logger) are assumed to come from the project's own modules.
import datetime
import gc
import glob
import logging
import os
import time
from dataclasses import asdict

import numpy as np
import tensorflow as tf
import tqdm
import horovod.tensorflow as hvd
from transformers import (
    ElectraConfig,
    ElectraTokenizerFast,
    HfArgumentParser,
    TFAutoModelForPreTraining,
    TFElectraForMaskedLM,
    TFElectraForPreTraining,
)

try:
    import wandb  # Only used when is_wandb_available() is True
except ImportError:
    wandb = None


def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    is_sagemaker = path_args.filesystem_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker

    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")

    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        if log_args.run_name is None:
            metadata = (
                f"{model_args.model_type}"
                f"-{model_args.model_size}"
                f"-{model_args.load_from}"
                f"-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.learning_rate}maxlr"
                f"-{train_args.learning_rate_decay_power}power"
                f"-{train_args.optimizer}opt"
                f"-{train_args.total_steps}steps"
                f"-{'preln' if pre_layer_norm else 'postln'}"
                f"{loss_str}"
                f"-{model_args.hidden_dropout_prob}dropout"
            )
            run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
        else:
            run_name = log_args.run_name

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                os.path.join(path_args.filesystem_prefix, path_args.log_dir, f"{run_name}.log")
            ),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    # Create optimizer and enable AMP loss scaling.
    if train_args.optimizer == "lamb":
        optimizer = get_lamb_optimizer(train_args)
    elif train_args.optimizer == "adamw":
        optimizer = get_adamw_optimizer(train_args)

    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic"
    )
    gradient_accumulator = GradientAccumulator()

    loaded_optimizer_weights = None

    model = create_model(model_class=TFAutoModelForPreTraining, model_args=model_args)
    tokenizer = create_tokenizer(model_args.model_type)
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        model_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            if model_args.load_optimizer_state == "true":
                loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)
                # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord")
    validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
    )  # Of shape [per_gpu_batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, per_gpu_batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 1
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(i, dtype=tf.float32))
        # weight_decay = wd_schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = optimizer.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            optimizer=optimizer,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 1:
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            i = optimizer.get_weights()[0]

        is_final_step = i >= train_args.total_steps
        do_squad = (log_args.squad_frequency != 0) and (
            (i % log_args.squad_frequency == 0) or is_final_step
        )
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                tokenizer=tokenizer,
                model_size=model_args.model_size,
                filesystem_prefix=path_args.filesystem_prefix,
                step=i,
                dataset=data_args.squad_version,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results["f1"]
                logger.info(f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}")
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)
            gc.collect()

        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (log_args.checkpoint_frequency != 0) and (
                (i % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = (log_args.validation_frequency != 0) and (
                (i % log_args.validation_frequency == 0) or is_final_step
            )

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)

            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 1:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}")
                start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = os.path.join(
                    path_args.filesystem_prefix, path_args.checkpoint_dir, f"{run_name}-step{i}"
                )
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                optimizer_ckpt = f"{checkpoint_prefix}-optimizer.npy"
                logger.info(f"Saving model at {model_ckpt}, optimizer at {optimizer_ckpt}")
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                optimizer_weights = optimizer.get_weights()
                np.save(optimizer_ckpt, optimizer_weights)
                # optimizer.set_weights(optimizer_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                }
                if is_wandb_available():
                    wandb.init(config=config, project=model_args.model_type)
                    wandb.run.save()
                    wandb_run_name = wandb.run.name

            train_metrics = {
                "weight_norm": weight_norm,
                "grad_norm": grad_norm,
                "loss_scale": loss_scale,
                "learning_rate": learning_rate,
                "train/loss": loss,
                "train/mlm_loss": mlm_loss,
                "train/mlm_acc": mlm_acc,
                "train/sop_loss": sop_loss,
                "train/sop_acc": sop_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_loss,
                    "val/mlm_loss": val_mlm_loss,
                    "val/mlm_acc": val_mlm_acc,
                    "val/sop_loss": val_sop_loss,
                    "val/sop_acc": val_sop_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_squad:
                squad_metrics = {
                    "squad/f1": squad_f1,
                    "squad/exact": squad_exact,
                }
                all_metrics = {**all_metrics, **squad_metrics}

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=i)

            # Log to Weights & Biases
            if is_wandb_available():
                wandb.log({"step": i, **all_metrics})

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")
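

# ---------------------------------------------------------------------------
# The main() below appears to come from a separate ELECTRA pretraining script
# in the same project: it redefines main() to pretrain a generator/discriminator
# pair (optionally tying their embeddings) instead of the MLM/SOP objective above.
# ---------------------------------------------------------------------------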
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."

    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")

    if train_args.eager == "true":
        tf.config.experimental_run_functions_eagerly(True)

    tokenizer = ElectraTokenizerFast.from_pretrained("bert-base-uncased")
    gen_config = ElectraConfig.from_pretrained(f"google/electra-{model_args.model_size}-generator")
    dis_config = ElectraConfig.from_pretrained(
        f"google/electra-{model_args.model_size}-discriminator"
    )
    gen = TFElectraForMaskedLM(config=gen_config)
    dis = TFElectraForPreTraining(config=dis_config)
    optimizer = get_adamw_optimizer(train_args)

    # Tie the weights
    if model_args.electra_tie_weights == "true":
        gen.electra.embeddings = dis.electra.embeddings

    loaded_optimizer_weights = None
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        dis_ckpt, gen_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if hvd.rank() == 0:
            dis.load_weights(dis_ckpt)
            gen.load_weights(gen_ckpt)
            loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)

    start_time = time.perf_counter()

    if hvd.rank() == 0:
        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            TqdmLoggingHandler(),
        ]
        summary_writer = None  # Only create a writer if we make it through a successful step
        logging.basicConfig(level=level, format=format, handlers=handlers)
        wandb_run_name = None

        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if log_args.run_name is None:
            metadata = (
                f"electra-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.total_steps}steps"
            )
            run_name = (
                f"{current_time}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
            )
        else:
            run_name = log_args.run_name

        logger.info(f"Training with dataset at {path_args.train_dir}")
        logger.info(f"Validating with dataset at {path_args.val_dir}")

    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord*")
    validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord*")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    logger.info(
        f"Number of train files {len(train_filenames)}, number of validation files {len(validation_filenames)}"
    )

    tf_train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
        max_seq_length=data_args.max_seq_length,
    )
    tf_train_dataset = tf_train_dataset.prefetch(buffer_size=8)

    if hvd.rank() == 0:
        tf_val_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
            max_seq_length=data_args.max_seq_length,
        )
        tf_val_dataset = tf_val_dataset.prefetch(buffer_size=8)
        wandb_run_name = None

    step = 1
    for batch in tf_train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(step, dtype=tf.float32))
        ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        train_result = train_step(
            optimizer=optimizer,
            gen=gen,
            dis=dis,
            ids=ids,
            attention_mask=attention_mask,
            mask_token_id=tokenizer.mask_token_id,
        )

        if step == 1:
            # Horovod broadcast
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(gen.variables, root_rank=0)
            hvd.broadcast_variables(dis.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            step = optimizer.get_weights()[0]

        is_final_step = step >= train_args.total_steps
        if hvd.rank() == 0:
            do_log = step % log_args.log_frequency == 0
            do_checkpoint = (step > 1) and (
                (step % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = step % log_args.validation_frequency == 0

            if do_log:
                elapsed_time = time.perf_counter() - start_time  # Off for first log
                it_s = log_args.log_frequency / elapsed_time
                start_time = time.perf_counter()
                description = f"Step {step} -- gen_loss: {train_result.gen_loss:.3f}, dis_loss: {train_result.dis_loss:.3f}, gen_acc: {train_result.gen_acc:.3f}, dis_acc: {train_result.dis_acc:.3f}, it/s: {it_s:.3f}\n"
                logger.info(description)

            if do_validation:
                for batch in tf_val_dataset.take(1):
                    val_ids = batch["input_ids"]
                    val_attention_mask = batch["attention_mask"]
                    val_result = val_step(
                        gen=gen,
                        dis=dis,
                        ids=val_ids,
                        attention_mask=val_attention_mask,
                        mask_token_id=tokenizer.mask_token_id,
                    )
                    log_example(
                        tokenizer,
                        val_ids,
                        val_result.masked_ids,
                        val_result.corruption_mask,
                        val_result.gen_ids,
                        val_result.dis_preds,
                    )
                    description = f"VALIDATION, Step {step} -- val_gen_loss: {val_result.gen_loss:.3f}, val_dis_loss: {val_result.dis_loss:.3f}, val_gen_acc: {val_result.gen_acc:.3f}, val_dis_acc: {val_result.dis_acc:.3f}\n"
                    logger.info(description)

            train_metrics = {
                "learning_rate": learning_rate,
                "train/loss": train_result.loss,
                "train/gen_loss": train_result.gen_loss,
                "train/dis_loss": train_result.dis_loss,
                "train/gen_acc": train_result.gen_acc,
                "train/dis_acc": train_result.dis_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_result.loss,
                    "val/gen_loss": val_result.gen_loss,
                    "val/dis_loss": val_result.dis_loss,
                    "val/gen_acc": val_result.gen_acc,
                    "val/dis_acc": val_result.dis_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_log:
                all_metrics = {"it_s": it_s, **all_metrics}

            if is_wandb_available():
                if wandb_run_name is None:
                    config = {
                        **asdict(model_args),
                        **asdict(data_args),
                        **asdict(train_args),
                        **asdict(log_args),
                        **asdict(path_args),
                        "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                        "n_gpus": hvd.size(),
                    }
                    wandb.init(config=config, project="electra")
                    wandb.run.save()
                    wandb_run_name = wandb.run.name
                wandb.log({"step": step, **all_metrics})

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    **asdict(path_args),
                    "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                    "n_gpus": hvd.size(),
                }

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=step)

            if do_checkpoint:
                dis_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-discriminator.ckpt",
                )
                gen_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-generator.ckpt",
                )
                optimizer_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-optimizer.npy",
                )
                logger.info(
                    f"Saving discriminator model at {dis_model_ckpt}, generator model at {gen_model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                dis.save_weights(dis_model_ckpt)
                gen.save_weights(gen_model_ckpt)
                np.save(optimizer_ckpt, optimizer.get_weights())

        step += 1
        if is_final_step:
            break
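

# Assumed entry point: not shown in this excerpt; added so the script can be
# launched directly (e.g. via horovodrun). Adjust or remove if the project
# invokes main() elsewhere.
if __name__ == "__main__":
    main()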