Пример #1
0
def main(_):
    """Run the SQuAD workflow selected by ``FLAGS.mode``.

    Parses the gin configuration and loads the input metadata JSON.

    - If ``FLAGS.mode == 'export_only'``, it exports the model and returns.
    - Otherwise it builds one distribution strategy and runs each of the
      train, predict and eval phases whose name occurs in ``FLAGS.mode``,
      in that order.

    The eval phase logs the final F1 score and writes it as a summary
    scalar. It also dumps the full eval metrics to
    ``<model_dir>/summaries/eval/eval_metrics.json``.

    Args:
      _: Unused positional argument (presumably the argv list supplied by
        the absl ``app.run`` launcher -- confirm against the caller).
    """
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    # Export runs without a distribution strategy, so it is handled (and
    # returns) before one is built.
    if FLAGS.mode == 'export_only':
        export_squad(FLAGS.model_export_path, input_meta_data)
        return

    # Configures cluster spec for multi-worker distribution strategy.
    # The return value is not needed here; the call is made for its effect.
    if FLAGS.num_gpus > 0:
        distribute_utils.configure_cluster(FLAGS.worker_hosts,
                                           FLAGS.task_index)
    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        all_reduce_alg=FLAGS.all_reduce_alg,
        tpu_address=FLAGS.tpu)

    # Phases are selected by substring match on the mode string, so a single
    # mode value may enable several of the phases below.
    if 'train' in FLAGS.mode:
        if FLAGS.log_steps:
            custom_callbacks = [
                keras_utils.TimeHistory(
                    batch_size=FLAGS.train_batch_size,
                    log_steps=FLAGS.log_steps,
                    logdir=FLAGS.model_dir,
                )
            ]
        else:
            custom_callbacks = None

        train_squad(
            strategy,
            input_meta_data,
            custom_callbacks=custom_callbacks,
            run_eagerly=FLAGS.run_eagerly,
            sub_model_export_name=FLAGS.sub_model_export_name,
        )
    if 'predict' in FLAGS.mode:
        predict_squad(strategy, input_meta_data)
    if 'eval' in FLAGS.mode:
        eval_metrics = eval_squad(strategy, input_meta_data)
        f1_score = eval_metrics['final_f1']
        logging.info('SQuAD eval F1-score: %f', f1_score)
        summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
        summary_writer = tf.summary.create_file_writer(summary_dir)
        try:
            with summary_writer.as_default():
                # TODO(lehou): write to the correct step number.
                tf.summary.scalar('F1-score', f1_score, step=0)
                summary_writer.flush()
        finally:
            # Release the writer's event-file resource even if writing the
            # scalar fails; close() also flushes any pending events.
            summary_writer.close()
        # Also write eval_metrics to json file.
        squad_lib_wp.write_to_json_files(
            eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
        # NOTE(review): unexplained fixed 60-second pause after eval, kept
        # as-is. Presumably it gives other processes or the filesystem time
        # to pick up the written outputs -- confirm before removing.
        time.sleep(60)
Пример #2
0
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
                  version_2_with_negative, output_dir):
  """Save question-answering outputs as JSON files inside `output_dir`.

  Args:
    all_predictions: Payload written to `predictions.json`.
    all_nbest_json: Payload written to `nbest_predictions.json`.
    scores_diff_json: Payload written to `null_odds.json`; only written when
      `version_2_with_negative` is truthy.
    version_2_with_negative: Whether the null-odds file is also written.
    output_dir: Directory that every output file name is joined onto.
  """
  def _in_output_dir(file_name):
    return os.path.join(output_dir, file_name)

  predictions_path = _in_output_dir('predictions.json')
  nbest_path = _in_output_dir('nbest_predictions.json')
  null_odds_path = _in_output_dir('null_odds.json')

  # Both log lines are emitted before any file is written.
  tf.compat.v1.logging.info('Writing predictions to: %s', predictions_path)
  tf.compat.v1.logging.info('Writing nbest to: %s', nbest_path)

  pending = [
      (all_predictions, predictions_path),
      (all_nbest_json, nbest_path),
  ]
  if version_2_with_negative:
    pending.append((scores_diff_json, null_odds_path))

  for payload, target_path in pending:
    squad_lib.write_to_json_files(payload, target_path)