def main(_):
    do_model_aggregation = FLAGS.num_aggregation_labels > 0
    do_model_classification = FLAGS.num_classification_labels > 0
    bert_config = experiment_utils.bert_config_from_flags()
    total_steps = experiment_utils.num_train_steps()
    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=total_steps,
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        positive_weight=FLAGS.positive_weight,
        num_aggregation_labels=FLAGS.num_aggregation_labels,
        num_classification_labels=FLAGS.num_classification_labels,
        aggregation_loss_importance=FLAGS.aggregation_loss_importance,
        use_answer_as_supervision=FLAGS.use_answer_as_supervision,
        answer_loss_importance=FLAGS.answer_loss_importance,
        use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
        huber_loss_delta=FLAGS.huber_loss_delta,
        temperature=FLAGS.temperature,
        agg_temperature=FLAGS.agg_temperature,
        use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
        use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
        average_approximation_function=(
            tapas_classifier_model.AverageApproximationFunction(
                FLAGS.average_approximation_function)),
        cell_select_pref=FLAGS.cell_select_pref,
        answer_loss_cutoff=FLAGS.answer_loss_cutoff,
        grad_clipping=FLAGS.grad_clipping,
        reset_output_cls=FLAGS.reset_output_cls,
        disabled_features=FLAGS.disabled_features,
        max_num_rows=FLAGS.max_num_rows,
        max_num_columns=FLAGS.max_num_columns,
        average_logits_per_cell=FLAGS.average_logits_per_cell,
        init_cell_selection_weights_to_zero=(
            FLAGS.init_cell_selection_weights_to_zero),
        select_one_column=FLAGS.select_one_column,
        allow_empty_column_selection=FLAGS.allow_empty_column_selection,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        disable_per_token_loss=FLAGS.disable_per_token_loss,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        span_prediction=tapas_classifier_model.SpanPredictionMode(
            FLAGS.span_prediction),
        proj_value_length=(FLAGS.proj_value_length
                           if FLAGS.proj_value_length > 0 else None),
    )

    model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
    estimator = experiment_utils.build_estimator(model_fn)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        tapas_config.to_json_file(
            os.path.join(FLAGS.model_dir, "tapas_config.json"))
        train_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=False)
        estimator.train(input_fn=train_input_fn, max_steps=total_steps)

    eval_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="eval",
        file_patterns=FLAGS.input_file_eval,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)
    if FLAGS.do_eval:
        eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
        for _, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                marker_file_prefix=os.path.join(estimator.model_dir,
                                                f"eval_{eval_name}"),
                minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
            tf.logging.info("Running eval: %s", eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=eval_name,
                                        checkpoint_path=checkpoint)
            tf.logging.info("Eval result:\n%s", result)

    if FLAGS.do_predict:
        predict_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="predict",
            file_patterns=FLAGS.input_file_predict,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=False,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=not FLAGS.use_tpu)

        if FLAGS.prediction_output_dir:
            prediction_output_dir = FLAGS.prediction_output_dir
            tf.io.gfile.makedirs(prediction_output_dir)
        else:
            prediction_output_dir = estimator.model_dir

        marker_file_prefix = os.path.join(prediction_output_dir, "predict")
        # When two separate jobs are launched we don't want conflicting markers.
        if FLAGS.input_file_predict is not None:
            marker_file_prefix += "_test"
        if FLAGS.input_file_eval is not None:
            marker_file_prefix += "_dev"

        for current_step, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                single_step=FLAGS.evaluated_checkpoint_step,
                marker_file_prefix=marker_file_prefix):

            try:
                if FLAGS.input_file_predict is not None:
                    _predict_and_export_metrics(
                        mode="predict",
                        name=FLAGS.eval_name,
                        input_fn=predict_input_fn,
                        output_dir=prediction_output_dir,
                        estimator=estimator,
                        checkpoint=checkpoint,
                        current_step=current_step,
                        do_model_classification=do_model_classification,
                        do_model_aggregation=do_model_aggregation,
                        output_token_answers=not FLAGS.disable_per_token_loss)
                if FLAGS.input_file_eval is not None:
                    _predict_and_export_metrics(
                        mode="eval",
                        name=FLAGS.eval_name,
                        input_fn=eval_input_fn,
                        output_dir=prediction_output_dir,
                        estimator=estimator,
                        checkpoint=checkpoint,
                        current_step=current_step,
                        do_model_classification=do_model_classification,
                        do_model_aggregation=do_model_aggregation,
                        output_token_answers=not FLAGS.disable_per_token_loss)
            except ValueError:
                tf.logging.error(
                    "Error getting predictions for checkpoint %s: %s",
                    checkpoint, traceback.format_exc())
def main(_):
    bert_config = experiment_utils.bert_config_from_flags()
    total_steps = experiment_utils.num_train_steps()
    retriever_config = table_retriever_model.RetrieverConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=total_steps,
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        grad_clipping=FLAGS.grad_clipping,
        down_projection_dim=FLAGS.down_projection_dim,
        init_from_single_encoder=FLAGS.init_from_single_encoder,
        max_query_length=FLAGS.max_query_length,
        mask_repeated_tables=FLAGS.mask_repeated_tables,
        mask_repeated_questions=FLAGS.mask_repeated_questions,
        use_out_of_core_negatives=FLAGS.use_out_of_core_negatives,
        ignore_table_content=FLAGS.ignore_table_content,
        disabled_features=FLAGS.disabled_features,
        use_mined_negatives=FLAGS.use_mined_negatives,
    )

    model_fn = table_retriever_model.model_fn_builder(retriever_config)
    estimator = experiment_utils.build_estimator(model_fn)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        retriever_config.to_json_file(
            os.path.join(FLAGS.model_dir, "tapas_config.json"))
        train_input_fn = functools.partial(
            table_retriever_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            compression_type=FLAGS.compression_type,
            use_mined_negatives=FLAGS.use_mined_negatives,
            include_id=False)
        estimator.train(input_fn=train_input_fn, max_steps=total_steps)

    eval_input_fn = _get_test_input_fn("eval", FLAGS.input_file_eval)
    if FLAGS.do_eval:
        if eval_input_fn is None:
            raise ValueError("No input_file_eval specified!")
        for _, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                marker_file_prefix=os.path.join(estimator.model_dir,
                                                f"eval_{FLAGS.eval_name}"),
                minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
            tf.logging.info("Running eval: %s", FLAGS.eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=FLAGS.eval_name,
                                        checkpoint_path=checkpoint)
            tf.logging.info("Eval result:\n%s", result)

    if FLAGS.do_predict:
        predict_input_fn = _get_test_input_fn("predict",
                                              FLAGS.input_file_predict)
        if FLAGS.prediction_output_dir:
            prediction_output_dir = FLAGS.prediction_output_dir
            tf.io.gfile.makedirs(prediction_output_dir)
        else:
            prediction_output_dir = estimator.model_dir

        marker_file_prefix = os.path.join(prediction_output_dir, "predict")
        # When two separate jobs are launched we don't want conflicting markers.
        if predict_input_fn is not None:
            marker_file_prefix += "_test"
        if eval_input_fn is not None:
            marker_file_prefix += "_dev"

        single_step = FLAGS.evaluated_checkpoint_step
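        # If a metric name is given, override the explicit step with the
        # checkpoint that scored best on that metric.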
        if FLAGS.evaluated_checkpoint_metric:
            single_step = experiment_utils.get_best_step_for_metric(
                estimator.model_dir, FLAGS.evaluated_checkpoint_metric)
        for current_step, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                marker_file_prefix=marker_file_prefix,
                single_step=single_step):
            if predict_input_fn is not None:
                _predict_and_export_metrics(mode="predict",
                                            input_fn=predict_input_fn,
                                            checkpoint_path=checkpoint,
                                            step=current_step,
                                            estimator=estimator,
                                            output_dir=prediction_output_dir)

            if eval_input_fn is not None:
                _predict_and_export_metrics(mode="eval",
                                            input_fn=eval_input_fn,
                                            checkpoint_path=checkpoint,
                                            step=current_step,
                                            estimator=estimator,
                                            output_dir=prediction_output_dir)
def main(_):
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction(
              FLAGS.average_approximation_function)),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      classification_label_weight={
          int(pair.split(":")[0]): float(pair.split(":")[1])
          for pair in FLAGS.classification_label_weight.split(",")
          if pair
      },
      table_pruning_config_file=FLAGS.table_pruning_config_file,
      restrict_attention_mode=(attention_utils.RestrictAttentionMode(
          FLAGS.restrict_attention_mode)),
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      mask_examples_without_labels=FLAGS.mask_examples_without_labels,
      cell_cross_entropy_hard_em=FLAGS.cell_cross_entropy_hard_em,
      cell_cross_entropy=FLAGS.cell_cross_entropy,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
      proj_value_length=_get_projection_length(FLAGS.proj_value_length),
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=(
          FLAGS.attention_bias_use_relative_scalar_only),
  )
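  # The dict comprehension above turns the classification_label_weight flag,
  # a comma-separated string of "label:weight" pairs, into a mapping; e.g.
  # "0:1.0,1:2.5" becomes {0: 1.0, 1: 2.5}.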

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if tapas_config.classification_label_weight:
    if any(x < 0 for x in tapas_config.classification_label_weight.values()):
      raise ValueError("Label weights cannot be negative in input: "
                       f"{tapas_config.classification_label_weight}.")
    if any(x < 0 or x >= tapas_config.num_classification_labels
           for x in tapas_config.classification_label_weight.keys()):
      raise ValueError("Invalid label in label weights for input: "
                       f"{tapas_config.classification_label_weight}.")

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Copy the table pruning config if pruning is used.
    if FLAGS.table_pruning_config_file:
      table_pruning_path = os.path.join(FLAGS.model_dir,
                                        "table_pruning_config.textproto")
      tf.io.gfile.copy(
          FLAGS.table_pruning_config_file, table_pruning_path, overwrite=True)
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      include_id=not FLAGS.use_tpu)
  if FLAGS.do_eval:
    eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)

    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir

    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched we don't want conflicting markers.
    if FLAGS.input_file_predict is not None:
      marker_file_prefix += "_test"
    if FLAGS.input_file_eval is not None:
      marker_file_prefix += "_dev"

    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix):

      try:
        if FLAGS.input_file_predict is not None:
          _predict_and_export_metrics(
              mode="predict",
              name=FLAGS.eval_name,
              input_fn=predict_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
        if FLAGS.input_file_eval is not None:
          _predict_and_export_metrics(
              mode="eval",
              name=FLAGS.eval_name,
              input_fn=eval_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
      except ValueError:
        tf.logging.error("Error getting predictions for checkpoint %s: %s",
                         checkpoint, traceback.format_exc())
def main(_):
    bert_config = experiment_utils.bert_config_from_flags()
    model_fn = tapas_pretraining_model.model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=experiment_utils.num_train_steps(),
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        restrict_attention_mode=attention_utils.RestrictAttentionMode(
            FLAGS.restrict_attention_mode),
        restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
        restrict_attention_header_size=FLAGS.restrict_attention_header_size,
        restrict_attention_row_heads_ratio=(
            FLAGS.restrict_attention_row_heads_ratio),
        disabled_features=FLAGS.disabled_features,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        proj_value_length=(FLAGS.proj_value_length
                           if FLAGS.proj_value_length > 0 else None),
        attention_bias_disabled=FLAGS.attention_bias_disabled,
        attention_bias_use_relative_scalar_only=(
            FLAGS.attention_bias_use_relative_scalar_only),
    )
    estimator = experiment_utils.build_estimator(model_fn)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        train_input_fn = functools.partial(
            tapas_pretraining_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq)
        estimator.train(input_fn=train_input_fn,
                        max_steps=experiment_utils.num_train_steps())

    if FLAGS.do_eval:
        eval_input_fn = functools.partial(
            tapas_pretraining_model.input_fn,
            name="eval",
            file_patterns=FLAGS.input_file_eval,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=False,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq)

        current_step = 0
        prev_checkpoint = None
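        # Poll for checkpoints: re-check every 5 minutes until a new one
        # appears, evaluate it, and stop once the final training step has
        # been evaluated.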
        while True:
            checkpoint = estimator.latest_checkpoint()

            if checkpoint == prev_checkpoint:
                tf.logging.info("Sleeping 5 mins before evaluation")
                time.sleep(5 * 60)
                continue

            tf.logging.info("Running eval: %s", FLAGS.eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=FLAGS.eval_name)
            tf.logging.info("Eval result:\n%s", result)

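            # TF checkpoint basenames look like "model.ckpt-12345"; the
            # suffix after "-" is the global step.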
            current_step = int(os.path.basename(checkpoint).split("-")[1])
            if current_step >= experiment_utils.num_train_steps():
                tf.logging.info("Evaluation finished after training step %d",
                                current_step)
                break

            prev_checkpoint = checkpoint
def main(_):
    do_model_aggregation = FLAGS.num_aggregation_labels > 0
    do_model_classification = FLAGS.num_classification_labels > 0
    bert_config = experiment_utils.bert_config_from_flags()
    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=experiment_utils.num_train_steps(),
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        positive_weight=FLAGS.positive_weight,
        num_aggregation_labels=FLAGS.num_aggregation_labels,
        num_classification_labels=FLAGS.num_classification_labels,
        aggregation_loss_importance=FLAGS.aggregation_loss_importance,
        use_answer_as_supervision=FLAGS.use_answer_as_supervision,
        answer_loss_importance=FLAGS.answer_loss_importance,
        use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
        huber_loss_delta=FLAGS.huber_loss_delta,
        temperature=FLAGS.temperature,
        agg_temperature=FLAGS.agg_temperature,
        use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
        use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
        average_approximation_function=(
            tapas_classifier_model.AverageApproximationFunction(
                FLAGS.average_approximation_function)),
        cell_select_pref=FLAGS.cell_select_pref,
        answer_loss_cutoff=FLAGS.answer_loss_cutoff,
        grad_clipping=FLAGS.grad_clipping,
        disabled_features=FLAGS.disabled_features,
        max_num_rows=FLAGS.max_num_rows,
        max_num_columns=FLAGS.max_num_columns,
        average_logits_per_cell=FLAGS.average_logits_per_cell,
        init_cell_selection_weights_to_zero=(
            FLAGS.init_cell_selection_weights_to_zero),
        select_one_column=FLAGS.select_one_column,
        allow_empty_column_selection=FLAGS.allow_empty_column_selection,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        disable_per_token_loss=FLAGS.disable_per_token_loss)

    model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
    estimator = experiment_utils.build_estimator(model_fn)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        tapas_config.to_json_file(
            os.path.join(FLAGS.model_dir, "tapas_config.json"))
        train_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=False)
        estimator.train(input_fn=train_input_fn,
                        max_steps=experiment_utils.num_train_steps())

    eval_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="eval",
        file_patterns=FLAGS.input_file_eval,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=True)
    if FLAGS.do_eval:
        prev_checkpoint = None
        while True:
            checkpoint = estimator.latest_checkpoint()

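            # If no new checkpoint has appeared, optionally sleep and retry;
            # with a non-positive sleep flag the same checkpoint is
            # re-evaluated immediately.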
            if checkpoint == prev_checkpoint:
                if FLAGS.minutes_to_sleep_before_predictions > 0:
                    tf.logging.info("Sleeping %d mins before evaluation",
                                    FLAGS.minutes_to_sleep_before_predictions)
                    time.sleep(FLAGS.minutes_to_sleep_before_predictions * 60)
                    continue

            tf.logging.info("Running eval: %s", FLAGS.eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=FLAGS.eval_name)
            tf.logging.info("Eval result:\n%s", result)

            current_step = int(os.path.basename(checkpoint).split("-")[1])
            num_train_steps = experiment_utils.num_train_steps()
            if num_train_steps is None or current_step >= num_train_steps:
                tf.logging.info("Evaluation finished after training step %d",
                                current_step)
                break

            prev_checkpoint = checkpoint

    if FLAGS.do_predict:
        predict_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="predict",
            file_patterns=FLAGS.input_file_predict,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=False,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=True)

        prev_checkpoint = None
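        # Same polling pattern as eval: wait for a fresh checkpoint, predict
        # on it, and exit once training has reached its final step.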
        while True:
            checkpoint = estimator.latest_checkpoint()

            if checkpoint == prev_checkpoint:
                tf.logging.info("Sleeping 5 mins before predicting")
                time.sleep(5 * 60)
                continue

            current_step = int(os.path.basename(checkpoint).split("-")[1])

            def _predict_and_export_metrics(
                mode,
                input_fn,
                input_file,
                interactions_file,
            ):
                """Exports model predictions and calculates denotation metric."""
                # Predict for each new checkpoint.
                tf.logging.info("Running predictor for step %d.", current_step)
                result = estimator.predict(input_fn=input_fn)
                if FLAGS.prediction_output_dir:
                    output_dir = FLAGS.prediction_output_dir
                    tf.io.gfile.makedirs(output_dir)
                else:
                    output_dir = FLAGS.model_dir
                output_predict_file = os.path.join(
                    output_dir, f"{mode}_results_{current_step}.tsv")
                prediction_utils.write_predictions(
                    result, output_predict_file, do_model_aggregation,
                    do_model_classification,
                    FLAGS.cell_classification_threshold)

                if FLAGS.do_sequence_prediction:
                    examples_by_position = prediction_utils.read_classifier_dataset(
                        predict_data=input_file,
                        data_format=FLAGS.data_format,
                        compression_type=FLAGS.compression_type,
                        max_seq_length=FLAGS.max_seq_length,
                        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
                        add_aggregation_function_id=do_model_aggregation,
                        add_classification_labels=do_model_classification,
                        add_answer=FLAGS.use_answer_as_supervision)
                    result_sequence = prediction_utils.compute_prediction_sequence(
                        estimator=estimator,
                        examples_by_position=examples_by_position)
                    output_predict_file_sequence = os.path.join(
                        FLAGS.model_dir,
                        f"{mode}_results_sequence_{current_step}.tsv")
                    prediction_utils.write_predictions(
                        result_sequence, output_predict_file_sequence,
                        do_model_aggregation, do_model_classification,
                        FLAGS.cell_classification_threshold)

            if FLAGS.input_file_predict is not None:
                _predict_and_export_metrics(
                    mode="predict",
                    input_fn=predict_input_fn,
                    input_file=FLAGS.input_file_predict,
                    interactions_file=FLAGS.predict_interactions_file)
            _predict_and_export_metrics(
                mode="eval",
                input_fn=eval_input_fn,
                input_file=FLAGS.input_file_eval,
                interactions_file=FLAGS.eval_interactions_file)

            num_train_steps = experiment_utils.num_train_steps()
            if num_train_steps is None or current_step >= num_train_steps:
                tf.logging.info("Predictor finished after training step %d",
                                current_step)
                break
            prev_checkpoint = checkpoint
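
# A minimal entry-point sketch for launching one of these mains; this is an
# assumption, since the real TAPAS scripts define their flags elsewhere and
# wire up app.run themselves:
from absl import app

if __name__ == "__main__":
    app.run(main)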