def test_total_num_batches(self):
    """Check micro-batch accounting across dataset sizes and logging intervals."""
    cfg = BatchConfig(micro_batch_size=2,
                      num_replicas=2,
                      gradient_accumulation_count=2,
                      dataset_size=16,
                      global_batches_per_log=1)
    assert cfg.num_micro_batches_per_epoch == 8
    assert cfg.steps_per_execution == 4
    assert cfg.epochs == 28803072
    assert cfg.total_num_micro_batches == 230424576

    # Doubling the logging interval doubles steps_per_execution; the totals are unchanged.
    cfg = BatchConfig(micro_batch_size=2,
                      num_replicas=2,
                      gradient_accumulation_count=2,
                      dataset_size=16,
                      global_batches_per_log=2)
    assert cfg.num_micro_batches_per_epoch == 8
    assert cfg.steps_per_execution == 8
    assert cfg.epochs == 28803072
    assert cfg.total_num_micro_batches == 230424576

    # A dataset size that is not a multiple of the global batch size (18 vs 16)
    # yields the same per-epoch count but a different overall total.
    cfg = BatchConfig(micro_batch_size=2,
                      num_replicas=2,
                      gradient_accumulation_count=2,
                      dataset_size=18,
                      global_batches_per_log=2)
    assert cfg.num_micro_batches_per_epoch == 8
    assert cfg.steps_per_execution == 8
    assert cfg.total_num_micro_batches == 204821840
def test_get_num_micro_batches_per_epoch(self):
    """Micro batches per epoch only count whole global batches."""
    exact_fit = BatchConfig(micro_batch_size=3,
                            num_replicas=2,
                            gradient_accumulation_count=3,
                            dataset_size=18)
    assert exact_fit.num_micro_batches_per_epoch == 6

    # Two leftover samples (dataset of 20 vs 18) do not add a partial micro batch.
    with_remainder = BatchConfig(micro_batch_size=3,
                                 num_replicas=2,
                                 gradient_accumulation_count=3,
                                 dataset_size=20)
    assert with_remainder.num_micro_batches_per_epoch == 6
def test_raise_error_for_no_phase(self):
    """Constructing a BatchConfig for Task.OTHER without an explicit training
    length must raise ValueError."""
    with pytest.raises(ValueError):
        BatchConfig(micro_batch_size=2,
                    num_replicas=2,
                    gradient_accumulation_count=2,
                    dataset_size=40,
                    task=Task.OTHER)
def test_custom_num_train_steps(self):
    """An explicit total_num_train_samples overrides the default training length."""
    cfg = BatchConfig(micro_batch_size=2,
                      num_replicas=2,
                      gradient_accumulation_count=2,
                      dataset_size=8,
                      total_num_train_samples=80)
    # Global batch size is 2 * 2 * 2 = 8, so 80 samples -> 10 weight updates.
    assert cfg.num_train_steps == 10
    assert cfg.epochs == 10
    assert cfg.total_num_micro_batches == 40
def test_num_train_steps(self):
    """Default training length expressed as steps, epochs and micro batches."""
    cfg = BatchConfig(micro_batch_size=2,
                      num_replicas=2,
                      gradient_accumulation_count=2,
                      dataset_size=16)
    assert cfg.num_train_steps == 57606144
    assert cfg.epochs == 28803072
    assert cfg.num_micro_batches_per_epoch == 8
    assert cfg.total_num_micro_batches == 230424576
def test_micro_batch_size(self):
    """The configured micro batch size is exposed unchanged."""
    cfg = BatchConfig(micro_batch_size=3,
                      num_replicas=1,
                      gradient_accumulation_count=1,
                      dataset_size=18)
    assert cfg.micro_batch_size == 3
def test_calc_global_batch_size(self):
    """Global batch size = micro batch size * replicas * gradient accumulation."""
    cfg = BatchConfig(micro_batch_size=3,
                      num_replicas=4,
                      gradient_accumulation_count=2,
                      dataset_size=40)
    assert cfg.global_batch_size == 3 * 4 * 2
def test_gradient_accumulation_count(self):
    """The configured gradient accumulation count is exposed unchanged."""
    cfg = BatchConfig(micro_batch_size=1,
                      num_replicas=1,
                      gradient_accumulation_count=2,
                      dataset_size=40)
    assert cfg.gradient_accumulation_count == 2
