Example no. 1
0
def fwd_graph(popart_model, torch_model, mapping=None, transform=None):
    """Build and run the PopART graph, then run the equivalent PyTorch
    model on the same random inputs and compare the outputs.

    Args:
        popart_model: Wrapper exposing `config`, `builder`, `build_graph`
            and `layer_offset`.
        torch_model: PyTorch module to compare against.
        mapping: Optional initial torch->onnx name mapping.
        transform: Optional initial weight-transform mapping.
    """
    #  ------------------- PopART --------------------
    config = popart_model.config
    builder = popart_model.builder

    n_tokens = config.batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("INT32", [n_tokens])

    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)

    # Random inputs; the draw order (indices, positions, segments) is
    # kept fixed so the RNG sequence matches the original.
    data = {}
    for tensor, high in ((indices, config.vocab_length),
                         (positions, config.sequence_length),
                         (segments, 2)):
        data[tensor] = np.random.randint(0, high, (n_tokens,)).astype(np.int32)

    output = popart_model.build_graph(indices, positions, segments)
    proto = builder.getModelProto()

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        ipus=math.ceil(config.num_layers / config.layers_per_ipu) +
        popart_model.layer_offset)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    batch_shape = (config.batch_size, config.sequence_length)
    inputs = {
        "input_ids": data[indices].reshape(batch_shape),
        "position_ids": data[positions].reshape(batch_shape),
        "token_type_ids": data[segments].reshape(batch_shape),
    }

    torch_to_onnx = get_mapping(config, init=mapping)
    transform_weights = get_transform(config, init=transform)

    #  ------------------- PyTorch -------------------------
    torch_model.eval()  # turn off dropout

    copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform_weights)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
Example no. 2
0
def pytorch_result_and_model(torch_config, inputs, popart_proto, is_bwd=False):
    """Build a torch BertFCN with weights taken from `popart_proto` and
    run it forward; optionally take one SGD step on an L1 loss.

    Args:
        torch_config: Config passed to BertFCN.
        inputs: Iterable of numpy arrays fed to the model.
        popart_proto: Serialized PopART/ONNX model providing the weights.
        is_bwd: When True, also run loss/backward/step and return the
            detached output array.

    Returns:
        Tuple of (result, torch_model).
    """
    # Parse the popart protobuf back into an onnx model so its weights
    # can be copied into the torch module.
    proto = onnx.load_model_from_string(popart_proto)

    torch_model = BertFCN(torch_config)
    torch_model.eval()  # turn off dropout

    copy_weights_to_torch(torch_model,
                          proto,
                          TORCH_TO_ONNX,
                          transform=TRANSPOSE_WEIGHTS)

    result = run_fwd_model(inputs, torch_model)

    if is_bwd:
        l1_lambda = 0.1
        optimizer = torch.optim.SGD(torch_model.parameters(),
                                    0.01,
                                    weight_decay=0.0,
                                    momentum=0.0)

        tensors = [torch.from_numpy(t).float() for t in inputs]
        result = torch_model(*tensors)[0]
        torch_loss = l1_lambda * torch.norm(result, 1)
        torch_loss.backward()
        optimizer.step()
        result = result.detach().numpy()

    return result, torch_model
