def _create_estimator(self, params):
    """Creates a `TPUEstimator` around a tiny TAPAS classifier for tests.

    Seeds numpy's global RNG with a fixed value, builds a two-layer BERT
    config and a `TapasClassifierConfig` whose hyper-parameters are read from
    `params`, and wires the resulting model function into an estimator whose
    model directory is `self.get_temp_dir()`.

    Args:
      params: Mapping of hyper-parameters. Every key listed in
        `required_keys` below plus "batch_size" must be present;
        "disable_per_token_loss" (default False) and
        "gradient_accumulation_steps" (default 1) are optional.

    Returns:
      The configured `tf.estimator.tpu.TPUEstimator`.
    """
    seed = 420
    tf.logging.info("Setting random seed to {}".format(seed))
    np.random.seed(seed)

    # Small bert model for testing.
    tiny_bert_dict = {
        "vocab_size": 10,
        "type_vocab_size": [3, 256, 256, 2, 256, 256, 10],
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "hidden_size": 128,
        "intermediate_size": 512,
    }
    tiny_bert = modeling.BertConfig.from_dict(tiny_bert_dict)

    # Entries of `params` that are handed to the classifier config unchanged.
    required_keys = (
        "init_checkpoint",
        "learning_rate",
        "num_train_steps",
        "num_warmup_steps",
        "use_tpu",
        "num_aggregation_labels",
        "num_classification_labels",
        "aggregation_loss_importance",
        "use_answer_as_supervision",
        "answer_loss_importance",
        "use_normalized_answer_loss",
        "huber_loss_delta",
        "temperature",
        "agg_temperature",
        "use_gumbel_for_cells",
        "use_gumbel_for_agg",
        "average_approximation_function",
        "cell_select_pref",
        "answer_loss_cutoff",
        "grad_clipping",
        "max_num_rows",
        "max_num_columns",
        "average_logits_per_cell",
        "select_one_column",
    )
    forwarded = {key: params[key] for key in required_keys}
    classifier_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=tiny_bert,
        positive_weight=10.0,
        disable_per_token_loss=params.get("disable_per_token_loss", False),
        **forwarded)
    model_fn = tapas_classifier_model.model_fn_builder(classifier_config)

    accumulation_steps = params.get("gradient_accumulation_steps", 1)
    use_tpu = params["use_tpu"]
    model_dir = self.get_temp_dir()
    # Summaries and checkpoints use the same interval: the number of
    # training steps.
    num_steps = params["num_train_steps"]
    run_config = tf.estimator.tpu.RunConfig(
        model_dir=model_dir,
        save_summary_steps=num_steps,
        save_checkpoints_steps=num_steps)
    batch_size = params["batch_size"]

    return tf.estimator.tpu.TPUEstimator(
        params={"gradient_accumulation_steps": accumulation_steps},
        use_tpu=use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=batch_size,
        predict_batch_size=batch_size,
        eval_batch_size=batch_size)
