def get_data_loader(self, file):
        """Build a sequential, batch-size-1 loader over the examples in `file`.

        Side effects: caches the per-feature context tokens, answer texts and
        the question-id tensor on ``self`` (``all_c_tokens``,
        ``all_answer_text``, ``golden_q_ids``).

        Args:
            file: path handed to ``read_squad_examples``.

        Returns:
            A ``DataLoader`` (no shuffling, batch size 1) over
            ``(c_ids, c_lens, tag_ids, q_ids)``.
        """
        examples = read_squad_examples(file, is_training=True,
                                       debug=config.debug)
        features = convert_examples_to_features(
            examples,
            tokenizer=self.tokenizer,
            max_seq_length=config.max_seq_len,
            max_query_length=config.max_query_len,
            doc_stride=128,
            is_training=True)

        def as_long_tensor(field):
            # Stack one attribute of every feature into a single int64 tensor.
            return torch.tensor([getattr(feat, field) for feat in features],
                                dtype=torch.long)

        c_ids = as_long_tensor("c_ids")
        # Row-wise count of positive ids (presumably 0 is the padding id).
        c_lens = torch.sum(torch.sign(c_ids), 1)
        q_ids = as_long_tensor("q_ids")
        tag_ids = as_long_tensor("tag_ids")

        dataset = TensorDataset(c_ids, c_lens, tag_ids, q_ids)
        loader = DataLoader(dataset, shuffle=False, batch_size=1)

        self.all_c_tokens = [feat.context_tokens for feat in features]
        self.all_answer_text = [feat.answer_text for feat in features]
        self.golden_q_ids = q_ids

        return loader
Beispiel #2
0
    def get_data_loader(self, file):
        """Build a randomly-sampled training loader over the examples in `file`.

        Args:
            file: path handed to ``read_squad_examples``.

        Returns:
            A ``DataLoader`` over ``(c_ids, c_lens, tag_ids, q_ids, input_ids,
            input_mask, segment_ids, start_positions, end_positions,
            noq_start_positions, noq_end_positions)``. Its batch size is
            ``int(config.batch_size / config.gradient_accumulation_steps)``.
        """
        examples = read_squad_examples(file, is_training=True, debug=config.debug)
        features = convert_examples_to_features(
            examples,
            tokenizer=self.tokenizer,
            max_seq_length=config.max_seq_len,
            max_query_length=config.max_query_len,
            doc_stride=128,
            is_training=True)

        def stack(name):
            # One int64 tensor holding attribute `name` of every feature.
            return torch.tensor([getattr(feat, name) for feat in features],
                                dtype=torch.long)

        c_ids = stack("c_ids")
        # Row-wise count of positive ids (presumably 0 is the padding id).
        c_lens = torch.sum(torch.sign(c_ids), 1)
        columns = [c_ids, c_lens]
        columns.extend(stack(name) for name in (
            "tag_ids", "q_ids", "input_ids", "input_mask", "segment_ids",
            "start_position", "end_position",
            "noq_start_position", "noq_end_position"))
        dataset = TensorDataset(*columns)

        sampler = RandomSampler(dataset)
        # The configured batch size is split across gradient-accumulation
        # steps (presumably so accumulated micro-batches add up to it).
        micro_batch = int(config.batch_size / config.gradient_accumulation_steps)
        return DataLoader(dataset, sampler=sampler, batch_size=micro_batch)
Beispiel #3
0
    def get_data_loader(self, file):
        """Build a shuffled loader of contexts and their no-question answer spans.

        Args:
            file: path handed to ``read_squad_examples``.

        Returns:
            A shuffling ``DataLoader`` (batch size ``config.batch_size``) over
            ``(c_ids, c_lens, noq_start_positions, noq_end_positions)``.
        """
        examples = read_squad_examples(file, is_training=True, debug=config.debug)
        features = convert_examples_to_features(
            examples,
            tokenizer=self.tokenizer,
            max_seq_length=config.max_seq_len,
            max_query_length=config.max_query_len,
            doc_stride=128,
            is_training=True)

        c_ids, noq_starts, noq_ends = (
            torch.tensor([getattr(feat, name) for feat in features],
                         dtype=torch.long)
            for name in ("c_ids", "noq_start_position", "noq_end_position"))
        # Row-wise count of positive ids (presumably 0 is the padding id).
        c_lens = torch.sign(c_ids).sum(dim=1).long()

        dataset = TensorDataset(c_ids, c_lens, noq_starts, noq_ends)
        return DataLoader(dataset, shuffle=True, batch_size=config.batch_size)
