Exemplo n.º 1
0
def main():
    """Answer one question about one context passage and print the answer.

    The question, context, model name and checkpoint path all come from the
    parsed command-line arguments. The model is run once on the encoded
    (question, context) pair; candidate spans are ranked by the sum of their
    start and end logits and the top span is joined back into text.

    Returns:
        The answer string. It is empty when ``args.version_2_with_negative``
        is set and ``args.null_score_diff_threshold`` exceeds the null score.
    """
    args = parse_args()

    print("***** Loading tokenizer and model *****")
    model_name = args.electra_model
    electra_config = ElectraConfig.from_pretrained(model_name)
    tokenizer = ElectraTokenizer.from_pretrained(model_name)
    model = TFElectraForQuestionAnswering.from_pretrained(
        model_name, config=electra_config, args=args)

    checkpoint_banner = "***** Loading fine-tuned checkpoint: {} *****"
    print(checkpoint_banner.format(args.init_checkpoint))
    load_status = model.load_weights(
        args.init_checkpoint, by_name=False, skip_mismatch=False)
    load_status.expect_partial()

    # Encode the (question, context) pair as TensorFlow tensors.
    encoded = tokenizer.encode_plus(
        args.question, args.context, return_tensors='tf')
    input_ids = encoded["input_ids"]
    token_type_ids = encoded["token_type_ids"]
    attention_mask = encoded["attention_mask"]
    all_tokens = tokenizer.convert_ids_to_tokens(input_ids.numpy()[0])
    num_tokens = len(all_tokens)
    forward_kwargs = {
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }

    if args.joint_head:
        # Joint head: the model outputs are read positionally as start logits,
        # start top indices, end logits, end top indices and cls logits.
        outputs = model(input_ids, **forward_kwargs)
        first_item = [tensor[0].numpy().tolist() for tensor in outputs]
        result = SquadResult(
            0,
            first_item[0],
            first_item[2],
            start_top_index=first_item[1],
            end_top_index=first_item[3],
            cls_logits=first_item[4],
        )
        predictions = get_predictions_joint_head(
            result.start_top_index, result.end_top_index, result,
            num_tokens, args)
        null_score = result.cls_logits
    else:
        # Plain head: only the first two model outputs (start/end logits) are used.
        start_tensor, end_tensor = model(input_ids, **forward_kwargs)[:2]
        result = RawResult(unique_id=0,
                           start_logits=start_tensor[0].numpy().tolist(),
                           end_logits=end_tensor[0].numpy().tolist())
        best_starts = _get_best_indices(result.start_logits, args.n_best_size)
        best_ends = _get_best_indices(result.end_logits, args.n_best_size)
        predictions = get_predictions(best_starts, best_ends, result,
                                      num_tokens, args)
        # The null score is taken from position 0 of both logit vectors.
        null_score = result.start_logits[0] + result.end_logits[0]

    def span_score(prediction):
        return prediction.start_logit + prediction.end_logit

    ranked = sorted(predictions, key=span_score, reverse=True)
    top = ranked[0]
    answer = ' '.join(all_tokens[top.start_index:top.end_index + 1])

    below_threshold = args.null_score_diff_threshold > null_score
    if below_threshold and args.version_2_with_negative:
        answer = ''

    print(answer)
    return answer
def _summarize_latency_ms(latency):
    """Return latency statistics in milliseconds, or ``None`` if there are too few samples.

    The two largest samples are dropped before averaging (presumably to
    discard outliers such as warm-up batches -- TODO confirm). The statistics
    are the mean of the remaining samples and the means of their fastest
    90%, 95% and 99%.

    Args:
        latency: iterable of per-batch inference times in seconds.

    Returns:
        A ``(mean, mean_90, mean_95, mean_99)`` tuple of floats in ms, or
        ``None`` when nothing is left after dropping the two largest samples.
    """
    trimmed = sorted(latency)[:-2]
    if not trimmed:
        return None

    def mean_of_fastest(fraction):
        # Always keep at least one sample so short runs cannot divide by zero.
        count = max(1, int(len(trimmed) * fraction))
        return sum(trimmed[:count]) / count * 1000

    return (sum(trimmed) / len(trimmed) * 1000,
            mean_of_fastest(0.9),
            mean_of_fastest(0.95),
            mean_of_fastest(0.99))


