def test_tied_gather_pattern_correctness(splits, phase, optimizer, custom_ops):
    """Run the same session with and without the tied-gather patterns and
    verify the outputs match. Model protos are only compared when training,
    since that is when the pattern rewrites the graph weights."""
    is_training = phase == "bwd"
    with_patterns, proto_with = session(
        is_training, skip_execution=False, splits=splits, optim=optimizer)
    without_patterns, proto_without = session(
        is_training, skip_execution=False, include_patterns=False,
        splits=splits, optim=optimizer)
    check_tensors(with_patterns, without_patterns)
    if is_training:
        check_onnx_model(proto_with, proto_without)
def test_tied_gather_pattern_outlining_correctness(phase, custom_ops):
    """Same correctness check as the basic tied-gather test, but with
    outlining enabled and a fixed serialisation factor of 4."""
    is_training = phase == "bwd"
    with_patterns, proto_with = session(
        is_training, skip_execution=False, splits=4, outline=True)
    without_patterns, proto_without = session(
        is_training, skip_execution=False, include_patterns=False,
        splits=4, outline=True)
    check_tensors(with_patterns, without_patterns)
    if is_training:
        check_onnx_model(proto_with, proto_without)
def test_split_qkv_weight_loading():
    """Check that checkpoints round-trip between split and unsplit QKV layouts.

    An unsplit checkpoint is loaded into a split-QKV model, then the split
    weights are loaded back into an unsplit model; both directions must
    produce equivalent protos under the Q/K/V slicing transform.
    """
    config = BertConfig(task="SQUAD",
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        split_qkv=False)

    def get_split(full_t, t):
        # Slice the fused QKV weight into thirds and pick the one named by t.
        return np.split(full_t, 3, axis=1)["QKV".index(t)]

    # Map each per-tensor name onto the fused tensor, and attach the
    # matching slicing transform for the comparison.
    mapping = {}
    transform = {}
    for layer in range(config.num_layers):
        for tensor in "QKV":
            split_name = f"Layer{layer}/Attention/{tensor}"
            mapping[split_name] = f"Layer{layer}/Attention/QKV"
            transform[split_name] = partial(get_split, t=tensor)

    # Get a unsplit checkpoint
    np.random.seed(123)
    proto_1 = get_model_proto(config)

    initializers = get_initializers(proto_1)

    split_config = config._replace(split_qkv=True)

    # Load the unsplit checkpoint into a split model
    np.random.seed(456)
    proto_2 = get_model_proto(split_config, initializers=initializers)

    check_onnx_model(proto_1, proto_2, mapping, transform, allow_missing=False)

    # Extract weights
    initializers = get_initializers(proto_2)

    # Load the split checkpoint into an unsplit model
    np.random.seed(456)
    proto_3 = get_model_proto(config, initializers=initializers)

    check_onnx_model(proto_3, proto_2, mapping, transform, allow_missing=False)
def test_weight_mapping(num_vocab_splits, task):
    """Verify that weights saved from a pipelined model can be remapped
    into a phased-execution model, by comparing the resulting protos
    through the phased-to-default name mapping and transforms."""
    config = BertConfig(task=task,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        mask_tokens=8,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        no_dropout=True,
                        no_attn_dropout=True,
                        embedding_serialization_vocab_steps=num_vocab_splits,
                        inference=True)

    # Run pipelined BERT
    pipelined_proto = get_model_proto(config, mode=ExecutionMode.PIPELINE)

    # Extract weights: serialize to disk, then read the initializers back
    # and augment them with their phased-execution equivalents.
    with tempfile.TemporaryDirectory() as tmp:
        file_path = os.path.join(tmp, "model.onnx")
        onnx.save(pipelined_proto, file_path)
        initializers = load_initializers_from_onnx(file_path)
        initializers.update(
            **get_phased_initializers_from_default(config, initializers))

    # Create phased_execution version of the model
    config_nosplit = config._replace(embedding_serialization_vocab_steps=1)
    # NOTE(review): the phased proto is built from `config` (with the
    # embedding split), not `config_nosplit` — looks intentional since the
    # mapping handles the split, but worth confirming.
    phased_proto = get_model_proto(config,
                                   mode=ExecutionMode.PHASED,
                                   initializers=initializers)
    # Create a pipelined version of the model without any embedding split for the comparison
    pipelined_proto_nosplit = get_model_proto(config_nosplit,
                                              mode=ExecutionMode.PIPELINE,
                                              initializers=initializers)

    # Check inital protos match for pipelined vs phased_execution model
    check_onnx_model(pipelined_proto_nosplit,
                     phased_proto,
                     phased_to_default_mapping(config),
                     phased_from_default_transform(config),
                     allow_missing=False)
def test_lamb_serialised_pattern_correctness(splits, custom_ops):
    """The serialised LAMB pattern must produce the same outputs and the
    same model proto as the unserialised (splits=1) baseline."""
    baseline_outputs, baseline_proto = session(splits=1)
    serialised_outputs, serialised_proto = session(splits=splits)
    check_tensors(baseline_outputs, serialised_outputs)
    check_onnx_model(baseline_proto, serialised_proto)