示例#1
0
def answer_questions(input):
  """Run SQuAD-style prediction for ``input`` and return the n-best answers.

  The payload built by ``create_qa_json(input)`` is written to the prediction
  file, ``run_squad.main`` is invoked with prediction enabled and training
  disabled, and the ``nbest_predictions.json`` file in the output directory
  (expected to be produced by ``run_squad.main``) is parsed and returned.

  Args:
    input: Passed unchanged to ``create_qa_json``; the expected structure is
      defined by that helper. The name shadows the builtin but is kept so
      existing keyword callers keep working.

  Returns:
    The decoded JSON content of ``nbest_predictions.json``.
  """
  import os  # local import: only needed to build the results path

  bert_model = "../files/models/squad/"
  output_dir = "../files/results"
  predict_file = '../files/json/my-data.json'

  # Persist the questions where run_squad is told to read them from.
  with open(predict_file, 'w') as f:
    json.dump(create_qa_json(input), f)

  settings = {
      # Prediction settings
      "bert_model": bert_model,
      "predict_batch_size": 8,
      "do_predict": True,
      "max_seq_length": 384,
      "doc_stride": 128,
      "output_dir": output_dir,
      "version_2_with_negative": True,
      "null_score_diff_threshold": -3.3908887952566147,
      "predict_file": predict_file,
      "overwrite_output_dir": True,
      "do_lower_case": True,
      "n_best_size": 3,
      # Arguments with default values
      "train_file": None,
      "max_query_length": 64,
      "do_train": False,
      "train_batch_size": 32,
      "learning_rate": 5e-5,
      "num_train_epochs": 3.0,
      "warmup_proportion": 0.1,
      "max_answer_length": 30,
      "verbose_logging": False,
      "no_cuda": False,
      "seed": 42,
      "gradient_accumulation_steps": 1,
      "local_rank": -1,
      "fp16": False,
      "loss_scale": 0,
      "server_ip": '',
      "server_port": '',
  }
  args = Namespace()
  for name, value in settings.items():
    setattr(args, name, value)

  run_squad.main(args)

  # Derive the results path from output_dir so the two cannot drift apart.
  nbest_path = os.path.join(output_dir, 'nbest_predictions.json')
  with open(nbest_path, 'r') as f:
    return json.load(f)
示例#2
0
    def test_run_squad(self):
        """Run ``run_squad.main`` with a patched ``sys.argv`` and check metrics.

        The script is invoked with ``--do_train`` and ``--do_eval`` on the
        ``sshleifer/tiny-distilbert-base-cased-distilled-squad`` checkpoint;
        the returned ``f1`` must be at least 25 and ``exact`` at least 21.
        """
        # Mirror log records to stdout so they show up in the test output.
        logger.addHandler(logging.StreamHandler(sys.stdout))

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list rather than splitting a formatted
        # string: the output directory then stays a single argument even when
        # its path contains whitespace.
        testargs = [
            "run_squad.py",
            "--model_type=distilbert",
            "--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad",
            "--data_dir=./tests/fixtures/tests_samples/SQUAD",
            "--output_dir",
            f"{tmp_dir}",
            "--overwrite_output_dir",
            "--max_steps=10",
            "--warmup_steps=2",
            "--do_train",
            "--do_eval",
            "--version_2_with_negative",
            "--learning_rate=2e-4",
            "--per_gpu_train_batch_size=2",
            "--per_gpu_eval_batch_size=1",
            "--seed=42",
        ]

        with patch.object(sys, "argv", testargs):
            result = run_squad.main()
            self.assertGreaterEqual(result["f1"], 25)
            self.assertGreaterEqual(result["exact"], 21)
    def test_run_squad_adapters(self):
        """Run ``run_squad.main`` with adapter flags and check the metrics.

        ``sys.argv`` is patched with the ``--train_adapter`` options below;
        the returned ``f1`` and ``exact`` values must each be at least 30.
        """
        # Mirror log records to stdout so they show up in the test output.
        logger.addHandler(logging.StreamHandler(sys.stdout))

        argv = [
            "run_squad.py",
            "--model_type=bert",
            "--model_name_or_path=bert-base-uncased",
            "--data_dir=./tests/fixtures/tests_samples/SQUAD",
            "--model_name=bert-base-uncased",
            "--output_dir=./tests/fixtures/tests_samples/temp_dir",
            "--max_steps=20",
            "--warmup_steps=2",
            "--do_train",
            "--do_eval",
            "--version_2_with_negative",
            "--learning_rate=2e-4",
            "--per_gpu_train_batch_size=2",
            "--per_gpu_eval_batch_size=1",
            "--overwrite_output_dir",
            "--seed=42",
            "--train_adapter",
            "--adapter_config=houlsby",
            "--adapter_reduction_factor=8",
        ]
        minimum_scores = (("f1", 30), ("exact", 30))
        with patch.object(sys, "argv", argv):
            metrics = run_squad.main()
            for metric_name, floor in minimum_scores:
                self.assertGreaterEqual(metrics[metric_name], floor)
    def test_run_squad(self):
        """Run ``run_squad.main`` on the small SQuAD v2.0 sample and check metrics.

        ``sys.argv`` is patched with the arguments below (``--do_train`` and
        ``--do_eval`` with ``bert-base-uncased``); the returned ``f1`` and
        ``exact`` values must each be at least 30.
        """
        # Mirror log records to stdout so they show up in the test output.
        logger.addHandler(logging.StreamHandler(sys.stdout))

        argv = [
            "run_squad.py",
            "--train_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json",
            "--predict_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json",
            "--model_name=bert-base-uncased",
            "--output_dir=./examples/tests_samples/temp_dir",
            "--max_steps=10",
            "--warmup_steps=2",
            "--do_train",
            "--do_eval",
            "--version_2_with_negative",
            "--learning_rate=1e-4",
            "--per_gpu_train_batch_size=2",
            "--per_gpu_eval_batch_size=1",
            "--overwrite_output_dir",
            "--seed=42",
            # Model selection flags go last, after the shared arguments.
            "--model_type=bert",
            "--model_name_or_path=bert-base-uncased",
        ]
        minimum_scores = (("f1", 30), ("exact", 30))
        with patch.object(sys, "argv", argv):
            metrics = run_squad.main()
            for metric_name, floor in minimum_scores:
                self.assertGreaterEqual(metrics[metric_name], floor)
示例#5
0
import run_squad as rs

# Call run_squad's entry point with `_` as its only argument.
# NOTE(review): `_` is not bound anywhere in the visible snippet. Unless it is
# defined elsewhere (e.g. the interactive interpreter's last-result variable),
# this line raises NameError — confirm which argument object was intended.
rs.main(_)