Exemplo n.º 2
0
    def _create_estimator(
        self,
        span_prediction=_SpanPredictionMode.NONE,
        num_aggregation_labels=0,
        num_classification_labels=0,
    ):
        """Creates a CPU-only `TPUEstimator` with a tiny TAPAS classifier.

        All hyper-parameters are fixed constants except the three arguments,
        which are forwarded to `TapasClassifierConfig` unchanged.

        Args:
          span_prediction: Value for the config's `span_prediction` field.
          num_aggregation_labels: Value for `num_aggregation_labels`.
          num_classification_labels: Value for `num_classification_labels`.

        Returns:
          A `tf.estimator.tpu.TPUEstimator` with `use_tpu=False` whose model
          directory is `self.get_temp_dir()` and whose batch sizes are all
          `_BATCH_SIZE`.
        """
        # Small bert model for testing.
        tiny_bert_dict = {
            'vocab_size': 30522,
            'type_vocab_size': [3, 256, 256, 2, 256, 256, 10],
            'num_hidden_layers': 2,
            'num_attention_heads': 2,
            'hidden_size': 128,
            'intermediate_size': 512,
        }
        tiny_bert = modeling.BertConfig.from_dict(tiny_bert_dict)

        classifier_config = tapas_classifier_model.TapasClassifierConfig(
            bert_config=tiny_bert,
            init_checkpoint=None,
            # Optimisation settings.
            learning_rate=0.0,
            num_train_steps=1,
            num_warmup_steps=1,
            use_tpu=False,
            grad_clipping=1.0,
            # Heads configured by the caller.
            num_aggregation_labels=num_aggregation_labels,
            num_classification_labels=num_classification_labels,
            span_prediction=span_prediction,
            # Loss settings.
            positive_weight=1.0,
            aggregation_loss_importance=0.0,
            use_answer_as_supervision=False,
            answer_loss_importance=0.0,
            use_normalized_answer_loss=False,
            huber_loss_delta=0.0,
            answer_loss_cutoff=0.0,
            # Cell selection / aggregation settings.
            temperature=1.0,
            agg_temperature=1.0,
            use_gumbel_for_cells=False,
            use_gumbel_for_agg=False,
            average_approximation_function='ratio',
            cell_select_pref=1.0,
            max_num_rows=64,
            max_num_columns=32,
            average_logits_per_cell=True,
            select_one_column=True)
        model_fn = tapas_classifier_model.model_fn_builder(classifier_config)

        run_config = tf.estimator.tpu.RunConfig(model_dir=self.get_temp_dir())
        return tf.estimator.tpu.TPUEstimator(
            use_tpu=False,
            model_fn=model_fn,
            config=run_config,
            train_batch_size=_BATCH_SIZE,
            predict_batch_size=_BATCH_SIZE,
            eval_batch_size=_BATCH_SIZE)
