Example #1
  def test_iterate_checkpoints_multi_step(self):
    test_tmpdir = tempfile.mkdtemp()
    checkpoints = [
        os.path.join(test_tmpdir, checkpoint) for checkpoint in
        ['model.ckpt-00001', 'model.ckpt-00002', 'model.ckpt-00003']
    ]
    # Write fake checkpoint file to tmpdir.
    state = tf.train.generate_checkpoint_state_proto(
        test_tmpdir,
        model_checkpoint_path=checkpoints[-1],
        all_model_checkpoint_paths=checkpoints)
    with open(os.path.join(test_tmpdir, 'checkpoint'), 'w') as f:
      f.write(text_format.MessageToString(state))
    for checkpoint in checkpoints:
      with open(f'{checkpoint}.index', 'w') as f:
        f.write('\n')

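    # First pass: iterate_checkpoints should yield every (step, checkpoint) pair
    # and write a .done marker file per step.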
    marker_file_prefix = os.path.join(test_tmpdir, 'marker')
    results = list(
        experiment_utils.iterate_checkpoints(
            model_dir=test_tmpdir,
            total_steps=3,
            marker_file_prefix=marker_file_prefix))

    expected_steps = [1, 2, 3]
    self.assertEqual(results, list(zip(expected_steps, checkpoints)))
    for step in expected_steps:
      self.assertTrue(tf.io.gfile.exists(f'{marker_file_prefix}-{step}.done'))

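    # Second pass: all steps already have marker files, so nothing is yielded.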
    results = list(
        experiment_utils.iterate_checkpoints(
            model_dir=test_tmpdir,
            total_steps=3,
            marker_file_prefix=marker_file_prefix))
    self.assertEmpty(results)

  def test_iterate_checkpoints_single_step(self):
    results = list(
        experiment_utils.iterate_checkpoints(
            model_dir='path',
            single_step=100,
            marker_file_prefix='path',
            total_steps=None))
    self.assertEqual(results, [(100, 'path/model.ckpt-100')])

def main(_):
    do_model_aggregation = FLAGS.num_aggregation_labels > 0
    do_model_classification = FLAGS.num_classification_labels > 0
    bert_config = experiment_utils.bert_config_from_flags()
    total_steps = experiment_utils.num_train_steps()
    tapas_config = tapas_classifier_model.TapasClassifierConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=total_steps,
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        positive_weight=FLAGS.positive_weight,
        num_aggregation_labels=FLAGS.num_aggregation_labels,
        num_classification_labels=FLAGS.num_classification_labels,
        aggregation_loss_importance=FLAGS.aggregation_loss_importance,
        use_answer_as_supervision=FLAGS.use_answer_as_supervision,
        answer_loss_importance=FLAGS.answer_loss_importance,
        use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
        huber_loss_delta=FLAGS.huber_loss_delta,
        temperature=FLAGS.temperature,
        agg_temperature=FLAGS.agg_temperature,
        use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
        use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
        average_approximation_function=(
            tapas_classifier_model.AverageApproximationFunction(
                FLAGS.average_approximation_function)),
        cell_select_pref=FLAGS.cell_select_pref,
        answer_loss_cutoff=FLAGS.answer_loss_cutoff,
        grad_clipping=FLAGS.grad_clipping,
        reset_output_cls=FLAGS.reset_output_cls,
        disabled_features=FLAGS.disabled_features,
        max_num_rows=FLAGS.max_num_rows,
        max_num_columns=FLAGS.max_num_columns,
        average_logits_per_cell=FLAGS.average_logits_per_cell,
        init_cell_selection_weights_to_zero=FLAGS
        .init_cell_selection_weights_to_zero,
        select_one_column=FLAGS.select_one_column,
        allow_empty_column_selection=FLAGS.allow_empty_column_selection,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        disable_per_token_loss=FLAGS.disable_per_token_loss,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        span_prediction=tapas_classifier_model.SpanPredictionMode(
            FLAGS.span_prediction),
        proj_value_length=(FLAGS.proj_value_length
                           if FLAGS.proj_value_length > 0 else None),
    )

    model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
    estimator = experiment_utils.build_estimator(model_fn)

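    # Training: persist the BERT and TAPAS configs alongside the model
    # checkpoints before training starts.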
    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        tapas_config.to_json_file(
            os.path.join(FLAGS.model_dir, "tapas_config.json"))
        train_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=False)
        estimator.train(input_fn=train_input_fn, max_steps=total_steps)

    eval_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="eval",
        file_patterns=FLAGS.input_file_eval,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)
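    # Evaluation: process each checkpoint as it appears; marker files keep an
    # already-evaluated checkpoint from being re-run.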
    if FLAGS.do_eval:
        eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
        for _, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                marker_file_prefix=os.path.join(estimator.model_dir,
                                                f"eval_{eval_name}"),
                minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
            tf.logging.info("Running eval: %s", eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=eval_name,
                                        checkpoint_path=checkpoint)
            tf.logging.info("Eval result:\n%s", result)

    if FLAGS.do_predict:
        predict_input_fn = functools.partial(
            tapas_classifier_model.input_fn,
            name="predict",
            file_patterns=FLAGS.input_file_predict,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=False,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision,
            include_id=not FLAGS.use_tpu)

        if FLAGS.prediction_output_dir:
            prediction_output_dir = FLAGS.prediction_output_dir
            tf.io.gfile.makedirs(prediction_output_dir)
        else:
            prediction_output_dir = estimator.model_dir

        marker_file_prefix = os.path.join(prediction_output_dir, "predict")
        # When two separate jobs are launched, we don't want conflicting markers.
        if FLAGS.input_file_predict is not None:
            marker_file_prefix += "_test"
        if FLAGS.input_file_eval is not None:
            marker_file_prefix += "_dev"

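        # Predict for each checkpoint (or only the requested single step);
        # failures on a checkpoint are logged and skipped.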
        for current_step, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                single_step=FLAGS.evaluated_checkpoint_step,
                marker_file_prefix=marker_file_prefix):

            try:
                if FLAGS.input_file_predict is not None:
                    _predict_and_export_metrics(
                        mode="predict",
                        name=FLAGS.eval_name,
                        input_fn=predict_input_fn,
                        output_dir=prediction_output_dir,
                        estimator=estimator,
                        checkpoint=checkpoint,
                        current_step=current_step,
                        do_model_classification=do_model_classification,
                        do_model_aggregation=do_model_aggregation,
                        output_token_answers=not FLAGS.disable_per_token_loss)
                if FLAGS.input_file_eval is not None:
                    _predict_and_export_metrics(
                        mode="eval",
                        name=FLAGS.eval_name,
                        input_fn=eval_input_fn,
                        output_dir=prediction_output_dir,
                        estimator=estimator,
                        checkpoint=checkpoint,
                        current_step=current_step,
                        do_model_classification=do_model_classification,
                        do_model_aggregation=do_model_aggregation,
                        output_token_answers=not FLAGS.disable_per_token_loss)
            except ValueError:
                tf.logging.error(
                    "Error getting predictions for checkpoint %s: %s",
                    checkpoint, traceback.format_exc())
Example #4
def main(_):
    bert_config = experiment_utils.bert_config_from_flags()
    total_steps = experiment_utils.num_train_steps()
    retriever_config = table_retriever_model.RetrieverConfig(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=total_steps,
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        grad_clipping=FLAGS.grad_clipping,
        down_projection_dim=FLAGS.down_projection_dim,
        init_from_single_encoder=FLAGS.init_from_single_encoder,
        max_query_length=FLAGS.max_query_length,
        mask_repeated_tables=FLAGS.mask_repeated_tables,
        mask_repeated_questions=FLAGS.mask_repeated_questions,
        use_out_of_core_negatives=FLAGS.use_out_of_core_negatives,
        ignore_table_content=FLAGS.ignore_table_content,
        disabled_features=FLAGS.disabled_features,
        use_mined_negatives=FLAGS.use_mined_negatives,
    )

    model_fn = table_retriever_model.model_fn_builder(retriever_config)
    estimator = experiment_utils.build_estimator(model_fn)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        retriever_config.to_json_file(
            os.path.join(FLAGS.model_dir, "tapas_config.json"))
        train_input_fn = functools.partial(
            table_retriever_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            compression_type=FLAGS.compression_type,
            use_mined_negatives=FLAGS.use_mined_negatives,
            include_id=False)
        estimator.train(input_fn=train_input_fn, max_steps=total_steps)

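    # Build the eval input function unconditionally; the prediction branch
    # below reuses it.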
    eval_input_fn = _get_test_input_fn("eval", FLAGS.input_file_eval)
    if FLAGS.do_eval:
        if eval_input_fn is None:
            raise ValueError("No input_file_eval specified!")
        for _, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                marker_file_prefix=os.path.join(estimator.model_dir,
                                                f"eval_{FLAGS.eval_name}"),
                minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
            tf.logging.info("Running eval: %s", FLAGS.eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=FLAGS.eval_name,
                                        checkpoint_path=checkpoint)
            tf.logging.info("Eval result:\n%s", result)

    if FLAGS.do_predict:
        predict_input_fn = _get_test_input_fn("predict",
                                              FLAGS.input_file_predict)
        if FLAGS.prediction_output_dir:
            prediction_output_dir = FLAGS.prediction_output_dir
            tf.io.gfile.makedirs(prediction_output_dir)
        else:
            prediction_output_dir = estimator.model_dir

        marker_file_prefix = os.path.join(prediction_output_dir, "predict")
        # When two separate jobs are launched, we don't want conflicting markers.
        if predict_input_fn is not None:
            marker_file_prefix += "_test"
        if eval_input_fn is not None:
            marker_file_prefix += "_dev"

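        # Predict on an explicitly requested step, or on the best step for a
        # given metric.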
        single_step = FLAGS.evaluated_checkpoint_step
        if FLAGS.evaluated_checkpoint_metric:
            single_step = experiment_utils.get_best_step_for_metric(
                estimator.model_dir, FLAGS.evaluated_checkpoint_metric)
        for current_step, checkpoint in experiment_utils.iterate_checkpoints(
                model_dir=estimator.model_dir,
                total_steps=total_steps,
                marker_file_prefix=marker_file_prefix,
                single_step=single_step):
            if predict_input_fn is not None:
                _predict_and_export_metrics(mode="predict",
                                            input_fn=predict_input_fn,
                                            checkpoint_path=checkpoint,
                                            step=current_step,
                                            estimator=estimator,
                                            output_dir=prediction_output_dir)

            if eval_input_fn is not None:
                _predict_and_export_metrics(mode="eval",
                                            input_fn=eval_input_fn,
                                            checkpoint_path=checkpoint,
                                            step=current_step,
                                            estimator=estimator,
                                            output_dir=prediction_output_dir)
def main(_):
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction(
              FLAGS.average_approximation_function)),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      classification_label_weight={
          int(pair.split(":")[0]): float(pair.split(":")[1])
          for pair in FLAGS.classification_label_weight.split(",")
          if pair
      },
      table_pruning_config_file=FLAGS.table_pruning_config_file,
      restrict_attention_mode=(attention_utils.RestrictAttentionMode(
          FLAGS.restrict_attention_mode)),
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      mask_examples_without_labels=FLAGS.mask_examples_without_labels,
      cell_cross_entropy_hard_em=FLAGS.cell_cross_entropy_hard_em,
      cell_cross_entropy=FLAGS.cell_cross_entropy,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=FLAGS
      .init_cell_selection_weights_to_zero,
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
      proj_value_length=_get_projection_length(FLAGS.proj_value_length),
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=FLAGS
      .attention_bias_use_relative_scalar_only,
  )

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

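  # Validate the optional per-class label weights parsed from the flag above.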
  if tapas_config.classification_label_weight:
    if any(x < 0 for x in tapas_config.classification_label_weight.values()):
      raise ValueError("Label weights cannot be negative in input: "
                       f"{tapas_config.classification_label_weight}.")
    if any(x < 0 or x >= tapas_config.num_classification_labels
           for x in tapas_config.classification_label_weight.keys()):
      raise ValueError("Invalid label in label weights for input: "
                       f"{tapas_config.classification_label_weight}.")

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Copy the table pruning config if pruning is used.
    if FLAGS.table_pruning_config_file:
      table_pruning_path = os.path.join(FLAGS.model_dir,
                                        "table_pruning_config.textproto")
      tf.io.gfile.copy(
          FLAGS.table_pruning_config_file, table_pruning_path, overwrite=True)
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      include_id=not FLAGS.use_tpu)
  if FLAGS.do_eval:
    eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)

    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir

    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched, we don't want conflicting markers.
    if FLAGS.input_file_predict is not None:
      marker_file_prefix += "_test"
    if FLAGS.input_file_eval is not None:
      marker_file_prefix += "_dev"

    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix):

      try:
        if FLAGS.input_file_predict is not None:
          _predict_and_export_metrics(
              mode="predict",
              name=FLAGS.eval_name,
              input_fn=predict_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
        if FLAGS.input_file_eval is not None:
          _predict_and_export_metrics(
              mode="eval",
              name=FLAGS.eval_name,
              input_fn=eval_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
      except ValueError:
        tf.logging.error("Error getting predictions for checkpoint %s: %s",
                         checkpoint, traceback.format_exc())