Example no. 3
0
def pytorch_result_and_model(torch_config,
                             inputs,
                             popart_proto,
                             weight_decay=0.0,
                             lr=0.0,
                             l1_lambda=0.0):
    """Run a torch BertFCN (weights from `popart_proto`) through one SGD
    step with separate weight-decay parameter groups.

    Args:
        torch_config: Config passed to BertFCN.
        inputs: Iterable of numpy arrays fed to the model.
        popart_proto: Serialized PopART/ONNX model providing the weights.
        weight_decay: Decay applied to non-bias, non-LayerNorm params.
        lr: SGD learning rate.
        l1_lambda: Scale of the L1-norm loss.

    Returns:
        Tuple of (result, torch_model).
    """
    proto = onnx.load_model_from_string(popart_proto)
    torch_model = BertFCN(torch_config)
    torch_model.eval()  # Turn off dropout
    copy_weights_to_torch(torch_model,
                          proto,
                          TORCH_TO_ONNX,
                          transform=TRANSPOSE_WEIGHTS)
    run_fwd_model(inputs, torch_model)

    # Biases and LayerNorm parameters are exempt from weight decay.
    def _is_no_decay(name):
        return "bias" in name or "LayerNorm" in name

    named = list(torch_model.named_parameters())
    no_decay = [p for n, p in named if _is_no_decay(n)]
    decay = [p for n, p in named if not _is_no_decay(n)]

    params = [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]

    optim = torch.optim.SGD(params, lr, momentum=0.0)

    tensors = [torch.from_numpy(t).float() for t in inputs]
    result = torch_model(*tensors)[0]
    torch_loss = l1_lambda * torch.norm(result, 1)
    torch_loss.backward()
    optim.step()
    result = result.detach().numpy()

    return result, torch_model
Example no. 4
0
def run_models(config, proto, indices, positions, segments, output,
               popart_model, torch_model):
    """Run the PopART and PyTorch models on identical random inputs and
    check that their weights and outputs agree before and after."""
    onnx_proto = onnx.load_model_from_string(proto)
    check_model(torch_model, onnx_proto, get_mapping(config),
                get_transform(config))

    # Random inputs, one flat token stream per tensor.  Draw order
    # (indices, positions, segments) is kept fixed for reproducibility.
    n_tokens = config.batch_size * config.sequence_length
    popart_inputs = {}
    for tensor, high in ((indices, config.vocab_length),
                         (positions, config.sequence_length),
                         (segments, 2)):
        popart_inputs[tensor] = np.random.randint(
            0, high, (n_tokens,)).astype(np.uint32)

    popart_outputs, post_proto = run_py(
        proto,
        popart_inputs,
        output,
        ipus=popart_model.total_ipus,
    )

    batch_shape = (config.batch_size, config.sequence_length)
    torch_inputs = {
        "input_ids": popart_inputs[indices].reshape(batch_shape),
        "position_ids": popart_inputs[positions].reshape(batch_shape),
        "token_type_ids": popart_inputs[segments].reshape(batch_shape),
    }

    torch_model.eval()  # turn off dropout
    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    check_model(torch_model, post_proto, get_mapping(config),
                get_transform(config))
    check_tensors(torch_outputs, popart_outputs)
    print("Test succeeded")
Example no. 5
0
def pytorch_result_and_model(config,
                             inputs,
                             popart_proto,
                             weight_transposed,
                             is_bwd=False):
    """Run the torch embedding reference built from a PopART proto.

    Args:
        config (BertConfig): Popart config.
        inputs (np.ndarray): Input np array.
        popart_proto (onnx.proto): Onnx protobuf.
        weight_transposed (bool): If True, onnx weights are constructed
            transposed.
        is_bwd (bool, optional): True to also run one SGD step on an L1
            loss. Defaults to False.

    Returns:
        Tuple: Output np.array and Torch model.
    """
    torch_config = TorchBertConfig(config.vocab_length,
                                   config.hidden_size,
                                   config.num_layers,
                                   config.attention_heads,
                                   layer_norm_eps=config.layer_norm_eps)
    torch_model = nn.Embedding(torch_config.vocab_size,
                               torch_config.hidden_size,
                               padding_idx=0)
    torch_model.eval()  # turn off dropout

    # Parse the popart model and pull its initializers out by name.
    proto = onnx.load_model_from_string(popart_proto)
    initializers = get_initializers(proto, weight_transposed)

    for name, parameter in torch_model.named_parameters():
        parameter.data.copy_(torch.from_numpy(initializers[name]).float())

    result = run_fwd_model(inputs, torch_model)

    if is_bwd:
        optimizer = torch.optim.SGD(torch_model.parameters(),
                                    0.01,
                                    weight_decay=0.0,
                                    momentum=0.0)

        tensors = [torch.from_numpy(t).long() for t in inputs]
        result = torch_model(*tensors)[0]
        torch_loss = 0.1 * torch.norm(result, 1)
        torch_loss.backward()
        optimizer.step()
        result = [result.detach().numpy()]

    return result, torch_model