Exemplo n.º 3
0
def _train_and_predict(
    task,
    tpu_options,
    test_batch_size,
    train_batch_size,
    gradient_accumulation_steps,
    bert_config_file,
    init_checkpoint,
    test_mode,
    mode,
    output_dir,
    model_dir,
):
  """Trains, produces test predictions and eval metric.

  Args:
    task: The `tasks.Task` to run. SQA disables aggregation; WTQ, WIKISQL and
      WIKISQL_SUPERVISED use 4 aggregation labels.
    tpu_options: Object providing `use_tpu`, `tpu_name`, `tpu_zone`,
      `gcp_project`, `master`, `iterations_per_loop` and `num_tpu_cores`.
    test_batch_size: Batch size used for prediction; forced to 1 in test mode.
    train_batch_size: Total training batch size, or None to use the default
      (1 in test mode, `hparams['train_batch_size']` otherwise).
    gradient_accumulation_steps: Divides `train_batch_size` to obtain the
      estimator's train batch size; also forwarded in the estimator params.
    bert_config_file: Path of the JSON file the BERT config is loaded from.
    init_checkpoint: Forwarded to the classifier config's `init_checkpoint`.
    test_mode: If true, runs only 10 training steps with 1 warm-up step.
    mode: `Mode.TRAIN` or `Mode.PREDICT_AND_EVALUATE`.
    output_dir: Directory handed to the example-file, predict and eval helpers.
    model_dir: Estimator model directory; passed to
      `file_utils.make_directories` and used for the saved config files.

  Raises:
    ValueError: If `task` or `mode` is not supported.
  """
  file_utils.make_directories(model_dir)

  if task == tasks.Task.SQA:
    num_aggregation_labels = 0
    do_model_aggregation = False
    use_answer_as_supervision = None
  elif task in [
      tasks.Task.WTQ, tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED
  ]:
    num_aggregation_labels = 4
    do_model_aggregation = True
    # Only the supervised WikiSQL variant trains without the answer signal.
    use_answer_as_supervision = task != tasks.Task.WIKISQL_SUPERVISED
  else:
    raise ValueError(f'Unknown task: {task.name}')

  hparams = hparam_utils.get_hparams(task)
  if test_mode:
    # Tiny schedule so a smoke test finishes quickly.
    if train_batch_size is None:
      train_batch_size = 1
    test_batch_size = 1
    num_train_steps = 10
    num_warmup_steps = 1
  else:
    if train_batch_size is None:
      train_batch_size = hparams['train_batch_size']
    num_train_examples = hparams['num_train_examples']
    num_train_steps = int(num_train_examples / train_batch_size)
    num_warmup_steps = int(num_train_steps * hparams['warmup_ratio'])

  bert_config = modeling.BertConfig.from_json_file(bert_config_file)
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=init_checkpoint,
      learning_rate=hparams['learning_rate'],
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=tpu_options.use_tpu,
      positive_weight=10.0,
      num_aggregation_labels=num_aggregation_labels,
      num_classification_labels=0,
      aggregation_loss_importance=1.0,
      use_answer_as_supervision=use_answer_as_supervision,
      answer_loss_importance=1.0,
      use_normalized_answer_loss=False,
      huber_loss_delta=hparams.get('huber_loss_delta'),
      temperature=hparams.get('temperature', 1.0),
      agg_temperature=1.0,
      use_gumbel_for_cells=False,
      use_gumbel_for_agg=False,
      average_approximation_function=tapas_classifier_model.\
        AverageApproximationFunction.RATIO,
      cell_select_pref=hparams.get('cell_select_pref'),
      answer_loss_cutoff=hparams.get('answer_loss_cutoff'),
      grad_clipping=hparams.get('grad_clipping'),
      disabled_features=[],
      max_num_rows=64,
      max_num_columns=32,
      average_logits_per_cell=False,
      init_cell_selection_weights_to_zero=\
        hparams['init_cell_selection_weights_to_zero'],
      select_one_column=hparams['select_one_column'],
      allow_empty_column_selection=hparams['allow_empty_column_selection'],
      disable_position_embeddings=False)

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)

  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2

  # A cluster resolver is only created when a named TPU is requested.
  tpu_cluster_resolver = None
  if tpu_options.use_tpu and tpu_options.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_options.tpu_name,
        zone=tpu_options.tpu_zone,
        project=tpu_options.gcp_project,
    )

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=tpu_options.master,
      model_dir=model_dir,
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_steps=1000,
      keep_checkpoint_max=5,
      keep_checkpoint_every_n_hours=4.0,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=tpu_options.iterations_per_loop,
          num_shards=tpu_options.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  # If TPU is not available, this will fall back to normal Estimator on CPU/GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      params={'gradient_accumulation_steps': gradient_accumulation_steps},
      use_tpu=tpu_options.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=train_batch_size // gradient_accumulation_steps,
      eval_batch_size=None,
      predict_batch_size=test_batch_size)

  if mode == Mode.TRAIN:
    _print('Training')
    # Store the configs next to the checkpoints.
    bert_config.to_json_file(os.path.join(model_dir, 'bert_config.json'))
    tapas_config.to_json_file(os.path.join(model_dir, 'tapas_config.json'))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name='train',
        file_patterns=_get_train_examples_file(task, output_dir),
        data_format='tfrecord',
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=False,
        add_answer=use_answer_as_supervision,
        include_id=False,
    )
    estimator.train(
        input_fn=train_input_fn,
        max_steps=tapas_config.num_train_steps,
    )

  elif mode == Mode.PREDICT_AND_EVALUATE:

    # Starts a continuous eval that starts with the latest checkpoint and runs
    # until a checkpoint with 'num_train_steps' is reached.
    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()

      # Wait when there is no checkpoint yet, or no new one since the last
      # evaluation round.
      if checkpoint is None or checkpoint == prev_checkpoint:
        _print('Sleeping 5 mins before predicting')
        time.sleep(5 * 60)
        continue

      # The step is parsed from the checkpoint file name, splitting on '-'.
      current_step = int(os.path.basename(checkpoint).split('-')[1])
      _predict(
          estimator,
          task,
          output_dir,
          model_dir,
          do_model_aggregation,
          use_answer_as_supervision,
          use_tpu=tapas_config.use_tpu,
          global_step=current_step,
      )
      _eval(
          task=task,
          output_dir=output_dir,
          model_dir=model_dir,
          global_step=current_step)
      if current_step >= tapas_config.num_train_steps:
        _print(f'Evaluation finished after training step {current_step}.')
        break

      # Remember what was just evaluated; without this the same checkpoint
      # would be predicted and evaluated again in a busy loop while training
      # is still producing the next one.
      prev_checkpoint = checkpoint

  else:
    raise ValueError(f'Unexpected mode: {mode}.')