def main():
    """Fine-tune and/or evaluate an ELECTRA question-answering model on SQuAD.

    Driven entirely by the parsed command-line arguments. In order, it:

    1. initialises Horovod, dllogger, the random seed and the TensorFlow
       GPU / XLA / mixed-precision configuration;
    2. loads the ELECTRA config, tokenizer and model;
    3. loads SQuAD features from pickle cache files, or builds them from the
       examples (and writes the cache on the main process);
    4. when ``args.do_train`` is set, runs the training loop and saves a
       checkpoint after every epoch unless ``args.skip_checkpoint`` is set;
    5. when ``args.do_predict`` is set, runs inference on the main process,
       writes ``predictions.json`` / ``nbest_predictions.json`` into
       ``args.output_dir`` and, with ``args.do_eval``, runs the external
       evaluation script and logs EM / F1;
    6. logs latency and final result summaries.
    """
    args = parse_args()

    hvd.init()
    set_affinity(hvd.local_rank())

    if is_main_process():
        log("Running total processes: {}".format(get_world_size()))
    log("Starting process: {}".format(get_rank()))

    # Only the main process gets real logging backends.
    if is_main_process():
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.json_summary),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
    else:
        dllogger.init(backends=[])

    tf.random.set_seed(args.seed)
    dllogger.log(step="PARAMETER", data={"SEED": args.seed})
    # script parameters
    BATCH_SIZE = args.train_batch_size
    EVAL_BATCH_SIZE = args.predict_batch_size
    USE_XLA = args.xla
    USE_AMP = args.amp
    EPOCHS = args.num_train_epochs

    if not args.do_train:
        # A single pass through the epoch loop is enough to reach evaluation.
        EPOCHS = args.num_train_epochs = 1
        log("Since running inference only, setting args.num_train_epochs to 1")

    if not os.path.exists(args.output_dir) and is_main_process():
        os.makedirs(args.output_dir)

    # TensorFlow configuration: each process sees only the GPU matching its local rank.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    tf.config.optimizer.set_jit(USE_XLA)
    #tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

    if args.amp:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' % policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' % policy.variable_dtype)  # Variable dtype: float32

    if is_main_process():
        log("***** Loading tokenizer and model *****")
    # Load the config, tokenizer and model from the pretrained model / vocabulary.
    electra_model = args.electra_model
    config = ElectraConfig.from_pretrained(electra_model, cache_dir=args.cache_dir)
    config.update({"amp": args.amp})
    if args.vocab_file is None:
        tokenizer = ElectraTokenizer.from_pretrained(electra_model, cache_dir=args.cache_dir)
    else:
        tokenizer = ElectraTokenizer(
            vocab_file=args.vocab_file,
            do_lower_case=args.do_lower_case)

    model = TFElectraForQuestionAnswering.from_pretrained(electra_model, config=config, cache_dir=args.cache_dir, args=args)

    if is_main_process():
        log("***** Loading dataset *****")
    # Load data
    processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
    train_examples = processor.get_train_examples(args.data_dir) if args.do_train else None
    dev_examples = processor.get_dev_examples(args.data_dir) if args.do_predict else None

    if is_main_process():
        log("***** Loading features *****")
    # Load cached features
    squad_version = '2.0' if args.version_2_with_negative else '1.1'
    if args.cache_dir is None:
        args.cache_dir = args.data_dir
    # The cache file name encodes the SQuAD version, max_seq_length, doc_stride
    # and max_query_length. The model name is not part of it, so it is no
    # longer split on "/" (that raised IndexError for names without a slash).
    cache_prefix = args.cache_dir.rstrip('/') + '/'
    cached_train_features_file = cache_prefix + 'TF2_train-v{}.json_{}_{}_{}'.format(
        squad_version, str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length))
    cached_dev_features_file = cache_prefix + 'TF2_dev-v{}.json_{}_{}_{}'.format(
        squad_version, str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length))

    try:
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # the cache directory must only contain trusted files.
        with open(cached_train_features_file, "rb") as reader:
            train_features = pickle.load(reader) if args.do_train else []
        with open(cached_dev_features_file, "rb") as reader:
            dev_features = pickle.load(reader) if args.do_predict else []
    except Exception:
        # Any failure to read the cache (missing or unreadable file) falls back
        # to rebuilding the features; KeyboardInterrupt/SystemExit propagate.
        train_features = (  # TODO: (yy) do on rank 0?
            squad_convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True,
                return_dataset="",
            )
            if args.do_train
            else []
        )
        dev_features = (
            squad_convert_examples_to_features(
                examples=dev_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=False,
                return_dataset="",
            )
            if args.do_predict
            else []
        )
        # Dump Cached features
        if not args.skip_cache and is_main_process():
            if args.do_train:
                log("***** Building Cache Files: {} *****".format(cached_train_features_file))
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
            if args.do_predict:
                log("***** Building Cache Files: {} *****".format(cached_dev_features_file))
                with open(cached_dev_features_file, "wb") as writer:
                    pickle.dump(dev_features, writer)

    len_train_features = len(train_features)
    total_train_steps = int((len_train_features * EPOCHS / BATCH_SIZE) / get_world_size()) + 1
    train_steps_per_epoch = int((len_train_features / BATCH_SIZE) / get_world_size()) + 1
    len_dev_features = len(dev_features)
    total_dev_steps = int((len_dev_features / EVAL_BATCH_SIZE)) + 1

    train_dataset = get_dataset_from_features(train_features, BATCH_SIZE,
                                              v2=args.version_2_with_negative) if args.do_train else []
    dev_dataset = get_dataset_from_features(dev_features, EVAL_BATCH_SIZE, drop_remainder=False, ngpu=1, mode="dev",
                                            v2=args.version_2_with_negative) if args.do_predict else []

    opt = create_optimizer(init_lr=args.learning_rate, num_train_steps=total_train_steps,
                           num_warmup_steps=int(args.warmup_proportion * total_train_steps),
                           weight_decay_rate=args.weight_decay_rate,
                           layerwise_lr_decay=args.layerwise_lr_decay,
                           n_transformer_layers=model.num_hidden_layers)
    if USE_AMP:
        # loss scaling is currently required when using mixed precision
        opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")

    # Define loss function
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    loss_class = tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        name='binary_crossentropy'
    )
    metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
    model.compile(optimizer=opt, loss=loss, metrics=[metric])
    train_loss_results = []

    if args.do_train and is_main_process():
        log("***** Running training *****")
        log("  Num examples = ", len_train_features)
        log("  Num Epochs = ", args.num_train_epochs)
        log("  Instantaneous batch size per GPU = ", args.train_batch_size)
        log(
            "  Total train batch size (w. parallel, distributed & accumulation) = ",
            args.train_batch_size
            * get_world_size(),
        )
        log("  Total optimization steps =", total_train_steps)

    total_train_time = 0
    latency = []
    # Stay None unless the evaluation script actually ran, so the final
    # summary can tell whether there is anything to report.
    exact_match = None
    f1 = None
    for epoch in range(EPOCHS):
        if args.do_train:
            epoch_loss_avg = tf.keras.metrics.Mean()
            epoch_perf_avg = tf.keras.metrics.Mean()
            epoch_start = time.time()

            epoch_iterator = tqdm(train_dataset, total=train_steps_per_epoch, desc="Iteration", mininterval=5,
                                  disable=not is_main_process())
            for step, inputs in enumerate(epoch_iterator):
                # Stop this epoch once the global step exceeds max_steps (only when max_steps > 0).
                # NOTE(review): with ">" the loop runs max_steps + 1 steps -- confirm this is intended.
                if args.max_steps > 0 and (epoch * train_steps_per_epoch + step) > args.max_steps:
                    break
                iter_start = time.time()
                # Optimize the model
                loss_value = train_step(model, inputs, loss, USE_AMP, opt, (step == 0 and epoch == 0),
                                        v2=args.version_2_with_negative, loss_class=loss_class, fp16=USE_AMP)
                epoch_perf_avg.update_state(1. * BATCH_SIZE / (time.time() - iter_start))
                if step % args.log_freq == 0:
                    if is_main_process():
                        log("\nEpoch: {:03d}, Step:{:6d}, Loss:{:12.8f}, Perf:{:5.0f}, loss_scale:{}, opt_step:{}".format(
                            epoch, step, loss_value,
                            epoch_perf_avg.result() * get_world_size(), opt.loss_scale if config.amp else 1,
                            int(opt.iterations)))
                    dllogger.log(step=(epoch, step,), data={"step_loss": float(loss_value.numpy()),
                                                            "train_perf": float(epoch_perf_avg.result().numpy() * get_world_size())})

                # Track progress
                epoch_loss_avg.update_state(loss_value)  # Add current batch loss

            # End epoch
            train_loss_results.append(epoch_loss_avg.result())
            total_train_time += float(time.time() - epoch_start)
            # Summarize and save checkpoint at the end of each epoch
            if is_main_process():

                dllogger.log(step=tuple(), data={"e2e_train_time": total_train_time,
                                                 "training_sequences_per_second": float(
                                                     epoch_perf_avg.result().numpy() * get_world_size()),
                                                 "final_loss": float(epoch_loss_avg.result().numpy())})

            if not args.skip_checkpoint:
                if args.ci:
                    checkpoint_name = "{}/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.output_dir, args.version_2_with_negative, epoch + 1)
                else:
                    checkpoint_name = "checkpoints/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.version_2_with_negative, epoch + 1)
                if is_main_process():
                    model.save_weights(checkpoint_name)

        # Evaluate after every epoch when requested, otherwise only after the last one.
        if args.do_predict and (args.evaluate_during_training or epoch == args.num_train_epochs - 1):
            if not args.do_train:
                log("***** Loading checkpoint: {} *****".format(args.init_checkpoint))
                model.load_weights(args.init_checkpoint).expect_partial()

            current_feature_id = 0
            all_results = []
            if is_main_process():
                log("***** Running evaluation *****")
                log("  Num Batches = ", total_dev_steps)
                log("  Batch size = ", args.predict_batch_size)

            raw_infer_start = time.time()
            if is_main_process():
                infer_perf_avg = tf.keras.metrics.Mean()
                dev_iterator = tqdm(dev_dataset, total=total_dev_steps, desc="Iteration", mininterval=5,
                                    disable=not is_main_process())
                for input_ids, input_mask, segment_ids, start_positions, end_positions, cls_index, p_mask, is_impossible in dev_iterator:
                    # training=False is needed only if there are layers with different
                    # behavior during training versus inference (e.g. Dropout).

                    iter_start = time.time()

                    if not args.joint_head:
                        batch_start_logits, batch_end_logits = infer_step(model, input_ids,
                                                                          attention_mask=input_mask,
                                                                          token_type_ids=segment_ids,
                                                                          )[:2]
                        #Synchronize with GPU to compute time
                        _ = batch_start_logits.numpy()

                    else:

                        outputs = infer_step(model, input_ids,
                                             attention_mask=input_mask,
                                             token_type_ids=segment_ids,
                                             cls_index=cls_index,
                                             p_mask=p_mask,
                                             )
                        #Synchronize with GPU to compute time
                        _ = outputs[0].numpy()

                    infer_time = (time.time() - iter_start)
                    infer_perf_avg.update_state(1. * EVAL_BATCH_SIZE / infer_time)
                    latency.append(infer_time)

                    # Results are matched to dev_features by position: the i-th
                    # row produced overall belongs to dev_features[i].
                    for iter_ in range(input_ids.shape[0]):

                        if not args.joint_head:
                            start_logits = batch_start_logits[iter_].numpy().tolist()
                            end_logits = batch_end_logits[iter_].numpy().tolist()
                            dev_feature = dev_features[current_feature_id]
                            current_feature_id += 1
                            unique_id = int(dev_feature.unique_id)
                            all_results.append(RawResult(unique_id=unique_id,
                                                         start_logits=start_logits,
                                                         end_logits=end_logits))
                        else:
                            dev_feature = dev_features[current_feature_id]
                            current_feature_id += 1
                            unique_id = int(dev_feature.unique_id)
                            output = [output[iter_].numpy().tolist() for output in outputs]

                            start_logits = output[0]
                            start_top_index = output[1]
                            end_logits = output[2]
                            end_top_index = output[3]
                            cls_logits = output[4]
                            result = SquadResult(
                                unique_id,
                                start_logits,
                                end_logits,
                                start_top_index=start_top_index,
                                end_top_index=end_top_index,
                                cls_logits=cls_logits,
                            )

                            all_results.append(result)

                # Compute and save predictions
                answers, nbest_answers = get_answers(dev_examples, dev_features, all_results, args)

                output_prediction_file = os.path.join(args.output_dir, "predictions.json")
                output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
                e2e_infer_time = time.time() - raw_infer_start
                # if args.version_2_with_negative:
                #     output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
                # else:
                #     output_null_log_odds_file = None
                with open(output_prediction_file, "w") as f:
                    f.write(json.dumps(answers, indent=4) + "\n")
                with open(output_nbest_file, "w") as f:
                    f.write(json.dumps(nbest_answers, indent=4) + "\n")

                if args.do_eval:
                    if args.version_2_with_negative:
                        dev_file = "dev-v2.0.json"
                    else:
                        dev_file = "dev-v1.1.json"

                    eval_out = subprocess.check_output([sys.executable, args.eval_script,
                                                        args.data_dir + "/" + dev_file, output_prediction_file])
                    log(eval_out.decode('UTF-8'))
                    # str() of the bytes output is its repr (b'...'); the scores are
                    # cut out of that text by splitting on ':' and ',' / '}'.
                    scores = str(eval_out).strip()
                    exact_match = float(scores.split(":")[1].split(",")[0])
                    if args.version_2_with_negative:
                        f1 = float(scores.split(":")[2].split(",")[0])
                    else:
                        f1 = float(scores.split(":")[2].split("}")[0])

                    log("Epoch: {:03d} Results: {}".format(epoch, eval_out.decode('UTF-8')))
                    log("**EVAL SUMMARY** - Epoch: {:03d},  EM: {:6.3f}, F1: {:6.3f}, Infer_Perf: {:4.0f} seq/s"
                          .format(epoch, exact_match, f1, infer_perf_avg.result()))

                # Very short runs leave no samples after trimming; skip the
                # summary instead of dividing by zero.
                latency_stats = _summarize_latency_ms(latency)
                if latency_stats is None:
                    log("**LATENCY SUMMARY** - Epoch: {:03d},  not enough samples".format(epoch))
                else:
                    log(
                        "**LATENCY SUMMARY** - Epoch: {:03d},  Ave: {:6.3f} ms, 90%: {:6.3f} ms, 95%: {:6.3f} ms, 99%: {:6.3f} ms"
                        .format(epoch, *latency_stats))
                dllogger.log(step=tuple(),
                             data={"inference_sequences_per_second": float(infer_perf_avg.result().numpy()),
                                   "e2e_inference_time": e2e_infer_time})

    # exact_match is only set when evaluation ran on this process, which also
    # guarantees infer_perf_avg (and, with do_train, epoch_perf_avg) exist.
    if is_main_process() and args.do_train and args.do_eval and exact_match is not None:
        log(
            "**RESULTS SUMMARY** - EM: {:6.3f}, F1: {:6.3f}, Train_Time: {:4.0f} s, Train_Perf: {:4.0f} seq/s, Infer_Perf: {:4.0f} seq/s"
            .format(exact_match, f1, total_train_time, epoch_perf_avg.result() * get_world_size(),
                    infer_perf_avg.result()))
        dllogger.log(step=tuple(), data={"exact_match": exact_match, "F1": f1})