def test_num_micro_batches_per_weight_update(self):
    """One weight update consumes replicas * gradient-accumulation micro batches."""
    cfg = BatchConfig(micro_batch_size=1,
                      num_replicas=4,
                      gradient_accumulation_count=4,
                      dataset_size=40)
    assert cfg.num_micro_batches_per_weight_update == 4 * 4
def pretrain(**config):
    """Pre-train BERT (MLM + NSP objectives) on IPUs.

    Builds the pre-training dataset, constructs and (optionally) pipelines the
    model inside an IPU strategy scope, compiles it with the MLM/NSP losses and
    a scheduled learning rate, then runs `model.fit` for the number of micro
    batches dictated by the batch configuration.

    Args:
        **config: Flat configuration dictionary. Required keys include
            "micro_batch_size", "replicas", "grad_acc_steps_per_replica",
            "optimizer_opts", "use_outlining", "replicated_tensor_sharding",
            "fp_exceptions", "bert_config", "wandb_opts", "pipeline_stages",
            "device_mapping", "seed", "dataset_dir", "global_batches_per_log",
            "enable_recomputation", "min_remote_tensor_size", "replace_layers",
            "embedding_serialization_factor", "optimizer_state_offchip",
            "matmul_available_memory_proportion_per_pipeline_stage",
            "matmul_partials_type", "learning_rate" and "loss_scaling".
            Optional keys ("total_num_train_samples", "save_ckpt_path",
            "pretrained_ckpt_path", "ckpt_every_n_steps_per_execution",
            "name") have defaults set below.

    Returns:
        The Keras `History` object produced by `model.fit`.

    Raises:
        ValueError: If the configured max sequence length is neither 128 nor 384.
    """
    # Get required options
    micro_batch_size = config["micro_batch_size"]
    replicas = config["replicas"]
    grad_acc_steps_per_replica = config["grad_acc_steps_per_replica"]
    optimizer_opts = config["optimizer_opts"]
    use_outlining = config["use_outlining"]
    replicated_tensor_sharding = config["replicated_tensor_sharding"]
    fp_exceptions = config["fp_exceptions"]
    bert_config = config["bert_config"]
    wandb_opts = config["wandb_opts"]
    pipeline_stages = config["pipeline_stages"]
    device_mapping = config["device_mapping"]

    # Get optional options
    total_num_train_samples = config.get("total_num_train_samples", None)
    save_ckpt_path = config.get("save_ckpt_path",
                                Path(__file__).parent.joinpath("checkpoints").absolute())
    pretrained_ckpt_path = config.get("pretrained_ckpt_path", None)
    ckpt_every_n_steps_per_execution = config.get("ckpt_every_n_steps_per_execution", 2000)

    # Run name defaults to "<config file stem>-<wandb run name>", time-stamped for uniqueness.
    universal_run_name = config.get(
        "name", f"{Path(config['config']).stem}-{wandb_opts['init']['name']}")
    universal_run_name += f"-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    print(f"Universal name for run: {universal_run_name}")

    set_random_seeds(config["seed"])

    # Each replica spans all pipeline stages, so the total IPU count scales with both.
    num_pipeline_stages = max(device_mapping) + 1
    num_ipus = replicas * num_pipeline_stages

    bert_config = BertConfig(**bert_config, hidden_act=ipu.nn_ops.gelu)
    dataset, filenames = get_pretraining_dataset(micro_batch_size=micro_batch_size,
                                                 dataset_dir=config["dataset_dir"],
                                                 max_seq_length=bert_config.max_seq_length,
                                                 max_predictions_per_seq=bert_config.max_predictions_per_seq,
                                                 distributed_worker_count=1,
                                                 seed=config["seed"],
                                                 data_type=tf.float16)
    num_samples = get_dataset_files_count(filenames)

    # The sequence length selects the pre-training phase (128 -> phase 1, 384 -> phase 2).
    if bert_config.max_seq_length == 128:
        task = Task.PRETRAIN_PHASE_ONE
    elif bert_config.max_seq_length == 384:
        task = Task.PRETRAIN_PHASE_TWO
    else:
        raise ValueError("Sequence length must be 128 or 384")

    batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                               num_replicas=replicas,
                               gradient_accumulation_count=grad_acc_steps_per_replica,
                               dataset_size=num_samples,
                               global_batches_per_log=config["global_batches_per_log"],
                               total_num_train_samples=total_num_train_samples,
                               task=task)

    # Train in float16; the global policy must be set before the model is built.
    policy = tf.keras.mixed_precision.Policy("float16")
    tf.keras.mixed_precision.set_global_policy(policy)

    strategy = create_ipu_strategy(num_ipus,
                                   fp_exceptions=fp_exceptions,
                                   enable_recomputation=config["enable_recomputation"],
                                   min_remote_tensor_size=config["min_remote_tensor_size"])
    with strategy.scope():
        model = IpuTFBertForPreTraining(config=bert_config)
        # Convert subclass model to functional, expand main layers to enable pipelining, and replace some layers to
        # optimise performance.
        model = convert_tf_bert_model(
            model,
            dataset,
            post_process_bert_input_layer,
            replace_layers=config["replace_layers"],
            use_outlining=use_outlining,
            embedding_serialization_factor=config["embedding_serialization_factor"]
        )
        # Load from pretrained checkpoint if requested.
        if pretrained_ckpt_path:
            print("Attempting to load pretrained checkpoint from"
                  f" path {pretrained_ckpt_path}")
            load_checkpoint_into_model(model, pretrained_ckpt_path)
        else:
            if task == Task.PRETRAIN_PHASE_TWO:
                print("WARNING: Phase 2 pre-training should be done from a completed Phase 1 checkpoint. "
                      "Please specify the path to the Phase 1 checkpoint with 'pretrained_ckpt_path' in the config or "
                      "as a command line argument.")
        if num_pipeline_stages > 1:
            # Configure pipeline stages.
            pipeline_assigner = PipelineStagesAssigner(PIPELINE_ALLOCATE_PREVIOUS, PIPELINE_NAMES)
            assignments = model.get_pipeline_stage_assignment()
            assignments = pipeline_assigner.assign_pipeline_stages(assignments, pipeline_stages)
            model.set_pipeline_stage_assignment(assignments)
            model.print_pipeline_stage_assignment_summary()
            poplar_options_per_pipeline_stage = get_poplar_options_per_pipeline_stage(
                num_pipeline_stages,
                device_mapping,
                config["matmul_available_memory_proportion_per_pipeline_stage"],
                config["matmul_partials_type"])
            model.set_pipelining_options(
                gradient_accumulation_steps_per_replica=batch_config.gradient_accumulation_count,
                pipeline_schedule=ipu.ops.pipelining_ops.PipelineSchedule.Grouped,
                device_mapping=device_mapping,
                offload_weight_update_variables=config["optimizer_state_offchip"],
                replicated_optimizer_state_sharding=replicated_tensor_sharding,
                forward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                backward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                recomputation_mode=ipu.pipelining_ops.RecomputationMode.RecomputeAndBackpropagateInterleaved,
            )

        # Prepare losses and wrap them in an out-feed queue.
        nsp_loss_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        nsp_loss = wrap_loss_in_enqueuer(NSPLossFunction,
                                         nsp_loss_outfeed_queue,
                                         ["nsp_loss_average"])()
        mlm_loss_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        mlm_loss = wrap_loss_in_enqueuer(MLMLossFunction,
                                         mlm_loss_outfeed_queue,
                                         ["mlm_loss_average"])()

        # Prepare learning rate and wrap it in an out-feed queue.
        # NOTE: total_steps is injected into the shared config dict before the
        # scheduler is built, so the schedule sees the computed training length.
        config['learning_rate']['lr_schedule_params']['total_steps'] = batch_config.num_train_steps
        lr_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue(outfeed_mode=ipu.ipu_outfeed_queue.IPUOutfeedMode.LAST)
        lr_scheduler = get_lr_scheduler(scheduler_name=config["learning_rate"]["lr_schedule"],
                                        schedule_params=config["learning_rate"]["lr_schedule_params"],
                                        queue=lr_outfeed_queue)

        # Prepare optimizer.
        # NOTE(review): outlining of the optimizer's apply-gradients step is disabled
        # when replicated tensor sharding is on — presumably the two are incompatible;
        # confirm against the optimizer implementation.
        outline_optimizer_apply_gradients = False if replicated_tensor_sharding else use_outlining
        optimizer = get_optimizer(
            optimizer_opts["name"],
            grad_acc_steps_per_replica,
            replicas,
            lr_scheduler,
            outline_optimizer_apply_gradients,
            loss_scaling=config["loss_scaling"],
            weight_decay_rate=optimizer_opts["params"]["weight_decay_rate"],
        )
        # Compile the model.
        model.compile(
            optimizer=optimizer,
            loss={"nsp___cls": nsp_loss, "mlm___cls": mlm_loss},
            steps_per_execution=batch_config.steps_per_execution,
        )
        # Set up callbacks
        callbacks = CallbackFactory.get_callbacks(
            universal_run_name=universal_run_name,
            batch_config=batch_config,
            model=model,
            checkpoint_path=save_ckpt_path,
            ckpt_every_n_steps_per_execution=ckpt_every_n_steps_per_execution,
            outfeed_queues=[lr_outfeed_queue, nsp_loss_outfeed_queue, mlm_loss_outfeed_queue],
            config=config,
        )
        # Print configs to be logged in wandb's terminal.
        print(config)
        print(f"Training batch config:\n{batch_config}")
        # Train the model
        # In order to achieve a specific number of steps, we set the number of
        # epochs to 1 and the steps per epoch to the number of steps we require.
        print("Forcing `model.fit` to run a particular number of steps by"
              " running a single 'epoch' with the number of steps we"
              " require. This allows running a fraction of actual"
              " epochs.")
        history = model.fit(dataset,
                            steps_per_epoch=batch_config.total_num_micro_batches,
                            epochs=1,
                            callbacks=callbacks)
        return history
