Example #1
 def popart_loss_fn(outputs):
     losses = [
         popart.L1Loss(outputs[0], "startsLossVal", l1_lambda),
         popart.L1Loss(outputs[1], "endsLossVal", l1_lambda),
     ]
     for loss in losses:
         loss.virtualGraph(popart_model.squad_scope.virtualGraph)
     return losses
Example #2
def popart_result_and_model(popart_config, is_bwd=False):
    builder = popart.Builder()
    popart_model = Bert(popart_config, builder=builder)

    input_info = popart.TensorInfo(popart_config.popart_dtype, [
        popart_config.batch_size * popart_config.sequence_length,
        popart_config.hidden_size
    ])
    input_tensor = builder.addInputTensor(input_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02,
                         input_info.shape()).astype(popart_config.dtype)
    }

    output = popart_model.feed_forward(input_tensor)
    proto = builder.getModelProto()

    if is_bwd:
        l1_lambda = 0.1
        l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
        optimizer = popart.ConstSGD(0.01)

        outputs, post_proto = run_py(proto,
                                     data, (output, l1.output(0)),
                                     loss=l1,
                                     optimizer=optimizer)
    else:
        outputs, post_proto = run_py(proto, data, output)

    return data[input_tensor], outputs, proto, post_proto
Example #3
def test_outline_dropout_pattern_one(custom_ops):
    '''
    Tests that OutlineDropoutPattern outlines all 3 dropouts (fwd and bwd)
    into a single subgraph.
    Expected IR graph (excluding adds etc.):
    fwd...
        x = add(data0, weight0)
        0_seed = seedModify(seed, 0)
        x = call_0(x, 0_seed)
        1_seed = seedModify(seed, 1)
        x = call_0(x, 1_seed)
        2_seed = seedModify(seed, 2)
        x = call_0(x, 2_seed)
    bwd...
        x = call_0(x, 0_seed)
        x = call_0(x, 1_seed)
        x = call_0(x, 2_seed)

        where call_0(x, seed) = dropout(x, seed)
    '''

    input_data = np.random.rand(2, 2).astype(np.float32)

    builder = popart.Builder()

    d0 = builder.addInputTensor(popart.TensorInfo('FLOAT', input_data.shape),
                                'data0')

    w0 = builder.addInitializedInputTensor(input_data, 'weight0')

    x = builder.aiOnnx.add([d0, w0])

    x = builder.aiOnnx.dropout([x], 1)[0]

    x = builder.aiOnnx.dropout([x], 1)[0]

    x = builder.aiOnnx.dropout([x], 1)[0]

    session = run_py(builder.getModelProto(),
                     data={d0: input_data},
                     outputs=x,
                     loss=popart.L1Loss(x, 'loss', 0.1),
                     optimizer=popart.ConstSGD(0.1),
                     patterns=popart.Patterns(
                         ["OutlineDropoutPattern", "PostNRepl"]),
                     user_options={"outlineThreshold": -1},
                     skip_execution=True)

    ir = json.loads(session._serializeIr(popart.IrSerializationFormat.JSON))

    # There should only be a main graph and 1 subgraph containing dropout
    assert len(ir.keys()) == 2

    ops = [o["type"] for o in ir["_subgraph(0)"]]
    assert "Dropout" in ops

    ops = [o["type"] for o in ir["maingraph"]]
    # One SeedModify per dropout: 3 fwd + 3 bwd = 6
    assert len(list(filter(lambda op: op == "SeedModify", ops))) == 6
    # Fwd and bwd dropouts share the outlined subgraph: 3 + 3 calls
    assert len(list(filter(lambda op: op == "Call", ops))) == 6
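
A small helper can make the two counting asserts above reusable. This is a minimal sketch assuming only the JSON IR layout already used in this test (each graph name maps to a list of op records carrying a "type" field):

def count_ops(graph, op_type):
    # Count ops of the given type in one serialized IR graph.
    return sum(1 for op in graph if op["type"] == op_type)

# Equivalent to the asserts above:
assert count_ops(ir["maingraph"], "SeedModify") == 6
assert count_ops(ir["maingraph"], "Call") == 6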
Example #4
def popart_result_and_model(popart_config, weight_decay=0, lr=0, l1_lambda=0):
    builder = popart.Builder()
    popart_model = Bert(popart_config, builder=builder)

    input_info = popart.TensorInfo(popart_config.popart_dtype, [
        popart_config.batch_size * popart_config.sequence_length,
        popart_config.hidden_size
    ])
    input_tensor = builder.addInputTensor(input_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02,
                         input_info.shape()).astype(popart_config.dtype)
    }

    output = popart_model.feed_forward(input_tensor)
    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)

    iteration = MockIteration()
    args = MockArgs(lr, weight_decay)
    optimizer_factory = BaseOptimizerFactory(args, iteration,
                                             popart_model.tensors)
    optimizer = optimizer_factory.create()

    outputs, post_proto = run_py(proto,
                                 data, (output, l1.output(0)),
                                 loss=l1,
                                 optimizer=optimizer)

    return data[input_tensor], outputs, proto, post_proto
Example #5
def session(skip_execution=False, include_patterns=True, momentum=False):
    proto, data, x = model()
    # Required
    patterns = [
        "MatMulOp", "MatMulLhsGradOp", "MatMulRhsGradOp", "OpToIdentity",
        "PreUniRepl", "PostNRepl", "InPlace"
    ]
    if include_patterns:
        patterns += ["InplaceWorkaroundPattern"]
    optimizer = popart.ConstSGD(0.1)
    if momentum:
        optimizer = popart.SGD({
            "defaultLearningRate": (0.1, True),
            "defaultMomentum": (0.9, True)
        })
    return run_py(proto,
                  data=data,
                  outputs=x,
                  loss=popart.L1Loss(x, 'loss', 0.1),
                  optimizer=optimizer,
                  patterns=popart.Patterns(patterns),
                  user_options={"enableOutlining": False},
                  skip_execution=skip_execution)
Example #6
def session(train=False,
            skip_execution=False,
            include_patterns=True,
            splits=1,
            outline=False):
    proto, data, x = model(splits=splits)
    # Required
    patterns = [
        "MatMulOp", "MatMulLhsGradOp", "MatMulRhsGradOp", "OpToIdentity",
        "PreUniRepl"
    ]
    if include_patterns:
        patterns += ["TiedGatherPattern", "TiedGatherGradPattern"]
    if train:
        return run_py(
            proto,
            data=data,
            outputs=x,
            loss=popart.L1Loss(x, 'loss', 0.1),
            optimizer=popart.SGD({
                "defaultLearningRate": (0.1, True),
                "defaultMomentum": (0.9, True),
                "defaultDampening": (0, True)
            }),  # Zero dampening amplifies incorrect gradients, making errors easier to detect
            patterns=popart.Patterns(patterns),
            user_options={"enableOutlining": outline},
            skip_execution=skip_execution)
    else:
        return run_py(proto,
                      data=data,
                      outputs=x,
                      patterns=popart.Patterns(patterns),
                      user_options={
                          "enableOutlining": outline,
                          "constantWeights": False
                      },
                      skip_execution=skip_execution)
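
A note on the defaultDampening setting in the training branch above: SGD with momentum updates its velocity as v <- mu * v + (1 - dampening) * g, so any non-zero dampening scales down the gradient's contribution, while dampening = 0 feeds the full gradient into the velocity and makes an incorrect gradient perturb the weights more. A minimal numeric sketch of that update rule (plain Python, not the PopART API):

def sgd_momentum_step(w, v, g, lr=0.1, mu=0.9, dampening=0.0):
    # v <- mu * v + (1 - dampening) * g ;  w <- w - lr * v
    v = mu * v + (1.0 - dampening) * g
    return w - lr * v, v

# With dampening=0 a wrong gradient enters the velocity at full strength,
# which is what the comment above relies on to surface errors.
w, v = 1.0, 0.0
w, v = sgd_momentum_step(w, v, g=0.5)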
Example #7
 def popart_loss_fn(logits):
     loss = popart.L1Loss(logits[0], "l1LossVal", l1_lambda)
     loss.virtualGraph(popart_model.mlm_scope.virtualGraph)
     return [loss]
Example #8
 def popart_loss_fn(outputs):
     loss = popart.L1Loss(outputs[0], "l1Loss", 0.1)
     loss.virtualGraph(popart_model.nsp_scope.virtualGraph)
     return [loss]
Example #9
def test_embedding_bwd(custom_ops):
    l1_lambda = 0.1

    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'])
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes from being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        loss=l1,
        optimizer=optimizer,
        user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    torch_to_onnx = {
        "word_embeddings.weight": "Embedding_Dict",
        "position_embeddings.weight": "Positional_Dict",
        "token_type_embeddings.weight": "Segment_Dict",
        "LayerNorm.weight": "Gamma",
        "LayerNorm.bias": "Beta"
    }

    transposed_weights = {
        "word_embeddings.weight": np.transpose,
        "position_embeddings.weight": np.transpose,
    }

    #  ------------------- PyTorch -------------------------

    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=transposed_weights)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(*[torch.from_numpy(t).long() for t in inputs])
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    torch_outputs = [torch_output.detach().numpy()]

    check_tensors(torch_outputs, outputs)

    check_model(torch_model,
                post_proto,
                torch_to_onnx,
                transform=transposed_weights)
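
The PyTorch reference loss in this test, l1_lambda * torch.norm(torch_output, 1), mirrors PopART's L1Loss because the 1-norm of a tensor is the sum of absolute values, i.e. both compute lambda * sum(|x|). A quick standalone check of that equivalence (NumPy only):

import numpy as np

x = np.random.normal(0, 0.02, (4, 8)).astype(np.float32)
l1_lambda = 0.1
# The flattened 1-norm equals the sum of absolute values.
assert np.isclose(l1_lambda * np.abs(x).sum(),
                  l1_lambda * np.linalg.norm(x.ravel(), 1))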
Example #10
def test_warmup(custom_ops, num_steps=100000):
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather', 'attention'])
    popart_model = Bert(config, builder=builder)

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.sequence_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    output = popart_model.build_graph(indices, positions)[0]

    losses = [popart.L1Loss(output, "l1LossVal", 0.1)]

    for loss in losses:
        loss.virtualGraph(popart_model.ipu)

    proto = popart_model.builder.getModelProto()
    optimizer = popart.SGD(0.00001)

    ipus = math.ceil(config.num_layers / config.layers_per_ipu) \
        + popart_model.layer_offset

    # Analogous to run_py, but runs only the setup stages
    print("Creating session and compiling graph")
    session, anchors, device = create_session(proto,
                                              data,
                                              output,
                                              optimizer,
                                              losses,
                                              ipus=ipus)

    print("Running with opimiser updates")
    times_with_optimiser = timed_run_steps(session,
                                           anchors,
                                           data,
                                           0.1,
                                           num_steps=num_steps)
    print("Running without opimiser updates")
    times_no_optimiser = timed_run_steps(session,
                                         anchors,
                                         data,
                                         None,
                                         num_steps=num_steps)

    device.detach()

    # Convert seconds to milliseconds.
    opt_np = 1000 * times_with_optimiser
    noopt_np = 1000 * times_no_optimiser

    print(f"W/  Optimiser Update")
    print(f"\tMean: {opt_np.mean():.5f}")
    print(f"\tSum:  {opt_np.sum():.5f}")
    print(f"\tRng: {opt_np.min():.5f} -> {opt_np.max():.5f}")

    print(f"W/o  Optimiser Update")
    print(f"\tMean: {noopt_np.mean():.5f}")
    print(f"\tSum:  {noopt_np.sum():.5f}")
    print(f"\tRng: {noopt_np.min():.5f} -> {noopt_np.max():.5f}")

    mean_diff = opt_np.mean() - noopt_np.mean()
    percentage_difference = 100 * mean_diff / noopt_np.mean()
    print(
        f"Mean difference, {mean_diff:.5f}ms (~{percentage_difference:.1f}%)")

    assert (percentage_difference < 5)
Example #11
    def test(config, iteration, true_scaling, test_case):
        builder = popart.Builder()

        w0name = "weight_0"
        w1name = "weight_1"
        w2name = "weight_2"

        input0Shape = [1, 1, 1]
        input0 = builder.addInputTensor(
            popart.TensorInfo("FLOAT", input0Shape), "input0")

        w0data = np.array([test_case[0][0]], dtype=np.float32)
        w0R = np.empty([
            1,
        ], dtype=np.float32)
        w0Id = builder.addInitializedInputTensor(w0data, w0name)

        w1data = np.array([test_case[1][0]], dtype=np.float32)
        w1R = np.empty([
            1,
        ], dtype=np.float32)
        w1Id = builder.addInitializedInputTensor(w1data, w1name)

        w2data = np.array([test_case[2][0]], dtype=np.float32)
        w2R = np.empty([
            1,
        ], dtype=np.float32)
        w2Id = builder.addInitializedInputTensor(w2data, w2name)

        add0 = builder.aiOnnx.add([w0Id, input0])
        add1 = builder.aiOnnx.add([w1Id, add0])
        add2 = builder.aiOnnx.add([w2Id, add1])

        builder.addOutputTensor(add2)

        proto = builder.getModelProto()
        dataFlow = popart.DataFlow(1, {})
        opts = popart.SessionOptions()
        opts.reportOptions = {"showExecutionSteps": "true"}
        opts.enableGroupedMatmuls = False
        pat = popart.Patterns(popart.PatternsLevel.DEFAULT)
        device = popart.DeviceManager().acquireAvailableDevice(1)
        if device is None:
            raise OSError("Failed to acquire IPU.")

        # The stage->tensor map would come from the Bert model in reality
        # (see model.pipeline_stage_tensors)
        mock_tensor_map = {0: [w0Id], 1: [w1Id], 2: [w2Id]}

        factory = ScheduledOptimizerFactory(config,
                                            iteration,
                                            tensors=mock_tensor_map)
        assert_scaled_lr(factory, true_scaling)

        optimizer_step0 = factory.create()

        session = popart.TrainingSession(
            fnModel=proto,
            dataFeed=dataFlow,
            userOptions=opts,
            losses=[popart.L1Loss(add2, "l1LossVal", 1.0)],
            optimizer=optimizer_step0,
            passes=pat,
            deviceInfo=device)

        session.prepareDevice()
        session.weightsFromHost()
        anchors = session.initAnchorArrays()

        input_data = np.array([3.1415], dtype=np.float32)
        stepio = popart.PyStepIO({input0: input_data}, anchors)
        session.optimizerFromHost()

        for step in range(iteration.total_steps):
            session.run(stepio)
            session.weightsToHost()
            weightsRead = popart.PyWeightsIO({w0Id: w0R, w1Id: w1R, w2Id: w2R})
            session.readWeights(weightsRead)

            assert (np.isclose(test_case[0][step + 1], w0R))
            assert (np.isclose(test_case[1][step + 1], w1R))
            assert (np.isclose(test_case[2][step + 1], w2R))

            iteration.count += 1

            if factory.should_update(iteration):
                optimizer_step1 = factory.update_and_create(iteration)
                assert_scaled_lr(factory, true_scaling)

                session.updateOptimizer(optimizer_step1)
                session.optimizerFromHost()
Example #12
def test_embedding_projection_bwd(custom_ops):
    l1_lambda = 0.1

    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'])
    popart_model = Bert(config, builder=builder)

    sequence_info = popart.TensorInfo(
        "INT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32)
    }

    x = popart_model.embedding_custom(
        indices, config.vocab_length, "Embedding_Dict", detach=True)
    x = popart_model.norm(x)
    x = popart_model.dropout(x)
    with popart_model.device_scope(nameScope="CLS"):
        x = popart_model.lm_prediction_head(x)
    output = popart_model.projection(x)

    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, output,
                                 loss=l1,
                                 optimizer=optimizer,
                                 user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [data[indices].reshape(config.batch_size, config.sequence_length)]

    #  ------------------- PyTorch -------------------------

    torch_model = EmbeddingProjectionModel(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=transposed_weights)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(*[torch.from_numpy(t).long() for t in inputs])
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs, margin=1e-5)

    check_model(torch_model,
                post_proto,
                torch_to_onnx,
                transform=transposed_weights)
Example #13
def test_attention_bwd(custom_ops):
    l1_lambda = 0.1

    #  ------------------- PopART --------------------
    builder = popart.Builder()
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['attention'])
    popart_model = Bert(config, builder=builder)

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo("INT32", [config.batch_size])
    mmask_tensor = builder.addInputTensor(mask_info)
    smask_tensor = builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.int32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.int32)
    }

    output = popart_model.attention(input_tensor, [mmask_tensor, smask_tensor])
    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1.output(0)),
                                 loss=l1,
                                 optimizer=optimizer)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    torch_to_onnx = {
        "self.query.weight": "QKV",
        "self.key.weight": "QKV",
        "self.value.weight": "QKV",
        "output.dense.weight": "Out",
        "output.LayerNorm.weight": "Gamma",
        "output.LayerNorm.bias": "Beta"
    }

    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=split_qkv)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(*[torch.from_numpy(t).float()
                                 for t in inputs])[0]
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs)

    check_model(torch_model, post_proto, torch_to_onnx, transform=split_qkv)
Example #14
batchSize = 2
batchesPerStep = 4
anchors = {
    "l1LossVal": popart.AnchorReturnType("EveryN", 2),
    "out": popart.AnchorReturnType("Final"),
    "im0": popart.AnchorReturnType("All")
}
dataFlow = popart.DataFlow(batchesPerStep, anchors)
inputShapeInfo = popart.InputShapeInfo()
inputShapeInfo.add("im0",
                   popart.TensorInfo("FLOAT", [batchSize, nInChans, 32, 32]))

inNames = ["im0"]
outNames = ["out"]
cifarInIndices = {"im0": 0}
losses = [popart.L1Loss("out", "l1LossVal", 0.1)]


class Module0(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)

        self.sin = torch.sin
        self.conv1 = torchwriter.conv3x3(nInChans, nOutChans)
        self.in2 = torch.nn.InstanceNorm2d(nOutChans,
                                           eps=0.1,
                                           affine=True,
                                           momentum=0)
        # Force random initialization
        np.random.seed(0)
        self.in2.weight.data = torch.tensor(
Example #15
dataFlow = popart.DataFlow(batchesPerStep, anchors)
inputShapeInfo = popart.InputShapeInfo()
inputShapeInfo.add("image0",
                   popart.TensorInfo("FLOAT", [batchSize, nInChans, 32, 32]))
inputShapeInfo.add("image1",
                   popart.TensorInfo("FLOAT", [batchSize, nInChans, 32, 32]))
inputShapeInfo.add("label", popart.TensorInfo("INT32", [batchSize]))

inNames = ["image0", "image1"]
cifarInIndices = {"image0": 0, "image1": 0, "label": 1}
outNames = ["imageSum", "postConv0", "preProbSquared", "probs"]

losses = [
    popart.NllLoss("probs", "label", "nllLossVal"),
    popart.L1Loss("preProbSquared", "l1LossVal", 0.01)
]

willowOptPatterns = popart.Patterns(popart.PatternsLevel.All)


class Module0(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.conv1 = torchwriter.conv3x3(nInChans, nOutChans)
        self.conv2 = torchwriter.conv3x3(nOutChans, nOutChans)
        self.sin = torch.sin
        self.pad = torch.nn.functional.pad
        # for softmax dim -1 is correct for [sample][class],
        # gives class probabilities for each sample.
        self.softmax = torch.nn.Softmax(dim=-1)
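
To make the dim=-1 comment above concrete: softmax over the last dimension of a [sample][class] tensor normalises each row into a probability distribution, so each sample's class probabilities sum to 1. A small standalone check:

import torch

logits = torch.randn(2, 5)  # [sample][class]
probs = torch.nn.Softmax(dim=-1)(logits)
# Each row (one sample) sums to 1; dim=0 would wrongly
# normalise across samples instead of across classes.
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))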
Example #16
def test_embedding_projection_bwd(custom_ops):
    l1_lambda = 0.1

    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(
        vocab_length=9728,
        projection_serialization_steps=4,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        # Updating the embedding dict with the projection is currently only
        # available with momentum, and PopART momentum != PyTorch momentum
        # due to a bootstrapping step on iteration 0.
        update_embedding_dict=False)
    popart_model = Bert(config, builder=builder)

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    x = popart_model.gather(indices, config.vocab_length, "Embedding_Dict")
    x = popart_model.norm(x)
    x = popart_model.dropout(x)
    with popart_model.device_scope(nameScope="CLS"):
        x = popart_model.lm_prediction_head(x)
    output = popart_model.projection(x)

    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 loss=l1,
                                 optimizer=optimizer)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[indices].reshape(config.batch_size,
                              config.sequence_length).astype(np.int32)
    ]

    #  ------------------- PyTorch -------------------------

    torch_model = EmbeddingProjectionModel(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          TORCH_TO_ONNX,
                          transform=TRANSPOSE_WEIGHTS)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(*[torch.from_numpy(t).long() for t in inputs])
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs, margin=1e-5)

    check_model(torch_model,
                post_proto,
                TORCH_TO_ONNX,
                transform=TRANSPOSE_WEIGHTS)
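
On the momentum caveat in the config comment above: PyTorch's SGD initialises its momentum buffer to the first gradient on the first step rather than updating it from zero, so with non-zero dampening the very first update already differs from the plain recurrence. PopART's exact formulation is not reproduced here; this sketch only illustrates why the two schemes are hard to match step-for-step:

mu, dampening, g = 0.9, 0.1, 1.0

# Plain recurrence starting from v0 = 0:  v1 = mu * 0 + (1 - dampening) * g
v_plain = mu * 0.0 + (1 - dampening) * g   # 0.9

# PyTorch bootstraps the buffer on its first step (dampening skipped):  v1 = g
v_torch = g                                # 1.0

# The first weight updates (w -= lr * v) therefore diverge immediately,
# and the difference propagates through every later step.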