Example no. 6
0
def pytorch_result_and_model(torch_config,
                             inputs,
                             popart_proto,
                             mode,
                             is_bwd=False,
                             momentum=0.0):
    """Build a torch BertFCN from `popart_proto`, run it forward and
    optionally through several SGD steps.

    Args:
        torch_config: Config passed to BertFCN.
        inputs: Iterable of numpy arrays fed to the model.
        popart_proto: Serialized PopART/ONNX model whose weights are
            copied into the torch model.
        mode: Key selecting the name mapping in TORCH_TO_ONNX.
        is_bwd: When True, run the backward/optimizer loop as well.
        momentum: SGD momentum; when > 0 the optimizer state buffers
            are pre-seeded with zeros.

    Returns:
        Tuple of (result, torch_model).
    """
    # Conversion of the popart model to onnx
    proto = onnx.load_model_from_string(popart_proto)

    torch_model = BertFCN(torch_config)
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          TORCH_TO_ONNX[mode],
                          transform=TRANSPOSE_WEIGHTS)

    result = run_fwd_model(inputs, torch_model)

    if is_bwd:
        l1_lambda = 0.1
        # NOTE(review): `lr` is a free variable here — not a parameter
        # and not defined in this function.  Presumably a module-level
        # constant; confirm it exists, otherwise this raises NameError.
        optim = torch.optim.SGD(torch_model.parameters(),
                                lr,
                                weight_decay=0.0,
                                momentum=momentum)

        if momentum > 0.0:
            # Zero-initialise optimizer state so the first step is
            # deterministic (comparable against the PopART run).
            for group in optim.param_groups:
                for p in group['params']:
                    optim.state[p]['momentum_buffer'] = p.data * 0.
                    optim.state[p]['exp_avg'] = p.data * 0.
                    optim.state[p]['exp_avg_sq'] = p.data * 0.
                    optim.state[p]['step'] = 0

        # NOTE(review): `num_reps_bwd` is also a free variable —
        # assumed to be defined at module level; verify.
        for _ in range(num_reps_bwd):
            result = torch_model(
                *[torch.from_numpy(t).float() for t in inputs])[0]
            torch_loss = l1_lambda * torch.norm(result, 1)
            torch_loss.backward()
            optim.step()
            optim.zero_grad()
        result = [result.detach().numpy()]

    return result, torch_model
