Exemplo n.º 1
0
Arquivo: squad.py Projeto: vsl9/NeMo
def create_pipeline(
    data_dir,
    model,
    head,
    loss_fn,
    max_query_length,
    max_seq_length,
    doc_stride,
    batch_size,
    version_2_with_negative,
    num_gpus=1,
    batches_per_step=1,
    mode="train",
):
    """Wire a SQuAD question-answering graph and report its epoch length.

    A ``nemo_nlp.BertQuestionAnsweringDataLayer`` is created for ``mode`` and
    its outputs are passed through ``model`` (token ids / type ids / mask),
    then ``head``, then ``loss_fn`` (with the gold start/end positions).

    Args:
        data_dir: Directory handed to the data layer.
        model: Encoder module; called with ``input_ids``, ``token_type_ids``
            and ``attention_mask`` keywords.
        head: Module called with ``hidden_states`` producing logits.
        loss_fn: Module called with ``logits``, ``start_positions`` and
            ``end_positions``; its output must expose ``loss``,
            ``start_logits`` and ``end_logits``.
        max_query_length, max_seq_length, doc_stride, batch_size,
        version_2_with_negative: Forwarded unchanged to the data layer.
        num_gpus: Number of GPUs; only used to compute ``steps_per_epoch``.
        batches_per_step: Batches per optimizer step; only used to compute
            ``steps_per_epoch``.
        mode: Dataset split forwarded to the data layer (default "train").

    Returns:
        A 4-tuple ``(loss, steps_per_epoch, eval_tensors, data_layer)`` where
        ``eval_tensors`` is ``[start_logits, end_logits, unique_ids]`` and
        ``steps_per_epoch`` is ``len(data_layer)`` floor-divided by
        ``batch_size * num_gpus * batches_per_step`` (a trailing partial step
        is not counted).
    """
    # NOTE(review): `tokenizer` is neither a parameter nor a local here; it is
    # looked up as a module-level global at call time -- confirm it is defined
    # before this function runs.
    layer_kwargs = dict(
        mode=mode,
        version_2_with_negative=version_2_with_negative,
        batch_size=batch_size,
        tokenizer=tokenizer,
        data_dir=data_dir,
        max_query_length=max_query_length,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
    )
    data_layer = nemo_nlp.BertQuestionAnsweringDataLayer(**layer_kwargs)
    batch = data_layer()

    encoded = model(
        input_ids=batch.input_ids,
        token_type_ids=batch.input_type_ids,
        attention_mask=batch.input_mask,
    )
    logits = head(hidden_states=encoded)
    losses = loss_fn(
        logits=logits,
        start_positions=batch.start_positions,
        end_positions=batch.end_positions,
    )

    num_samples = len(data_layer)
    steps_per_epoch = num_samples // (batch_size * num_gpus * batches_per_step)

    eval_tensors = [losses.start_logits, losses.end_logits, batch.unique_ids]
    return losses.loss, steps_per_epoch, eval_tensors, data_layer
Exemplo n.º 2
0
    def test_squad_v1(self):
        """Smoke-test short BERT fine-tuning on the bundled SQuAD v1.1 data.

        A training graph and a 'dev' evaluation graph are built; both share
        one BERT encoder, one token-classifier head and one QA loss module.
        Training then runs with the 'adam_w' optimizer under a
        'WarmupAnnealing' learning-rate policy for ``max_steps`` steps, with a
        loss logger and a periodic evaluator attached. The test contains no
        explicit assertions; it passes when this whole run does not raise.
        """
        # Data / model configuration.
        pretrained_bert_model = 'bert-base-uncased'
        version_2_with_negative = False
        data_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), '../data/nlp/squad/v1.1'))
        batch_size = 3
        max_query_length = 64
        max_seq_length = 384
        doc_stride = 128

        # Optimisation schedule.
        max_steps = 100
        lr = 3e-6
        lr_warmup_proportion = 0
        eval_step_freq = 50

        # Answer post-processing settings used by the evaluation callback.
        do_lower_case = True
        n_best_size = 5
        max_answer_length = 20
        null_score_diff_threshold = 0.0

        tokenizer = nemo_nlp.NemoBertTokenizer(pretrained_bert_model)
        neural_factory = nemo.core.NeuralModuleFactory(
            backend=nemo.core.Backend.PyTorch,
            local_rank=None,
            create_tb_writer=False,
        )
        model = nemo_nlp.huggingface.BERT(
            pretrained_model_name=pretrained_bert_model)
        qa_head = nemo_nlp.TokenClassifier(
            hidden_size=model.local_parameters["hidden_size"],
            num_classes=2,
            num_layers=1,
            log_softmax=False,
        )
        squad_loss = nemo_nlp.QuestionAnsweringLoss()

        def build_graph(mode):
            """Create a data layer for `mode` and run it through BERT, head, loss.

            Returns (data_layer, loss, start_logits, end_logits, unique_ids).
            """
            layer = nemo_nlp.BertQuestionAnsweringDataLayer(
                mode=mode,
                version_2_with_negative=version_2_with_negative,
                batch_size=batch_size,
                tokenizer=tokenizer,
                data_dir=data_dir,
                max_query_length=max_query_length,
                max_seq_length=max_seq_length,
                doc_stride=doc_stride,
            )
            ids, type_ids, mask, starts, ends, unique_ids = layer()
            encoded = model(
                input_ids=ids,
                token_type_ids=type_ids,
                attention_mask=mask,
            )
            logits = qa_head(hidden_states=encoded)
            loss_tensor, start_logits, end_logits = squad_loss(
                logits=logits,
                start_positions=starts,
                end_positions=ends,
            )
            return layer, loss_tensor, start_logits, end_logits, unique_ids

        # Training graph: only the loss tensor is consumed; the data layer
        # reference is kept alive for the duration of the test.
        _train_layer, loss, _, _, _ = build_graph('train')
        # Evaluation graph on the dev split: logits and ids feed the evaluator.
        (data_layer_eval, _, start_logits_eval, end_logits_eval,
         unique_ids_eval) = build_graph('dev')
        eval_output = [start_logits_eval, end_logits_eval, unique_ids_eval]

        callback_train = nemo.core.SimpleLossLoggerCallback(
            tensors=[loss],
            print_func=lambda x: print("Loss: {:.3f}".format(x[0].item())),
            get_tb_values=lambda x: [["loss", x[0]]],
            step_freq=10,
            tb_writer=neural_factory.tb_writer,
        )

        callbacks_eval = nemo.core.EvaluatorCallback(
            eval_tensors=eval_output,
            user_iter_callback=lambda x, y: eval_iter_callback(x, y),
            user_epochs_done_callback=lambda x: eval_epochs_done_callback(
                x,
                eval_data_layer=data_layer_eval,
                do_lower_case=do_lower_case,
                n_best_size=n_best_size,
                max_answer_length=max_answer_length,
                version_2_with_negative=version_2_with_negative,
                null_score_diff_threshold=null_score_diff_threshold,
            ),
            tb_writer=neural_factory.tb_writer,
            eval_step=eval_step_freq,
        )

        lr_policy_fn = get_lr_policy(
            'WarmupAnnealing',
            total_steps=max_steps,
            warmup_ratio=lr_warmup_proportion,
        )

        neural_factory.train(
            tensors_to_optimize=[loss],
            callbacks=[callback_train, callbacks_eval],
            lr_policy=lr_policy_fn,
            optimizer='adam_w',
            optimization_params=dict(max_steps=max_steps, lr=lr),
        )