Exemplo n.º 1
0
def test_attention_fwd(custom_ops):
    """Check the PopART attention forward pass against a PyTorch reference.

    Builds the attention layer of a PopART ``Bert`` model configured with the
    ``attention`` custom op, runs it on random inputs through ``run_py``,
    copies the PopART weights into a PyTorch ``BertAttention`` module and
    compares both outputs with ``check_tensors``.

    Args:
        custom_ops: Test fixture. It is not referenced in the body and is
            presumably requested only for its side effect of loading the
            custom ops -- TODO confirm against the fixture definition.
    """
    # Local import: this test does not otherwise reference ``torch`` by name.
    import torch

    # Fixed seeds so the random inputs (and any randomly initialised weights)
    # are identical from run to run, making failures reproducible.
    np.random.seed(1984)
    torch.manual_seed(1984)

    #  ------------------- PopART --------------------
    builder = popart.Builder()
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['attention'],
                        inference=True)

    popart_model = Bert(config, builder=builder)

    # The activation input is fed flattened: (batch * sequence, hidden).
    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = builder.addInputTensor(input_info)
    # Each mask input carries one INT32 value per batch element.
    mask_info = popart.TensorInfo("INT32", [config.batch_size])
    mmask_tensor = builder.addInputTensor(mask_info)
    smask_tensor = builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        # Drawn from [0, mask_tokens] (randint's upper bound is exclusive).
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.int32),
        # Drawn from [mask_tokens, sequence_length].
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.int32)
    }

    output = popart_model.attention(input_tensor, [mmask_tensor, smask_tensor])
    proto = builder.getModelProto()

    # Only the output tensors are needed here; the second value returned by
    # run_py is not used in a forward-only comparison.
    outputs, _ = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    # The PyTorch module takes the activations as (batch, sequence, hidden).
    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    # PyTorch parameter name -> PopART/ONNX tensor name. Query, key and value
    # all map onto the single fused "QKV" tensor.
    torch_to_onnx = {
        "self.query.weight": "QKV",
        "self.key.weight": "QKV",
        "self.value.weight": "QKV",
        "output.dense.weight": "Out",
        "output.LayerNorm.weight": "Gamma",
        "output.LayerNorm.bias": "Beta"
    }

    # Per-parameter transforms: slice the fused QKV tensor column-wise into
    # three hidden_size-wide chunks and transpose each one; the output
    # projection is only transposed.
    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=split_qkv)

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
Exemplo n.º 2
0
def test_attention_bwd(attention_bias, split_qkv):
    """Compare PopART attention training against a PyTorch reference.

    Runs ``num_reps`` SGD steps of an L1 loss on the attention output in both
    frameworks, starting from the same weights, then checks the outputs with
    ``check_tensors`` and the resulting models with ``check_model``.

    Args:
        attention_bias: Forwarded to both the PopART and the PyTorch configs
            as ``attention_bias``.
        split_qkv: Forwarded to the PopART config; also selects which
            torch-to-ONNX name mapping and weight transform are used.
    """
    # Deterministic inputs and initialisation.
    np.random.seed(1984)
    torch.manual_seed(1984)

    l1_scale = 0.1
    steps = 5

    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias)
    popart_model = Bert(config, pipeline=True)
    builder = popart_model.builder

    batch = config.micro_batch_size
    seq_len = config.sequence_length
    hidden = config.hidden_size

    # Graph inputs: flattened activations plus two (batch, sequence) masks.
    input_info = popart.TensorInfo(config.popart_dtype,
                                   [batch * seq_len, hidden])
    input_tensor = builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo("UINT32", [batch, seq_len])
    mmask_tensor = builder.addInputTensor(mask_info)
    smask_tensor = builder.addInputTensor(mask_info)

    # Random feed data, drawn in the order: activations, mmask, smask.
    mask_shape = (batch, seq_len)
    activations = np.random.normal(0, 0.02,
                                   input_info.shape()).astype(config.dtype)
    mmask_values = np.random.randint(0, config.mask_tokens + 1,
                                     mask_shape).astype(np.uint32)
    smask_values = np.random.randint(config.mask_tokens, seq_len + 1,
                                     mask_shape).astype(np.uint32)
    data = {
        input_tensor: activations,
        mmask_tensor: mmask_values,
        smask_tensor: smask_values,
    }

    output = popart_model.attention(input_tensor, [mmask_tensor, smask_tensor])
    l1 = builder.aiGraphcore.l1loss([output],
                                    l1_scale,
                                    reduction=popart.ReductionType.Sum)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto,
                                 data, (output, l1),
                                 loss=l1,
                                 optimizer=popart.ConstSGD(0.01),
                                 num_reps=steps,
                                 user_options={},
                                 pipeline=popart_model.pipeline)

    # ----------------- PopART -> PyTorch ----------------
    onnx_model = onnx.load_model_from_string(proto)

    # The PyTorch module takes the activations as (batch, sequence, hidden).
    inputs = [
        activations.reshape(batch, seq_len, hidden),
        get_torch_mask(config, [mmask_values, smask_values]),
    ]

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(config.vocab_length,
                                   hidden,
                                   config.num_layers,
                                   config.attention_heads,
                                   attention_bias=config.attention_bias,
                                   layer_norm_eps=config.layer_norm_eps)
    torch_model = BertAttention(torch_config)
    # Turn off dropout
    torch_model.eval()

    if split_qkv:
        mapping = TORCH_TO_ONNX_SPLIT_QKV
    else:
        mapping = TORCH_TO_ONNX

    copy_weights_to_torch(torch_model,
                          onnx_model,
                          mapping,
                          transform=get_transform(split_qkv, hidden))

    optim = torch.optim.SGD(torch_model.parameters(), 0.01)

    # Reference training loop: same loss, same number of SGD steps.
    for _ in range(steps):
        tensors = [torch.from_numpy(arr).float() for arr in inputs]
        torch_output = torch_model(*tensors)[0]
        torch_loss = l1_scale * torch.norm(torch_output, 1)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    check_tensors([torch_output.detach().numpy()], outputs, margin=6e-07)

    check_model(torch_model,
                post_proto,
                mapping,
                transform=get_transform(split_qkv, hidden),
                margin=2e-7)