Example no. 7
0
def test_embedding_fwd(custom_ops):
    """Compare the PopART embedding layer against the torch
    BertEmbeddings reference on random inputs (inference only)."""
    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config)

    n_tokens = config.micro_batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [n_tokens])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)

    # Random inputs; draw order (indices, positions, segments) fixed.
    data = {}
    for tensor, high in ((indices, config.vocab_length),
                         (positions, config.max_positional_length),
                         (segments, 2)):
        data[tensor] = np.random.randint(0, high, (n_tokens,)).astype(np.uint32)

    user_options = {"enableStochasticRounding": True}
    with popart_model.builder.nameScope("Embedding"):
        output = popart_model.embedding(indices, positions, segments)

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.micro_batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()  # turn off dropout

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs, margin=5e-7)
Example no. 8
0
def test_embedding_fwd(custom_ops):
    """Forward check of the custom-gather embedding against the torch
    BertEmbeddings reference (no virtualGraph attributes added)."""
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'],
                        inference=True)
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    n_tokens = config.batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [n_tokens])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)

    # Random inputs; draw order (indices, positions, segments) fixed.
    data = {}
    for tensor, high in ((indices, config.vocab_length),
                         (positions, config.max_positional_length),
                         (segments, 2)):
        data[tensor] = np.random.randint(0, high, (n_tokens,)).astype(np.uint32)

    # Use the custom embedding for layout
    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(
        proto, data, output, user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    torch_to_onnx = {
        "word_embeddings.weight": "Embedding_Dict",
        "position_embeddings.weight": "Positional_Dict",
        "token_type_embeddings.weight": "Segment_Dict",
        "LayerNorm.weight": "Gamma",
        "LayerNorm.bias": "Beta"
    }

    # These two onnx weights are transposed when copied into torch.
    transposed_weights = {
        "word_embeddings.weight": np.transpose,
        "position_embeddings.weight": np.transpose,
    }

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()  # turn off dropout

    copy_weights_to_torch(torch_model, proto, torch_to_onnx,
                          transposed_weights)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_embedding_projection_fwd(custom_ops):
    """Forward check of gather -> norm -> dropout -> CLS head ->
    projection against the torch EmbeddingProjectionModel."""
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        embedding_serialization_vocab_steps=4,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_cls_layer=False,
                        inference=True)
    popart_model = Bert(config, builder=builder)

    n_tokens = config.micro_batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [n_tokens])
    indices = builder.addInputTensor(sequence_info)
    data = {
        indices: np.random.randint(
            0, config.vocab_length, (n_tokens,)).astype(np.uint32)
    }

    # Build the popart graph piece by piece.
    x = popart_model.gather(indices, config.vocab_length, "Embedding_Dict")
    x = popart_model.norm(x)
    x = popart_model.dropout(x)
    with popart_model.builder.nameScope("CLS"):
        x = popart_model.lm_prediction_head(x)
    output = popart_model.projection(x)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [data[indices].reshape(
        config.micro_batch_size, config.sequence_length).astype(np.int32)]

    #  ------------------- PyTorch -------------------------
    torch_model = EmbeddingProjectionModel(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        no_cls_layer=config.no_cls_layer))

    torch_model.eval()  # turn off dropout

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX,
                          TRANSPOSE_WEIGHTS)
    torch_model.tie_weights()

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
Example no. 10
0
def test_embedding(config, phase):
    """Cross-check the PyTorch BertEmbeddings layer against the TF (IPU)
    embeddings layer.

    phase == "fwd": run both forward on the same inputs and compare
    outputs.  phase == "bwd": take one SGD step on an L1 loss in both
    frameworks, then compare weights and outputs.

    NOTE(review): the body reads the free variable `test_config`, not
    the `config` parameter — presumably a module-level config/fixture;
    confirm this is intentional and not a typo for `config`.

    Raises:
        ValueError: if `phase` is neither "fwd" nor "bwd".
    """
    # define input
    indices = np.random.randint(
        0, test_config.vocab_size,
        (test_config.batch_size, test_config.sequence_length)).astype(np.int32)
    # NOTE(review): reshaping arange(sequence_length) to
    # (batch_size, sequence_length) only works when batch_size == 1 —
    # confirm the fixture guarantees that.
    positions = np.reshape(
        np.arange(test_config.sequence_length),
        (test_config.batch_size, test_config.sequence_length)).astype(np.int32)
    segments = np.random.randint(
        0, 2,
        (test_config.batch_size, test_config.sequence_length)).astype(np.int32)
    inputs = [d for d in [indices, positions, segments]]

    # build model
    # PyTorch model
    torch_config = TorchBertConfig(
        vocab_size_or_config_json_file=test_config.vocab_size,
        hidden_size=test_config.hidden_size,
        hidden_act=test_config.hidden_act,
        num_attention_heads=test_config.num_attention_heads,
        hidden_dropout_prob=test_config.hidden_dropout_prob,
        max_position_embeddings=test_config.max_position_embeddings,
        type_vocab_size=test_config.type_vocab_size,
        update_embedding_dict=True,
        layer_norm_eps=test_config.layer_norm_eps)
    torch_model = TorchBertEmbeddings(torch_config)
    torch_model.eval()  # turn off dropout

    # TF model
    tf_config = TFBertConfig(
        vocab_size=test_config.vocab_size,
        hidden_size=test_config.hidden_size,
        hidden_act=test_config.hidden_act,
        num_attention_heads=test_config.num_attention_heads,
        max_position_embeddings=test_config.max_position_embeddings,
        max_predictions_per_seq=test_config.max_predictions_per_seq,
        hidden_dropout_prob=test_config.hidden_dropout_prob,
        type_vocab_size=test_config.type_vocab_size,
        initializer_range=test_config.initializer_range,
        dtype=test_config.dtype,
        matmul_serialize_factor=test_config.matmul_serialize_factor,
        static_mask=False)

    # forward check
    if phase == "fwd":
        torch_outputs = run_fwd_model(inputs, torch_model)

        with tf.Graph().as_default():
            tf_model = TFBertModel(tf_config, is_training=True)

            # Placeholders live on the CPU; the compiled embeddings
            # graph runs on the IPU device below.
            with ops.device('cpu'):
                input_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                           dtype=tf.int32)
                position_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                              dtype=tf.int32)
                segment_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                             dtype=tf.int32)
            cfg = utils.create_ipu_config()
            cfg = utils.auto_select_ipus(cfg, 1)
            utils.configure_ipu_system(cfg)
            utils.move_variable_initialization_to_cpu()
            with ops.device("/device:IPU:0"):
                opt = ipu_compiler.compile(
                    tf_model.embeddings_layer,
                    inputs=[input_ids, position_ids, segment_ids])

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                # copy pytorch weights to tf
                var_and_init = copy_torch_weights_to_tf(
                    torch_model, tf_model, TF_TO_TORCH, {}, sess)
                sess.run(var_and_init)
                # run the tf forward pass
                tf_outputs = sess.run(
                    opt, {
                        input_ids: indices,
                        position_ids: positions,
                        segment_ids: segments
                    })
                # compare tf output with pytorch output
                check_tensors(tf_outputs, torch_outputs, margin=1.5e-8)

    # backward check
    elif phase == "bwd":
        l1_lambda = 0.1
        base_lr = 0.01
        optim = torch.optim.SGD(torch_model.parameters(),
                                base_lr,
                                weight_decay=0.0,
                                momentum=0.0)

        torch_output = torch_model(
            *[torch.from_numpy(t).long() for t in inputs])
        # pytorch backward
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()  # calculate gradients
        optim.step()  # apply the update
        torch_outputs = [torch_output.detach().numpy()]

        # TF
        with tf.Graph().as_default():
            tf_model = TFBertModel(tf_config, is_training=True)
            with ops.device('cpu'):
                input_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                           dtype=tf.int32)
                position_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                              dtype=tf.int32)
                segment_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                             dtype=tf.int32)
            cfg = utils.create_ipu_config()
            cfg = utils.auto_select_ipus(cfg, 1)
            utils.configure_ipu_system(cfg)
            utils.move_variable_initialization_to_cpu()

            def embedding_graph(input_ids, position_ids, segment_ids):
                # Same embeddings layer plus an L1 loss and one SGD
                # step, mirroring the torch update above.
                embedding_output = tf_model.embeddings_layer(
                    input_ids, position_ids, segment_ids)
                l1_loss = l1_lambda * tf.norm(embedding_output, 1)
                optimizer = tf.train.GradientDescentOptimizer(base_lr)
                train_step = optimizer.minimize(l1_loss)
                return embedding_output, l1_loss, train_step

            with ops.device("/device:IPU:0"):
                opt = ipu_compiler.compile(
                    embedding_graph,
                    inputs=[input_ids, position_ids, segment_ids])

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                var_and_init = copy_torch_weights_to_tf(
                    torch_model, tf_model, TF_TO_TORCH, {}, sess)
                sess.run(var_and_init)
                tvars = sess.run({v.name: v for v in tf.trainable_variables()})
                print(tvars)
                # NOTE(review): embedding_graph returns three values;
                # confirm the compiled op yields exactly two here
                # (e.g. the train_step is excluded), otherwise this
                # unpacking fails.
                tf_outputs, tf_loss = sess.run(
                    opt, {
                        input_ids: indices,
                        position_ids: positions,
                        segment_ids: segments
                    })
                # sess.run(opt, {input_ids: indices, position_ids: positions, segment_ids: segments})
                # Compare the forward output
                check_tf_torch_model(sess,
                                     torch_model,
                                     TF_TO_TORCH,
                                     margin=5e-7)
            check_tensors(torch_outputs, tf_outputs, margin=5e-7)
    else:
        raise ValueError(
            f"`phase` only can be set to [`fwd`, `bwd`] which mean farward or backward respectively."
        )
