# Example #1
def test_squad_fwd():
    """Forward-pass parity check: PopART SQUAD BERT vs. the PyTorch reference."""
    #  ------------------- PopART --------------------
    opset_versions = {
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1,
    }
    builder = popart.Builder(opsets=opset_versions)
    config = BertConfig(
        task="SQUAD",
        vocab_length=9728,
        num_layers=2,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type="relu",
        popart_dtype="FLOAT",
        no_dropout=True,
        custom_ops=[],
        inference=True,
    )
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        num_hidden_layers=config.num_layers,
        num_attention_heads=config.attention_heads,
        intermediate_size=config.ff_size,
        hidden_act="relu",
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
        mask_tokens=config.mask_tokens,
        num_labels=2,
    )
    torch_model = BertForQuestionAnswering(torch_config)

    # PyTorch parameter name -> PopART tensor name.
    name_map = {
        "cls.transform.dense.weight": "CLS/LMPredictionW",
        "cls.transform.dense.bias": "CLS/LMPredictionB",
        "cls.transform.LayerNorm.weight": "CLS/Gamma",
        "cls.transform.LayerNorm.bias": "CLS/Beta",
        "qa_outputs.weight": "Squad/SquadW",
        "qa_outputs.bias": "Squad/SquadB",
    }
    # Linear weights are stored transposed relative to PyTorch's layout.
    weight_transforms = {
        "cls.transform.dense.weight": np.transpose,
        "qa_outputs.weight": np.transpose,
    }
    fwd_graph(popart_model,
              torch_model,
              mapping=name_map,
              transform=weight_transforms)
# Example #2
def test_squad_fwd(mode, replication_factor, replicated_weight_sharding):
    """Forward-pass parity check, parametrised over execution mode,
    replication factor and replicated weight sharding."""
    split_qkv = False
    #  ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        vocab_length=9728,
        num_layers=2,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type="relu",
        popart_dtype="FLOAT",
        no_dropout=True,
        no_attn_dropout=True,
        inference=True,
        no_mask=True,
        execution_mode=mode,
        split_qkv=split_qkv,
        squad_single_output=False,
    )
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        num_hidden_layers=config.num_layers,
        num_attention_heads=config.attention_heads,
        intermediate_size=config.ff_size,
        hidden_act="relu",
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
        mask_tokens=config.mask_tokens,
        num_labels=2,
    )
    torch_model = BertForQuestionAnswering(torch_config)

    fwd_graph(
        popart_model,
        torch_model,
        mode,
        # Name mapping depends on the execution mode under test.
        mapping=ONNX_TORCH_MAPPING[mode],
        transform={"qa_outputs.weight": np.transpose},
        replication_factor=replication_factor,
        replicated_weight_sharding=replicated_weight_sharding,
    )
# Example #3
def test_squad_fwd(custom_ops):
    """Forward-pass parity check for a small SQUAD BERT configuration."""
    #  ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        encoder_start_ipu=1,
        vocab_length=1024,
        micro_batch_size=1,
        hidden_size=64,
        attention_heads=2,
        sequence_length=20,
        max_positional_length=20,
        activation_type="relu",
        popart_dtype="FLOAT",
        no_dropout=True,
        no_attn_dropout=True,
        inference=True,
        no_mask=True,
        split_qkv=False,
        squad_single_output=False,
    )
    popart_model = Bert(config)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        num_hidden_layers=config.num_layers,
        num_attention_heads=config.attention_heads,
        intermediate_size=config.ff_size,
        hidden_act="relu",
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
        mask_tokens=2,
        num_labels=2,
    )
    torch_model = BertForQuestionAnswering(torch_config)

    fwd_graph(
        popart_model,
        torch_model,
        mapping=ONNX_TORCH_MAPPING,
        # QA output weights are stored transposed relative to PyTorch.
        transform={"qa_outputs.weight": np.transpose},
    )
# Example #4
def test_squad_bwd():
    """Backward-pass parity check: PopART SQUAD BERT vs. the PyTorch reference."""
    #  ------------------- PopART --------------------
    opset_versions = {
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1,
    }
    builder = popart.Builder(opsets=opset_versions)
    config = BertConfig(
        task="SQUAD",
        vocab_length=9728,
        num_layers=1,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type="relu",
        popart_dtype="FLOAT",
        no_dropout=True,
        custom_ops=[],
        update_embedding_dict=False,
    )
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        num_hidden_layers=config.num_layers,
        num_attention_heads=config.attention_heads,
        intermediate_size=config.ff_size,
        hidden_act="relu",
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
        mask_tokens=config.mask_tokens,
        num_labels=2,
    )
    torch_model = BertForQuestionAnswering(torch_config)

    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        # One L1 loss per span-logit output, both pinned to the SQUAD head's
        # virtual graph.
        starts_loss = popart.L1Loss(outputs[0], "startsLossVal", l1_lambda)
        ends_loss = popart.L1Loss(outputs[1], "endsLossVal", l1_lambda)
        losses = [starts_loss, ends_loss]
        for loss in losses:
            loss.virtualGraph(popart_model.squad_scope.virtualGraph)
        return losses

    def torch_loss_fn(outputs):
        # Equivalent scaled L1 norms on the PyTorch side, summed together.
        per_output = [l1_lambda * torch.norm(out, 1) for out in outputs]
        return torch.add(*per_output)

    # PyTorch parameter name -> PopART tensor name.
    name_map = {
        "cls.transform.dense.weight": "CLS/LMPredictionW",
        "cls.transform.dense.bias": "CLS/LMPredictionB",
        "cls.transform.LayerNorm.weight": "CLS/Gamma",
        "cls.transform.LayerNorm.bias": "CLS/Beta",
        "qa_outputs.weight": "Squad/SquadW",
        "qa_outputs.bias": "Squad/SquadB",
    }
    weight_transforms = {
        "cls.transform.dense.weight": np.transpose,
        "qa_outputs.weight": np.transpose,
    }
    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=name_map,
              transform=weight_transforms)
