Example No. 1
def test_pretraining_fwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1})
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        custom_ops=["gather", "attention"],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length, config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens))

    fwd_graph(popart_model, torch_model, mapping=onnx_torch_mapping, transform=onnx_torch_tform)
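fwd_graph is a shared test helper; as the checkpoint test in Example No. 6 makes explicit, the forward comparison amounts to feeding both models the same random token and position ids and comparing the output tensors. Below is a minimal sketch of building such matching inputs, assuming a config that exposes batch_size, sequence_length and vocab_length as in Example No. 6 (illustrative only, not the real helper):

import numpy as np

def make_matching_inputs(config, indices_id, positions_id):
    # Illustrative only: one set of random inputs, reused for both the PopART
    # session (flat int32 arrays keyed by tensor id) and the PyTorch model
    # (arrays reshaped to [batch, sequence]).
    n = config.batch_size * config.sequence_length
    token_ids = np.random.randint(0, config.vocab_length, n).astype(np.int32)
    position_ids = np.random.randint(0, config.sequence_length, n).astype(np.int32)

    popart_inputs = {indices_id: token_ids, positions_id: position_ids}
    torch_inputs = {
        "input_ids": token_ids.reshape(config.batch_size, config.sequence_length),
        "position_ids": position_ids.reshape(config.batch_size, config.sequence_length),
    }
    return popart_inputs, torch_inputs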
Example No. 2
def test_pretraining_bwd(custom_ops, opt_type):
    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        mask_tokens=2,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_cls_layer=True,
                        no_mask=True,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = Bert(config)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=True,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        loss = popart_model.builder.aiGraphcore.l1loss(
            [logits[0]],
            l1_lambda,
            debugContext="l1LossVal",
            reduction=popart.ReductionType.Sum)
        popart_model.builder.virtualGraph(loss,
                                          popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(
        popart_model,
        torch_model,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
        mapping={},
        transform=onnx_torch_tform,
        opt_type=opt_type)
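The two loss functions handed to bwd_graph are intended to agree numerically: a sum-reduced L1 loss scaled by l1_lambda is exactly l1_lambda * ||logits||_1, which is what the torch_loss_fn lambda computes. A standalone sanity check of that equivalence (not part of the test suite):

import torch

l1_lambda = 0.1
logits = torch.randn(2, 20, 64)

# Sum-reduced L1 loss, matching what the PopART l1loss op computes with
# ReductionType.Sum ...
loss_sum_l1 = l1_lambda * logits.abs().sum()
# ... and the expression used in torch_loss_fn above.
loss_norm = l1_lambda * torch.norm(logits, 1)

assert torch.allclose(loss_sum_l1, loss_norm)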
Example No. 3
def test_pretraining_bwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        projection_serialization_steps=4,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        update_embedding_dict=False)
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        loss = builder.aiGraphcore.l1loss([logits[0]],
                                          l1_lambda,
                                          debugPrefix="l1LossVal",
                                          reduction=popart.ReductionType.Sum)
        builder.virtualGraph(loss, popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(
        popart_model,
        torch_model,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
        mapping=onnx_torch_mapping,
        transform=onnx_torch_tform)
Example No. 4
def test_pretraining_fwd(custom_ops, mode, replication_factor,
                         replicated_tensor_sharding):
    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        mask_tokens=2,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_cls_layer=False,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=False)

    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        no_cls_layer=config.no_cls_layer))

    fwd_graph(popart_model,
              torch_model,
              mode,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform=onnx_torch_tform,
              replication_factor=replication_factor,
              replicated_tensor_sharding=replicated_tensor_sharding)
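custom_ops, mode, replication_factor and replicated_tensor_sharding are injected by the test harness; the real decorators and the ExecutionMode values live alongside the test. A hypothetical pytest parametrisation producing this signature might look like the following (illustrative values only):

import pytest

@pytest.mark.parametrize("mode", ["DEFAULT", "PHASED"])           # hypothetical values
@pytest.mark.parametrize("replication_factor,replicated_tensor_sharding",
                         [(1, False), (2, False), (2, True)])     # hypothetical values
def test_pretraining_fwd(custom_ops, mode, replication_factor,
                         replicated_tensor_sharding):
    ...  # body as in Example No. 4 above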
Example No. 5
def pretraining_bwd(custom_ops,
                    mode,
                    replication_factor,
                    replicated_weight_sharding,
                    opt_type,
                    vocab_length=9728,
                    hidden_size=768):
    #  ------------------- PopART --------------------
    if mode == ExecutionMode.PHASED:
        # Phased execution requires at least two transformer layers to ensure the MLM and embedding layers are in the same virtual graph.
        num_layers = 2
    else:
        num_layers = 1
    config = BertConfig(task="PRETRAINING",
                        vocab_length=vocab_length,
                        num_layers=num_layers,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_cls_layer=True,
                        no_mask=True,
                        phased_execution_type="single",
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=True,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.mlm_scope):
                loss = popart_model.builder.aiGraphcore.l1loss(
                    [logits[0]],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)
        else:
            loss = popart_model.builder.aiGraphcore.l1loss(
                [logits[0]],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
            popart_model.builder.virtualGraph(
                loss, popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(
        popart_model,
        torch_model,
        mode,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
        mapping={},
        transform=onnx_torch_tform,
        replication_factor=replication_factor,
        replicated_weight_sharding=replicated_weight_sharding,
        opt_type=opt_type)
Example No. 6
def test_load_from_chkpt(config_path, chkpt_path, custom_ops):
    """
    Compare the model loaded into our popart model against the modified
    PyTorch model:
        - Load tf weights into BERT using torch impl -> run fwd model
        - Load tf weights into BERT using popart impl -> run fwd model
        - Compare output tensors
    """
    config = load_bert_config_tf(config_path)

    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })

    # Load Torch version
    torch_model = TorchModel(
        TorchBertConfig(
            config.vocab_length,
            config.hidden_size,
            num_hidden_layers=config.num_layers,
            num_attention_heads=config.attention_heads,
            intermediate_size=config.ff_size,
            hidden_act="relu",
            max_position_embeddings=config.max_positional_length,
            layer_norm_eps=config.layer_norm_eps,
            mask_tokens=config.mask_tokens,
        ))

    torch_model.eval()
    torch_model = load_tf_weights_in_bert(torch_model, config, chkpt_path)

    # Load Popart model
    sequence_info = popart.TensorInfo(
        "INT32", [config.batch_size * config.sequence_length])

    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)

    popart_model, proto, output = load_from_tf(chkpt_path,
                                               True,
                                               config,
                                               indices,
                                               positions,
                                               builder=builder)

    # Run the models
    popart_inputs = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32),
        positions:
        np.random.randint(
            0,
            config.sequence_length,
            (config.batch_size * config.sequence_length),
        ).astype(np.int32),
    }

    torch_inputs = {
        "input_ids":
        popart_inputs[indices].reshape(config.batch_size,
                                       config.sequence_length),
        "position_ids":
        popart_inputs[positions].reshape(config.batch_size,
                                         config.sequence_length),
    }

    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    popart_outputs, post_proto = run_py(
        proto,
        popart_inputs,
        output,
        ipus=math.ceil(config.num_layers / config.layers_per_ipu) + 1,
    )

    check_tensors(torch_outputs, popart_outputs)
    print("Test succeeded")