Example no. 11
0
def test_embedding_fwd(custom_ops, mode, batch_size,
                       batch_serialization_factor,
                       embedding_serialization_vocab_steps):
    """Forward check of the embedding layer for a given execution mode
    against the torch BertEmbeddings reference."""
    #  ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        vocab_length=9728,
        batch_size=batch_size,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        inference=True,
        embedding_serialization_vocab_steps=embedding_serialization_vocab_steps
    )
    popart_model = get_model(config, mode, 'embedding')

    n_tokens = config.batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [n_tokens])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)

    # Random inputs; draw order (indices, positions, segments) fixed.
    data = {}
    for tensor, high in ((indices, config.vocab_length),
                         (positions, config.max_positional_length),
                         (segments, 2)):
        data[tensor] = np.random.randint(0, high, (n_tokens,)).astype(np.uint32)

    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialization_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(indices, positions, segments)
    else:
        user_options = {"enableStochasticRounding": True}
        with popart_model.builder.nameScope("Embedding"):
            output = popart_model.embedding(indices, positions, segments)

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()  # turn off dropout

    # Expand the generic name map for this config/mode before copying.
    expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map(
        TORCH_TO_ONNX[mode], config, mode)
    copy_weights_to_torch(torch_model, proto, expanded_name_map,
                          remapped_transform_map)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs, margin=5e-7)