def main(_):
    """Builds the TAPAS classifier from flags and runs the requested phases.

    Depending on `FLAGS.do_train`, `FLAGS.do_eval` and `FLAGS.do_predict`
    this trains the estimator, evaluates each checkpoint yielded by
    `experiment_utils.iterate_checkpoints`, and/or exports predictions and
    metrics for the predict and eval input files.

    Args:
      _: Unused positional argument.
    """
    use_aggregation = FLAGS.num_aggregation_labels > 0
    use_classification = FLAGS.num_classification_labels > 0
    bert_config = experiment_utils.bert_config_from_flags()
    total_steps = experiment_utils.num_train_steps()

    # Values derived from flags, computed in the order the config lists them.
    warmup_steps = experiment_utils.num_warmup_steps()
    approximation = tapas_classifier_model.AverageApproximationFunction(
        FLAGS.average_approximation_function)
    span_mode = tapas_classifier_model.SpanPredictionMode(
        FLAGS.span_prediction)
    if FLAGS.proj_value_length > 0:
        proj_value_length = FLAGS.proj_value_length
    else:
        # Non-positive flag values are passed on as None.
        proj_value_length = None

    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=total_steps,
        num_warmup_steps=warmup_steps,
        use_tpu=FLAGS.use_tpu,
        positive_weight=FLAGS.positive_weight,
        num_aggregation_labels=FLAGS.num_aggregation_labels,
        num_classification_labels=FLAGS.num_classification_labels,
        aggregation_loss_importance=FLAGS.aggregation_loss_importance,
        use_answer_as_supervision=FLAGS.use_answer_as_supervision,
        answer_loss_importance=FLAGS.answer_loss_importance,
        use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
        huber_loss_delta=FLAGS.huber_loss_delta,
        temperature=FLAGS.temperature,
        agg_temperature=FLAGS.agg_temperature,
        use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
        use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
        average_approximation_function=approximation,
        cell_select_pref=FLAGS.cell_select_pref,
        answer_loss_cutoff=FLAGS.answer_loss_cutoff,
        grad_clipping=FLAGS.grad_clipping,
        reset_output_cls=FLAGS.reset_output_cls,
        disabled_features=FLAGS.disabled_features,
        max_num_rows=FLAGS.max_num_rows,
        max_num_columns=FLAGS.max_num_columns,
        average_logits_per_cell=FLAGS.average_logits_per_cell,
        init_cell_selection_weights_to_zero=(
            FLAGS.init_cell_selection_weights_to_zero),
        select_one_column=FLAGS.select_one_column,
        allow_empty_column_selection=FLAGS.allow_empty_column_selection,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        disable_per_token_loss=FLAGS.disable_per_token_loss,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        span_prediction=span_mode,
        proj_value_length=proj_value_length,
    )

    model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
    estimator = experiment_utils.build_estimator(model_fn)

    def build_input_fn(name, file_patterns, is_training, include_id):
        # All phases share the data settings; only these four values differ.
        return functools.partial(
            tapas_classifier_model.input_fn,
            name=name,
            file_patterns=file_patterns,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=is_training,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=use_aggregation,
            add_classification_labels=use_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=include_id)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        # Store both configs in the model directory before training starts.
        for config, filename in ((bert_config, "bert_config.json"),
                                 (tapas_config, "tapas_config.json")):
            config.to_json_file(os.path.join(FLAGS.model_dir, filename))
        train_input_fn = build_input_fn(
            "train",
            FLAGS.input_file_train,
            is_training=True,
            include_id=False)
        estimator.train(input_fn=train_input_fn, max_steps=total_steps)

    # Built unconditionally: both the eval and the predict phase use it.
    eval_input_fn = build_input_fn(
        "eval",
        FLAGS.input_file_eval,
        is_training=False,
        include_id=not FLAGS.use_tpu)

    if FLAGS.do_eval:
        eval_name = "default" if FLAGS.eval_name is None else FLAGS.eval_name
        eval_checkpoints = experiment_utils.iterate_checkpoints(
            model_dir=estimator.model_dir,
            total_steps=total_steps,
            marker_file_prefix=os.path.join(estimator.model_dir,
                                            f"eval_{eval_name}"),
            minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions)
        for _, checkpoint in eval_checkpoints:
            tf.logging.info("Running eval: %s", eval_name)
            result = estimator.evaluate(
                input_fn=eval_input_fn,
                steps=FLAGS.num_eval_steps,
                name=eval_name,
                checkpoint_path=checkpoint)
            tf.logging.info("Eval result:\n%s", result)

    if FLAGS.do_predict:
        predict_input_fn = build_input_fn(
            "predict",
            FLAGS.input_file_predict,
            is_training=False,
            include_id=not FLAGS.use_tpu)

        prediction_output_dir = FLAGS.prediction_output_dir
        if prediction_output_dir:
            tf.io.gfile.makedirs(prediction_output_dir)
        else:
            prediction_output_dir = estimator.model_dir

        # When two separate jobs are launched we don't want conflicting markers.
        marker_suffix = ""
        if FLAGS.input_file_predict is not None:
            marker_suffix += "_test"
        if FLAGS.input_file_eval is not None:
            marker_suffix += "_dev"
        marker_file_prefix = (
            os.path.join(prediction_output_dir, "predict") + marker_suffix)

        # (mode, input file flag, input function), in execution order.
        export_jobs = (
            ("predict", FLAGS.input_file_predict, predict_input_fn),
            ("eval", FLAGS.input_file_eval, eval_input_fn),
        )

        predict_checkpoints = experiment_utils.iterate_checkpoints(
            model_dir=estimator.model_dir,
            total_steps=total_steps,
            single_step=FLAGS.evaluated_checkpoint_step,
            marker_file_prefix=marker_file_prefix)
        for current_step, checkpoint in predict_checkpoints:
            try:
                for job_mode, input_file, input_fn in export_jobs:
                    if input_file is None:
                        continue
                    _predict_and_export_metrics(
                        mode=job_mode,
                        name=FLAGS.eval_name,
                        input_fn=input_fn,
                        output_dir=prediction_output_dir,
                        estimator=estimator,
                        checkpoint=checkpoint,
                        current_step=current_step,
                        do_model_classification=use_classification,
                        do_model_aggregation=use_aggregation,
                        output_token_answers=(
                            not FLAGS.disable_per_token_loss))
            except ValueError:
                # A failing checkpoint is logged and the loop moves on.
                tf.logging.error(
                    "Error getting predictions for checkpoint %s: %s",
                    checkpoint, traceback.format_exc())
