def main(_):
    """Builds the TAPAS estimator and runs the phases selected by flags.

    The model configuration is read from `FLAGS`. Depending on
    `FLAGS.do_train`, `FLAGS.do_eval` and `FLAGS.do_predict` the model is
    trained, evaluated for every checkpoint yielded by
    `experiment_utils.iterate_checkpoints`, and/or used to export predictions
    for every such checkpoint.

    Args:
      _: Unused.
    """
    do_model_aggregation = FLAGS.num_aggregation_labels > 0
    do_model_classification = FLAGS.num_classification_labels > 0
    bert_config = experiment_utils.bert_config_from_flags()
    total_steps = experiment_utils.num_train_steps()
    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=total_steps,
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        positive_weight=FLAGS.positive_weight,
        num_aggregation_labels=FLAGS.num_aggregation_labels,
        num_classification_labels=FLAGS.num_classification_labels,
        aggregation_loss_importance=FLAGS.aggregation_loss_importance,
        use_answer_as_supervision=FLAGS.use_answer_as_supervision,
        answer_loss_importance=FLAGS.answer_loss_importance,
        use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
        huber_loss_delta=FLAGS.huber_loss_delta,
        temperature=FLAGS.temperature,
        agg_temperature=FLAGS.agg_temperature,
        use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
        use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
        average_approximation_function=(
            tapas_classifier_model.AverageApproximationFunction(
                FLAGS.average_approximation_function)),
        cell_select_pref=FLAGS.cell_select_pref,
        answer_loss_cutoff=FLAGS.answer_loss_cutoff,
        grad_clipping=FLAGS.grad_clipping,
        reset_output_cls=FLAGS.reset_output_cls,
        disabled_features=FLAGS.disabled_features,
        max_num_rows=FLAGS.max_num_rows,
        max_num_columns=FLAGS.max_num_columns,
        average_logits_per_cell=FLAGS.average_logits_per_cell,
        init_cell_selection_weights_to_zero=(
            FLAGS.init_cell_selection_weights_to_zero),
        select_one_column=FLAGS.select_one_column,
        allow_empty_column_selection=FLAGS.allow_empty_column_selection,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        disable_per_token_loss=FLAGS.disable_per_token_loss,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        span_prediction=tapas_classifier_model.SpanPredictionMode(
            FLAGS.span_prediction),
        # A non-positive flag value is passed on as None.
        proj_value_length=(FLAGS.proj_value_length
                           if FLAGS.proj_value_length > 0 else None),
    )

    estimator = experiment_utils.build_estimator(
        tapas_classifier_model.model_fn_builder(tapas_config))

    def _make_input_fn(name, file_patterns, is_training, include_id):
        """Binds the flag-derived arguments shared by every input function."""
        return functools.partial(
            tapas_classifier_model.input_fn,
            name=name,
            file_patterns=file_patterns,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=is_training,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=include_id)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        # Store both configs next to the checkpoints before training starts.
        for config, filename in ((bert_config, "bert_config.json"),
                                 (tapas_config, "tapas_config.json")):
            config.to_json_file(os.path.join(FLAGS.model_dir, filename))
        train_input_fn = _make_input_fn(
            "train",
            FLAGS.input_file_train,
            is_training=True,
            include_id=False)
        estimator.train(input_fn=train_input_fn, max_steps=total_steps)

    # The eval input function is shared by the eval and the predict phases.
    eval_input_fn = _make_input_fn(
        "eval",
        FLAGS.input_file_eval,
        is_training=False,
        include_id=not FLAGS.use_tpu)

    if FLAGS.do_eval:
        eval_name = FLAGS.eval_name
        if eval_name is None:
            eval_name = "default"
        eval_checkpoints = experiment_utils.iterate_checkpoints(
            model_dir=estimator.model_dir,
            total_steps=total_steps,
            marker_file_prefix=os.path.join(estimator.model_dir,
                                            f"eval_{eval_name}"),
            minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions)
        for _, checkpoint in eval_checkpoints:
            tf.logging.info("Running eval: %s", eval_name)
            result = estimator.evaluate(
                input_fn=eval_input_fn,
                steps=FLAGS.num_eval_steps,
                name=eval_name,
                checkpoint_path=checkpoint)
            tf.logging.info("Eval result:\n%s", result)

    if FLAGS.do_predict:
        predict_input_fn = _make_input_fn(
            "predict",
            FLAGS.input_file_predict,
            is_training=False,
            include_id=not FLAGS.use_tpu)

        prediction_output_dir = (FLAGS.prediction_output_dir or
                                 estimator.model_dir)
        if FLAGS.prediction_output_dir:
            tf.io.gfile.makedirs(prediction_output_dir)

        # One entry per data split: (mode, input files, input function,
        # marker suffix). Splits without input files are skipped.
        export_jobs = (
            ("predict", FLAGS.input_file_predict, predict_input_fn, "_test"),
            ("eval", FLAGS.input_file_eval, eval_input_fn, "_dev"),
        )

        # The marker name encodes the configured splits so that two separately
        # launched jobs do not use conflicting markers.
        marker_suffix = "".join(
            suffix for _, patterns, _, suffix in export_jobs
            if patterns is not None)
        marker_file_prefix = os.path.join(prediction_output_dir,
                                          "predict" + marker_suffix)

        predict_checkpoints = experiment_utils.iterate_checkpoints(
            model_dir=estimator.model_dir,
            total_steps=total_steps,
            single_step=FLAGS.evaluated_checkpoint_step,
            marker_file_prefix=marker_file_prefix)
        for current_step, checkpoint in predict_checkpoints:
            try:
                for mode, patterns, input_fn, _ in export_jobs:
                    if patterns is None:
                        continue
                    _predict_and_export_metrics(
                        mode=mode,
                        name=FLAGS.eval_name,
                        input_fn=input_fn,
                        output_dir=prediction_output_dir,
                        estimator=estimator,
                        checkpoint=checkpoint,
                        current_step=current_step,
                        do_model_classification=do_model_classification,
                        do_model_aggregation=do_model_aggregation,
                        output_token_answers=not FLAGS.disable_per_token_loss)
            except ValueError:
                # A ValueError is logged and the loop moves on to the next
                # checkpoint instead of aborting.
                tf.logging.error(
                    "Error getting predictions for checkpoint %s: %s",
                    checkpoint, traceback.format_exc())
