Example No. 1
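All six snippets are variants of one parity test for BERT's next-sentence-prediction (NSP) head: a PopART model and a PyTorch reference are built from the same configuration, and fwd_graph/bwd_graph appear to be the suite's harnesses that run both and compare outputs (and, for the bwd tests, gradients), using mapping/transform to align the two weight sets. A minimal sketch of the shared imports, assuming hypothetical local test-utility modules whose exact paths may differ between repo versions:

import numpy as np
import torch
import popart

# Local helpers from the surrounding test suite. The module names below are
# assumptions inferred from the identifiers the examples use; the real paths
# differ between repo versions.
from bert_model import Bert, BertConfig, ExecutionMode, get_model
from torch_bert import BertForNextSentencePrediction, TorchBertConfig
from full_graph_utils import (NSP_MAPPING, NSP_TRANSFORM, bwd_graph,
                              fwd_graph)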
def test_nsp_fwd(custom_ops):
    #  ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=2,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        mask_tokens=0,
                        split_qkv=False)
    popart_model = Bert(config)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model, torch_model, NSP_MAPPING, transform=NSP_TRANSFORM)
Example No. 2
def test_nsp_fwd(custom_ops):
    #  ------------------- PopART --------------------
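    # Pin the ONNX and ai.graphcore opset versions the graph is built against.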
    builder = popart.Builder(
        opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1})
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=["gather", "attention"],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length, config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM)
Example No. 3
def test_nsp_bwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=["gather", "attention"])
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    def popart_loss_fn(outputs):
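        # Legacy (pre-1.0) PopART loss API: the loss is a popart.L1Loss
        # object returned as a list; later releases replace this with the
        # builder.aiGraphcore.l1loss op used in Examples No. 4-6.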
        loss = popart.L1Loss(outputs[0], "l1Loss", 0.1)
        loss.virtualGraph(popart_model.nsp_scope.virtualGraph)
        return [loss]

    def torch_loss_fn(outputs):
        return 0.1 * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping={
                  "bert.pooler.dense.weight": "NSP/PoolW",
                  "bert.pooler.dense.bias": "NSP/PoolB",
                  "cls.seq_relationship.weight": "NSP/NspW",
                  "cls.seq_relationship.bias": "NSP/NspB"
              },
              transform={
                  "bert.pooler.dense.weight": np.transpose,
                  "cls.seq_relationship.weight": np.transpose
              })
Example No. 4
def test_nsp_bwd(custom_ops, opt_type):
    #  ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=2432,
                        micro_batch_size=1,
                        hidden_size=288,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_mask=True,
                        update_embedding_dict=True,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = Bert(config)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))
    l1_lambda = 0.1

    def popart_loss_fn(outputs):
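        # Build the L1 loss with the aiGraphcore builder op and pin it to
        # the same virtual graph as the NSP head.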
        loss = popart_model.builder.aiGraphcore.l1loss(
            [outputs[0]],
            l1_lambda,
            debugContext="l1LossVal",
            reduction=popart.ReductionType.Sum)
        popart_model.builder.virtualGraph(loss,
                                          popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return l1_lambda * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM,
              opt_type=opt_type)
Example No. 5
def test_nsp_bwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True)
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    def popart_loss_fn(outputs):
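        # debugPrefix is the older name for this argument; newer PopART
        # releases spell it debugContext (compare Example No. 4).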
        loss = builder.aiGraphcore.l1loss([outputs[0]],
                                          0.1,
                                          debugPrefix="l1Loss",
                                          reduction=popart.ReductionType.Sum)
        builder.virtualGraph(loss, popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return 0.1 * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM)
Example No. 6
def nsp_bwd(custom_ops, mode, opt_type, vocab_length=9728, hidden_size=768):
    if mode == ExecutionMode.PHASED:
        # Phased execution requires at least two transformer layers so the
        # NSP head and the embedding land in the same virtual graph.
        num_layers = 2
    else:
        num_layers = 1

    #  ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=vocab_length,
                        num_layers=num_layers,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_mask=True,
                        update_embedding_dict=True,
                        phased_execution_type="single",
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))
    l1_lambda = 0.1

    def popart_loss_fn(outputs):
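        # Under phased execution the loss op is placed through the model's
        # scope provider; otherwise it is assigned to a virtual graph
        # explicitly, as in the earlier examples.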
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.nsp_scope):
                loss = popart_model.builder.aiGraphcore.l1loss(
                    [outputs[0]],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)
        else:
            loss = popart_model.builder.aiGraphcore.l1loss(
                [outputs[0]],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
            popart_model.builder.virtualGraph(
                loss, popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return l1_lambda * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              mode,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING[mode],
              transform=NSP_TRANSFORM,
              opt_type=opt_type)