def main(_):
  """Configures a TAPAS classifier from flags, then trains/evals/predicts.

  The phases are selected by `FLAGS.do_train`, `FLAGS.do_eval` and
  `FLAGS.do_predict`. Classification label weights are parsed from the
  comma separated "label:weight" pairs in `FLAGS.classification_label_weight`
  and validated after the estimator has been built.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If a label weight is negative or a weighted label lies
      outside `[0, num_classification_labels)`.
  """
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()

  # Flag-derived values, computed in the order the config lists them.
  warmup_steps = experiment_utils.num_warmup_steps()
  approximation = tapas_classifier_model.AverageApproximationFunction(
      FLAGS.average_approximation_function)

  parsed_label_weights = {}
  for pair in FLAGS.classification_label_weight.split(","):
    if not pair:
      continue
    label = int(pair.split(":")[0])
    weight = float(pair.split(":")[1])
    parsed_label_weights[label] = weight

  attention_mode = attention_utils.RestrictAttentionMode(
      FLAGS.restrict_attention_mode)
  span_mode = tapas_classifier_model.SpanPredictionMode(FLAGS.span_prediction)
  projection_length = _get_projection_length(FLAGS.proj_value_length)

  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=warmup_steps,
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=approximation,
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      classification_label_weight=parsed_label_weights,
      table_pruning_config_file=FLAGS.table_pruning_config_file,
      restrict_attention_mode=attention_mode,
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      mask_examples_without_labels=FLAGS.mask_examples_without_labels,
      cell_cross_entropy_hard_em=FLAGS.cell_cross_entropy_hard_em,
      cell_cross_entropy=FLAGS.cell_cross_entropy,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=span_mode,
      proj_value_length=projection_length,
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=(
          FLAGS.attention_bias_use_relative_scalar_only),
  )

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  # Sanity-check the label weights as stored on the config.
  label_weights = tapas_config.classification_label_weight
  if label_weights:
    if any(weight < 0 for weight in label_weights.values()):
      raise ValueError("Label weights cannot be negative in input: "
                       f"{label_weights}.")
    num_labels = tapas_config.num_classification_labels
    if not all(0 <= label < num_labels for label in label_weights.keys()):
      raise ValueError("Invalid label in label weights for input: "
                       f"{label_weights}.")

  # Input settings shared by the train, eval and predict pipelines.
  shared_input_kwargs = dict(
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
  )

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Copy the table pruning config if pruning is used.
    if FLAGS.table_pruning_config_file:
      tf.io.gfile.copy(
          FLAGS.table_pruning_config_file,
          os.path.join(FLAGS.model_dir, "table_pruning_config.textproto"),
          overwrite=True)
    bert_config_path = os.path.join(FLAGS.model_dir, "bert_config.json")
    bert_config.to_json_file(bert_config_path)
    tapas_config_path = os.path.join(FLAGS.model_dir, "tapas_config.json")
    tapas_config.to_json_file(tapas_config_path)
    estimator.train(
        input_fn=functools.partial(
            tapas_classifier_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            is_training=True,
            include_id=False,
            **shared_input_kwargs),
        max_steps=total_steps)

  # Needed by both the eval phase and the prediction phase below.
  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      is_training=False,
      include_id=not FLAGS.use_tpu,
      **shared_input_kwargs)

  if FLAGS.do_eval:
    eval_name = "default" if FLAGS.eval_name is None else FLAGS.eval_name
    eval_marker = os.path.join(estimator.model_dir, f"eval_{eval_name}")
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=eval_marker,
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if not FLAGS.do_predict:
    return

  predict_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="predict",
      file_patterns=FLAGS.input_file_predict,
      is_training=False,
      include_id=not FLAGS.use_tpu,
      **shared_input_kwargs)

  prediction_output_dir = FLAGS.prediction_output_dir
  if prediction_output_dir:
    tf.io.gfile.makedirs(prediction_output_dir)
  else:
    prediction_output_dir = estimator.model_dir

  has_predict_file = FLAGS.input_file_predict is not None
  has_eval_file = FLAGS.input_file_eval is not None

  # When two separate jobs are launched we don't want conflicting markers.
  marker_file_prefix = os.path.join(prediction_output_dir, "predict")
  if has_predict_file:
    marker_file_prefix += "_test"
  if has_eval_file:
    marker_file_prefix += "_dev"

  # Arguments that are identical for the "predict" and the "eval" export.
  export_metrics = functools.partial(
      _predict_and_export_metrics,
      name=FLAGS.eval_name,
      output_dir=prediction_output_dir,
      estimator=estimator,
      do_model_classification=do_model_classification,
      do_model_aggregation=do_model_aggregation,
      output_token_answers=not FLAGS.disable_per_token_loss)

  for current_step, checkpoint in experiment_utils.iterate_checkpoints(
      model_dir=estimator.model_dir,
      total_steps=total_steps,
      single_step=FLAGS.evaluated_checkpoint_step,
      marker_file_prefix=marker_file_prefix):
    try:
      if has_predict_file:
        export_metrics(
            mode="predict",
            input_fn=predict_input_fn,
            checkpoint=checkpoint,
            current_step=current_step)
      if has_eval_file:
        export_metrics(
            mode="eval",
            input_fn=eval_input_fn,
            checkpoint=checkpoint,
            current_step=current_step)
    except ValueError:
      # A failing checkpoint is logged and the loop moves on to the next one.
      tf.logging.error("Error getting predictions for checkpoint %s: %s",
                       checkpoint, traceback.format_exc())