Example no. 12
0
def test_load_from_chkpt(config_path, chkpt_path, custom_ops):
    """Check that a TF checkpoint loads identically into the popart and
    torch BERT implementations.

    Loads the checkpoint weights into the torch model and into the
    popart model, runs both forward on the same random inputs, and
    compares the output tensors.
    """
    config = load_bert_config_tf(config_path)

    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })

    # Torch reference built from the same TF checkpoint.
    torch_model = TorchModel(
        TorchBertConfig(
            config.vocab_length,
            config.hidden_size,
            num_hidden_layers=config.num_layers,
            num_attention_heads=config.attention_heads,
            intermediate_size=config.ff_size,
            hidden_act="relu",
            max_position_embeddings=config.max_positional_length,
            layer_norm_eps=config.layer_norm_eps,
            mask_tokens=config.mask_tokens,
        ))
    torch_model.eval()  # turn off dropout
    torch_model = load_tf_weights_in_bert(torch_model, config, chkpt_path)

    # Popart model built from the same checkpoint.
    n_tokens = config.batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("INT32", [n_tokens])

    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)

    popart_model, proto, output = load_from_tf(chkpt_path,
                                               True,
                                               config,
                                               indices,
                                               positions,
                                               builder=builder)

    # Shared random inputs; indices are drawn before positions so the
    # RNG sequence matches the original.
    popart_inputs = {}
    for tensor, high in ((indices, config.vocab_length),
                         (positions, config.sequence_length)):
        popart_inputs[tensor] = np.random.randint(
            0, high, (n_tokens,)).astype(np.int32)

    batch_shape = (config.batch_size, config.sequence_length)
    torch_inputs = {
        "input_ids": popart_inputs[indices].reshape(batch_shape),
        "position_ids": popart_inputs[positions].reshape(batch_shape),
    }

    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    popart_outputs, post_proto = run_py(
        proto,
        popart_inputs,
        output,
        ipus=math.ceil(config.num_layers / config.layers_per_ipu) + 1,
    )

    check_tensors(torch_outputs, popart_outputs)
    print("Test succeeded")