# Example #5
def squad_bwd(mode,
              replication_factor,
              replicated_weight_sharding,
              opt_type,
              vocab_length=9728,
              hidden_size=768):
    """Backward-pass parity check, parametrised over execution mode,
    replication, weight sharding and optimiser type."""
    #  ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        vocab_length=vocab_length,
        num_layers=1,
        batch_size=1,
        hidden_size=hidden_size,
        sequence_length=128,
        activation_type="relu",
        popart_dtype="FLOAT",
        no_dropout=True,
        no_attn_dropout=True,
        update_embedding_dict=True,
        no_mask=True,
        execution_mode=mode,
        # LAMB requires the QKV projection split into separate tensors.
        split_qkv=(opt_type == "LAMB"),
    )
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        num_hidden_layers=config.num_layers,
        num_attention_heads=config.attention_heads,
        intermediate_size=config.ff_size,
        hidden_act="relu",
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
        mask_tokens=config.mask_tokens,
        update_embedding_dict=True,
        num_labels=2,
    )
    torch_model = BertForQuestionAnswering(torch_config)

    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        builder = popart_model.builder

        def make_l1(tensor, name):
            # Scaled, sum-reduced L1 loss on a single output tensor.
            return builder.aiGraphcore.l1loss(
                [tensor],
                l1_lambda,
                debugPrefix=name,
                reduction=popart.ReductionType.Sum)

        if mode == ExecutionMode.PHASED:
            # Loss ops must be created while the SQUAD scope is active.
            with popart_model.scope_provider(builder,
                                             popart_model.squad_scope):
                losses = [
                    make_l1(outputs[0], "startsLossVal"),
                    make_l1(outputs[1], "endsLossVal"),
                ]
                final_loss = builder.aiOnnx.sum(losses,
                                                debugPrefix="finalLoss")
        else:
            losses = [
                make_l1(outputs[0], "startsLossVal"),
                make_l1(outputs[1], "endsLossVal"),
            ]
            # Pin every loss op (and their sum) to the SQUAD virtual graph.
            vgraph = popart_model.squad_scope.virtualGraph
            for loss in losses:
                builder.virtualGraph(loss, vgraph)
            final_loss = builder.aiOnnx.sum(losses, debugPrefix="finalLoss")
            builder.virtualGraph(final_loss, vgraph)
        return final_loss

    def torch_loss_fn(outputs):
        # Equivalent scaled L1 norms on the PyTorch side, summed together.
        per_output = [l1_lambda * torch.norm(out, 1) for out in outputs]
        return torch.add(*per_output)

    bwd_graph(
        popart_model,
        torch_model,
        mode,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=torch_loss_fn,
        mapping=ONNX_TORCH_MAPPING[mode],
        transform={"qa_outputs.weight": np.transpose},
        replication_factor=replication_factor,
        replicated_weight_sharding=replicated_weight_sharding,
        opt_type=opt_type,
    )
# Example #6
def test_squad_bwd(custom_ops, replication_factor, replicated_tensor_sharding,
                   opt_type):
    """Backward-pass parity check for a small SQUAD BERT, parametrised over
    replication, tensor sharding and optimiser type."""
    #  ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        num_layers=2,
        encoder_start_ipu=1,
        vocab_length=1024,
        micro_batch_size=1,
        hidden_size=64,
        attention_heads=2,
        sequence_length=20,
        activation_type="relu",
        popart_dtype="FLOAT",
        no_dropout=True,
        no_attn_dropout=True,
        update_embedding_dict=True,
        no_mask=True,
        # LAMB requires the QKV projection split into separate tensors.
        split_qkv=(opt_type == "LAMB"),
    )
    popart_model = Bert(config)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        num_hidden_layers=config.num_layers,
        num_attention_heads=config.attention_heads,
        intermediate_size=config.ff_size,
        hidden_act="relu",
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
        mask_tokens=2,
        update_embedding_dict=True,
        num_labels=2,
    )
    torch_model = BertForQuestionAnswering(torch_config)

    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        builder = popart_model.builder

        def make_l1(tensor, name):
            # Scaled, sum-reduced L1 loss on a single output tensor.
            return builder.aiGraphcore.l1loss(
                [tensor],
                l1_lambda,
                debugContext=name,
                reduction=popart.ReductionType.Sum)

        losses = [
            make_l1(outputs[0], "startsLossVal"),
            make_l1(outputs[1], "endsLossVal"),
        ]
        # Pin every loss op (and their sum) to the SQUAD virtual graph.
        vgraph = popart_model.squad_scope.virtualGraph
        for loss in losses:
            builder.virtualGraph(loss, vgraph)
        final_loss = builder.aiOnnx.sum(losses, debugContext="finalLoss")
        builder.virtualGraph(final_loss, vgraph)
        return final_loss

    def torch_loss_fn(outputs):
        # Equivalent scaled L1 norms on the PyTorch side, summed together.
        per_output = [l1_lambda * torch.norm(out, 1) for out in outputs]
        return torch.add(*per_output)

    bwd_graph(
        popart_model,
        torch_model,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=torch_loss_fn,
        mapping=ONNX_TORCH_MAPPING,
        transform={"qa_outputs.weight": np.transpose},
        replication_factor=replication_factor,
        replicated_tensor_sharding=replicated_tensor_sharding,
        opt_type=opt_type,
    )