Code example #1
def test_trainable_params():
    config = BertConfig(task="PRETRAINING",
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        mask_tokens=8,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        no_dropout=True,
                        no_attn_dropout=True,
                        embedding_serialization_vocab_steps=4,
                        inference=False)

    # Create phased_execution version of the model
    model = get_model(config, ExecutionMode.PHASED)
    data = {
        'indices':
        np.random.randint(
            0, config.vocab_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        'positions':
        np.random.randint(
            0, config.sequence_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        'segments':
        np.random.randint(
            0, 2, (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32)
    }

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = model.builder.addInputTensor(sequence_info)
    positions = model.builder.addInputTensor(sequence_info)
    segments = model.builder.addInputTensor(sequence_info)

    data_popart = {}
    data_popart[indices] = data['indices']
    data_popart[segments] = data['segments']
    data_popart[positions] = data['positions']

    model(indices, positions, segments)
    proto = model.builder.getModelProto()

    # Extract the weights from the ONNX model and check it has the same number of initializers as model.tensors[0]
    with tempfile.TemporaryDirectory() as tmp:
        model_path = os.path.join(tmp, "model.onnx")
        onnx.save(proto, model_path)
        onnx_model = onnx.load(model_path)
        assert len(model.tensors[0]) == len(onnx_model.graph.initializer)
Code example #2
def get_model_proto(config, mode, initializers=None):
    model = get_model(config, mode, initializers=initializers)

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = model.builder.addInputTensor(sequence_info)
    positions = model.builder.addInputTensor(sequence_info)
    segments = model.builder.addInputTensor(sequence_info)

    if mode == ExecutionMode.PHASED:
        output = model(indices, positions, segments)
    else:
        output = model.build_graph(indices, positions, segments)
    return onnx.load_model_from_string(model.builder.getModelProto())
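A minimal usage sketch for the helper above (hypothetical; it assumes BertConfig, ExecutionMode and get_model_proto are imported as in the surrounding tests, and reuses the config values from code example #1):

# Hypothetical usage sketch: build the phased-execution graph and inspect the exported ONNX model.
config = BertConfig(task="PRETRAINING",
                    vocab_length=1024,
                    micro_batch_size=1,
                    hidden_size=64,
                    attention_heads=2,
                    sequence_length=20,
                    mask_tokens=8,
                    popart_dtype="FLOAT",
                    num_layers=2,
                    no_mask=True,
                    no_dropout=True,
                    no_attn_dropout=True,
                    embedding_serialization_vocab_steps=4,
                    inference=False)
onnx_model = get_model_proto(config, ExecutionMode.PHASED)
print(len(onnx_model.graph.node), "nodes,",
      len(onnx_model.graph.initializer), "initializers")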
Code example #3
def test_pretraining_fwd(custom_ops, mode, replication_factor,
                         replicated_tensor_sharding):
    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        mask_tokens=2,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_cls_layer=False,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=False)

    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        no_cls_layer=config.no_cls_layer))

    fwd_graph(popart_model,
              torch_model,
              mode,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform=onnx_torch_tform,
              replication_factor=replication_factor,
              replicated_tensor_sharding=replicated_tensor_sharding)
Code example #4
def test_squad_fwd(mode, replication_factor, replicated_weight_sharding):
    split_qkv = False
    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=split_qkv,
                        squad_single_output=False)

    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mode,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform={"qa_outputs.weight": np.transpose},
              replication_factor=replication_factor,
              replicated_weight_sharding=replicated_weight_sharding)
Code example #5
File: nsp_test.py  Project: muzzynine/examples-1
def test_nsp_fwd(custom_ops, mode):
    #  ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        mask_tokens=0,
                        split_qkv=False)
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mode,
              NSP_MAPPING[mode],
              transform=NSP_TRANSFORM)
Code example #6
def pretraining_bwd(custom_ops,
                    mode,
                    replication_factor,
                    replicated_weight_sharding,
                    opt_type,
                    vocab_length=9728,
                    hidden_size=768):
    #  ------------------- PopART --------------------
    if mode == ExecutionMode.PHASED:
        # Phased execution requires at least two transformer layers to ensure the MLM head and the embedding end up in the same virtual graph.
        num_layers = 2
    else:
        num_layers = 1
    config = BertConfig(task="PRETRAINING",
                        vocab_length=vocab_length,
                        num_layers=num_layers,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_cls_layer=True,
                        no_mask=True,
                        phased_execution_type="single",
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=True,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.mlm_scope):
                loss = popart_model.builder.aiGraphcore.l1loss(
                    [logits[0]],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)
        else:
            loss = popart_model.builder.aiGraphcore.l1loss(
                [logits[0]],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
            popart_model.builder.virtualGraph(
                loss, popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(
        popart_model,
        torch_model,
        mode,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
        mapping={},
        transform=onnx_torch_tform,
        replication_factor=replication_factor,
        replicated_weight_sharding=replicated_weight_sharding,
        opt_type=opt_type)
Code example #7
def popart_result_and_model(popart_config,
                            mode,
                            batch_serialization_factor,
                            is_bwd=False,
                            momentum=0.0):
    popart_model = get_model(popart_config, mode, 'feedforward')

    input_info = popart.TensorInfo(popart_config.popart_dtype, [
        popart_config.micro_batch_size * popart_config.sequence_length,
        popart_config.hidden_size
    ])
    input_tensor = popart_model.builder.addInputTensor(input_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02,
                         input_info.shape()).astype(popart_config.dtype)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialization_factor,
            "executionPhases": popart_model.total_execution_phases,
        }
        output = popart_model(input_tensor)
    else:
        output = popart_model.feed_forward(input_tensor)

    if is_bwd:
        l1_lambda = 0.1
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.norm.scope):
                l1 = popart_model.builder.aiGraphcore.l1loss(
                    [output],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)

        else:
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
        proto = popart_model.builder.getModelProto()
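        # NOTE: `lr` and `num_reps_bwd` are not defined in this snippet; they are
        # assumed to be module-level constants in the original test file.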

        if momentum > 0.0:
            optimizer = popart.SGD({
                "defaultLearningRate": (lr, False),
                "defaultMomentum": (momentum, False),
                "defaultWeightDecay": (0.0, False)
            })
        else:
            optimizer = popart.ConstSGD(lr)

        outputs, post_proto = run_py(proto,
                                     data, (output, l1),
                                     loss=l1,
                                     optimizer=optimizer,
                                     user_options=user_options,
                                     execution_mode=mode,
                                     num_reps=num_reps_bwd)
    else:
        proto = popart_model.builder.getModelProto()
        outputs, post_proto = run_py(proto,
                                     data,
                                     output,
                                     user_options=user_options,
                                     execution_mode=mode)

    return data[input_tensor], outputs, proto, post_proto
Code example #8
def popart_result_and_model(config, mode, weight_transposed, is_bwd=False):
    """Run popart model based on config.

    Args:
        config (BertConfig): Popart config.
        mode (ExecutionMode): Execution mode (e.g. ExecutionMode.PHASED).
        weight_transposed (bool): Construct the embedding dict transposed.
        is_bwd (bool, optional): Construct training graph if True,
                                 else inference graph. Defaults to False.

    Returns:
        Tuple: Gathered numpy data, outputs from model, proto, post_proto
    """

    scope_provider = ScopeProvider()
    user_options = {}
    if mode == ExecutionMode.PHASED:
        builder = popart.Builder()

        indices_len = config.batch_size * config.sequence_length
        sequence_info = popart.TensorInfo("UINT32", [indices_len])
        indices = builder.addInputTensor(sequence_info)
        data = {indices: np.random.randint(0, config.vocab_length, (indices_len)).astype(np.uint32)}

        popart_model = EmbeddingSerialised(scope_provider.get_scope('Token'),
                                           input_dim=config.vocab_length,
                                           output_dim=config.hidden_size,
                                           num_splits=config.embedding_serialization_vocab_steps,
                                           custom=True,
                                           dtype=config.dtype,
                                           detach=not config.update_embedding_dict,
                                           weight_transposed=weight_transposed,
                                           builder=builder,
                                           scope_provider=scope_provider)
        user_options = {
            "batchSerializationFactor": 1,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(indices)
    else:
        popart_model = get_model(config, mode, block="embedding", initializers={})
        builder = popart_model.builder

        indices_len = config.batch_size * config.sequence_length
        sequence_info = popart.TensorInfo("UINT32", [indices_len])
        indices = builder.addInputTensor(sequence_info)
        data = {indices: np.random.randint(0, config.vocab_length, (indices_len)).astype(np.uint32)}
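        # NOTE: `num_splits` is not defined in this snippet; it is assumed to come from
        # module scope (likely matching config.embedding_serialization_vocab_steps).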
        output = popart_model.word_embedding_serialized(indices, num_splits)

    if is_bwd:
        l1_lambda = 0.1
        if mode == ExecutionMode.PHASED:
            loss_scope = scope_provider.get_scope('Loss', 'prev')
            with popart_model.scope_provider(popart_model.builder, loss_scope):
                l1_loss = popart_model.builder.aiGraphcore.l1loss([output],
                                                                  l1_lambda,
                                                                  debugPrefix="l1LossVal",
                                                                  reduction=popart.ReductionType.Sum)
        else:
            l1_loss = popart_model.builder.aiGraphcore.l1loss([output],
                                                              l1_lambda,
                                                              debugPrefix="l1LossVal",
                                                              reduction=popart.ReductionType.Sum)
        proto = builder.getModelProto()
        optimizer = popart.ConstSGD(0.01)
        outputs, post_proto = run_py(proto,
                                     data, (output, l1_loss),
                                     loss=l1_loss,
                                     optimizer=optimizer,
                                     user_options=user_options,
                                     execution_mode=mode)
    else:
        proto = builder.getModelProto()
        outputs, post_proto = run_py(proto, data, output,
                                     user_options=user_options,
                                     execution_mode=mode)

    return [data[indices]], outputs, proto, post_proto
Code example #9
def main(args):
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)

    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specifying ai.onnx opset9 for the slice syntax
    model = get_model(config,
                      mode=args.execution_mode,
                      initializers=initializers,
                      block=None)

    # If config.host_embedding is enabled, indices and positions will contain the embedding matrices instead of the index vectors.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               args.execution_mode)

    if args.inference:
        accuracies = None
        losses = []
        if args.task == "PRETRAINING":
            # If this is a pretraining session, labels for NSP and MLM are already within the dataset,
            # so we can always calculate prediction performance
            predictions, _ = bert_infer_graph(model,
                                              logits,
                                              include_probs=False)

            if args.inference_lm_perplexity:
                losses, _ = bert_perplexity_graph(args, model, logits, labels)
            else:
                losses = [None, None]

            outputs, accuracies, losses = bert_add_validation_outputs(
                args, model, predictions, labels, losses)
        else:
            if args.inference_lm_perplexity:
                raise RuntimeError(
                    "Masked LM perplexity is only supported in pretraining.")

            outputs = bert_add_logit_outputs(model, logits)

        writer = None
    else:
        predictions, probs = bert_infer_graph(model, logits)
        losses, final_loss = bert_loss_graph(args, model, probs, labels)
        outputs, accuracies, losses = bert_add_validation_outputs(
            args, model, predictions, labels, losses)
        writer = bert_writer(args)

    device = acquire_device(args, bert_required_ipus(args, model))

    dataset = get_bert_dataset(model, args,
                               [indices, positions, segments, masks, labels])
    logger.info(f"Dataset length: {len(dataset)}")

    data_flow = popart.DataFlow(args.batches_per_step, outputs)

    iteration = bert_iteration(args, dataset, writer)

    if args.inference:
        session, anchors = bert_inference_session(model, args, data_flow,
                                                  device)
        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks, *labels]
        bert_infer_loop(args, session, dataset, inputs, logits, anchors,
                        accuracies, losses, iteration)
        device.detach()
    else:
        if not args.no_training:
            optimizer_factory = bert_optimizer_factory(args, model, iteration)

            session, anchors = bert_training_session(model, args, data_flow,
                                                     final_loss, device,
                                                     optimizer_factory)
            logger.info("Training Started")
            bert_train_loop(args, session, writer, dataset, accuracies, losses,
                            anchors, iteration, optimizer_factory)

            save_model_and_stats(args, session, writer, iteration.count)

            device.detach()
            logger.info("Training Finished")

    return session, iteration
Code example #10
def test_attention_bwd(mode):
    l1_lambda = 0.1

    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo("UINT32", [config.batch_size])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.uint32)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": 1,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
        with popart_model.scope_provider(popart_model.builder,
                                         popart_model.norm.scope):
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)

    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])
        l1 = popart_model.builder.aiGraphcore.l1loss(
            [output], l1_lambda, reduction=popart.ReductionType.Sum)

    proto = popart_model.builder.getModelProto()

    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1),
                                 loss=l1,
                                 optimizer=optimizer,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          TORCH_TO_ONNX[mode],
                          transform=split_qkv)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(*[torch.from_numpy(t).float()
                                 for t in inputs])[0]
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs)

    check_model(torch_model,
                post_proto,
                TORCH_TO_ONNX[mode],
                transform=split_qkv)
Code example #11
def test_attention_fwd(mode, micro_batch_size, batch_serialisation_factor,
                       number_attention_splits, attention_bias, split_qkv):
    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=micro_batch_size,
                        hidden_size=768,
                        attention_heads=4,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias,
                        num_attention_splits=number_attention_splits)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialisation_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()
    mapping = TORCH_TO_ONNX[mode]
    if split_qkv:
        mapping = TORCH_TO_ONNX_SPLIT_QKV[mode]
    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
Code example #12
def test_attention_bwd(mode, momentum, micro_batch_size,
                       batch_serialisation_factor, number_attention_splits,
                       attention_bias):
    l1_lambda = 0.1
    num_reps = 5
    np.random.seed(1984)
    torch.manual_seed(1984)
    split_qkv = False

    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=micro_batch_size,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias,
                        num_attention_splits=number_attention_splits)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialisation_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
        with popart_model.scope_provider(popart_model.builder,
                                         popart_model.norm.scope):
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
    else:
        user_options = {}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])
        l1 = popart_model.builder.aiGraphcore.l1loss(
            [output], l1_lambda, reduction=popart.ReductionType.Sum)

    proto = popart_model.builder.getModelProto()

    if momentum:
        optimizer = popart.SGD({
            "defaultLearningRate": (0.01, True),
            "defaultMomentum": (momentum, True)
        })
    else:
        optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1),
                                 loss=l1,
                                 optimizer=optimizer,
                                 num_reps=num_reps,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length, config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    mapping = TORCH_TO_ONNX[mode]
    if split_qkv:
        mapping = TORCH_TO_ONNX_SPLIT_QKV[mode]

    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=momentum)

    if momentum:
        for group in optim.param_groups:
            for p in group['params']:
                optim.state[p]['momentum_buffer'] = p.data * 0
                optim.state[p]['exp_avg'] = p.data * 0
                optim.state[p]['exp_avg_sq'] = p.data * 0
                optim.state[p]['step'] = 0

    for _ in range(num_reps):
        torch_output = torch_model(
            *[torch.from_numpy(t).float() for t in inputs])[0]
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    check_tensors([torch_output.detach().numpy()], outputs, margin=6e-07)

    check_model(torch_model,
                post_proto,
                mapping,
                transform=get_transform(split_qkv, config.hidden_size),
                margin=2e-7)
Code example #13
def test_embedding_fwd(custom_ops, mode, batch_size,
                       batch_serialization_factor,
                       embedding_serialization_vocab_steps):
    #  ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        vocab_length=9728,
        batch_size=batch_size,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        inference=True,
        embedding_serialization_vocab_steps=embedding_serialization_vocab_steps
    )
    popart_model = get_model(config, mode, 'embedding')

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialization_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(indices, positions, segments)
    else:
        user_options = {"enableStochasticRounding": True}
        with popart_model.builder.nameScope("Embedding"):
            output = popart_model.embedding(indices, positions, segments)

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()

    expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map(
        TORCH_TO_ONNX[mode], config, mode)
    copy_weights_to_torch(torch_model, proto, expanded_name_map,
                          remapped_transform_map)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs, margin=5e-7)
Code example #14
def embedding_bwd(custom_ops,
                  mode,
                  momentum,
                  batch_size,
                  batch_serialization_factor,
                  embedding_serialization_vocab_steps,
                  vocab_length=9728,
                  hidden_size=768):
    #  ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        vocab_length=vocab_length,
        batch_size=batch_size,
        hidden_size=hidden_size,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        update_embedding_dict=True,
        embedding_serialization_vocab_steps=embedding_serialization_vocab_steps
    )

    popart_model = get_model(config, mode, 'embedding')
    # Prevent virtualGraph attributes being added to the ops

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    if momentum:
        optimizer = popart.SGD({
            "defaultLearningRate": (0.01, True),
            "defaultMomentum": (momentum, True),
            "defaultDampening": (0.0, True),
            "defaultVelocityScaling": (1.0, True),
            "lossScaling": (1.0, True),
            "defaultWeightDecay": (0.0, True)
        })
    else:
        optimizer = popart.ConstSGD(0.01)

    l1_lambda = 0.1

    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialization_factor,
            "executionPhases": popart_model.total_execution_phases,
        }
        output = popart_model(indices, positions, segments)
        with popart_model.scope_provider(popart_model.builder,
                                         popart_model.norm.scope):
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
    else:
        user_options = {"enableStochasticRounding": True}
        with popart_model.builder.nameScope("Embedding"):
            output = popart_model.embedding(indices, positions, segments)
        l1 = popart_model.builder.aiGraphcore.l1loss(
            [output],
            l1_lambda,
            debugPrefix="l1LossVal",
            reduction=popart.ReductionType.Sum)

    num_reps = 5
    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 ipus=1,
                                 loss=l1,
                                 num_reps=num_reps,
                                 optimizer=optimizer,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------

    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=config.update_embedding_dict))
    # Turn off dropout
    torch_model.eval()

    expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map(
        TORCH_TO_ONNX[mode], config, mode)
    copy_weights_to_torch(torch_model, proto, expanded_name_map,
                          remapped_transform_map)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            dampening=0.0,
                            momentum=momentum)

    if momentum > 0.:
        for group in optim.param_groups:
            for p in group['params']:
                optim.state[p]['momentum_buffer'] = p.data * 0
                optim.state[p]['exp_avg'] = p.data * 0
                optim.state[p]['exp_avg_sq'] = p.data * 0
                optim.state[p]['step'] = 0

    for _ in range(num_reps):
        torch_output = torch_model(
            *[torch.from_numpy(t).long() for t in inputs])
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    torch_outputs = [torch_output.detach().numpy()]

    check_tensors(torch_outputs, outputs, margin=7e-6)

    expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map(
        TORCH_TO_ONNX[mode], config, mode)
    check_model(torch_model,
                post_proto,
                expanded_name_map,
                remapped_transform_map,
                margin=7e-06)
Code example #15
def test_embedding_bwd(custom_ops, mode):
    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        update_embedding_dict=False,
                        embedding_serialization_vocab_steps=1)

    popart_model = get_model(config, mode, 'embedding')
    # Prevent virtualGraph attributes being added to the ops.

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    optimizer = popart.ConstSGD(0.01)
    l1_lambda = 0.1

    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": 1,
            "executionPhases": popart_model.total_execution_phases,
        }
        output = popart_model(indices, positions, segments)
        with popart_model.scope_provider(popart_model.builder,
                                         popart_model.norm.scope):
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.embedding(indices, positions, segments)
        l1 = popart_model.builder.aiGraphcore.l1loss(
            [output],
            l1_lambda,
            debugPrefix="l1LossVal",
            reduction=popart.ReductionType.Sum)

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 ipus=1,
                                 loss=l1,
                                 optimizer=optimizer,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    #  ------------------- PyTorch -------------------------

    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX[mode], {})

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(*[torch.from_numpy(t).long() for t in inputs])
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    torch_outputs = [torch_output.detach().numpy()]

    check_tensors(torch_outputs, outputs, margin=1e-06)

    check_model(torch_model, post_proto, TORCH_TO_ONNX[mode], {}, margin=1e-06)
Code example #16
File: nsp_test.py  Project: muzzynine/examples-1
def nsp_bwd(custom_ops, mode, opt_type, vocab_length=9728, hidden_size=768):
    if mode == ExecutionMode.PHASED:
        # Phased execution requires at least two transformer layers to ensure the MLM head and the embedding end up in the same virtual graph.
        num_layers = 2
    else:
        num_layers = 1

    #  ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=vocab_length,
                        num_layers=num_layers,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_mask=True,
                        update_embedding_dict=True,
                        phased_execution_type="single",
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))
    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.nsp_scope):
                loss = popart_model.builder.aiGraphcore.l1loss(
                    [outputs[0]],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)
        else:
            loss = popart_model.builder.aiGraphcore.l1loss(
                [outputs[0]],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
            popart_model.builder.virtualGraph(
                loss, popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return l1_lambda * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              mode,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING[mode],
              transform=NSP_TRANSFORM,
              opt_type=opt_type)
Code example #17
def popart_result_and_model(popart_config, mode, is_bwd=False):
    popart_model = get_model(popart_config, mode, 'feedforward')

    input_info = popart.TensorInfo(popart_config.popart_dtype, [
        popart_config.batch_size * popart_config.sequence_length,
        popart_config.hidden_size
    ])
    input_tensor = popart_model.builder.addInputTensor(input_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02,
                         input_info.shape()).astype(popart_config.dtype)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": 1,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor)
    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.feed_forward(input_tensor)

    if is_bwd:
        l1_lambda = 0.1
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.norm.scope):
                l1 = popart_model.builder.aiGraphcore.l1loss(
                    [output],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)

        else:
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
        proto = popart_model.builder.getModelProto()
        optimizer = popart.ConstSGD(0.01)

        outputs, post_proto = run_py(proto,
                                     data, (output, l1),
                                     loss=l1,
                                     optimizer=optimizer,
                                     user_options=user_options,
                                     execution_mode=mode)
    else:
        proto = popart_model.builder.getModelProto()
        outputs, post_proto = run_py(proto,
                                     data,
                                     output,
                                     user_options=user_options,
                                     execution_mode=mode)

    return data[input_tensor], outputs, proto, post_proto
Code example #18
def squad_bwd(mode,
              replication_factor,
              replicated_weight_sharding,
              opt_type,
              vocab_length=9728,
              hidden_size=768):
    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=vocab_length,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))

    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.squad_scope):
                losses = [
                    popart_model.builder.aiGraphcore.l1loss(
                        [outputs[0]],
                        l1_lambda,
                        debugPrefix="startsLossVal",
                        reduction=popart.ReductionType.Sum),
                    popart_model.builder.aiGraphcore.l1loss(
                        [outputs[1]],
                        l1_lambda,
                        debugPrefix="endsLossVal",
                        reduction=popart.ReductionType.Sum),
                ]
                final_loss = popart_model.builder.aiOnnx.sum(
                    losses, debugPrefix="finalLoss")

        else:
            losses = [
                popart_model.builder.aiGraphcore.l1loss(
                    [outputs[0]],
                    l1_lambda,
                    debugPrefix="startsLossVal",
                    reduction=popart.ReductionType.Sum),
                popart_model.builder.aiGraphcore.l1loss(
                    [outputs[1]],
                    l1_lambda,
                    debugPrefix="endsLossVal",
                    reduction=popart.ReductionType.Sum),
            ]
            for loss in losses:
                popart_model.builder.virtualGraph(
                    loss, popart_model.squad_scope.virtualGraph)

            final_loss = popart_model.builder.aiOnnx.sum(
                losses, debugPrefix="finalLoss")
            popart_model.builder.virtualGraph(
                final_loss, popart_model.squad_scope.virtualGraph)
        return final_loss

    def torch_loss_fn(outputs):
        torch_losses = [
            l1_lambda * torch.norm(output, 1) for output in outputs
        ]
        return torch.add(*torch_losses)

    bwd_graph(popart_model,
              torch_model,
              mode,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform={"qa_outputs.weight": np.transpose},
              replication_factor=replication_factor,
              replicated_weight_sharding=replicated_weight_sharding,
              opt_type=opt_type)
Code example #19
File: predict.py  Project: WangJengYun/NER-Model
                ' 老 牌 服 飾 業 者 涉 詐 貸 1 6 家 銀 行 踩 雷 12.25 億 台 銀 是 最 大 苦 主',
                '解 讀 台 積 電 張 忠 謀 在 演 講 的 4 大 關 鍵 ! 一 場 劉 德 音 、 魏 哲 家 都 沒 錯 過 的 談 話']
    test_dataset = NER_Dataset(config).load_data('test', input_sentences=sentence)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False,  # num_workers=8, pin_memory=True,
                             collate_fn=test_dataset.collate_fn)
    
    # train_dataset = NER_Dataset(config).load_data('train')
    # train_loader = DataLoader(train_dataset, batch_size = batch_size,
    #                          shuffle=True,#num_workers = 8,pin_memory = True,
    #                          collate_fn=train_dataset.collate_fn)
    # 
    # val_dataset = NER_Dataset(config).load_data('val')
    # val_loader = DataLoader(val_dataset, batch_size=batch_size,
    #                          shuffle=True,#num_workers = 8,pin_memory = True,
    #                          collate_fn=val_dataset.collate_fn)

    # import model
    selected_model = model_config['architectures'][0]
    config['model']['selected_model'] = selected_model
    model = get_model(config = config)
    

    result = trainer(model, config, logger)
    
    # result.import_data((train_loader,val_loader),(len(train_dataset),len(val_dataset)))
    # result.selected_model = selected_model
    # metrics = result.evaluate()

    result.import_data(test_loader,len(test_dataset))
    AA = result.predict()