def test_embedding_projection_fwd(custom_ops):
    """Forward-pass parity test for the embedding + LM projection path.

    Builds an embedding -> norm -> dropout -> CLS head -> projection subgraph
    in PopART, runs it on random token indices, then copies the resulting ONNX
    weights into an equivalent PyTorch model and checks the outputs agree.
    """
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # Single flat INT32 input of token indices in [0, vocab_length).
    sequence_info = popart.TensorInfo(
        "INT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32)
    }

    # Assemble the subgraph under test manually instead of build_graph().
    x = popart_model.embedding_custom(
        indices, config.vocab_length, "Embedding_Dict", detach=True)
    x = popart_model.norm(x)
    # config sets no_dropout=True, so presumably a pass-through — confirm.
    x = popart_model.dropout(x)
    with popart_model.device_scope(nameScope="CLS"):
        x = popart_model.lm_prediction_head(x)
    output = popart_model.projection(x)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto, data, output,
                                 user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [data[indices].reshape(config.batch_size, config.sequence_length)]

    #  ------------------- PyTorch -------------------------
    torch_model = EmbeddingProjectionModel(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()

    # NOTE(review): torch_to_onnx and transposed_weights are not defined in
    # this function — presumably module-level constants; verify they exist in
    # the enclosing module.
    copy_weights_to_torch(torch_model, proto, torch_to_onnx,
                          transposed_weights)
    # Re-tie the decoder weights to the (just copied) embedding weights.
    torch_model.tie_weights()

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
# Esempio n. 14 (Example 14, score: 0)
def test_attention_fwd(custom_ops):
    """Check the PopART attention layer against torch ``BertAttention``.

    Builds the PopART attention graph, runs it on random activations and
    masks, then loads the resulting ONNX weights into a PyTorch reference
    model and compares the forward outputs.
    """
    #  ------------------- PopART --------------------
    builder = popart.Builder()
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['attention'],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    acts_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    acts = builder.addInputTensor(acts_info)
    mask_info = popart.TensorInfo("INT32", [config.batch_size])
    mask_mlm = builder.addInputTensor(mask_info)
    mask_seq = builder.addInputTensor(mask_info)

    # Random activations plus per-sample mask lengths.
    data = {
        acts:
        np.random.normal(0, 0.02, acts_info.shape()).astype(config.dtype),
        mask_mlm:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.int32),
        mask_seq:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.int32)
    }

    output = popart_model.attention(acts, [mask_mlm, mask_seq])
    proto = builder.getModelProto()

    popart_outputs, _ = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    onnx_proto = onnx.load_model_from_string(proto)

    torch_inputs = [
        data[acts].reshape(config.batch_size, config.sequence_length,
                           config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mask_mlm], data[mask_seq]])
    ]

    # Map torch parameter names onto the ONNX weight names.
    name_map = {
        "self.query.weight": "QKV",
        "self.key.weight": "QKV",
        "self.value.weight": "QKV",
        "output.dense.weight": "Out",
        "output.LayerNorm.weight": "Gamma",
        "output.LayerNorm.bias": "Beta"
    }

    # The fused QKV matrix is sliced column-wise into the three torch weights.
    weight_transform = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()  # turn off dropout

    copy_weights_to_torch(torch_model,
                          onnx_proto,
                          name_map,
                          transform=weight_transform)

    # Model to test against
    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    check_tensors(torch_outputs, popart_outputs)