def main(_):
    """Trains, evaluates and/or writes predictions, as selected by flags.

    * ``--do_train``: creates ``--model_dir``, writes ``bert_config.json``
      and ``tapas_config.json`` into it and trains on ``--input_file_train``.
    * ``--do_eval``: evaluates the latest checkpoint on ``--input_file_eval``
      in a loop that ends once the checkpoint step reaches the number of
      training steps (or that number is ``None``).
    * ``--do_predict``: for every new latest checkpoint writes
      ``{mode}_results_{step}.tsv`` (and, with ``--do_sequence_prediction``,
      ``{mode}_results_sequence_{step}.tsv``) for the "predict" and "eval"
      inputs that are configured. Files go to ``--prediction_output_dir``
      when set, otherwise to ``--model_dir``.

    Evaluation and prediction are pinned to the checkpoint whose step is
    logged / used in the output file name, so a checkpoint written in the
    meantime cannot be picked up under the wrong step number.

    Args:
      _: Unused (presumably the argv list passed by the app runner).
    """
    do_model_aggregation = FLAGS.num_aggregation_labels > 0
    do_model_classification = FLAGS.num_classification_labels > 0
    bert_config = experiment_utils.bert_config_from_flags()
    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=experiment_utils.num_train_steps(),
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        positive_weight=FLAGS.positive_weight,
        num_aggregation_labels=FLAGS.num_aggregation_labels,
        num_classification_labels=FLAGS.num_classification_labels,
        aggregation_loss_importance=FLAGS.aggregation_loss_importance,
        use_answer_as_supervision=FLAGS.use_answer_as_supervision,
        answer_loss_importance=FLAGS.answer_loss_importance,
        use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
        huber_loss_delta=FLAGS.huber_loss_delta,
        temperature=FLAGS.temperature,
        agg_temperature=FLAGS.agg_temperature,
        use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
        use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
        average_approximation_function=(
            tapas_classifier_model.AverageApproximationFunction(
                FLAGS.average_approximation_function)),
        cell_select_pref=FLAGS.cell_select_pref,
        answer_loss_cutoff=FLAGS.answer_loss_cutoff,
        grad_clipping=FLAGS.grad_clipping,
        disabled_features=FLAGS.disabled_features,
        max_num_rows=FLAGS.max_num_rows,
        max_num_columns=FLAGS.max_num_columns,
        average_logits_per_cell=FLAGS.average_logits_per_cell,
        init_cell_selection_weights_to_zero=(
            FLAGS.init_cell_selection_weights_to_zero),
        select_one_column=FLAGS.select_one_column,
        allow_empty_column_selection=FLAGS.allow_empty_column_selection,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        disable_per_token_loss=FLAGS.disable_per_token_loss)

    model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
    estimator = experiment_utils.build_estimator(model_fn)

    def _make_input_fn(name, file_patterns, is_training, include_id):
        """Binds the flag-derived arguments shared by all input functions."""
        return functools.partial(
            tapas_classifier_model.input_fn,
            name=name,
            file_patterns=file_patterns,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=is_training,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=include_id)

    def _checkpoint_step(checkpoint):
        """Returns the second '-'-separated field of the file name as int."""
        return int(os.path.basename(checkpoint).split("-")[1])

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        tapas_config.to_json_file(
            os.path.join(FLAGS.model_dir, "tapas_config.json"))
        train_input_fn = _make_input_fn(
            name="train",
            file_patterns=FLAGS.input_file_train,
            is_training=True,
            include_id=False)
        estimator.train(input_fn=train_input_fn,
                        max_steps=experiment_utils.num_train_steps())

    # Built unconditionally: both the eval and the predict branch use it.
    eval_input_fn = _make_input_fn(
        name="eval",
        file_patterns=FLAGS.input_file_eval,
        is_training=False,
        include_id=True)

    if FLAGS.do_eval:
        prev_checkpoint = None
        while True:
            checkpoint = estimator.latest_checkpoint()

            # NOTE(review): when sleeping is disabled (flag <= 0) an unchanged
            # checkpoint is evaluated again right away, and a missing
            # checkpoint (None) fails in _checkpoint_step below. This mirrors
            # the previous behaviour -- confirm whether it is intended.
            if (checkpoint == prev_checkpoint and
                    FLAGS.minutes_to_sleep_before_predictions > 0):
                tf.logging.info("Sleeping %d mins before evaluation",
                                FLAGS.minutes_to_sleep_before_predictions)
                time.sleep(FLAGS.minutes_to_sleep_before_predictions * 60)
                continue

            tf.logging.info("Running eval: %s", FLAGS.eval_name)
            # Pin the checkpoint so the metrics belong to the step computed
            # below even if training writes a newer checkpoint meanwhile.
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        checkpoint_path=checkpoint,
                                        name=FLAGS.eval_name)
            tf.logging.info("Eval result:\n%s", result)

            current_step = _checkpoint_step(checkpoint)
            num_train_steps = experiment_utils.num_train_steps()
            if num_train_steps is None or current_step >= num_train_steps:
                tf.logging.info("Evaluation finished after training step %d",
                                current_step)
                break

            prev_checkpoint = checkpoint

    if FLAGS.do_predict:
        predict_input_fn = _make_input_fn(
            name="predict",
            file_patterns=FLAGS.input_file_predict,
            is_training=False,
            include_id=True)

        def _export_predictions(mode, input_fn, input_file, checkpoint,
                                current_step):
            """Writes the predictions of one checkpoint for one input set.

            Args:
              mode: Prefix of the output file names ("predict" or "eval").
              input_fn: Input function handed to `estimator.predict`.
              input_file: Data file re-read for sequence prediction.
              checkpoint: Checkpoint path the predictions are computed from.
              current_step: Step number used in log and output file names.
            """
            tf.logging.info("Running predictor for step %d.", current_step)
            # Pin the checkpoint so the file name matches the weights used.
            result = estimator.predict(input_fn=input_fn,
                                       checkpoint_path=checkpoint)
            if FLAGS.prediction_output_dir:
                output_dir = FLAGS.prediction_output_dir
                tf.io.gfile.makedirs(output_dir)
            else:
                output_dir = FLAGS.model_dir
            output_predict_file = os.path.join(
                output_dir, f"{mode}_results_{current_step}.tsv")
            prediction_utils.write_predictions(
                result, output_predict_file, do_model_aggregation,
                do_model_classification,
                FLAGS.cell_classification_threshold)

            if not FLAGS.do_sequence_prediction:
                return

            examples_by_position = prediction_utils.read_classifier_dataset(
                predict_data=input_file,
                data_format=FLAGS.data_format,
                compression_type=FLAGS.compression_type,
                max_seq_length=FLAGS.max_seq_length,
                max_predictions_per_seq=FLAGS.max_predictions_per_seq,
                add_aggregation_function_id=do_model_aggregation,
                add_classification_labels=do_model_classification,
                add_answer=FLAGS.use_answer_as_supervision)
            # NOTE(review): this helper only receives the estimator, so it
            # presumably restores the latest checkpoint itself -- confirm.
            result_sequence = prediction_utils.compute_prediction_sequence(
                estimator=estimator,
                examples_by_position=examples_by_position)
            # Same directory as the regular predictions; only the file name
            # is formatted so braces in the directory path are left alone.
            output_predict_file_sequence = os.path.join(
                output_dir, f"{mode}_results_sequence_{current_step}.tsv")
            prediction_utils.write_predictions(
                result_sequence, output_predict_file_sequence,
                do_model_aggregation, do_model_classification,
                FLAGS.cell_classification_threshold)

        prev_checkpoint = None
        while True:
            checkpoint = estimator.latest_checkpoint()

            # Also covers the "no checkpoint yet" case (None == None).
            if checkpoint == prev_checkpoint:
                tf.logging.info("Sleeping 5 mins before predicting")
                time.sleep(5 * 60)
                continue

            current_step = _checkpoint_step(checkpoint)

            if FLAGS.input_file_predict is not None:
                _export_predictions(
                    mode="predict",
                    input_fn=predict_input_fn,
                    input_file=FLAGS.input_file_predict,
                    checkpoint=checkpoint,
                    current_step=current_step)
            # Skip the eval set when no eval file was given, like predict.
            if FLAGS.input_file_eval is not None:
                _export_predictions(
                    mode="eval",
                    input_fn=eval_input_fn,
                    input_file=FLAGS.input_file_eval,
                    checkpoint=checkpoint,
                    current_step=current_step)

            num_train_steps = experiment_utils.num_train_steps()
            if num_train_steps is None or current_step >= num_train_steps:
                tf.logging.info("Predictor finished after training step %d",
                                current_step)
                break
            prev_checkpoint = checkpoint