def fine_tune_squad(**config):
    """Fine-tune a pre-trained BERT model on SQuAD and evaluate it on IPUs.

    Loads the starting weights either from Hugging Face's model hub
    (`bert_model_name`) or from a local checkpoint (`pretrained_ckpt_path`) —
    exactly one of the two must be provided. Trains with `model.fit`, then
    runs prediction on the validation split and computes SQuAD metrics.

    Args:
        **config: Flat configuration dictionary; see the option extraction
            below for required keys and the `config.get` defaults for the
            optional ones.

    Returns:
        The Keras `History` object produced by `model.fit`.

    Raises:
        AssertionError: If no checkpoint, or both kinds of checkpoint, are given.
    """
    # Get checkpoint from Hugging Face's models repo or from our own checkpoint.
    bert_model_name = config.get("bert_model_name", None)
    pretrained_ckpt_path = config.get("pretrained_ckpt_path", None)  # This can be passed from both cli and config file.
    assert bert_model_name is not None or pretrained_ckpt_path is not None, \
        "SQuAD requires a pretrained model, either `bert_model_name` (via config file) or `pretrained_ckpt_path` " \
        "(via config file or `--pretrained-ckpt-path` command line argument) but none provided."
    assert (bert_model_name is not None and pretrained_ckpt_path is None) \
        or (bert_model_name is None and pretrained_ckpt_path is not None), \
        f"Only one checkpoint is accepted, but two provided: `bert_model_name`={bert_model_name}, " \
        f"and `pretrained_ckpt_path`={pretrained_ckpt_path}."
    # A local checkpoint carries only weights, so the architecture comes from the config.
    if pretrained_ckpt_path is not None:
        bert_config_params = config["bert_config"]

    # Get required options
    micro_batch_size = config["micro_batch_size"]
    num_epochs = config["num_epochs"]
    optimizer_opts = config["optimizer_opts"]
    learning_rate = config['learning_rate']
    replicas = config["replicas"]
    grad_acc_steps_per_replica = config["grad_acc_steps_per_replica"]
    wandb_opts = config["wandb_opts"]
    use_outlining = config["use_outlining"]
    replace_layers = config["replace_layers"]
    enable_recomputation = config["enable_recomputation"]
    embedding_serialization_factor = config["embedding_serialization_factor"]
    optimizer_state_offchip = config["optimizer_state_offchip"]
    matmul_available_memory_proportion_per_pipeline_stage = config[
        "matmul_available_memory_proportion_per_pipeline_stage"]
    matmul_partials_type = config["matmul_partials_type"]
    pipeline_stages = config["pipeline_stages"]
    device_mapping = config["device_mapping"]
    global_batches_per_log = config["global_batches_per_log"]
    seed = config["seed"]
    cache_dir = config["cache_dir"]
    output_dir = config["output_dir"]

    # Get optional options
    save_ckpt_path = config.get("save_ckpt_path",
                                Path(__file__).parent.joinpath("checkpoints").absolute())
    ckpt_every_n_steps_per_execution = config.get("ckpt_every_n_steps_per_execution", 2000)

    # Run name defaults to "<config file stem>-<wandb run name>", time-stamped for uniqueness.
    universal_run_name = config.get("name",
                                    f"{Path(config['config']).stem}-{wandb_opts['init']['name']}")
    universal_run_name += f"-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    print(f"Universal name for run: {universal_run_name}")

    set_random_seeds(seed)

    # Several pipeline stages may share one IPU, so the stage count (len) and the
    # IPUs per replica (max id + 1) are computed separately.
    num_pipeline_stages = len(device_mapping)
    num_ipus_per_replicas = max(device_mapping) + 1
    num_ipus = replicas * num_ipus_per_replicas

    # Load training and validation data
    # =================================
    train_dataset, eval_dataset, num_train_samples, num_eval_samples, raw_datasets = get_squad_data(
        micro_batch_size, cache_dir)
    train_batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                                     total_num_train_samples=num_epochs * num_train_samples.numpy(),
                                     num_replicas=replicas,
                                     gradient_accumulation_count=grad_acc_steps_per_replica,
                                     dataset_size=num_train_samples.numpy(),
                                     global_batches_per_log=global_batches_per_log,
                                     task=Task.OTHER)

    # Create model
    # ============
    # Train in float16; the global policy must be set before the model is built.
    policy = tf.keras.mixed_precision.Policy("float16")
    tf.keras.mixed_precision.set_global_policy(policy)
    strategy = create_ipu_strategy(num_ipus, enable_recomputation=enable_recomputation)
    with strategy.scope():
        # Instantiate the pretrained model given in the config.
        if bert_model_name is not None:
            model = TFBertForQuestionAnswering.from_pretrained(bert_model_name)
        else:
            bert_config = BertConfig(**bert_config_params, hidden_act=ipu.nn_ops.gelu)
            model = TFBertForQuestionAnswering(config=bert_config)
        # Convert subclass model to functional, expand main layers to enable pipelining, and replace some layers to
        # optimise performance.
        model = convert_tf_bert_model(
            model,
            train_dataset,
            post_process_bert_input_layer,
            replace_layers=replace_layers,
            use_outlining=use_outlining,
            embedding_serialization_factor=embedding_serialization_factor,
            rename_outputs={'tf.compat.v1.squeeze': 'start_positions', 'tf.compat.v1.squeeze_1': 'end_positions'}
        )
        # Load from pretrained checkpoint if requested.
        if pretrained_ckpt_path is not None:
            print(f"Attempting to load pretrained checkpoint from path {pretrained_ckpt_path}. "
                  f"This will overwrite the current weights")
            load_checkpoint_into_model(model, pretrained_ckpt_path)

        # Configure pipeline stages
        # =========================
        if num_pipeline_stages > 1:
            pipeline_assigner = PipelineStagesAssigner(PIPELINE_ALLOCATE_PREVIOUS, PIPELINE_NAMES)
            assignments = model.get_pipeline_stage_assignment()
            assignments = pipeline_assigner.assign_pipeline_stages(assignments, pipeline_stages)
            model.set_pipeline_stage_assignment(assignments)
            model.print_pipeline_stage_assignment_summary()
            poplar_options_per_pipeline_stage = get_poplar_options_per_pipeline_stage(
                num_ipus_per_replicas,
                device_mapping,
                matmul_available_memory_proportion_per_pipeline_stage,
                matmul_partials_type)
            model.set_pipelining_options(
                gradient_accumulation_steps_per_replica=grad_acc_steps_per_replica,
                pipeline_schedule=ipu.ops.pipelining_ops.PipelineSchedule.Grouped,
                device_mapping=device_mapping,
                offload_weight_update_variables=optimizer_state_offchip,
                forward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                backward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                recomputation_mode=ipu.pipelining_ops.RecomputationMode.RecomputeAndBackpropagateInterleaved,
            )

        # Compile the model for training
        # ==============================
        # Wrap loss in an out-feed queue.
        loss_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        qa_loss = wrap_loss_in_enqueuer(QuestionAnsweringLossFunction,
                                        loss_outfeed_queue,
                                        ["end_positions_loss", "start_positions_loss"])()
        # Define optimiser with polynomial decay learning rate.
        # NOTE: total_steps is injected into the shared config dict before the
        # scheduler is built, so the schedule sees the computed training length.
        learning_rate['lr_schedule_params']['total_steps'] = train_batch_config.num_train_steps
        lr_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue(outfeed_mode=ipu.ipu_outfeed_queue.IPUOutfeedMode.LAST)
        lr_scheduler = get_lr_scheduler(scheduler_name=learning_rate["lr_schedule"],
                                        schedule_params=learning_rate["lr_schedule_params"],
                                        queue=lr_outfeed_queue)
        # Prepare optimizer.
        outline_optimizer_apply_gradients = use_outlining
        optimizer = get_optimizer(
            optimizer_opts["name"],
            grad_acc_steps_per_replica,
            replicas,
            lr_scheduler,
            outline_optimizer_apply_gradients,
            weight_decay_rate=optimizer_opts["params"]["weight_decay_rate"],
        )
        # Compile the model.
        model.compile(
            optimizer=optimizer,
            loss={"end_positions": qa_loss, "start_positions": qa_loss},
            metrics='accuracy',
            steps_per_execution=train_batch_config.steps_per_execution
        )

        # Train the model
        # ===============
        # Set up callbacks
        callbacks = CallbackFactory.get_callbacks(
            universal_run_name=universal_run_name,
            batch_config=train_batch_config,
            model=model,
            checkpoint_path=save_ckpt_path,
            ckpt_every_n_steps_per_execution=ckpt_every_n_steps_per_execution,
            outfeed_queues=[lr_outfeed_queue, loss_outfeed_queue],
            config=config,
        )
        # Print configs to be logged in wandb's terminal.
        print(config)
        print(f"Training batch config:\n{train_batch_config}")
        # Train the model
        history = model.fit(
            train_dataset,
            steps_per_epoch=train_batch_config.num_micro_batches_per_epoch,
            epochs=num_epochs,
            callbacks=callbacks
        )

    # Evaluate the model on the validation set
    # ========================================
    # Prepare the dataset to be evaluated in the IPU.
    eval_batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                                    total_num_train_samples=num_eval_samples.numpy(),
                                    num_replicas=replicas,
                                    gradient_accumulation_count=grad_acc_steps_per_replica,
                                    dataset_size=num_eval_samples.numpy(),
                                    task=Task.OTHER)
    max_eval_samples = eval_batch_config.micro_batch_size * eval_batch_config.num_micro_batches_per_epoch
    eval_pred_dataset = get_prediction_dataset(eval_dataset, max_eval_samples)
    with strategy.scope():
        # Re-compile the model for prediction if needed.
        if train_batch_config.steps_per_execution != eval_batch_config.steps_per_execution:
            model.compile(steps_per_execution=eval_batch_config.steps_per_execution)
        # Get predictions for the validation data.
        print(f"Running inference:\nGenerating predictions on the validation data...")
        predictions = model.predict(eval_pred_dataset,
                                    batch_size=eval_batch_config.micro_batch_size)
    # The predictions for the end position goes first in the model outputs tuple (note the output of model.summary()).
    end_predictions, start_predictions = predictions
    # Match the predictions to answers in the original context.
    # This will also write out the predictions to a json file in the directory given by `output_dir`.
    final_predictions = postprocess_qa_predictions(list(raw_datasets["validation"].as_numpy_iterator()),
                                                   list(eval_dataset.unbatch().as_numpy_iterator()),
                                                   (start_predictions, end_predictions),
                                                   output_dir=output_dir)
    # Format the predictions and the actual labels as expected by the metric.
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
    formatted_labels = format_raw_data_for_metric(raw_datasets["validation"])
    metric = load_metric("squad")
    metrics = metric.compute(predictions=formatted_predictions, references=formatted_labels)
    print("Evaluation metrics:")
    for key, value in metrics.items():
        print(f"{key}: {value:.3f}")
    return history