# --- Code example #1 ---
def test_tied_gather_pattern_correctness(splits, phase, optimizer, custom_ops):
    """Results must be identical whether or not the tied-gather patterns run.

    Builds the same session twice — once with the patterns enabled (default)
    and once with ``include_patterns=False`` — and compares outputs; when
    training ("bwd" phase) the ONNX protos are compared as well.
    """
    is_training = phase == "bwd"

    patterned_outputs, patterned_proto = session(
        is_training, skip_execution=False, splits=splits, optim=optimizer)
    plain_outputs, plain_proto = session(
        is_training, skip_execution=False, include_patterns=False,
        splits=splits, optim=optimizer)

    check_tensors(patterned_outputs, plain_outputs)
    if is_training:
        check_onnx_model(patterned_proto, plain_proto)
# --- Code example #2 ---
def test_tied_gather_pattern_outlining_correctness(phase, custom_ops):
    """Tied-gather patterns must not change results when outlining is on.

    Fixed at 4 serialisation splits; outputs are always compared, and the
    ONNX protos are additionally compared in the training ("bwd") phase.
    """
    is_training = phase == "bwd"

    with_patterns = session(is_training, skip_execution=False, splits=4, outline=True)
    without_patterns = session(
        is_training, skip_execution=False, include_patterns=False, splits=4,
        outline=True)

    check_tensors(with_patterns[0], without_patterns[0])
    if is_training:
        check_onnx_model(with_patterns[1], without_patterns[1])
# --- Code example #3 ---
def test_split_qkv_weight_loading():
    """Round-trip a fused-QKV checkpoint through a split-QKV model and back.

    Loads an unsplit checkpoint into a ``split_qkv=True`` model, verifies the
    protos agree under the Q/K/V mapping, then loads the split weights back
    into an unsplit model and verifies again.
    """
    config = BertConfig(task="SQUAD",
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        split_qkv=False)

    def slice_qkv(full_t, t):
        # Q, K and V are concatenated along axis 1 of the fused weight;
        # pick out the third corresponding to the requested letter.
        return np.split(full_t, 3, axis=1)["QKV".index(t)]

    # Map each per-letter split weight name onto the fused QKV name, and
    # pair it with the transform that extracts that slice.
    mapping = {}
    transform = {}
    for layer in range(config.num_layers):
        for t in "QKV":
            mapping[f"Layer{layer}/Attention/{t}"] = f"Layer{layer}/Attention/QKV"
            transform[f"Layer{layer}/Attention/{t}"] = partial(slice_qkv, t=t)

    # Build an unsplit model and capture its initializers (the "checkpoint").
    np.random.seed(123)
    proto_1 = get_model_proto(config)
    initializers = get_initializers(proto_1)

    split_config = config._replace(split_qkv=True)

    # Load the unsplit checkpoint into a split model (different seed so any
    # weight not loaded from the checkpoint would differ and fail the check).
    np.random.seed(456)
    proto_2 = get_model_proto(split_config, initializers=initializers)

    check_onnx_model(proto_1,
                     proto_2,
                     mapping,
                     transform,
                     allow_missing=False)

    # Extract the split model's weights.
    initializers = get_initializers(proto_2)

    # Load the split checkpoint into an unsplit model.
    np.random.seed(456)
    proto_3 = get_model_proto(config, initializers=initializers)

    check_onnx_model(proto_3,
                     proto_2,
                     mapping,
                     transform,
                     allow_missing=False)
# --- Code example #4 ---
def test_weight_mapping(num_vocab_splits, task):
    """Pipelined and phased-execution models must start from matching weights.

    Exports a pipelined model's initializers, converts them to the phased
    layout, rebuilds both a phased model and an unsplit pipelined model from
    them, and checks the two protos agree under the phased-to-default mapping.
    """
    config = BertConfig(task=task,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        mask_tokens=8,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        no_dropout=True,
                        no_attn_dropout=True,
                        embedding_serialization_vocab_steps=num_vocab_splits,
                        inference=True)

    # Run pipelined BERT.
    pipelined_proto = get_model_proto(config, mode=ExecutionMode.PIPELINE)

    # Export its weights and derive the phased-execution initializer set.
    with tempfile.TemporaryDirectory() as tmp:
        onnx_path = os.path.join(tmp, "model.onnx")
        onnx.save(pipelined_proto, onnx_path)
        initializers = load_initializers_from_onnx(onnx_path)
        initializers.update(
            **get_phased_initializers_from_default(config, initializers))

    # Phased-execution version of the model.
    # NOTE(review): this uses `config` (with the vocab split), not
    # `config_nosplit` — presumably intentional since the phased
    # initializers were derived from the split config; confirm.
    config_nosplit = config._replace(embedding_serialization_vocab_steps=1)
    phased_proto = get_model_proto(config,
                                   mode=ExecutionMode.PHASED,
                                   initializers=initializers)

    # Pipelined version without any embedding split, for the comparison.
    pipelined_proto_nosplit = get_model_proto(config_nosplit,
                                              mode=ExecutionMode.PIPELINE,
                                              initializers=initializers)

    # Initial protos must match between the pipelined and phased models.
    check_onnx_model(pipelined_proto_nosplit,
                     phased_proto,
                     phased_to_default_mapping(config),
                     phased_from_default_transform(config),
                     allow_missing=False)
# --- Code example #5 ---
def test_lamb_serialised_pattern_correctness(splits, custom_ops):
    """A serialised LAMB run must match the unserialised (splits=1) baseline.

    Compares both the output tensors and the resulting ONNX protos.
    """
    baseline_outputs, baseline_proto = session(splits=1)
    serialised_outputs, serialised_proto = session(splits=splits)

    check_tensors(baseline_outputs, serialised_outputs)
    check_onnx_model(baseline_proto, serialised_proto)