Ejemplo n.º 1
0
def model_fn(features, labels, mode, params):
    """Build the `EstimatorSpec` for the retrieve-then-read model.

    Calls `retrieve` under the "/cpu:0" device, hands the first
    `reader_beam_size` retrieved logits and blocks to `read` inside the
    "reader" variable scope, and derives predictions from the reader outputs.
    Outside of PREDICT mode it additionally builds the eval metrics, the loss
    and the training op.

    Args:
        features: Forwarded unchanged to `retrieve` and `read`.
        labels: Answers forwarded to `read` and `orqa_ops.has_answer`; when
            None, a constant holding a single empty string is used instead.
        mode: One of the `tf.estimator.ModeKeys` values.
        params: Mapping read for "reader_beam_size", "retriever_beam_size"
            (not in PREDICT mode), "learning_rate" and "num_train_steps"
            (both only outside PREDICT mode); also forwarded to the helpers.

    Returns:
        A `tf.estimator.EstimatorSpec`; `loss`, `train_op` and
        `eval_metric_ops` are None in PREDICT mode.
    """
    if labels is None:
        labels = tf.constant([""])

    predicting = mode == tf.estimator.ModeKeys.PREDICT

    # In PREDICT mode the retriever beam is the reader beam; otherwise the
    # retriever beam comes from params.
    reader_beam_size = params["reader_beam_size"]
    retriever_beam_size = (
        reader_beam_size if predicting else params["retriever_beam_size"])
    assert reader_beam_size <= retriever_beam_size

    with tf.device("/cpu:0"):
        retriever_outputs = retrieve(
            features=features,
            retriever_beam_size=retriever_beam_size,
            mode=mode,
            params=params)

    with tf.variable_scope("reader"):
        # The reader only sees the leading `reader_beam_size` entries.
        top_logits = retriever_outputs.logits[:reader_beam_size]
        top_blocks = retriever_outputs.blocks[:reader_beam_size]
        reader_outputs = read(
            features=features,
            retriever_logits=top_logits,
            blocks=top_blocks,
            mode=mode,
            params=params,
            labels=labels)

    predictions = get_predictions(reader_outputs, params)

    # Training/eval-only outputs stay None when predicting.
    loss = None
    train_op = None
    eval_metric_ops = None

    if not predicting:
        # [retriever_beam_size]
        retriever_correct = orqa_ops.has_answer(
            blocks=retriever_outputs.blocks,
            answers=labels)

        # [reader_beam_size, num_candidates]
        reader_correct = compute_correct_candidates(
            candidate_starts=reader_outputs.candidate_starts,
            candidate_ends=reader_outputs.candidate_ends,
            gold_starts=reader_outputs.gold_starts,
            gold_ends=reader_outputs.gold_ends)

        eval_metric_ops = compute_eval_metrics(
            labels=labels,
            predictions=predictions,
            retriever_correct=retriever_correct,
            reader_correct=reader_correct)

        # []
        loss = compute_loss(
            retriever_logits=retriever_outputs.logits,
            retriever_correct=retriever_correct,
            reader_logits=reader_outputs.logits,
            reader_correct=reader_correct)

        learning_rate = params["learning_rate"]
        num_train_steps = params["num_train_steps"]
        # Warm up for a tenth of training, clamped to [100, 10000] steps.
        warmup_steps = int(num_train_steps / 10)
        warmup_steps = max(100, warmup_steps)
        warmup_steps = min(10000, warmup_steps)

        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=learning_rate,
            num_train_steps=num_train_steps,
            num_warmup_steps=warmup_steps,
            use_tpu=False)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        eval_metric_ops=eval_metric_ops)
Ejemplo n.º 2
0
 def test_has_answer(self):
     """Check `has_answer` on two blocks where only the second holds "hij"."""
     blocks = ["abcdefg", "hijklmn"]
     answers = ["hij"]
     # One flag per block: no match for the first, a match for the second.
     expected = [False, True]

     actual = orqa_ops.has_answer(blocks=blocks, answers=answers)

     self.assertAllEqual(actual.numpy(), expected)