Exemplo n.º 3
0
def test_attention_bwd(custom_ops):
    """Check a PopART attention training step against a PyTorch reference.

    Builds the attention layer of a PopART ``Bert`` model configured with the
    ``attention`` custom op, trains it through ``run_py`` with an L1 loss and
    a constant-SGD optimizer, performs one equivalent SGD step on a PyTorch
    ``BertAttention`` module initialised with the same weights, and then
    compares the outputs (``check_tensors``) and models (``check_model``).

    Args:
        custom_ops: Test fixture. It is not referenced in the body and is
            presumably requested only for its side effect of loading the
            custom ops -- TODO confirm against the fixture definition.
    """
    l1_lambda = 0.1

    # Fixed seeds so the random inputs (and any randomly initialised weights)
    # are identical from run to run, making failures reproducible.
    np.random.seed(1984)
    torch.manual_seed(1984)

    #  ------------------- PopART --------------------
    builder = popart.Builder()
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['attention'])
    popart_model = Bert(config, builder=builder)

    # The activation input is fed flattened: (batch * sequence, hidden).
    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = builder.addInputTensor(input_info)
    # Each mask input carries one INT32 value per batch element.
    mask_info = popart.TensorInfo("INT32", [config.batch_size])
    mmask_tensor = builder.addInputTensor(mask_info)
    smask_tensor = builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        # Drawn from [0, mask_tokens] (randint's upper bound is exclusive).
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.int32),
        # Drawn from [mask_tokens, sequence_length].
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.int32)
    }

    output = popart_model.attention(input_tensor, [mmask_tensor, smask_tensor])
    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1.output(0)),
                                 loss=l1,
                                 optimizer=optimizer)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    # The PyTorch module takes the activations as (batch, sequence, hidden).
    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    # PyTorch parameter name -> PopART/ONNX tensor name. Query, key and value
    # all map onto the single fused "QKV" tensor.
    torch_to_onnx = {
        "self.query.weight": "QKV",
        "self.key.weight": "QKV",
        "self.value.weight": "QKV",
        "output.dense.weight": "Out",
        "output.LayerNorm.weight": "Gamma",
        "output.LayerNorm.bias": "Beta"
    }

    # Per-parameter transforms: slice the fused QKV tensor column-wise into
    # three hidden_size-wide chunks and transpose each one; the output
    # projection is only transposed.
    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=split_qkv)

    # Plain SGD (no weight decay, no momentum) with the same learning rate as
    # the PopART ConstSGD optimizer above.
    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    # One reference training step: forward, scaled L1-norm loss, backward,
    # parameter update.
    torch_output = torch_model(*[torch.from_numpy(t).float()
                                 for t in inputs])[0]
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs)

    check_model(torch_model, post_proto, torch_to_onnx, transform=split_qkv)
Exemplo n.º 4
0
def test_attention_fwd(attention_bias, split_qkv):
    """Check the PopART attention forward pass against a PyTorch reference.

    Builds the attention layer of a pipelined PopART ``Bert`` model in
    inference mode, runs it on random inputs through ``run_py``, copies the
    PopART weights into a PyTorch ``BertAttention`` module and compares both
    outputs with ``check_tensors``.

    Args:
        attention_bias: Forwarded to both the PopART and the PyTorch configs
            as ``attention_bias``.
        split_qkv: Forwarded to the PopART config; also selects which
            torch-to-ONNX name mapping and weight transform are used.
    """
    # Local import: this test does not otherwise reference ``torch`` by name.
    import torch

    # Fixed seeds so the random inputs (and any randomly initialised weights)
    # are identical from run to run, making failures reproducible.
    np.random.seed(1984)
    torch.manual_seed(1984)

    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        attention_heads=4,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias)
    popart_model = Bert(config, pipeline=True)

    # The activation input is fed flattened: (batch * sequence, hidden).
    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    # Both mask inputs are UINT32 tensors of shape (batch, sequence).
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        # Drawn from [0, mask_tokens] (randint's upper bound is exclusive).
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        # Drawn from [mask_tokens, sequence_length].
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    # NOTE(review): stochastic rounding is enabled even though this is a
    # FLOAT inference run -- confirm this is intentional.
    user_options = {"enableStochasticRounding": True}
    output = popart_model.attention(input_tensor, [mmask_tensor, smask_tensor])

    proto = popart_model.builder.getModelProto()
    # Only the output tensors are needed here; the second value returned by
    # run_py is not used in a forward-only comparison.
    outputs, _ = run_py(proto,
                        data,
                        output,
                        user_options=user_options,
                        pipeline=popart_model.pipeline)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    # The PyTorch module takes the activations as (batch, sequence, hidden).
    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()
    # The parameter-name mapping depends on whether Q, K and V are stored as
    # separate tensors or as one fused tensor on the PopART side.
    mapping = TORCH_TO_ONNX_SPLIT_QKV if split_qkv else TORCH_TO_ONNX
    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)