Beispiel #4
0
def main(_):
    """Train and/or evaluate an ALBERT SQuAD v2.0 model as configured by FLAGS.

    With ``--do_train`` the training examples are converted to
    ``FLAGS.train_feature_file`` (unless it already exists) and the estimator
    is trained for ``num_train_steps`` steps.

    With ``--do_predict`` the function polls ``FLAGS.output_dir`` for new
    checkpoints and evaluates each one. The best one by F1 is copied to
    ``model.ckpt-best``. Evaluated checkpoints are deleted once enough newer
    ones exist. Polling stops when an evaluated checkpoint's global step
    reaches ``num_train_steps``; the best checkpoint is then re-scored. All
    results are written to ``eval_results.txt`` in the output directory.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.get_logger().propagate = False

    albert_config = modeling.AlbertConfig.from_json_file(
        FLAGS.albert_config_file)

    validate_flags_or_throw(albert_config)

    tf.gfile.MakeDirs(FLAGS.output_dir)
    print("Output:", FLAGS.output_dir)

    tokenizer = fine_tuning_utils.create_vocab(
        vocab_file=FLAGS.vocab_file,
        do_lower_case=FLAGS.do_lower_case,
        spm_model_file=FLAGS.spm_model_file,
        hub_module=FLAGS.albert_hub_module_handle)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
    if FLAGS.do_train:
        # Never run more steps per TPU loop than there are between checkpoints.
        iterations_per_loop = int(
            min(FLAGS.iterations_per_loop, FLAGS.save_checkpoints_steps))
    else:
        iterations_per_loop = FLAGS.iterations_per_loop
    run_config = contrib_tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        # 0 keeps every checkpoint; the predict loop below prunes them itself.
        keep_checkpoint_max=0,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=contrib_tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    # The training file is read even in predict-only runs: num_train_steps is
    # also the stopping condition of the evaluation loop below.
    num_warmup_steps = None
    train_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    if FLAGS.do_train:
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        # Pre-shuffle the input to avoid having to make a very large shuffle
        # buffer in the `input_fn`.
        rng = random.Random(12345)
        rng.shuffle(train_examples)

    model_fn = squad_utils.v2_model_fn_builder(
        albert_config=albert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        max_seq_length=FLAGS.max_seq_length,
        start_n_top=FLAGS.start_n_top,
        end_n_top=FLAGS.end_n_top,
        dropout_prob=FLAGS.dropout_prob,
        hub_module=FLAGS.albert_hub_module_handle)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = contrib_tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        # Features go to a file (reused if it already exists) to avoid storing
        # very large constant tensors in memory.
        if not tf.gfile.Exists(FLAGS.train_feature_file):
            train_writer = squad_utils.FeatureWriter(
                filename=FLAGS.train_feature_file, is_training=True)
            squad_utils.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=FLAGS.max_seq_length,
                doc_stride=FLAGS.doc_stride,
                max_query_length=FLAGS.max_query_length,
                is_training=True,
                output_fn=train_writer.process_feature,
                do_lower_case=FLAGS.do_lower_case)
            train_writer.close()

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num orig examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        del train_examples

        train_input_fn = squad_utils.input_fn_builder(
            input_file=FLAGS.train_feature_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.train_batch_size,
            is_v2=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        with tf.gfile.Open(FLAGS.predict_file) as predict_file:
            prediction_json = json.load(predict_file)["data"]
        eval_examples = squad_utils.read_squad_examples(
            input_file=FLAGS.predict_file, is_training=False)

        # Reuse cached eval features only when both the feature file and the
        # pickled feature list exist; otherwise rebuild and cache both.
        if (tf.gfile.Exists(FLAGS.predict_feature_file)
                and tf.gfile.Exists(FLAGS.predict_feature_left_file)):
            tf.logging.info("Loading eval features from {}".format(
                FLAGS.predict_feature_left_file))
            # NOTE: unpickles a cache written by this script; never point
            # predict_feature_left_file at untrusted data.
            with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
                eval_features = pickle.load(fin)
        else:
            eval_writer = squad_utils.FeatureWriter(
                filename=FLAGS.predict_feature_file, is_training=False)
            eval_features = []

            def append_feature(feature):
                # Keep the feature in memory and also write it to the file.
                eval_features.append(feature)
                eval_writer.process_feature(feature)

            squad_utils.convert_examples_to_features(
                examples=eval_examples,
                tokenizer=tokenizer,
                max_seq_length=FLAGS.max_seq_length,
                doc_stride=FLAGS.doc_stride,
                max_query_length=FLAGS.max_query_length,
                is_training=False,
                output_fn=append_feature,
                do_lower_case=FLAGS.do_lower_case)
            eval_writer.close()

            with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
                pickle.dump(eval_features, fout)

        tf.logging.info("***** Running predictions *****")
        tf.logging.info("  Num orig examples = %d", len(eval_examples))
        tf.logging.info("  Num split examples = %d", len(eval_features))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = squad_utils.input_fn_builder(
            input_file=FLAGS.predict_feature_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.predict_batch_size,
            is_v2=True)

        def get_result(checkpoint):
            """Evaluate the checkpoint on SQuAD v2.0.

            Returns:
                ``(metrics, global_step)``: the result of
                ``squad_utils.evaluate_v2`` and the checkpoint's global step.
            """
            # If running eval on the TPU, you will need to specify the number of
            # steps.
            reader = tf.train.NewCheckpointReader(checkpoint)
            global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
            all_results = []
            for result in estimator.predict(predict_input_fn,
                                            yield_single_examples=True,
                                            checkpoint_path=checkpoint):
                if len(all_results) % 1000 == 0:
                    tf.logging.info("Processing example: %d" %
                                    (len(all_results)))
                unique_id = int(result["unique_ids"])
                start_top_log_probs = ([
                    float(x) for x in result["start_top_log_probs"].flat
                ])
                start_top_index = [
                    int(x) for x in result["start_top_index"].flat
                ]
                end_top_log_probs = ([
                    float(x) for x in result["end_top_log_probs"].flat
                ])
                end_top_index = [int(x) for x in result["end_top_index"].flat]

                cls_logits = float(result["cls_logits"].flat[0])
                all_results.append(
                    squad_utils.RawResultV2(
                        unique_id=unique_id,
                        start_top_log_probs=start_top_log_probs,
                        start_top_index=start_top_index,
                        end_top_log_probs=end_top_log_probs,
                        end_top_index=end_top_index,
                        cls_logits=cls_logits))

            output_prediction_file = os.path.join(FLAGS.output_dir,
                                                  "predictions.json")
            output_nbest_file = os.path.join(FLAGS.output_dir,
                                             "nbest_predictions.json")
            output_null_log_odds_file = os.path.join(FLAGS.output_dir,
                                                     "null_odds.json")

            result_dict = {}
            cls_dict = {}
            squad_utils.accumulate_predictions_v2(
                result_dict, cls_dict, eval_examples, eval_features,
                all_results, FLAGS.n_best_size, FLAGS.max_answer_length,
                FLAGS.start_n_top, FLAGS.end_n_top)

            return squad_utils.evaluate_v2(
                result_dict, cls_dict, prediction_json, eval_examples,
                eval_features, all_results, FLAGS.n_best_size,
                FLAGS.max_answer_length, output_prediction_file,
                output_nbest_file, output_null_log_odds_file), int(global_step)

        def _find_valid_cands(curr_step):
            """List ``*.index`` files of non-best checkpoints newer than `curr_step`."""
            filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
            candidates = []
            for filename in filenames:
                if filename.endswith(".index"):
                    ckpt_name = filename[:-6]
                    idx = ckpt_name.split("-")[-1]
                    if idx != "best" and int(idx) > curr_step:
                        candidates.append(filename)
            return candidates

        ckpt_exts = ["meta", "data-00000-of-00001", "index"]

        def _remove_checkpoint(ckpt_prefix):
            """Delete the .meta/.data/.index files of checkpoint `ckpt_prefix`."""
            for ext in ckpt_exts:
                src_ckpt = ckpt_prefix + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
        key_name = "f1"
        # The context manager guarantees the results file is flushed and
        # closed, including when evaluation raises.
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            if tf.gfile.Exists(checkpoint_path + ".index"):
                # Resume: start from the score of the saved best checkpoint.
                result = get_result(checkpoint_path)
                best_perf = result[0][key_name]
                global_step = result[1]
            else:
                global_step = -1
                best_perf = -1
                checkpoint_path = None
            # Keep polling until an evaluated checkpoint reaches the last step.
            while global_step < num_train_steps:
                steps_and_files = {}
                filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
                for filename in filenames:
                    if filename.endswith(".index"):
                        ckpt_name = filename[:-6]
                        cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
                        if cur_filename.split("-")[-1] == "best":
                            continue
                        gstep = int(cur_filename.split("-")[-1])
                        if gstep not in steps_and_files:
                            tf.logging.info(
                                "Add {} to eval list.".format(cur_filename))
                            steps_and_files[gstep] = cur_filename
                tf.logging.info("found {} files.".format(len(steps_and_files)))
                if not steps_and_files:
                    tf.logging.info(
                        "found 0 file, global step: {}. Sleeping.".format(
                            global_step))
                    time.sleep(60)
                else:
                    for step, checkpoint_path in sorted(steps_and_files.items()):
                        print("GS: ", global_step, step)
                        if global_step >= step:
                            # Already evaluated: delete it once more than one
                            # newer checkpoint is available.
                            if len(_find_valid_cands(step)) > 1:
                                _remove_checkpoint(checkpoint_path)
                            continue
                        result, global_step = get_result(checkpoint_path)
                        print("EVAL RESULTS")
                        tf.logging.info("***** Eval results *****")
                        for key in sorted(result.keys()):
                            tf.logging.info("  %s = %s", key, str(result[key]))
                            writer.write("%s = %s\n" % (key, str(result[key])))
                        if result[key_name] > best_perf:
                            # New best: copy it to "<prefix>-best.<ext>".
                            best_perf = result[key_name]
                            for ext in ckpt_exts:
                                src_ckpt = checkpoint_path + ".{}".format(ext)
                                tgt_ckpt = checkpoint_path.rsplit(
                                    "-", 1)[0] + "-best.{}".format(ext)
                                tf.logging.info("saving {} to {}".format(
                                    src_ckpt, tgt_ckpt))
                                tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
                                writer.write("saved {} to {}\n".format(
                                    src_ckpt, tgt_ckpt))
                        writer.write("best {} = {}\n".format(key_name, best_perf))
                        tf.logging.info("  best {} = {}\n".format(
                            key_name, best_perf))

                        # Drop the just-evaluated checkpoint once more than two
                        # newer ones exist (a best copy was made if needed).
                        if len(_find_valid_cands(global_step)) > 2:
                            _remove_checkpoint(checkpoint_path)
                        writer.write("=" * 50 + "\n")
                print("Sleeping")
                time.sleep(10)
            # Final report: re-score the best checkpoint.
            checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
            result, global_step = get_result(checkpoint_path)
            tf.logging.info("***** Final Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
            writer.write("best perf happened at step: {}\n".format(global_step))
Beispiel #5
0
def main(_):
    """Train and/or evaluate an ALBERT SQuAD v2.0 model as configured by FLAGS.

    With ``--do_train`` the training examples are read, converted to
    ``FLAGS.train_feature_file`` (unless it already exists) and the estimator
    is trained.

    With ``--do_predict`` only ``FLAGS.init_checkpoint`` is scored, and only
    on answerability. For each eval example the null score
    ``m = np.min(cls_dict[example_index])`` (``cls_dict`` is filled by
    ``squad_utils.accumulate_predictions_v2``) is passed through a sigmoid.
    The example is predicted "unanswerable" when that value exceeds 0.5.
    Precision, recall and F1 of this prediction (unanswerable = positive
    class) are logged. The per-question null scores go to ``predictions.json``
    in ``FLAGS.output_dir``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    albert_config = modeling.AlbertConfig.from_json_file(
        FLAGS.albert_config_file)

    validate_flags_or_throw(albert_config)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    tokenizer = fine_tuning_utils.create_vocab(
        vocab_file=FLAGS.vocab_file,
        do_lower_case=FLAGS.do_lower_case,
        spm_model_file=FLAGS.spm_model_file,
        hub_module=FLAGS.albert_hub_module_handle)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
    if FLAGS.do_train:
        # Never run more steps per TPU loop than there are between checkpoints.
        iterations_per_loop = int(
            min(FLAGS.iterations_per_loop, FLAGS.save_checkpoints_steps))
    else:
        iterations_per_loop = FLAGS.iterations_per_loop
    run_config = contrib_tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        keep_checkpoint_max=0,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=contrib_tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    # In predict-only runs the step counts stay None. Training needs the
    # examples and the step counts; without them the do_train branch below
    # would get examples=None and fail on len(None).
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = squad_utils.read_squad_examples(
            input_file=FLAGS.train_file, is_training=True)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        # Pre-shuffle the input to avoid having to make a very large shuffle
        # buffer in the `input_fn`.
        rng = random.Random(12345)
        rng.shuffle(train_examples)

    model_fn = squad_utils.v2_model_fn_builder(
        albert_config=albert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        max_seq_length=FLAGS.max_seq_length,
        start_n_top=FLAGS.start_n_top,
        end_n_top=FLAGS.end_n_top,
        dropout_prob=FLAGS.dropout_prob,
        hub_module=FLAGS.albert_hub_module_handle)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = contrib_tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        # Features go to a file (reused if it already exists) to avoid storing
        # very large constant tensors in memory.
        if not tf.gfile.Exists(FLAGS.train_feature_file):
            train_writer = squad_utils.FeatureWriter(
                filename=FLAGS.train_feature_file, is_training=True)
            squad_utils.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=FLAGS.max_seq_length,
                doc_stride=FLAGS.doc_stride,
                max_query_length=FLAGS.max_query_length,
                is_training=True,
                output_fn=train_writer.process_feature,
                do_lower_case=FLAGS.do_lower_case)
            train_writer.close()

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num orig examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        del train_examples

        train_input_fn = squad_utils.input_fn_builder(
            input_file=FLAGS.train_feature_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.train_batch_size,
            is_v2=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        with tf.gfile.Open(FLAGS.predict_file) as predict_file:
            prediction_json = json.load(predict_file)["data"]
        eval_examples = squad_utils.read_squad_examples(
            input_file=FLAGS.predict_file, is_training=False)

        # Reuse cached eval features only when both the feature file and the
        # pickled feature list exist; otherwise rebuild and cache both.
        if (tf.gfile.Exists(FLAGS.predict_feature_file)
                and tf.gfile.Exists(FLAGS.predict_feature_left_file)):
            tf.logging.info("Loading eval features from {}".format(
                FLAGS.predict_feature_left_file))
            # NOTE: unpickles a cache written by this script; never point
            # predict_feature_left_file at untrusted data.
            with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
                eval_features = pickle.load(fin)
        else:
            eval_writer = squad_utils.FeatureWriter(
                filename=FLAGS.predict_feature_file, is_training=False)
            eval_features = []

            def append_feature(feature):
                # Keep the feature in memory and also write it to the file.
                eval_features.append(feature)
                eval_writer.process_feature(feature)

            squad_utils.convert_examples_to_features(
                examples=eval_examples,
                tokenizer=tokenizer,
                max_seq_length=FLAGS.max_seq_length,
                doc_stride=FLAGS.doc_stride,
                max_query_length=FLAGS.max_query_length,
                is_training=False,
                output_fn=append_feature,
                do_lower_case=FLAGS.do_lower_case)
            eval_writer.close()

            with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
                pickle.dump(eval_features, fout)

        tf.logging.info("***** Running predictions *****")
        tf.logging.info("  Num orig examples = %d", len(eval_examples))
        tf.logging.info("  Num split examples = %d", len(eval_features))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = squad_utils.input_fn_builder(
            input_file=FLAGS.predict_feature_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.predict_batch_size,
            is_v2=True)

        def get_result(checkpoint):
            """Score "unanswerable" predictions of `checkpoint` on SQuAD v2.0.

            Returns:
                ``(metrics, global_step)``. ``metrics`` has keys
                ``precision``, ``recall``, ``f1`` (unanswerable = positive
                class) and ``total`` (number of eval examples).
            """
            # If running eval on the TPU, you will need to specify the number of
            # steps.
            reader = tf.train.NewCheckpointReader(checkpoint)
            global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
            all_results = []
            for result in estimator.predict(predict_input_fn,
                                            yield_single_examples=True,
                                            checkpoint_path=checkpoint):
                if len(all_results) % 1000 == 0:
                    tf.logging.info("Processing example: %d" %
                                    (len(all_results)))
                unique_id = int(result["unique_ids"])

                cls_logits = float(result["cls_logits"].flat[0])
                all_results.append(
                    squad_utils.RawResultV2(unique_id=unique_id,
                                            cls_logits=cls_logits))

            output_prediction_file = os.path.join(FLAGS.output_dir,
                                                  "predictions.json")

            result_dict = {}
            cls_dict = {}
            squad_utils.accumulate_predictions_v2(
                result_dict, cls_dict, eval_examples, eval_features,
                all_results, FLAGS.n_best_size, FLAGS.max_answer_length,
                FLAGS.start_n_top, FLAGS.end_n_top)

            from squad_utils import make_qid_to_has_ans
            import numpy as np
            qid_to_has_ans = make_qid_to_has_ans(
                prediction_json)  # maps qid to True/False
            has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
            no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
            print("has_ans", len(has_ans_qids))
            print("no_ans", len(no_ans_qids))

            def compute_metrics_with_threshold(threshold):
                """Return (precision, recall, f1) for predicting "unanswerable".

                An example is predicted unanswerable when
                ``sigmoid(min(cls_dict[example_index])) > threshold``.
                Rebinds the enclosing ``result_dict`` to map each qas_id to
                that raw null score. A ratio whose denominator is zero is
                reported as 0.0 instead of raising ZeroDivisionError.
                """
                nonlocal result_dict
                result_dict = {}
                tp = 0
                fp = 0
                fn = 0
                for example_index, example in enumerate(eval_examples):
                    m = np.min(cls_dict[example_index])
                    predict_is_impossible = 1 / (1 + np.exp(-m)) > threshold
                    # float() keeps json.dumps working for numpy scalar types.
                    result_dict[example.qas_id] = float(m)
                    if example.is_impossible:
                        if predict_is_impossible:
                            tp += 1
                        else:
                            fn += 1
                    elif predict_is_impossible:
                        fp += 1
                precision = tp / (tp + fp) if tp + fp else 0.0
                recall = tp / (tp + fn) if tp + fn else 0.0
                f1_denominator = 2 * tp + fp + fn
                f1 = 2 * tp / f1_denominator if f1_denominator else 0.0
                tf.logging.info(f"precision: {precision}, "
                                f"recall: {recall}, "
                                f"f1: {f1}")
                return precision, recall, f1

            precision, recall, f1 = compute_metrics_with_threshold(0.5)

            with tf.gfile.GFile(output_prediction_file, "w") as writer:
                writer.write(json.dumps(result_dict, indent=4) + "\n")

            return {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "total": len(eval_examples)
            }, int(global_step)

        # This variant scores a single checkpoint (FLAGS.init_checkpoint)
        # rather than polling output_dir for new checkpoints.
        result, global_step = get_result(FLAGS.init_checkpoint)
        tf.logging.info("***** Eval results (step %d) *****", global_step)
        for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
Beispiel #6
0
                    f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
RawResult = collections.namedtuple("RawResult",
                                   ["unique_id", "start_logits", "end_logits"])
test_file = "./squad/new_test-v1.1.json"
eval_examples = read_squad_examples(test_file, is_training=False, debug=False)
eval_features = convert_examples_to_features(eval_examples,
                                             tokenizer=tokenizer,
                                             max_seq_length=config.max_seq_len,
                                             max_query_length=config.max_query_len,
                                             doc_stride=128,
                                             is_training=False)

# Stack each feature field into an int64 tensor; the last column is the row
# index so each batch row can be traced back to its entry in eval_features.
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0))
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=8)