# Esempio n. 15 (Example 15, score: 0)
def test_embedding_fwd(custom_ops):
    """Forward-pass parity check: PopART embedding vs. torch ``BertEmbeddings``."""
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    seq_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(seq_info)
    positions = builder.addInputTensor(seq_info)
    segments = builder.addInputTensor(seq_info)

    # Random inputs: token ids, position ids and segment (0/1) ids.
    flat_len = config.batch_size * config.sequence_length
    upper_bounds = [(indices, config.vocab_length),
                    (positions, config.max_positional_length),
                    (segments, 2)]
    data = {
        tensor: np.random.randint(0, high, (flat_len, )).astype(np.uint32)
        for tensor, high in upper_bounds
    }

    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    popart_outputs, _ = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    onnx_proto = onnx.load_model_from_string(proto)

    torch_inputs = [
        data[name].reshape(config.batch_size,
                           config.sequence_length).astype(np.int32)
        for name in (indices, positions, segments)
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()  # turn off dropout

    copy_weights_to_torch(torch_model, onnx_proto, TORCH_TO_ONNX, {})

    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    check_tensors(torch_outputs, popart_outputs)
# Esempio n. 16 (Example 16, score: 0)
def fwd_graph(popart_model,
              torch_model,
              mapping=None,
              transform=None,
              replication_factor=1,
              replicated_tensor_sharding=False):
    """Run a forward pass through both models and assert matching outputs.

    Args:
        popart_model: PopART BERT model; supplies ``config`` and ``builder``.
        torch_model: PyTorch reference model to compare against.
        mapping: optional overrides merged into the torch->onnx name mapping.
        transform: optional overrides merged into the weight transforms.
        replication_factor: number of graph replicas; input data carries a
            leading replica dimension of this size.
        replicated_tensor_sharding: forwarded to ``run_py``.
    """
    #  ------------------- PopART --------------------
    config = popart_model.config
    builder = popart_model.builder

    # All three inputs are flat [micro_batch * seq_len] UINT32 tensors.
    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    # Random token ids / position ids / segment (0 or 1) ids, one row per replica.
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (replication_factor, config.micro_batch_size *
                           config.sequence_length)).astype(np.uint32),
        positions:
        np.random.randint(0, config.sequence_length,
                          (replication_factor, config.micro_batch_size *
                           config.sequence_length)).astype(np.uint32),
        segments:
        np.random.randint(0, 2, (replication_factor, config.micro_batch_size *
                                 config.sequence_length)).astype(np.uint32)
    }

    output = popart_model.build_graph(indices, positions, segments)
    ipus = popart_model.total_ipus

    proto = builder.getModelProto()

    outputs, _ = run_py(proto,
                        data,
                        output,
                        replication_factor=replication_factor,
                        replicated_tensor_sharding=replicated_tensor_sharding,
                        ipus=ipus)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    # Fold the replica dimension into the batch dimension for the torch model.
    inputs = {
        "input_ids":
        data[indices].reshape(replication_factor * config.micro_batch_size,
                              config.sequence_length).astype(np.int32),
        "position_ids":
        data[positions].reshape(replication_factor * config.micro_batch_size,
                                config.sequence_length).astype(np.int32),
        "token_type_ids":
        data[segments].reshape(replication_factor * config.micro_batch_size,
                               config.sequence_length).astype(np.int32)
    }

    torch_to_onnx = get_mapping(config, init=mapping)

    transform_weights = get_transform(config, init=transform)

    #  ------------------- PyTorch -------------------------
    # Turn off dropout
    torch_model.eval()
    copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform_weights)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
# Esempio n. 17 (Example 17, score: 0)
def test_attention_fwd(mode, micro_batch_size, batch_serialisation_factor,
                       number_attention_splits, attention_bias, split_qkv):
    """Parametrized forward-pass parity test for the attention layer.

    Runs the PopART attention layer (default or PHASED execution mode) on
    random activations and masks, then copies the resulting ONNX weights
    into a torch ``BertAttention`` and checks that outputs match.

    Args:
        mode: ExecutionMode selecting the PopART model variant.
        micro_batch_size: per-replica batch size for the config.
        batch_serialisation_factor: only used when mode is PHASED.
        number_attention_splits: passed through as num_attention_splits.
        attention_bias: whether attention projections carry bias terms.
        split_qkv: use separate Q/K/V weights instead of a fused matrix.
    """
    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=micro_batch_size,
                        hidden_size=768,
                        attention_heads=4,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias,
                        num_attention_splits=number_attention_splits)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    # Masks here are per-token, shape [micro_batch, seq_len] (unlike the
    # per-sample scalar masks used by older variants of this test).
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    # PHASED mode needs explicit serialization/phase options and calls the
    # model directly; otherwise call the attention sub-layer.
    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialisation_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()
    # Name mapping depends on both execution mode and whether QKV is split.
    mapping = TORCH_TO_ONNX[mode]
    if split_qkv:
        mapping = TORCH_TO_ONNX_SPLIT_QKV[mode]
    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)