def test_embedding_projection_fwd(custom_ops):
    """Check the PopART embedding + projection forward pass against PyTorch.

    The PopART graph chains ``embedding_custom`` -> ``norm`` -> ``dropout``
    -> ``lm_prediction_head`` -> ``projection``. It is executed with
    ``run_py``, its weights are copied into an ``EmbeddingProjectionModel``,
    and the outputs of the two frameworks are compared with
    ``check_tensors`` for the same random token indices.

    Args:
        custom_ops: Not referenced in the body; presumably a pytest fixture
            that makes the custom ops available -- confirm in conftest.
    """
    #  ------------------- PopART --------------------
    opsets = {"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}
    builder = popart.Builder(opsets=opsets)
    config = BertConfig(
        vocab_length=9728,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        custom_ops=['gather'],
        inference=True,
    )
    popart_model = Bert(config, builder=builder)

    # The PopART input is a flat vector with one index per (batch, position).
    num_tokens = config.batch_size * config.sequence_length
    indices = builder.addInputTensor(popart.TensorInfo("INT32", [num_tokens]))
    token_ids = np.random.randint(0, config.vocab_length, num_tokens)
    data = {indices: token_ids.astype(np.int32)}

    hidden = popart_model.embedding_custom(
        indices, config.vocab_length, "Embedding_Dict", detach=True)
    hidden = popart_model.norm(hidden)
    hidden = popart_model.dropout(hidden)
    with popart_model.device_scope(nameScope="CLS"):
        hidden = popart_model.lm_prediction_head(hidden)
    output = popart_model.projection(hidden)

    proto = builder.getModelProto()

    # Only the outputs are needed here; the second return value is unused.
    session_options = {"enableStochasticRounding": True}
    outputs, _ = run_py(proto, data, output, user_options=session_options)

    # ----------------- PopART -> PyTorch ----------------
    onnx_model = onnx.load_model_from_string(proto)

    # The PyTorch model is fed the same indices as a (batch, sequence) array.
    inputs = [data[indices].reshape(config.batch_size, config.sequence_length)]

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps)
    torch_model = EmbeddingProjectionModel(torch_config)

    torch_model.eval()

    copy_weights_to_torch(
        torch_model, onnx_model, torch_to_onnx, transposed_weights)
    torch_model.tie_weights()

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_embedding_projection_bwd(custom_ops):
    """Check one PopART training step of embedding + projection against PyTorch.

    The PopART graph chains ``embedding_custom`` -> ``norm`` -> ``dropout``
    -> ``lm_prediction_head`` -> ``projection`` and is run through
    ``run_py`` with an L1 loss and a constant-SGD optimizer. The same
    weights are copied into an ``EmbeddingProjectionModel``, which takes one
    SGD step on the equivalent scaled L1 norm. The forward outputs are
    compared with ``check_tensors`` and the PyTorch model is then checked
    against ``post_proto`` with ``check_model``.

    Args:
        custom_ops: Not referenced in the body; presumably a pytest fixture
            that makes the custom ops available -- confirm in conftest.
    """
    l1_lambda = 0.1
    # Both optimizers are given the same step size.
    learning_rate = 0.01

    #  ------------------- PopART --------------------
    opsets = {"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}
    builder = popart.Builder(opsets=opsets)
    config = BertConfig(
        vocab_length=9728,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        custom_ops=['gather'],
    )
    popart_model = Bert(config, builder=builder)

    # The PopART input is a flat vector with one index per (batch, position).
    num_tokens = config.batch_size * config.sequence_length
    indices = builder.addInputTensor(popart.TensorInfo("INT32", [num_tokens]))
    token_ids = np.random.randint(0, config.vocab_length, num_tokens)
    data = {indices: token_ids.astype(np.int32)}

    hidden = popart_model.embedding_custom(
        indices, config.vocab_length, "Embedding_Dict", detach=True)
    hidden = popart_model.norm(hidden)
    hidden = popart_model.dropout(hidden)
    with popart_model.device_scope(nameScope="CLS"):
        hidden = popart_model.lm_prediction_head(hidden)
    output = popart_model.projection(hidden)

    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(learning_rate)

    # post_proto is presumably the model after the optimizer step -- it is
    # what check_model receives at the end of the test.
    session_options = {"enableStochasticRounding": True}
    outputs, post_proto = run_py(
        proto,
        data,
        output,
        loss=l1,
        optimizer=optimizer,
        user_options=session_options,
    )

    # ----------------- PopART -> PyTorch ----------------
    onnx_model = onnx.load_model_from_string(proto)

    # The PyTorch model is fed the same indices as a (batch, sequence) array.
    inputs = [data[indices].reshape(config.batch_size, config.sequence_length)]

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps)
    torch_model = EmbeddingProjectionModel(torch_config)
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(
        torch_model, onnx_model, torch_to_onnx, transform=transposed_weights)

    optim = torch.optim.SGD(
        torch_model.parameters(),
        lr=learning_rate,
        weight_decay=0.0,
        momentum=0.0,
    )

    torch_inputs = [torch.from_numpy(array).long() for array in inputs]
    torch_output = torch_model(*torch_inputs)
    # Same loss as the PopART side: the L1 norm scaled by l1_lambda.
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs, margin=1e-5)

    check_model(
        torch_model, post_proto, torch_to_onnx, transform=transposed_weights)