def test_multi_value_matmul_prop():
    args = """
    --config unit_test
    --layers-per-ipu 3 7 7 7
    --num-hidden-layers 24
    --matmul-proportion 0.15 0.3 0.3 0.3
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert config.matmul_proportion == [0.15, 0.3, 0.3, 0.3]

    # Invalid inputs
    # Too few values: 3 proportions for 4 IPUs
    args = """
    --config unit_test
    --layers-per-ipu 3 7 7 7
    --num-hidden-layers 24
    --matmul-proportion 0.15 0.3 0.3
    """.split()
    with pytest.raises(SystemExit):
        config = BertConfig(**(vars(parse_bert_args(args))))

    # Too many values: 5 proportions for 4 IPUs
    args = """
    --config unit_test
    --layers-per-ipu 3 7 7 7
    --num-hidden-layers 24
    --matmul-proportion 0.15 0.3 0.3 0.3 0.3
    """.split()
    with pytest.raises(SystemExit):
        config = BertConfig(**(vars(parse_bert_args(args))))

def test_invalid_layers_per_ipu():
    # 1+1+1+1 = 4 layers specified, but the model has 3 hidden layers
    args = """
    --config unit_test
    --layers-per-ipu 1 1 1 1
    --num-hidden-layers 3
    """.split()
    with pytest.raises(SystemExit):
        config = BertConfig(**(vars(parse_bert_args(args))))

    # A single value of 4 layers per IPU cannot split 3 hidden layers evenly
    args = """
    --config unit_test
    --layers-per-ipu 4
    --num-hidden-layers 3
    """.split()
    with pytest.raises(SystemExit):
        config = BertConfig(**(vars(parse_bert_args(args))))

    # 0+1+2+1 = 4 layers specified, but the model has 3 hidden layers
    args = """
    --config unit_test
    --layers-per-ipu 0 1 2 1
    --num-hidden-layers 3
    """.split()
    with pytest.raises(SystemExit):
        config = BertConfig(**(vars(parse_bert_args(args))))

    # 0+1+1+1+1 = 4 layers specified, but the model has 3 hidden layers
    args = """
    --config unit_test
    --layers-per-ipu 0 1 1 1 1
    --num-hidden-layers 3
    """.split()
    with pytest.raises(SystemExit):
        config = BertConfig(**(vars(parse_bert_args(args))))

def test_checkpoint_recompute_checkpoint(recompute_checkpoint):
    """
    If a checkpoint is saved with `recompute_checkpoint_every_layer` then we
    should be able to restore the checkpoint in a new run that doesn't use
    `recompute_checkpoint_every_layer`, and vice versa.
    """
    args = """
    --config unit_test
    """.split()
    config1 = BertConfig(**(vars(parse_bert_args(args))))
    config1.recompute_checkpoint_every_layer = recompute_checkpoint
    model1 = PipelinedBertForPretraining(config1).parallelize()

    with tempfile.TemporaryDirectory() as dir:
        # Save checkpoint
        config1.checkpoint_output_dir = dir
        save_checkpoint(config1, model1, 0)

        # New model with opposite `recompute_checkpoint` to model1
        config2 = BertConfig(**(vars(parse_bert_args(args))))
        config2.recompute_checkpoint_every_layer = not recompute_checkpoint
        model2 = PipelinedBertForPretraining.from_pretrained(os.path.join(dir, "step_0"), config=config2).parallelize()

        # Models should now have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)

def test_get_layer_ipu():
    args = """
    --config unit_test
    --layers-per-ipu 2
    --num-hidden-layers 12
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert _get_layer_ipu(config.layers_per_ipu) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]

    args = """
    --config unit_test
    --layers-per-ipu 2 2 2 2 2 1
    --num-hidden-layers 11
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert _get_layer_ipu(config.layers_per_ipu) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]

    args = """
    --config unit_test
    --layers-per-ipu 0 1 1 1
    --num-hidden-layers 3
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert _get_layer_ipu(config.layers_per_ipu) == [1, 2, 3]

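# For reference, the mapping asserted above can be reproduced with a minimal
# sketch (an illustration of the expected behaviour only, not the library's
# _get_layer_ipu implementation): each entry of layers_per_ipu is a layer
# count for that IPU, expanded into one IPU index per encoder layer.
#
#   def layer_ipu_reference(layers_per_ipu):
#       # e.g. [2, 2] -> [0, 0, 1, 1] and [0, 1, 1, 1] -> [1, 2, 3]
#       return [ipu for ipu, n in enumerate(layers_per_ipu) for _ in range(n)]
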
def test_checkpoint_embedding_serialization(embedding_serialization_factor):
    """
    If a checkpoint is saved with embedding_serialization_factor then we
    should be able to restore the checkpoint in a new run where
    embedding_serialization_factor isn't used. The reverse should also hold.
    """
    args = """
    --config unit_test
    """.split()
    config1 = BertConfig(**(vars(parse_bert_args(args))))
    config1.embedding_serialization_factor = embedding_serialization_factor
    model1 = PipelinedBertForPretraining(config1).parallelize()

    with tempfile.TemporaryDirectory() as dir:
        # Save checkpoint
        config1.checkpoint_output_dir = dir
        save_checkpoint(config1, model1, 0)

        # New model with opposite embedding_serialization to model1
        config2 = BertConfig(**(vars(parse_bert_args(args))))
        config2.embedding_serialization_factor = 5 if embedding_serialization_factor == 1 else 1
        model2 = PipelinedBertForPretraining.from_pretrained(os.path.join(dir, "step_0"), config=config2).parallelize()
        assert model2.config.embedding_serialization_factor == config2.embedding_serialization_factor

        # Models should now have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)

def test_multi_value_layers_per_ipu():
    args = """
    --config unit_test
    --layers-per-ipu 1 2 3 4
    --num-hidden-layers 10
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert config.layers_per_ipu == [1, 2, 3, 4]

    args = """
    --config unit_test
    --layers-per-ipu 0 3 3 4
    --num-hidden-layers 10
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert config.layers_per_ipu == [0, 3, 3, 4]

def dataset():
    """
    Check if the data in two instances is different
    """
    args = "--config demo_tiny_128".split()
    config = transformers.BertConfig(**(vars(parse_bert_args(args))))
    opts = get_options(config)
    loader = TFRecordPretrainingDataset(config.input_files)
    loader = get_dataloader(config, opts)

    # Save part of the data as list
    loader_list = list(loader)[0][0][0].numpy()

    # MPI to broadcast data in root=1 to root=0
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    loader_list_copy = np.copy(loader_list)
    comm.Bcast(loader_list, root=1)

    # Assert if data broadcast to root=0 is different
    if comm.Get_rank() == 0 and not np.all(loader_list_copy == loader_list):
        print('Passed test: instances have different data')

    # Wait until both roots are finished
    time.sleep(2)

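# Note: this check is only meaningful in a multi-instance launch (e.g. under
# poprun/mpirun with at least two processes), so that rank 1 has a batch to
# broadcast to rank 0; with a single rank the Bcast from root=1 is not valid.
# (Launch details are an assumption based on the mpi4py usage above.)
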
def test_checkpoint_embedding_serialization_qa(embedding_serialization_factor):
    """
    If a checkpoint is saved with embedding_serialization_factor then we
    should be able to restore the checkpoint in a new run where
    embedding_serialization_factor isn't used. The reverse should also hold.
    For PipelinedBertForQuestionAnswering we will need to call `deparallelize`
    before checkpointing.
    """
    args = """
    --config unit_test
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    config.embedding_serialization_factor = embedding_serialization_factor
    model1 = PipelinedBertForQuestionAnswering(config).parallelize()

    with tempfile.TemporaryDirectory() as dir:
        # Save checkpoint
        config.checkpoint_output_dir = dir
        model1.deparallelize()
        save_checkpoint(config, model1, 0)

        # Load the checkpoint, but don't call parallelize
        model2 = PipelinedBertForQuestionAnswering.from_pretrained(os.path.join(dir, "step_0"))

        # Models should have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)

def test_single_value_layers_per_ipu():
    args = """
    --config unit_test
    --layers-per-ipu 1
    --num-hidden-layers 4
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert config.layers_per_ipu == [1, 1, 1, 1]

def test_single_value_matmul_prop():
    # Matmul proportion on all IPUs, not just encoder IPUs
    args = """
    --config unit_test
    --layers-per-ipu 1
    --num-hidden-layers 4
    --matmul-proportion 0.2
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert config.matmul_proportion == [0.2, 0.2, 0.2, 0.2]

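# A hedged sketch of how a per-IPU matmul proportion list might be applied
# through PopTorch (assumption: get_options forwards config.matmul_proportion
# to the available-memory-proportion option, one value per IPU; the names in
# the dict comprehension are illustrative only):
#
#   opts = poptorch.Options()
#   opts.setAvailableMemoryProportion(
#       {f"IPU{i}": p for i, p in enumerate(config.matmul_proportion)})
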
def test_wikipedia_dataset():
    args = "--config demo_tiny_128".split()
    config = transformers.BertConfig(**(vars(parse_bert_args(args))))
    config.vocab_size = 30522
    config.input_files = ["data/wikipedia/128/wiki_000.tfrecord"]

    num_tokens = 0
    replacement_counts = Counter({"103": 0, "same": 0, "random": 0})

    opts = get_options(config)
    loader = get_dataloader(config, opts)
    for datum in tqdm(loader):
        tokens, attn_mask, types, mask_lm_pos, labels, nsp = datum
        tokens = tokens.numpy()
        attn_mask = attn_mask.numpy()
        types = types.numpy()
        mask_lm_pos = mask_lm_pos.numpy()
        labels = labels.numpy()
        nsp = nsp.numpy()

        for b in range(config.micro_batch_size):
            check_dimensions(config, tokens[b], attn_mask[b], types[b], mask_lm_pos[b], labels[b], nsp[b])
            check_tokens(config, tokens[b], mask_lm_pos[b], labels[b])
            check_attention_mask(attn_mask[b], tokens[b])
            check_mask_lm_positions(config, mask_lm_pos[b])
            check_labels(config, tokens[b], mask_lm_pos[b], labels[b])
            check_token_type(types[b])
            check_nsp(nsp[b])
            replacement_counts += mask_type_count(tokens[b], mask_lm_pos[b], labels[b])
            # Number of tokens, not including padding
            num_tokens += attn_mask[b, attn_mask[b] == 1].shape[0]

    # Test masked token proportions
    total = sum(replacement_counts.values())
    for k in replacement_counts:
        replacement_counts[k] /= total
    assert 0.79 < replacement_counts["103"] < 0.81
    assert 0.09 < replacement_counts["same"] < 0.11
    assert 0.09 < replacement_counts["random"] < 0.11
    assert 0.14 < total / num_tokens < 0.16  # should be ~0.15

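# The bands asserted above follow BERT's standard masking recipe: roughly 15%
# of non-padding tokens are selected for prediction and, of those, ~80% are
# replaced with the [MASK] token (id 103 in the bert-base-uncased vocab),
# ~10% are kept unchanged and ~10% are swapped for a random token, hence the
# 0.79-0.81, 0.09-0.11 and 0.14-0.16 ranges checked against the counters.
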
def test_checkpoint_save_restore(recompute_checkpoint, embedding_serialization_factor):
    """
    Test that saving and restoring checkpoints works. Also test checkpointing
    with recomputation checkpoints and embedding serialization.
    """
    args = """
    --config unit_test
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    config.recompute_checkpoint_every_layer = recompute_checkpoint
    config.embedding_serialization_factor = embedding_serialization_factor

    model1 = PipelinedBertForPretraining(config).parallelize()
    model2 = PipelinedBertForPretraining(config).parallelize()

    # The two models should have different initial weights
    for name, tensor1 in model1.state_dict().items():
        tensor2 = model2.state_dict()[name]
        if (tensor1.dtype is not torch.int64) and ("LayerNorm" not in name) and ("bias" not in name):
            assert not torch.allclose(tensor1, tensor2)

    # Save and restore checkpoint
    with tempfile.TemporaryDirectory() as dir:
        config.checkpoint_output_dir = dir

        # No checkpoints should exist yet
        assert not checkpoints_exist(config.checkpoint_output_dir)
        save_checkpoint(config, model1, 0)

        # Checkpoint should now exist
        assert checkpoints_exist(config.checkpoint_output_dir)

        # Restore from checkpoint
        model2 = PipelinedBertForPretraining.from_pretrained(os.path.join(dir, "step_0"), config=config)

        # Models should now have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)

def test_recompute_checkpoint_not_in_ir():
    import warnings
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)

    # Config
    args = """
    --config unit_test
    --lr-schedule constant
    --layers-per-ipu 0 3
    --vocab-size 30400
    --weight-decay 0.0
    --recompute-checkpoint-every-layer False
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert config.recompute_checkpoint_every_layer is False

    # Execution parameters
    opts = get_options(config)
    model = PipelinedBertForPretraining(config).parallelize().half().train()
    optimizer = get_optimizer(config, model)
    poptorch_model = poptorch.trainingModel(model, opts, optimizer=optimizer)

    # Compile model
    datum = get_generated_datum(config)
    poptorch_model.compile(*datum)
    ir = json.loads(poptorch_model._debugGetPopartIR())

    assert not any(["Checkpoint" in node["name"] for node in ir["maingraph"]]), \
        "Popart IR should not contain any recompute checkpoints"

    # Stash: 5 inputs, and 1 stash for transformers on ipu1
    exp_num_stash = 5 + 1
    assert sum(["Stash" in node["type"] for node in ir["maingraph"]]) == exp_num_stash, \
        "Only the graph inputs and the single inter-IPU stash should be stashed"
    print(sum(["Stash" in node["type"] for node in ir["maingraph"]]))

def main():
    config = transformers.BertConfig(**(vars(parse_bert_args())))
    if not config.pretrained_checkpoint:
        logger("[warning] --pretrained-checkpoint was not specified; "
               "training with uninitialized BERT...")

    # Warnings for configs where embeddings may not fit
    if config.embedding_serialization_factor == 1:
        if config.replication_factor == 1:
            logger("[warning] With replication_factor == 1 you may need to set "
                   "embedding_serialization_factor > 1 for the model to fit")
        elif not config.replicated_tensor_sharding:
            logger("[warning] With replicated_tensor_sharding=False you may need to set "
                   "embedding_serialization_factor > 1 for the model to fit")

    samples_per_step = config.batches_per_step * config.micro_batch_size * \
        config.gradient_accumulation * config.replication_factor
    do_training = config.squad_do_training
    do_validation = config.squad_do_validation
    opts = get_options(config)
    opts.outputMode(poptorch.OutputMode.All)

    logger("Loading Dataset...")
    datasets = load_dataset("squad")
    train_dataset = datasets["train"]

    # Create train features from dataset
    logger("Tokenizing Train Dataset...")
    train_dataset = train_dataset.map(
        prepare_train_features,
        batched=True,
        num_proc=1,
        remove_columns=train_dataset.column_names,
        load_from_cache_file=True,
    )

    # Create validation features from dataset
    logger("Tokenizing Validation Dataset...")
    validation_features = datasets["validation"].map(
        prepare_validation_features,
        batched=True,
        num_proc=1,
        remove_columns=datasets["validation"].column_names,
        load_from_cache_file=True,
    )

    # W&B
    if config.wandb and (not config.use_popdist or config.popdist_rank == 0):
        wandb.init(project="torch-bert", settings=wandb.Settings(console="wrap"))
        wandb_config = vars(config)
        wandb_config['sdk_version'] = get_sdk_version()
        wandb.config.update(wandb_config)

    # Create the model
    if config.pretrained_checkpoint:
        model_ipu = PipelinedBertForQuestionAnswering.from_pretrained(
            config.pretrained_checkpoint, config=config).parallelize().half()
    else:
        model_ipu = PipelinedBertForQuestionAnswering(config).parallelize().half()

    if do_training:
        train_dl = poptorch.DataLoader(
            opts,
            train_dataset,
            batch_size=config.micro_batch_size,
            shuffle=True,
            drop_last=False,
            collate_fn=PadCollate(samples_per_step,
                                  {"input_ids": 0,
                                   "attention_mask": 0,
                                   "token_type_ids": 0,
                                   "start_positions": config.sequence_length,
                                   "end_positions": config.sequence_length}))

        optimizer = get_optimizer(config, model_ipu)
        model_ipu.train()
        training_model = poptorch.trainingModel(model_ipu, opts, optimizer)

        sample_batch = next(iter(train_dl))
        logger("Compiling Model...")
        start_compile = time.perf_counter()
        training_model.compile(sample_batch["input_ids"],
                               sample_batch["attention_mask"],
                               sample_batch["token_type_ids"],
                               sample_batch["start_positions"],
                               sample_batch["end_positions"])
        duration_compilation = time.perf_counter() - start_compile
        logger(f"Compiled/Loaded model in {duration_compilation} secs")

        if config.compile_only:
            sys.exit()

        # Train
        scheduler = get_lr_scheduler(optimizer, "linear",
                                     config.lr_warmup,
                                     config.num_epochs * len(train_dl))
        logger("Training...")
        for epoch in range(config.num_epochs):
            for step, batch in enumerate(train_dl):
                start_step = time.perf_counter()
                outputs = training_model(batch["input_ids"],
                                         batch["attention_mask"],
                                         batch["token_type_ids"],
                                         batch["start_positions"],
                                         batch["end_positions"])
                scheduler.step()
                training_model.setOptimizer(optimizer)
                step_length = time.perf_counter() - start_step
                step_throughput = samples_per_step / step_length
                loss = outputs[0].mean().item()
                logger(f"Epoch: {epoch}, Step:{step}, "
                       f"LR={scheduler.get_last_lr()[0]:.2e}, "
                       f"loss={loss:3.3f}, "
                       f"throughput={step_throughput:3.3f} samples/s")

                if config.wandb:
                    wandb.log({"Loss": loss,
                               "LR": scheduler.get_last_lr()[0],
                               "Step": step,
                               "Throughput": step_throughput})

        training_model.detachFromDevice()

    if do_validation:
        config.micro_batch_size = 2
        config.batches_per_step = 16
        config.gradient_accumulation = 1
        config.replication_factor = 1
        samples_per_step = config.batches_per_step * config.micro_batch_size * \
            config.gradient_accumulation * config.replication_factor
        opts = get_options(config)
        opts.outputMode(poptorch.OutputMode.All)

        val_dl = poptorch.DataLoader(
            opts,
            validation_features.remove_columns(['example_id', 'offset_mapping']),
            batch_size=config.micro_batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=default_data_collator)

        raw_predictions = [[], []]
        model_ipu.eval()
        inference_model = poptorch.inferenceModel(model_ipu, opts)

        sample_batch = next(iter(val_dl))
        logger("Compiling Inference Model...")
        inference_model.compile(sample_batch["input_ids"],
                                sample_batch["attention_mask"],
                                sample_batch["token_type_ids"])

        if config.compile_only:
            sys.exit()

        logger("Validating...")
        for step, batch in enumerate(val_dl):
            start_step = time.perf_counter()
            outputs = inference_model(batch["input_ids"],
                                      batch["attention_mask"],
                                      batch["token_type_ids"])
            step_length = time.perf_counter() - start_step
            step_throughput = samples_per_step / step_length
            raw_predictions[0].append(outputs[0])
            raw_predictions[1].append(outputs[1])
            logger(f"Step:{step}, throughput={step_throughput} samples/s")

        raw_predictions[0] = torch.vstack(raw_predictions[0]).float().numpy()
        raw_predictions[1] = torch.vstack(raw_predictions[1]).float().numpy()
        final_predictions = postprocess_qa_predictions(datasets["validation"],
                                                       validation_features,
                                                       raw_predictions)

        metric = load_metric("squad")
        formatted_predictions = [{"id": k, "prediction_text": v}
                                 for k, v in final_predictions.items()]
        references = [{"id": ex["id"], "answers": ex["answers"]}
                      for ex in datasets["validation"]]
        metrics = metric.compute(predictions=formatted_predictions,
                                 references=references)
        logger(metrics)

        if config.wandb:
            for k, v in metrics.items():
                wandb.run.summary[k] = v

from pretraining_data import get_dataloader, get_generated_datum
from modeling import PipelinedBertForPretraining
from ipu_options import get_options
from optimization import get_lr_scheduler, get_optimizer
from checkpointing import save_checkpoint, checkpoints_exist
from utils import get_sdk_version, cycle, logger, sync_metrics
from args import parse_bert_args


if __name__ == "__main__":
    # Ignore known warnings
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    logging.getLogger("poptorch::python").setLevel(logging.ERROR)

    # Build config from args
    config = transformers.BertConfig(**(vars(parse_bert_args())))

    # Warnings for configs where embeddings may not fit
    if config.embedding_serialization_factor == 1:
        if config.replication_factor == 1:
            logger("[warning] With replication_factor == 1 you may need to set "
                   "embedding_serialization_factor > 1 for the model to fit")
        elif not config.replicated_tensor_sharding:
            logger("[warning] With replicated_tensor_sharding=False you may need to set "
                   "embedding_serialization_factor > 1 for the model to fit")

    # Prevent overwriting of existing checkpoints
    if checkpoints_exist(config.checkpoint_output_dir):
        raise RuntimeError(
            "Found previously saved checkpoint(s) at checkpoint-dir. "