# One source of truth for the device: the model and anything later moved with
# `.to(device)` (presumably the eval batches) must share it. Before this, the
# model went to config.device while `device` was hard-coded to "cuda:2".
device = config.device
model = BertForQuestionAnswering.from_pretrained("./save/dual/train_507200353/bert_1_2.958")
model = model.to(device)
model.eval()
def _extract_context(sentences, reconstructed_word, zsre):
    """Choose the context sentence for one template fact.

    Args:
        sentences: Iterable of candidate sentences (may be lazy).
        reconstructed_word: Text substituted for the ``[MASK]`` token.
        zsre: When True, sentences are used verbatim and the LAST candidate
            wins (the original loop never stops early in this mode). When
            False, the first sentence containing ``[MASK]`` is returned with
            the mask filled in.

    Returns:
        The chosen context string, or None when no candidate qualifies.
    """
    context = None
    for sent in sentences:
        if zsre:
            # NOTE(review): no early exit here, so the last sentence is kept;
            # confirm that is intended rather than the first one.
            context = sent
        elif '[MASK]' in sent:
            # Some masked sentences contain no mask; take the first that does.
            return sent.replace('[MASK]', reconstructed_word)
    return context


def _run_qa_batch(samples_b,
                  qamodel,
                  tokenizer,
                  v2,
                  must_choose_answer,
                  max_seq_length=384,
                  doc_stride=128,
                  max_query_length=64,
                  n_best=20,
                  max_answer_length=30,
                  do_lower_case=True,
                  verbose_logging=False,
                  null_score_diff_threshold=0.0):
    """Run the extractive QA model over one batch of samples.

    Keyword defaults are the HuggingFace SQuAD example settings that the
    original inline code hard-coded.

    Returns:
        The values of the mapping returned by ``get_predictions``, in its
        iteration order.
    """
    examples = read_input_examples(samples_b)
    features = convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_training=False,
        cls_token_segment_id=0,
        pad_token_segment_id=0,
        cls_token_at_end=False,
        sequence_a_is_doc=False)

    # Convert to tensors and build a dataset whose 4th column is the row index
    # of each feature, so outputs can be mapped back to `features`.
    all_input_ids = torch.tensor([f.input_ids for f in features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                   dtype=torch.long)
    all_cls_index = torch.tensor([f.cls_index for f in features],
                                 dtype=torch.long)
    all_p_mask = torch.tensor([f.p_mask for f in features],
                              dtype=torch.float)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                            all_example_index, all_cls_index, all_p_mask)
    eval_dataloader = DataLoader(dataset,
                                 sampler=SequentialSampler(dataset),
                                 batch_size=len(samples_b))

    all_results = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.cuda() for t in batch)
        with torch.no_grad():
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2],  # XLM don't use segment_ids
            }
            example_indices = batch[3]
            outputs = qamodel(**inputs)

        for row, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            all_results.append(
                RawResult(unique_id=int(eval_feature.unique_id),
                          start_logits=to_list(outputs[0][row]),
                          end_logits=to_list(outputs[1][row])))

    predictions = get_predictions(examples,
                                  features,
                                  all_results,
                                  n_best,
                                  max_answer_length,
                                  do_lower_case,
                                  verbose_logging,
                                  v2,
                                  null_score_diff_threshold,
                                  must_choose_answer=must_choose_answer)
    return [predictions[key] for key in predictions]