Exemplo n.º 3
0
def main(args,
         shuffle_data=True,
         model=None,
         qamodel=None,
         tokenizer=None,
         zsre=False,
         v2=True,
         must_choose_answer=False,
         condition_on_answer_exists=False,
         condition_on_single_token=False,
         condition_on_multi_token=False,
         condition_on_answer_does_not_exist=False):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    print(model)
    #if model is None:
    #    #model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    # EM
    EM = 0.0

    # F1
    F1 = 0.0
    is_error = 0
    no_overlap = 0
    larger_by_1 = 0
    larger_by_2 = 0
    larger_by_3 = 0
    larger_by_4 = 0
    larger_by_5_or_more = 0
    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
    else:
        # keep samples as they are
        all_samples = data

    all_samples, ret_msg = filter_samples(
        model,
        data,
        vocab_subset,
        args.max_sentence_length,
        args.template,
        condition_on_answer_exists=condition_on_answer_exists,
        condition_on_single_token=condition_on_single_token,
        condition_on_multi_token=condition_on_multi_token,
        condition_on_answer_does_not_exist=condition_on_answer_does_not_exist,
        is_zsre=zsre)

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))
    if len(all_samples) == 0:  # or len(all_samples) >= 50:
        return None, None, None, None, None, None, None, None, None, None, None, None

    # if template is active (1) use a single example for (sub,obj) and (2) ...
    if args.template and args.template != "":
        facts = []
        sub_objs = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            target = sample['reconstructed_word']
            question = args.question
            if 'reconstructed_word' not in sample:
                raise Exception('Reconstructed word not in sample... fix this')
            else:
                if 'masked_sentences' in sample:
                    # Some of the masked sentences don't have a mask in them, need to find first with mask
                    context = None
                    for sent in sample['masked_sentences']:
                        if not zsre:
                            if '[MASK]' in sent:
                                context = sent.replace(
                                    '[MASK]', sample['reconstructed_word'])
                                break
                        else:
                            context = sent
                    if context is None:
                        print('No valid context found, skipping sample')
                        continue
                else:
                    context = None
                    for evidence in sample['evidences']:
                        if not zsre:
                            if '[MASK]' in evidence['masked_sentence']:
                                context = evidence['masked_sentence'].replace(
                                    '[MASK]', sample['reconstructed_word'])
                                break
                        else:
                            context = evidence['masked_sentence']
                    if context is None:
                        print('No valid context found, skipping sample')
                        continue

            #context = context.replace('(', '')
            #context = context.replace(')', '')
            if (sub, target, context) not in sub_objs:
                sub_objs.append((sub, target, context))
                if 'reconstructed_word' in sample:
                    facts.append((sub, obj, context, question,
                                  sample['reconstructed_word']))
                else:
                    facts.append((sub, obj, context, question, obj))

                #break
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj, context, question, rw) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            sample["reconstructed_word"] = rw
            # sobstitute all sentences with a standard template
            sample['context'] = context
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK)
            question = question.replace('[X]', sub)
            sample['question'] = question
            #query = sample['masked_sentences'][0].replace(base.MASK, '')
            #sample['query'] = query
            #print(f'query={query}')
            #docs = retrieve_docs(query, ranker, conn, 30)
            #sample['context'] = docs[0]
            #print(f'docs={docs}')
            all_samples.append(sample)
    #else:
    #    for sample in all_samples:
    #        query = sample['masked_sentences'][0].replace(base.MASK, '')
    #        sample['query'] = query
    #        #print(f'query={query}')
    #        docs = retrieve_docs(query, ranker, conn, 1)
    #        sample['context'] = docs[0]

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    viz = False
    num_viz = 10
    final_viz = []
    viz_thres = 11
    qamodel.eval().cuda()
    # Defaults from huggingface
    do_lower_case = True
    max_answer_length = 30
    verbose_logging = False
    null_score_diff_threshold = 0.0
    n_best = 20
    max_query_length = 64
    # Training specifics:
    doc_stride = 128
    max_seq_length = 384

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]
        mymodel_probs_list = []
        predictions_list = []

        examples = read_input_examples(samples_b)
        features = convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=False,
            cls_token_segment_id=0,
            pad_token_segment_id=0,
            cls_token_at_end=False,
            sequence_a_is_doc=False)

        # Convert to Tensors and build dataset
        all_input_ids = torch.tensor([f.input_ids for f in features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                       dtype=torch.long)
        all_cls_index = torch.tensor([f.cls_index for f in features],
                                     dtype=torch.long)
        all_p_mask = torch.tensor([f.p_mask for f in features],
                                  dtype=torch.float)
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_example_index, all_cls_index, all_p_mask)
        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset,
                                     sampler=eval_sampler,
                                     batch_size=len(samples_b))
        all_results = []
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            #stime = time.time()
            batch = tuple(t.cuda() for t in batch)
            with torch.no_grad():
                inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
                inputs['token_type_ids'] = batch[
                    2]  # XLM don't use segment_ids
                example_indices = batch[3]
                outputs = qamodel(**inputs)

            for i, example_index in enumerate(example_indices):
                eval_feature = features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                result = RawResult(unique_id=unique_id,
                                   start_logits=to_list(outputs[0][i]),
                                   end_logits=to_list(outputs[1][i]))
                all_results.append(result)
            #total_time = time.time() - stime
            #print(total_time)
            #import ipdb
            #ipdb.set_trace()

        predictions = get_predictions(examples,
                                      features,
                                      all_results,
                                      n_best,
                                      max_answer_length,
                                      do_lower_case,
                                      verbose_logging,
                                      v2,
                                      null_score_diff_threshold,
                                      must_choose_answer=must_choose_answer)
        predictions = [predictions[p] for p in predictions]
        predictions_list.extend(predictions)

        torch.cuda.empty_cache()

        original_log_probs_list, token_ids_list, masked_indices_list = model.get_batch_generation(
            sentences_b, logger=logger)
        mymodel_probs_list = original_log_probs_list

        #obj_len = 0
        #for obj in gc.get_objects():
        #    try:
        #        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        #            print(type(obj), obj.size())
        #            obj_len += 1
        #    except:
        #        pass
        #print(obj_len)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            #elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
            #    raise ValueError(
            #        "object label {} not in model vocabulary".format(
            #            sample["obj_label"]
            #        )
            #    )
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

        arguments = [{
            "mymodel_probs":
            mymodel_probs,
            "original_log_probs":
            original_log_probs,
            "filtered_log_probs":
            filtered_log_probs,
            "target":
            sample["reconstructed_word"],
            "prediction":
            pred,
            "token_ids":
            token_ids,
            "vocab":
            model.vocab,
            "label_index":
            label_index[0] if len(label_index) > 0 else 0,
            "masked_indices":
            masked_indices,
            "interactive":
            args.interactive,
            "index_list":
            index_list,
            "sample":
            sample,
        } for mymodel_probs, original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample, pred in zip(
                         mymodel_probs_list,
                         original_log_probs_list,
                         filtered_log_probs_list,
                         token_ids_list,
                         masked_indices_list,
                         label_index_list,
                         samples_b,
                         predictions_list,
                     )]

        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg, sample_em, sample_f1, sample_is_error, sample_no_overlap, sample_larger_by_1, sample_larger_by_2, sample_larger_by_3, sample_larger_by_4, sample_larger_by_5_or_more = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            element['sample_em'] = sample_em
            element['sample_f1'] = sample_f1

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]
            is_error += sample_is_error
            no_overlap += sample_no_overlap
            larger_by_1 += sample_larger_by_1
            larger_by_2 += sample_larger_by_2
            larger_by_3 += sample_larger_by_3
            larger_by_4 += sample_larger_by_4
            larger_by_5_or_more += sample_larger_by_5_or_more
            EM += sample_em
            F1 += sample_f1

            # the judgments of the annotators record whether there is
            # evidence in the sentence that indicates a relation between two entities.
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P

            list_of_results.append(element)

    if viz:
        with open('viz.pkl', 'wb') as wf:
            pickle.dump(final_viz, wf)

    pool.close()
    pool.join()

    # stats
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    Precision1 /= len(list_of_results)

    EM /= len(list_of_results)
    F1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)
    msg += "global EM {}\n".format(EM)
    msg += "global F1: {}\n".format(F1)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(list_of_results=list_of_results,
                       global_MRR=MRR,
                       global_P_at_10=Precision)
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    return Precision1, Precision, MRR, EM, F1, is_error, no_overlap, larger_by_1, larger_by_2, larger_by_3, larger_by_4, larger_by_5_or_more