# Example #2
# 0
def main(_):
    """Builds the TAPAS estimator and runs the phases selected by flags.

    Depending on `FLAGS.do_train`, `FLAGS.do_eval` and `FLAGS.do_predict` the
    model is trained, continuously evaluated on the latest checkpoint, and/or
    used to write prediction TSV files for the latest checkpoint. The eval and
    predict loops stop once a checkpoint at or beyond
    `experiment_utils.num_train_steps()` has been processed (or immediately
    after the first checkpoint when that value is None).

    Args:
      _: Unused.
    """
    do_model_aggregation = FLAGS.num_aggregation_labels > 0
    do_model_classification = FLAGS.num_classification_labels > 0
    bert_config = experiment_utils.bert_config_from_flags()
    # Resolve the step budget once and reuse it for the config, training and
    # the termination checks of the eval / predict loops.
    total_steps = experiment_utils.num_train_steps()
    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=total_steps,
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        positive_weight=FLAGS.positive_weight,
        num_aggregation_labels=FLAGS.num_aggregation_labels,
        num_classification_labels=FLAGS.num_classification_labels,
        aggregation_loss_importance=FLAGS.aggregation_loss_importance,
        use_answer_as_supervision=FLAGS.use_answer_as_supervision,
        answer_loss_importance=FLAGS.answer_loss_importance,
        use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
        huber_loss_delta=FLAGS.huber_loss_delta,
        temperature=FLAGS.temperature,
        agg_temperature=FLAGS.agg_temperature,
        use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
        use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
        average_approximation_function=(
            tapas_classifier_model.AverageApproximationFunction(
                FLAGS.average_approximation_function)),
        cell_select_pref=FLAGS.cell_select_pref,
        answer_loss_cutoff=FLAGS.answer_loss_cutoff,
        grad_clipping=FLAGS.grad_clipping,
        disabled_features=FLAGS.disabled_features,
        max_num_rows=FLAGS.max_num_rows,
        max_num_columns=FLAGS.max_num_columns,
        average_logits_per_cell=FLAGS.average_logits_per_cell,
        init_cell_selection_weights_to_zero=(
            FLAGS.init_cell_selection_weights_to_zero),
        select_one_column=FLAGS.select_one_column,
        allow_empty_column_selection=FLAGS.allow_empty_column_selection,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        disable_per_token_loss=FLAGS.disable_per_token_loss,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        span_prediction=tapas_classifier_model.SpanPredictionMode(
            FLAGS.span_prediction),
    )

    model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
    estimator = experiment_utils.build_estimator(model_fn)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        tapas_config.to_json_file(
            os.path.join(FLAGS.model_dir, "tapas_config.json"))
        train_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=False)
        estimator.train(input_fn=train_input_fn, max_steps=total_steps)

    # Built unconditionally: the predict phase below reuses it as well.
    eval_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="eval",
        file_patterns=FLAGS.input_file_eval,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=True)
    if FLAGS.do_eval:
        prev_checkpoint = None
        while True:
            checkpoint = estimator.latest_checkpoint()

            if checkpoint == prev_checkpoint:
                # Nothing new to evaluate (this includes "no checkpoint yet",
                # where both values are None). Always go back to polling: the
                # sleep is optional, but falling through would re-evaluate the
                # same checkpoint, or fail on `os.path.basename(None)`.
                if FLAGS.minutes_to_sleep_before_predictions > 0:
                    tf.logging.info("Sleeping %d mins before evaluation",
                                    FLAGS.minutes_to_sleep_before_predictions)
                    time.sleep(FLAGS.minutes_to_sleep_before_predictions * 60)
                continue

            tf.logging.info("Running eval: %s", FLAGS.eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=FLAGS.eval_name,
                                        checkpoint_path=checkpoint)
            tf.logging.info("Eval result:\n%s", result)

            # Checkpoint base names are expected to look like "<name>-<step>".
            current_step = int(os.path.basename(checkpoint).split("-")[1])
            if total_steps is None or current_step >= total_steps:
                tf.logging.info("Evaluation finished after training step %d",
                                current_step)
                break

            prev_checkpoint = checkpoint

    if FLAGS.do_predict:
        predict_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="predict",
            file_patterns=FLAGS.input_file_predict,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=False,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=True)

        def _predict_and_export_metrics(
            mode,
            input_fn,
            input_file,
            interactions_file,
            checkpoint,
            current_step,
        ):
            """Writes the predictions of one checkpoint to TSV files.

            Args:
              mode: Prefix of the output file names ("predict" or "eval").
              input_fn: Input function handed to `estimator.predict`.
              input_file: Data file(s) re-read for sequence prediction.
              interactions_file: Accepted for call-site compatibility; unused.
              checkpoint: Path of the checkpoint to predict with.
              current_step: Training step of `checkpoint`, used in file names.
            """
            tf.logging.info(
                "Running predictor for step %d (%s).",
                current_step,
                checkpoint,
            )
            result = estimator.predict(
                input_fn=input_fn,
                checkpoint_path=checkpoint,
            )
            if FLAGS.prediction_output_dir:
                output_dir = FLAGS.prediction_output_dir
                tf.io.gfile.makedirs(output_dir)
            else:
                output_dir = FLAGS.model_dir
            output_predict_file = os.path.join(
                output_dir, f"{mode}_results_{current_step}.tsv")
            prediction_utils.write_predictions(
                result, output_predict_file, do_model_aggregation,
                do_model_classification,
                FLAGS.cell_classification_threshold)

            if FLAGS.do_sequence_prediction:
                examples_by_position = prediction_utils.read_classifier_dataset(
                    predict_data=input_file,
                    data_format=FLAGS.data_format,
                    compression_type=FLAGS.compression_type,
                    max_seq_length=FLAGS.max_seq_length,
                    max_predictions_per_seq=FLAGS.max_predictions_per_seq,
                    add_aggregation_function_id=do_model_aggregation,
                    add_classification_labels=do_model_classification,
                    add_answer=FLAGS.use_answer_as_supervision)
                result_sequence = prediction_utils.compute_prediction_sequence(
                    estimator=estimator,
                    examples_by_position=examples_by_position)
                # Format only the file name: formatting the joined path would
                # misinterpret any braces contained in the directory.
                # NOTE(review): sequence results go to FLAGS.model_dir even
                # when FLAGS.prediction_output_dir is set; the location is
                # kept as-is here. Confirm whether that is intended.
                output_predict_file_sequence = os.path.join(
                    FLAGS.model_dir,
                    f"{mode}_results_sequence_{current_step}.tsv")
                prediction_utils.write_predictions(
                    result_sequence, output_predict_file_sequence,
                    do_model_aggregation, do_model_classification,
                    FLAGS.cell_classification_threshold)

        prev_checkpoint = None
        while True:
            checkpoint = estimator.latest_checkpoint()

            if checkpoint == prev_checkpoint:
                tf.logging.info("Sleeping 5 mins before predicting")
                time.sleep(5 * 60)
                continue

            # Checkpoint base names are expected to look like "<name>-<step>".
            current_step = int(os.path.basename(checkpoint).split("-")[1])

            if FLAGS.input_file_predict is not None:
                _predict_and_export_metrics(
                    mode="predict",
                    input_fn=predict_input_fn,
                    input_file=FLAGS.input_file_predict,
                    interactions_file=FLAGS.predict_interactions_file,
                    checkpoint=checkpoint,
                    current_step=current_step)
            if FLAGS.input_file_eval is not None:
                _predict_and_export_metrics(
                    mode="eval",
                    input_fn=eval_input_fn,
                    input_file=FLAGS.input_file_eval,
                    interactions_file=FLAGS.eval_interactions_file,
                    checkpoint=checkpoint,
                    current_step=current_step)

            if total_steps is None or current_step >= total_steps:
                tf.logging.info(
                    "Predictor finished after training step %d",
                    current_step,
                )
                break
            prev_checkpoint = checkpoint