def main(args,
         shuffle_data=True,
         model=None,
         qamodel=None,
         tokenizer=None,
         zsre=False,
         v2=True,
         must_choose_answer=False,
         condition_on_answer_exists=False,
         condition_on_single_token=False,
         condition_on_multi_token=False,
         condition_on_answer_does_not_exist=False):
    """Evaluate a QA model (plus a language model) on a fact-probing dataset.

    Loads ``args.dataset_filename`` and optionally lowercases and filters it.
    When ``args.template`` is set, it rebuilds each sample as a template fact
    with a context and question. It then answers every sample with
    ``qamodel`` and scores each one via ``run_thread`` in a thread pool.
    Per-sample records are pickled to ``<log_directory>/result.pkl``.

    Args:
        args: Parsed CLI arguments (models_names, dataset_filename, template,
            question, batch_size, threads, logdir/full_logdir, ...).
        shuffle_data: Shuffle samples before batching.
        model: Language-model wrapper providing ``get_batch_generation``,
            ``get_id`` and ``vocab``.
        qamodel: Extractive QA model; it is moved to CUDA.
        tokenizer: Tokenizer used to featurize QA inputs.
        zsre: Use the zsRE context-selection rule (see ``_extract_context``).
        v2, must_choose_answer: Forwarded to ``get_predictions``.
        condition_on_*: Forwarded to ``filter_samples``.

    Returns:
        A 12-tuple ``(Precision1, Precision, MRR, EM, F1, is_error,
        no_overlap, larger_by_1, larger_by_2, larger_by_3, larger_by_4,
        larger_by_5_or_more)``. The first five values are averages over
        evaluated samples and the rest are summed counts. Returns a tuple of
        12 ``None`` values when there is nothing to evaluate.

    Raises:
        ValueError: More than one model name, or an object label missing from
            the model vocabulary.
        Exception: A template sample lacks ``reconstructed_word``.
    """
    empty_result = (None,) * 12

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    [model_type_name] = args.models_names

    print(model)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg = "model name: {}\n".format(model_name)

    # Vocabulary subsetting is not wired up in this entry point.
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # Google-RE judgement stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    EM = 0.0
    F1 = 0.0
    # Error-breakdown counters accumulated from run_thread results.
    is_error = 0
    no_overlap = 0
    larger_by_1 = 0
    larger_by_2 = 0
    larger_by_3 = 0
    larger_by_4 = 0
    larger_by_5_or_more = 0

    data = load_file(args.dataset_filename)
    print(len(data))

    if args.lowercase:
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
    else:
        all_samples = data

    # Filter the (possibly lowercased) samples; previously the raw `data` was
    # passed here, which silently discarded the lowercasing above.
    all_samples, ret_msg = filter_samples(
        model,
        all_samples,
        vocab_subset,
        args.max_sentence_length,
        args.template,
        condition_on_answer_exists=condition_on_answer_exists,
        condition_on_single_token=condition_on_single_token,
        condition_on_multi_token=condition_on_multi_token,
        condition_on_answer_does_not_exist=condition_on_answer_does_not_exist,
        is_zsre=zsre)

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))
    if not all_samples:
        return empty_result

    # With a template: keep one sample per distinct (sub, target, context)
    # and rebuild it with a templated masked sentence and a question.
    if args.template:
        question_template = args.question
        facts = []
        seen = set()  # O(1) dedup; `facts` keeps first-seen order
        for sample in all_samples:
            # Check before reading the key so the intended error is raised.
            if 'reconstructed_word' not in sample:
                raise Exception('Reconstructed word not in sample... fix this')
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            target = sample['reconstructed_word']

            if 'masked_sentences' in sample:
                candidates = sample['masked_sentences']
            else:
                # Lazy so evidences after the chosen one are never touched.
                candidates = (evidence['masked_sentence']
                              for evidence in sample['evidences'])
            context = _extract_context(candidates, target, zsre)
            if context is None:
                print('No valid context found, skipping sample')
                continue

            key = (sub, target, context)
            if key not in seen:
                seen.add(key)
                facts.append((sub, obj, context, question_template, target))

        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for sub, obj, context, question, rw in facts:
            all_samples.append({
                "sub_label": sub,
                "obj_label": obj,
                "reconstructed_word": rw,
                'context': context,
                # substitute all sentences with a standard template
                "masked_sentences": parse_template(
                    args.template.strip(), sub.strip(), base.MASK),
                'question': question.replace('[X]', sub),
            })

    # create uuid if not present
    for position, sample in enumerate(all_samples):
        sample.setdefault("uuid", position)

    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")

    # Non-positive thread count means "use all available CPUs".
    num_threads = args.threads
    if num_threads <= 0:
        num_threads = multiprocessing.cpu_count()

    viz = False
    final_viz = []
    qamodel.eval().cuda()

    list_of_results = []
    pool = ThreadPool(num_threads)
    try:
        for samples_b, sentences_b in tqdm(
                zip(samples_batches, sentences_batches),
                total=len(samples_batches)):

            predictions_list = _run_qa_batch(samples_b, qamodel, tokenizer,
                                             v2, must_choose_answer)
            torch.cuda.empty_cache()

            original_log_probs_list, token_ids_list, masked_indices_list = \
                model.get_batch_generation(sentences_b, logger=logger)
            mymodel_probs_list = original_log_probs_list
            # vocab_subset is always None here, so nothing is filtered.
            filtered_log_probs_list = original_log_probs_list

            label_index_list = []
            for sample in samples_b:
                obj_label_id = model.get_id(sample["obj_label"])

                # MAKE SURE THAT obj_label IS IN VOCABULARIES
                if obj_label_id is None:
                    raise ValueError(
                        "object label {} not in model vocabulary".format(
                            sample["obj_label"]))
                if vocab_subset is not None and sample[
                        "obj_label"] not in vocab_subset:
                    raise ValueError(
                        "object label {} not in vocab subset".format(
                            sample["obj_label"]))

                label_index_list.append(obj_label_id)

            arguments = [{
                "mymodel_probs": mymodel_probs,
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "target": sample["reconstructed_word"],
                "prediction": pred,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0] if len(label_index) > 0 else 0,
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
            } for (mymodel_probs, original_log_probs, filtered_log_probs,
                   token_ids, masked_indices, label_index, sample,
                   pred) in zip(mymodel_probs_list, original_log_probs_list,
                                filtered_log_probs_list, token_ids_list,
                                masked_indices_list, label_index_list,
                                samples_b, predictions_list)]

            for idx, result in enumerate(pool.map(run_thread, arguments)):
                (result_masked_topk, sample_MRR, sample_P, sample_perplexity,
                 result_msg, sample_em, sample_f1, sample_is_error,
                 sample_no_overlap, sample_larger_by_1, sample_larger_by_2,
                 sample_larger_by_3, sample_larger_by_4,
                 sample_larger_by_5_or_more) = result

                logger.info("\n" + result_msg + "\n")

                sample = samples_b[idx]
                element = {
                    "sample": sample,
                    "uuid": sample["uuid"],
                    "token_ids": token_ids_list[idx],
                    "masked_indices": masked_indices_list[idx],
                    "label_index": label_index_list[idx],
                    "masked_topk": result_masked_topk,
                    "sample_MRR": sample_MRR,
                    "sample_Precision": sample_P,
                    "sample_perplexity": sample_perplexity,
                    "sample_Precision1": result_masked_topk["P_AT_1"],
                    'sample_em': sample_em,
                    'sample_f1': sample_f1,
                }

                MRR += sample_MRR
                Precision += sample_P
                Precision1 += element["sample_Precision1"]
                is_error += sample_is_error
                no_overlap += sample_no_overlap
                larger_by_1 += sample_larger_by_1
                larger_by_2 += sample_larger_by_2
                larger_by_3 += sample_larger_by_3
                larger_by_4 += sample_larger_by_4
                larger_by_5_or_more += sample_larger_by_5_or_more
                EM += sample_em
                F1 += sample_f1

                # Annotator judgments (only present for Google-RE) on whether
                # the sentence evidences the relation; ties count as negative.
                if "judgments" in sample:
                    num_yes = 0
                    num_no = 0
                    for x in sample["judgments"]:
                        if x["judgment"] == "yes":
                            num_yes += 1
                        else:
                            num_no += 1
                    if num_no >= num_yes:
                        samples_with_negative_judgement += 1
                        element["judgement"] = "negative"
                        MRR_negative += sample_MRR
                        Precision_negative += sample_P
                    else:
                        samples_with_positive_judgement += 1
                        element["judgement"] = "positive"
                        MRR_positive += sample_MRR
                        Precision_positivie += sample_P

                list_of_results.append(element)
    finally:
        # Release worker threads even if a batch raised.
        pool.close()
        pool.join()

    if viz:
        with open('viz.pkl', 'wb') as wf:
            pickle.dump(final_viz, wf)

    # Every sample may have been skipped during template processing; avoid
    # dividing by zero and report "nothing evaluated" like the early return.
    if not list_of_results:
        logger.info("\nno samples were evaluated\n")
        return empty_result

    num_results = len(list_of_results)
    MRR /= num_results
    Precision /= num_results
    Precision1 /= num_results
    EM /= num_results
    F1 /= num_results

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(num_results)
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)
    msg += "global EM {}\n".format(EM)
    msg += "global F1: {}\n".format(F1)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(list_of_results=list_of_results,
                       global_MRR=MRR,
                       global_P_at_10=Precision)
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    return Precision1, Precision, MRR, EM, F1, is_error, no_overlap, larger_by_1, larger_by_2, larger_by_3, larger_by_4, larger_by_5_or_more