Example #1
def run_embedding_layer(args):
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)

    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specifying ai.onnx opset 9 for the slice syntax
    # TODO: Change slice to opset10
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=args.execution_mode)

    # If config.host_embedding is enabled, indices and positions hold the embedding matrices instead of index vectors.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    logits = (model.embedding(indices, positions, segments),)

    if args.inference:
        outputs = bert_add_logit_outputs(model, logits)
        writer = None

        dataset = get_bert_dataset(
            model, args, [indices, positions, segments, masks, labels])

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            args,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=args.aggregate_metrics_over_steps)

        request_ipus = bert_required_ipus(args, model)

        device = acquire_device(args, request_ipus)

        session, anchors = bert_inference_session(model, args, data_flow,
                                                  device)
        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks]
        """bert_infer_loop(args, session,
                        dataset, inputs, logits, anchors,
                        iteration)"""
        save_results = args.task == "SQUAD" and not (args.synthetic_data
                                                     or args.generated_data)

        start_times = defaultdict(list)
        end_times = defaultdict(list)
        # Create the stepio once outside of the inference loop:
        static_data = {}
        if args.low_latency_inference and args.task == "SQUAD":
            stepio = create_callback_stepio(static_data, anchors, start_times,
                                            end_times,
                                            dataset.batches_per_step)
        else:
            stepio = None

        output = []
        logger.info(dataset)
        for data in dataset:
            static_data.update({t: data[t] for t in inputs})
            result = bert_process_infer_data(args, session, static_data,
                                             anchors, logits, iteration,
                                             start_times, end_times, stepio)
            if save_results:
                output.append(result)
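            # Unconditional break: only the first batch is run through the
            # embedding layer before results are gathered.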
            break

        device.detach()
        return output

    return None
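# --- Supplementary sketch (not part of the original example) ---
# `create_callback_stepio` is a repo-specific helper. A plausible minimal
# version, sketched here on top of popart.PyStepIOCallback, streams inputs
# and outputs through callbacks rather than one large buffer, which is what
# enables the low-latency path above. The callback signatures should be
# checked against your PopART release; `make_callback_stepio` and its
# arguments are illustrative names.
import time
import popart

def make_callback_stepio(static_data, anchors, start_times, end_times):
    def input_cb(tensor_id, prefetch):
        # Called when the device wants the next value of an input tensor.
        start_times[tensor_id].append(time.time())
        return static_data[tensor_id]

    def input_complete_cb(tensor_id):
        # Called once the device has consumed the input buffer.
        pass

    def output_cb(tensor_id):
        # Hand back the pre-allocated anchor buffer for the device to fill.
        return anchors[tensor_id]

    def output_complete_cb(tensor_id):
        # Called once the output buffer holds valid data.
        end_times[tensor_id].append(time.time())

    return popart.PyStepIOCallback(input_cb, input_complete_cb,
                                   output_cb, output_complete_cb)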
Example #2
def test_embedding_fwd(custom_ops):
    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config)

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(
            0, config.vocab_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        positions:
        np.random.randint(
            0, config.max_positional_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        segments:
        np.random.randint(
            0, 2, (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32)
    }

    user_options = {"enableStochasticRounding": True}
    with popart_model.builder.nameScope("Embedding"):
        output = popart_model.embedding(indices, positions, segments)

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.micro_batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs, margin=5e-7)
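# --- Supplementary sketch (not part of the original example) ---
# `check_tensors` is a repo helper; a minimal stand-in would compare each
# PyTorch output with the matching PopART anchor to within `margin`.
# Treating `margin` as an absolute tolerance and pairing the tensors by
# position are assumptions about the helper's behaviour.
import numpy as np

def check_tensors_sketch(expected_outputs, actual_outputs, margin=1e-7):
    for expected, actual in zip(expected_outputs, actual_outputs):
        np.testing.assert_allclose(np.asarray(actual).reshape(-1),
                                   np.asarray(expected).reshape(-1),
                                   rtol=0, atol=margin)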
Example #3
def test_embedding_bwd(custom_ops):
    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        update_embedding_dict=True)

    popart_model = Bert(config)

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(
            0, config.vocab_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        positions:
        np.random.randint(
            0, config.max_positional_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        segments:
        np.random.randint(
            0, 2, (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32)
    }

    optimizer = popart.ConstSGD(0.01)

    l1_lambda = 0.1
    with popart_model.builder.nameScope("Embedding"):
        output = popart_model.embedding(indices, positions, segments)
    l1 = popart_model.builder.aiGraphcore.l1loss(
        [output],
        l1_lambda,
        debugContext="l1LossVal",
        reduction=popart.ReductionType.Sum)

    num_reps = 5
    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 ipus=1,
                                 loss=l1,
                                 num_reps=num_reps,
                                 optimizer=optimizer)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.micro_batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------

    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=config.update_embedding_dict))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})

    optim = torch.optim.SGD(torch_model.parameters(), 0.01)
    for _ in range(num_reps):
        torch_output = torch_model(
            *[torch.from_numpy(t).long() for t in inputs])
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    torch_outputs = [torch_output.detach().numpy()]

    check_tensors(torch_outputs, outputs, margin=7e-6)

    check_model(torch_model, post_proto, TORCH_TO_ONNX, {}, margin=7e-06)
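# --- Supplementary sketch (not part of the original example) ---
# `copy_weights_to_torch` copies ONNX initializers into the PyTorch module
# through a name map such as TORCH_TO_ONNX, applying optional per-weight
# transforms (see `transposed_weights` in the next example). This is a
# plausible reimplementation, not the repo's; matching on a scope suffix is
# an assumption based on the "Embedding" nameScope used above.
import torch
from onnx import numpy_helper

def copy_weights_sketch(torch_model, onnx_proto, torch_to_onnx, transform=None):
    transform = transform or {}
    initializers = {i.name: numpy_helper.to_array(i)
                    for i in onnx_proto.graph.initializer}
    state = {}
    for torch_name, onnx_name in torch_to_onnx.items():
        # Initializer names may carry a scope prefix, e.g. "Embedding/Gamma".
        weight = next(v for k, v in initializers.items()
                      if k == onnx_name or k.endswith("/" + onnx_name))
        if torch_name in transform:
            weight = transform[torch_name](weight)
        state[torch_name] = torch.from_numpy(weight.copy())
    torch_model.load_state_dict(state, strict=False)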
Example #4
def test_embedding_fwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'],
                        inference=True)
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes from being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    # Use the custom embedding for layout
    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(
        proto, data, output, user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    torch_to_onnx = {
        "word_embeddings.weight": "Embedding_Dict",
        "position_embeddings.weight": "Positional_Dict",
        "token_type_embeddings.weight": "Segment_Dict",
        "LayerNorm.weight": "Gamma",
        "LayerNorm.bias": "Beta"
    }

    transposed_weights = {
        "word_embeddings.weight": np.transpose,
        "position_embeddings.weight": np.transpose,
    }

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, torch_to_onnx,
                          transposed_weights)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
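# --- Supplementary note (not part of the original example) ---
# The custom 'gather' embedding appears to store the word and position
# tables transposed on the IPU, which is why `transposed_weights` routes
# those two entries through np.transpose before the PyTorch comparison;
# the segment table and LayerNorm parameters keep their usual layout.
# A minimal illustration of that mapping (shapes assumed from the config):
import numpy as np

popart_word_table = np.zeros((768, 9728), dtype=np.float32)  # assumed IPU layout
torch_word_table = np.transpose(popart_word_table)           # nn.Embedding layout
assert torch_word_table.shape == (9728, 768)                 # (vocab, hidden)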
Example #5
def test_embedding_bwd(custom_ops):
    l1_lambda = 0.1

    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'])
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes from being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        loss=l1,
        optimizer=optimizer,
        user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    torch_to_onnx = {
        "word_embeddings.weight": "Embedding_Dict",
        "position_embeddings.weight": "Positional_Dict",
        "token_type_embeddings.weight": "Segment_Dict",
        "LayerNorm.weight": "Gamma",
        "LayerNorm.bias": "Beta"
    }

    transposed_weights = {
        "word_embeddings.weight": np.transpose,
        "position_embeddings.weight": np.transpose,
    }

    #  ------------------- PyTorch -------------------------

    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=transposed_weights)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(*[torch.from_numpy(t).long() for t in inputs])
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    torch_outputs = [torch_output.detach().numpy()]

    check_tensors(torch_outputs, outputs)

    check_model(torch_model,
                post_proto,
                torch_to_onnx,
                transform=transposed_weights)
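# --- Supplementary note (not part of the original example) ---
# This example builds the loss with the legacy popart.L1Loss class, where
# the loss is constructed outside the graph and handed to the session.
# Newer PopART releases instead express the loss as a graph op, as
# Example #3 does:
#
#     l1 = popart_model.builder.aiGraphcore.l1loss(
#         [output], l1_lambda, reduction=popart.ReductionType.Sum)
#
# Which form applies depends on the PopART version the repo targets.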
Example #6
def test_embedding_fwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes from being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
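# --- Supplementary sketch (not part of the original example) ---
# `run_py` is the repo's test harness. A minimal inference-only equivalent
# on the public PopART session API might look like the sketch below: it
# anchors the requested output, runs one step, and writes the post-run
# weights back out for comparison. The training path (loss, optimizer,
# num_reps, ipus) is omitted and all names here are illustrative.
import popart

def run_py_sketch(proto, data, output, user_options=None):
    data_flow = popart.DataFlow(1, {output: popart.AnchorReturnType("ALL")})
    options = popart.SessionOptions()
    for key, value in (user_options or {}).items():
        setattr(options, key, value)   # e.g. enableStochasticRounding
    device = popart.DeviceManager().createIpuModelDevice({})
    session = popart.InferenceSession(fnModel=proto,
                                      dataFlow=data_flow,
                                      userOptions=options,
                                      deviceInfo=device)
    session.prepareDevice()
    anchors = session.initAnchorArrays()
    session.run(popart.PyStepIO(data, anchors))
    session.modelToHost("model_post.onnx")  # post-run weights, cf. post_proto
    return [anchors[output]], "model_post.onnx"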