def main(_):
  """Builds the TAPAS estimator and runs the phases selected by flags.

  The model configuration is read from `FLAGS`. Depending on `FLAGS.do_train`,
  `FLAGS.do_eval` and `FLAGS.do_predict` the model is trained, evaluated for
  every checkpoint yielded by `experiment_utils.iterate_checkpoints`, and/or
  used to export predictions for every such checkpoint.

  Args:
    _: Unused.

  Raises:
    ValueError: If a classification label weight is negative or refers to a
      label outside `[0, num_classification_labels)`.
  """
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction(
              FLAGS.average_approximation_function)),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      # The flag holds comma-separated "label:weight" pairs; empty pairs are
      # skipped.
      classification_label_weight={
          int(fields[0]): float(fields[1])
          for fields in (
              pair.split(":")
              for pair in FLAGS.classification_label_weight.split(",")
              if pair)
      },
      table_pruning_config_file=FLAGS.table_pruning_config_file,
      restrict_attention_mode=attention_utils.RestrictAttentionMode(
          FLAGS.restrict_attention_mode),
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      mask_examples_without_labels=FLAGS.mask_examples_without_labels,
      cell_cross_entropy_hard_em=FLAGS.cell_cross_entropy_hard_em,
      cell_cross_entropy=FLAGS.cell_cross_entropy,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
      proj_value_length=_get_projection_length(FLAGS.proj_value_length),
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=(
          FLAGS.attention_bias_use_relative_scalar_only),
  )

  estimator = experiment_utils.build_estimator(
      tapas_classifier_model.model_fn_builder(tapas_config))

  # Reject negative weights and labels outside the configured label range.
  label_weights = tapas_config.classification_label_weight
  if label_weights:
    if any(weight < 0 for weight in label_weights.values()):
      raise ValueError("Label weights cannot be negative in input: "
                       f"{label_weights}.")
    if any(label < 0 or label >= tapas_config.num_classification_labels
           for label in label_weights.keys()):
      raise ValueError("Invalid label in label weights for input: "
                       f"{label_weights}.")

  def _make_input_fn(name, file_patterns, is_training, include_id):
    """Binds the flag-derived arguments shared by every input function."""
    return functools.partial(
        tapas_classifier_model.input_fn,
        name=name,
        file_patterns=file_patterns,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=is_training,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=include_id)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    if FLAGS.table_pruning_config_file:
      # Keep a copy of the table pruning config next to the checkpoints.
      tf.io.gfile.copy(
          FLAGS.table_pruning_config_file,
          os.path.join(FLAGS.model_dir, "table_pruning_config.textproto"),
          overwrite=True)
    for config, filename in ((bert_config, "bert_config.json"),
                             (tapas_config, "tapas_config.json")):
      config.to_json_file(os.path.join(FLAGS.model_dir, filename))
    train_input_fn = _make_input_fn(
        "train", FLAGS.input_file_train, is_training=True, include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  # The eval input function is shared by the eval and the predict phases.
  eval_input_fn = _make_input_fn(
      "eval",
      FLAGS.input_file_eval,
      is_training=False,
      include_id=not FLAGS.use_tpu)

  if FLAGS.do_eval:
    eval_name = FLAGS.eval_name
    if eval_name is None:
      eval_name = "default"
    eval_checkpoints = experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions)
    for _, checkpoint in eval_checkpoints:
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = _make_input_fn(
        "predict",
        FLAGS.input_file_predict,
        is_training=False,
        include_id=not FLAGS.use_tpu)

    prediction_output_dir = (
        FLAGS.prediction_output_dir or estimator.model_dir)
    if FLAGS.prediction_output_dir:
      tf.io.gfile.makedirs(prediction_output_dir)

    # One entry per data split: (mode, input files, input function, marker
    # suffix). Splits without input files are skipped.
    export_jobs = (
        ("predict", FLAGS.input_file_predict, predict_input_fn, "_test"),
        ("eval", FLAGS.input_file_eval, eval_input_fn, "_dev"),
    )

    # The marker name encodes the configured splits so that two separately
    # launched jobs do not use conflicting markers.
    marker_suffix = "".join(
        suffix for _, patterns, _, suffix in export_jobs
        if patterns is not None)
    marker_file_prefix = os.path.join(prediction_output_dir,
                                      "predict" + marker_suffix)

    predict_checkpoints = experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix)
    for current_step, checkpoint in predict_checkpoints:
      try:
        for mode, patterns, input_fn, _ in export_jobs:
          if patterns is None:
            continue
          _predict_and_export_metrics(
              mode=mode,
              name=FLAGS.eval_name,
              input_fn=input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
      except ValueError:
        # A ValueError is logged and the loop moves on to the next checkpoint
        # instead of aborting.
        tf.logging.error("Error getting predictions for checkpoint %s: %s",
                         checkpoint, traceback.format_exc())
# Example #4
# 0
def _train_and_predict(
    task,
    tpu_options,
    test_batch_size,
    train_batch_size,
    gradient_accumulation_steps,
    bert_config_file,
    init_checkpoint,
    test_mode,
    mode,
    output_dir,
    model_dir,
    loop_predict,
):
    """Trains a TAPAS model for `task`, or predicts (and evaluates) with it.

    Args:
      task: Member of `tasks.Task` selecting label counts and hyperparameters.
      tpu_options: Object providing the TPU settings read below (`use_tpu`,
        `tpu_name`, `tpu_zone`, `gcp_project`, `master`, `iterations_per_loop`,
        `num_tpu_cores`).
      test_batch_size: Prediction batch size; forced to 1 in test mode.
      train_batch_size: Training batch size, or None to use 1 in test mode and
        the task's `train_batch_size` hyperparameter otherwise.
      gradient_accumulation_steps: Divides the estimator's train batch size
        and is passed to the model through the estimator params.
      bert_config_file: JSON file the BERT config is loaded from.
      init_checkpoint: Stored in the model config as `init_checkpoint`.
      test_mode: If true, runs 10 training steps with 1 warmup step.
      mode: `Mode.TRAIN`, `Mode.PREDICT` or `Mode.PREDICT_AND_EVALUATE`.
      output_dir: Passed on to the example-file, prediction and eval helpers.
      model_dir: Directory for checkpoints and the config JSON files.
      loop_predict: If true, keeps predicting for new checkpoints until one
        with at least `num_train_steps` steps has been processed.

    Raises:
      ValueError: For an unknown task, an unexpected mode, or when no
        checkpoint exists while `loop_predict` is false.
    """
    file_utils.make_directories(model_dir)

    # Per-task label counts and supervision mode.
    if task in (tasks.Task.SQA, tasks.Task.HYBRIDQA, tasks.Task.HYBRIDQA_RC):
        num_aggregation_labels, num_classification_labels = 0, 0
        use_answer_as_supervision = False
    elif task in [
            tasks.Task.WTQ, tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED
    ]:
        num_aggregation_labels, num_classification_labels = 4, 0
        use_answer_as_supervision = task != tasks.Task.WIKISQL_SUPERVISED
    elif task == tasks.Task.TABFACT:
        num_aggregation_labels, num_classification_labels = 0, 2
        use_answer_as_supervision = True
    elif task == tasks.Task.NQ_RETRIEVAL:
        num_aggregation_labels, num_classification_labels = 0, 2
        use_answer_as_supervision = False
    else:
        raise ValueError(f'Unknown task: {task.name}')

    do_model_aggregation = num_aggregation_labels > 0
    do_model_classification = num_classification_labels > 0

    hparams = hparam_utils.get_hparams(task)
    if test_mode:
        # Tiny fixed schedule for test runs.
        if train_batch_size is None:
            train_batch_size = 1
        test_batch_size = 1
        num_train_steps, num_warmup_steps = 10, 1
    else:
        if train_batch_size is None:
            train_batch_size = hparams['train_batch_size']
        num_train_steps = int(hparams['num_train_examples'] / train_batch_size)
        num_warmup_steps = int(num_train_steps * hparams['warmup_ratio'])

    bert_config = modeling.BertConfig.from_json_file(bert_config_file)
    # Task hyperparameters may override the dropout rates of the BERT config.
    for attribute in ('attention_probs_dropout_prob', 'hidden_dropout_prob'):
        hparam_key = f'bert_config_{attribute}'
        if hparam_key in hparams:
            setattr(bert_config, attribute, hparams.get(hparam_key))

    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=init_checkpoint,
        learning_rate=hparams['learning_rate'],
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=tpu_options.use_tpu,
        positive_weight=10.0,
        num_aggregation_labels=num_aggregation_labels,
        num_classification_labels=num_classification_labels,
        aggregation_loss_importance=1.0,
        use_answer_as_supervision=use_answer_as_supervision,
        answer_loss_importance=1.0,
        use_normalized_answer_loss=False,
        huber_loss_delta=hparams.get('huber_loss_delta'),
        temperature=hparams.get('temperature', 1.0),
        agg_temperature=1.0,
        use_gumbel_for_cells=False,
        use_gumbel_for_agg=False,
        average_approximation_function=(
            tapas_classifier_model.AverageApproximationFunction.RATIO),
        cell_select_pref=hparams.get('cell_select_pref'),
        answer_loss_cutoff=hparams.get('answer_loss_cutoff'),
        grad_clipping=hparams.get('grad_clipping'),
        disabled_features=[],
        max_num_rows=64,
        max_num_columns=32,
        average_logits_per_cell=False,
        disable_per_token_loss=hparams.get('disable_per_token_loss', False),
        mask_examples_without_labels=hparams.get(
            'mask_examples_without_labels', False),
        init_cell_selection_weights_to_zero=(
            hparams['init_cell_selection_weights_to_zero']),
        select_one_column=hparams['select_one_column'],
        allow_empty_column_selection=hparams['allow_empty_column_selection'],
        span_prediction=tapas_classifier_model.SpanPredictionMode(
            hparams.get('span_prediction',
                        tapas_classifier_model.SpanPredictionMode.NONE)),
        disable_position_embeddings=False,
        reset_output_cls=FLAGS.reset_output_cls,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        table_pruning_config_file=FLAGS.table_pruning_config_file)

    model_fn = tapas_classifier_model.model_fn_builder(tapas_config)

    # A cluster resolver is only created for a named TPU.
    if tpu_options.use_tpu and tpu_options.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_options.tpu_name,
            zone=tpu_options.tpu_zone,
            project=tpu_options.gcp_project,
        )
    else:
        tpu_cluster_resolver = None

    tpu_config = tf_estimator.tpu.TPUConfig(
        iterations_per_loop=tpu_options.iterations_per_loop,
        num_shards=tpu_options.num_tpu_cores,
        per_host_input_for_training=(
            tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2))
    run_config = tf_estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=tpu_options.master,
        model_dir=model_dir,
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=5,
        keep_checkpoint_every_n_hours=4.0,
        tpu_config=tpu_config)

    # If TPU is not available, this will fall back to normal Estimator on
    # CPU/GPU.
    estimator = tf_estimator.tpu.TPUEstimator(
        params={'gradient_accumulation_steps': gradient_accumulation_steps},
        use_tpu=tpu_options.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=train_batch_size // gradient_accumulation_steps,
        eval_batch_size=None,
        predict_batch_size=test_batch_size)

    if mode == Mode.TRAIN:
        _print('Training')
        for config, filename in ((bert_config, 'bert_config.json'),
                                 (tapas_config, 'tapas_config.json')):
            config.to_json_file(os.path.join(model_dir, filename))
        train_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name='train',
            file_patterns=_get_train_examples_file(task, output_dir),
            data_format='tfrecord',
            compression_type=FLAGS.compression_type,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=use_answer_as_supervision,
            include_id=False,
        )
        estimator.train(
            input_fn=train_input_fn,
            max_steps=tapas_config.num_train_steps,
        )
        return

    if mode not in (Mode.PREDICT_AND_EVALUATE, Mode.PREDICT):
        raise ValueError(f'Unexpected mode: {mode}.')

    # Continuous prediction: start from the latest checkpoint and, when
    # looping, keep going until a checkpoint with 'num_train_steps' is reached.
    prev_checkpoint = None
    while True:
        checkpoint = estimator.latest_checkpoint()

        if loop_predict:
            if checkpoint == prev_checkpoint:
                # Nothing new yet; wait for the next checkpoint.
                _print('Sleeping 5 mins before predicting')
                time.sleep(5 * 60)
                continue
        elif not checkpoint:
            raise ValueError(f'No checkpoint found at {model_dir}.')

        # Checkpoint base names are expected to look like "<name>-<step>".
        current_step = int(os.path.basename(checkpoint).split('-')[1])
        _predict(
            estimator,
            task,
            output_dir,
            model_dir,
            do_model_aggregation,
            do_model_classification,
            use_answer_as_supervision,
            use_tpu=tapas_config.use_tpu,
            global_step=current_step,
        )
        if mode == Mode.PREDICT_AND_EVALUATE:
            _eval(task=task,
                  output_dir=output_dir,
                  model_dir=model_dir,
                  global_step=current_step)

        if not loop_predict or current_step >= tapas_config.num_train_steps:
            _print(f'Evaluation finished after training step {current_step}.')
            break

        prev_checkpoint = checkpoint