def test_split_qkv_weight_loading():
    """Round-trip weights between unsplit-QKV and split-QKV BERT models.

    An unsplit model (``split_qkv=False``) is built first and its
    initializers are loaded into a split model (``split_qkv=True``). The
    split model's initializers are then loaded back into a fresh unsplit
    model. After each load the two protos are compared with
    ``check_onnx_model`` using:

    * a name mapping from every ``Layer{i}/Attention/{Q,K,V}`` tensor to
      ``Layer{i}/Attention/QKV``;
    * a transform that takes the matching third of the combined tensor,
      cut along axis 1.
    """
    qkv_letters = "QKV"

    unsplit_config = BertConfig(
        task="SQUAD",
        vocab_length=1024,
        micro_batch_size=1,
        hidden_size=64,
        attention_heads=2,
        sequence_length=20,
        popart_dtype="FLOAT",
        num_layers=2,
        no_mask=True,
        split_qkv=False,
    )

    def get_split(full_t, t):
        # Cut the combined tensor into three pieces along axis 1 and return
        # the piece whose position matches `t` within "QKV".
        pieces = np.split(full_t, 3, axis=1)
        return pieces[qkv_letters.index(t)]

    # Build both lookup tables in a single pass: layers outermost, then Q/K/V.
    mapping = {}
    transform = {}
    for layer in range(unsplit_config.num_layers):
        combined_name = f"Layer{layer}/Attention/QKV"
        for letter in qkv_letters:
            split_name = f"Layer{layer}/Attention/{letter}"
            mapping[split_name] = combined_name
            transform[split_name] = partial(get_split, t=letter)

    # Produce an unsplit checkpoint.
    np.random.seed(123)
    unsplit_proto = get_model_proto(unsplit_config)
    unsplit_weights = get_initializers(unsplit_proto)

    split_config = unsplit_config._replace(split_qkv=True)

    # Load the unsplit checkpoint into a split model. A different seed is
    # used for this build than for the checkpoint above.
    np.random.seed(456)
    split_proto = get_model_proto(split_config, initializers=unsplit_weights)

    check_onnx_model(unsplit_proto, split_proto, mapping, transform,
                     allow_missing=False)

    # Pull the weights back out of the split model ...
    split_weights = get_initializers(split_proto)

    # ... and load that split checkpoint into a new unsplit model.
    np.random.seed(456)
    reloaded_proto = get_model_proto(unsplit_config,
                                     initializers=split_weights)

    check_onnx_model(reloaded_proto, split_proto, mapping, transform,
                     allow_missing=False)
def test_weight_mapping(num_vocab_splits, task):
    """Check that pipelined-model weights load into a phased-execution model.

    A pipelined model is built and saved to a temporary ONNX file. Its
    initializers are read back and extended with the entries returned by
    ``get_phased_initializers_from_default``. Those initializers are then
    used to build a phased-execution model and a pipelined model whose
    ``embedding_serialization_vocab_steps`` is 1. Finally the two protos are
    compared with ``check_onnx_model``.

    Args:
        num_vocab_splits: Passed to the config as
            ``embedding_serialization_vocab_steps``.
        task: Passed to the config as ``task``.

    Both arguments are presumably supplied by test parametrization defined
    outside this block -- TODO confirm.
    """
    config = BertConfig(
        task=task,
        vocab_length=1024,
        micro_batch_size=1,
        hidden_size=64,
        attention_heads=2,
        sequence_length=20,
        mask_tokens=8,
        popart_dtype="FLOAT",
        num_layers=2,
        no_mask=True,
        no_dropout=True,
        no_attn_dropout=True,
        embedding_serialization_vocab_steps=num_vocab_splits,
        inference=True,
    )

    # Build the pipelined BERT model.
    pipeline_model = get_model_proto(config, mode=ExecutionMode.PIPELINE)

    # Round-trip the model through an ONNX file to extract its weights, then
    # add the phased-execution entries derived from them.
    with tempfile.TemporaryDirectory() as workdir:
        onnx_path = os.path.join(workdir, "model.onnx")
        onnx.save(pipeline_model, onnx_path)
        weights = load_initializers_from_onnx(onnx_path)
        phased_weights = get_phased_initializers_from_default(config, weights)
        weights.update(**phased_weights)

    # Same configuration, but with a single embedding vocab step.
    config_nosplit = config._replace(embedding_serialization_vocab_steps=1)

    # Phased-execution model; note this is built from `config`, not from
    # `config_nosplit`.
    phased_model = get_model_proto(
        config,
        mode=ExecutionMode.PHASED,
        initializers=weights,
    )

    # Pipelined model without any embedding split, used as the reference.
    pipeline_model_nosplit = get_model_proto(
        config_nosplit,
        mode=ExecutionMode.PIPELINE,
        initializers=weights,
    )

    # The initial protos must match between the pipelined and phased models.
    name_mapping = phased_to_default_mapping(config)
    weight_transform = phased_from_default_transform(config)
    check_onnx_model(
        pipeline_model_nosplit,
        phased_model,
        name_mapping,
        weight_transform,
        allow_missing=False,
    )