Example 1
def main(args):

    if not args.subject and not args.relation:
        raise ValueError(
            'You need to specify --subject and --relation to query language models.'
        )

    print('Language Models: {}'.format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print('Common vocabulary size: {}'.format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    prompt_file = os.path.join(args.prompts, args.relation + '.jsonl')
    if not os.path.exists(prompt_file):
        raise ValueError('Relation "{}" does not exist.'.format(args.relation))
    prompts, weights = load_prompt_weights(prompt_file)

    for model_name, model in models.items():
        print('\n{}:'.format(model_name))

        index_list = None
        if vocab_subset is not None:
            filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                vocab_subset)

        ensemble_log_probs = 0
        for prompt, weight in zip(prompts, weights):
            prompt = parse_prompt(prompt, args.subject, model.mask_token)
            log_prob, [token_ids], [masked_indices], _, _ = model.get_batch_generation(
                [prompt], try_cuda=True)

            if vocab_subset is not None:
                filtered_log_probs = model.filter_logprobs(
                    log_prob, filter_logprob_indices)
            else:
                filtered_log_probs = log_prob

            # rank over the subset of the vocab (if defined) for the single masked token
            if masked_indices and len(masked_indices) > 0:
                filtered_log_probs = filtered_log_probs[0][masked_indices[0]]
                ensemble_log_probs += filtered_log_probs * weight

        ensemble_log_probs = F.log_softmax(ensemble_log_probs, dim=0)
        evaluation_metrics.get_ranking(ensemble_log_probs,
                                       model.vocab,
                                       label_index=None,
                                       index_list=index_list,
                                       topk=1000,
                                       P_AT=10,
                                       print_generation=True)
Example 2
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError('Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    elif model_type_name == "roberta":
        model_name = "RoBERTa_{}".format(args.roberta_model_name)
    elif model_type_name == "hfroberta":
        model_name = "hfRoBERTa_{}".format(args.hfroberta_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger
        )

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    # spearman rank correlation
    # overlap at 1
    if args.use_negated_probes:
        Spearman = 0.0
        Overlap = 0.0
        num_valid_negation = 0.0

    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(
            data, use_negated_probes=args.use_negated_probes
        )
    else:
        # keep samples as they are
        all_samples = data
        # TREx data
        for i, sample in enumerate(all_samples):
            if 'masked_sentences' not in sample.keys():
                sample['masked_sentences'] = []
                for evidence in sample['evidences']:
                    sample['masked_sentences'].append(evidence['masked_sentence'])
                if i == 0:
                    print("no 'masked_sentences' field; building it from each evidence's 'masked_sentence'.")

    # filter the (possibly lowercased) samples, not the raw data
    all_samples, ret_msg = filter_samples(
        model, all_samples, vocab_subset, args.max_sentence_length, args.template
    )

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if a template is active, (1) keep a single example per (sub, obj) pair and (2) replace each masked sentence with the template
    if args.template and args.template != "":
        facts = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            # substitute all sentences with a standard template
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK
            )
            if args.use_negated_probes:
                # substitute all negated sentences with a standard template
                sample["negated"] = parse_template(
                    args.template_negated.strip(),
                    sample["sub_label"].strip(),
                    base.MASK,
                )
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    if args.use_negated_probes:
        sentences_batches_negated, ret_msg = batchify_negated(
            all_samples, args.batch_size
        )
        logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices
            )
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            elif vocab_subset is not None and sample["obj_label"] not in vocab_subset:
                raise ValueError(
                    "object label {} not in vocab subset".format(sample["obj_label"])
                )

            label_index_list.append(obj_label_id)

        arguments = [
            {
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0],
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
            }
            for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip(
                original_log_probs_list,
                filtered_log_probs_list,
                token_ids_list,
                masked_indices_list,
                label_index_list,
                samples_b,
            )
        ]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        if args.use_negated_probes:
            sentences_b_negated = sentences_batches_negated[i]

            # if no negated sentences in batch
            if all(s[0] == "" for s in sentences_b_negated):
                res_negated = [(float("nan"), float("nan"), "")] * args.batch_size
            # eval negated batch
            else:
                (
                    original_log_probs_list_negated,
                    token_ids_list_negated,
                    masked_indices_list_negated,
                ) = model.get_batch_generation(sentences_b_negated, logger=logger)
                if vocab_subset is not None:
                    # filter log_probs
                    filtered_log_probs_list_negated = model.filter_logprobs(
                        original_log_probs_list_negated, filter_logprob_indices
                    )
                else:
                    filtered_log_probs_list_negated = original_log_probs_list_negated

                arguments = [
                    {
                        "log_probs": filtered_log_probs,
                        "log_probs_negated": filtered_log_probs_negated,
                        "token_ids": token_ids,
                        "vocab": model.vocab,
                        "label_index": label_index[0],
                        "masked_indices": masked_indices,
                        "masked_indices_negated": masked_indices_negated,
                        "index_list": index_list,
                    }
                    for filtered_log_probs, filtered_log_probs_negated, token_ids, masked_indices, masked_indices_negated, label_index in zip(
                        filtered_log_probs_list,
                        filtered_log_probs_list_negated,
                        token_ids_list,
                        masked_indices_list,
                        masked_indices_list_negated,
                        label_index_list,
                    )
                ]
                res_negated = pool.map(run_thread_negated, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            if args.use_negated_probes:
                overlap, spearman, msg = res_negated[idx]
                # sum overlap and spearmanr if not nan
                if spearman == spearman:
                    element["spearmanr"] = spearman
                    element["overlap"] = overlap
                    Overlap += overlap
                    Spearman += spearman
                    num_valid_negation += 1.0

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]

            # the judgment of the annotators recording whether there is
            # evidence in the sentence that indicates a relation between the two entities.
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P

            list_of_results.append(element)

    pool.close()
    pool.join()

    # stats
    try:
        # Mean reciprocal rank
        MRR /= len(list_of_results)

        # Precision
        Precision /= len(list_of_results)
        Precision1 /= len(list_of_results)
    except ZeroDivisionError:
        MRR = Precision = Precision1 = 0.0

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if args.use_negated_probes:
        Overlap /= num_valid_negation
        Spearman /= num_valid_negation
        msg += "\n"
        msg += "results negation:\n"
        msg += "all_negated_samples: {}\n".format(int(num_valid_negation))
        msg += "global spearman rank affirmative/negated: {}\n".format(Spearman)
        msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement
        )
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement
        )
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision
    )
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    return Precision1
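run_thread is not shown in these examples, but the per-sample quantities the loop above accumulates (MRR, precision at k, precision at 1) can be derived from the ranked log-probabilities roughly as sketched below (a hedged approximation, not the project's run_thread):

import torch

def rank_metrics(log_probs, label_index, k=10):
    # Sort the vocabulary by descending log-probability and find the
    # 1-based rank of the gold object label.
    ranking = torch.argsort(log_probs, descending=True)
    rank = (ranking == label_index).nonzero(as_tuple=True)[0].item() + 1
    mrr = 1.0 / rank
    p_at_k = 1.0 if rank <= k else 0.0
    p_at_1 = 1.0 if rank == 1 else 0.0
    return mrr, p_at_k, p_at_1

# Example with random scores over a 30k-word vocabulary and gold index 42.
print(rank_metrics(torch.randn(30000), label_index=42))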
Example 3
def main(args):

    if not args.text and not args.interactive:
        msg = "ERROR: either you start LAMA eval_generation with the " \
              "interactive option (--i) or you pass in input a piece of text (--t)"
        raise ValueError(msg)

    stopping_condition = True

    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print("common vocabulary size: {}".format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    while stopping_condition:
        if args.text:
            text = args.text
            stopping_condition = False
        else:
            text = input("insert text:")

        if args.split_sentence:
            import spacy
            # use spacy to tokenize input sentence
            nlp = spacy.load(args.spacy_model)
            tokens = nlp(text)
            print(tokens)
            sentences = []
            for s in tokens.sents:
                print(" - {}".format(s))
                sentences.append(s.text)
        else:
            sentences = [text]

        if len(sentences) > 2:
            print(
                "WARNING: only the first two sentences in the text will be considered!"
            )
            sentences = sentences[:2]

        for model_name, model in models.items():
            print("\n{}:".format(model_name))
            original_log_probs_list, [token_ids], [masked_indices] = model.get_batch_generation(
                [sentences], try_cuda=False)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list

            # rank over the subset of the vocab (if defined) for the single masked token
            if masked_indices and len(masked_indices) > 0:
                evaluation_metrics.get_ranking(filtered_log_probs_list[0],
                                               masked_indices,
                                               model.vocab,
                                               index_list=index_list)

            # prediction and perplexity for the whole softmax
            print_sentence_predictions(original_log_probs_list[0],
                                       token_ids,
                                       model.vocab,
                                       masked_indices=masked_indices)
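For readers without the LAMA model wrappers, the masked-token query that get_batch_generation performs for a single [MASK] can be approximated with the Hugging Face fill-mask pipeline (an illustrative analogue only, not the code path used above):

from transformers import pipeline

# bert-base-uncased is simply a convenient public checkpoint for illustration.
unmasker = pipeline("fill-mask", model="bert-base-uncased")
for prediction in unmasker("The theory of relativity was developed by [MASK]."):
    print(prediction["token_str"], round(prediction["score"], 4))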
Example 4
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    args.output_feature_path = getattr(args, 'output_feature_path', '')
    if getattr(args, 'knn_thresh', 0) > 0:
        assert hasattr(args, 'knn_path')
        assert hasattr(args, 'modify_ans')
    else:
        args.knn_thresh = 0

    if getattr(args, 'knn_path', ''):
        knn_dict = torch.load(args.knn_path)
        if getattr(args, 'consine_dist', True):
            knn_dict['mask_features'] = knn_dict['mask_features'] / torch.norm(
                knn_dict['mask_features'], dim=1, keepdim=True)
        new_ans_dict = json.load(open(args.modify_ans))
        knn_dict['obj_labels'] = [
            new_ans_dict[uuid] for uuid in knn_dict['uuids']
        ]
    else:
        new_ans_dict = None
        knn_dict = None

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open(os.path.join(log_directory, 'args.json'), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision1_modified = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    # spearman rank correlation
    # overlap at 1
    if args.use_negated_probes:
        Spearman = 0.0
        Overlap = 0.0
        num_valid_negation = 0.0

    data = load_file(args.dataset_filename)

    print(len(data))

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if a template is active, (1) keep a single example per (sub, obj) pair and (2) replace each masked sentence with the template
    if args.template and args.template != "":
        if getattr(args, 'use_evidences', False):
            new_all_samples = []
            for sample in all_samples:
                if len(args.uuid_list) > 0 and sample['uuid'] not in args.uuid_list:
                    continue
                elif len(args.uuid_list) > 0:
                    print(sample['uuid'])
                sub = sample["sub_label"]
                if new_ans_dict is not None and sample['uuid'] in new_ans_dict:
                    # we need to replace the answer in this way
                    obj = new_ans_dict[sample['uuid']]
                else:
                    obj = sample["obj_label"]

                if sample['uuid'] == '11fc104b-bba2-412c-b2d7-cf06cd2bd715':
                    sample['evidences'] = sample['evidences'][:32]

                for ne, evidence in enumerate(sample['evidences']):
                    # maximum of 10 evidences per fact
                    if ne >= 10:
                        continue
                    new_sample = {'sub_label': sub, 'obj_label': obj}
                    if '[MASK]' not in evidence['masked_sentence']:
                        continue
                    new_sample['masked_sentences'] = [
                        evidence['masked_sentence']
                    ]
                    new_sample['uuid'] = sample['uuid']
                    new_all_samples.append(new_sample)

            all_samples = new_all_samples
        else:
            facts = []
            for sample in all_samples:
                sub = sample["sub_label"]
                if new_ans_dict is not None and sample['uuid'] in new_ans_dict:
                    # we need to replace the answer in this way
                    obj = new_ans_dict[sample['uuid']]
                else:
                    obj = sample["obj_label"]
                if (sub, obj) not in facts:
                    facts.append((sample['uuid'], sub, obj))
            local_msg = "distinct template facts: {}".format(len(facts))
            logger.info("\n" + local_msg + "\n")
            print(local_msg)

            all_samples = []
            for fact in facts:
                (uuid, sub, obj) = fact
                sample = {}
                sample["sub_label"] = sub
                sample["obj_label"] = obj
                sample["uuid"] = uuid
                # substitute all sentences with a standard template
                sample["masked_sentences"] = parse_template(
                    args.template.strip(), sample["sub_label"].strip(),
                    base.MASK)
                if args.use_negated_probes:
                    # substitute all negated sentences with a standard template
                    sample["negated"] = parse_template(
                        args.template_negated.strip(),
                        sample["sub_label"].strip(),
                        base.MASK,
                    )
                all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    if args.use_negated_probes:
        sentences_batches_negated, ret_msg = batchify_negated(
            all_samples, args.batch_size)
        logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []
    total_modified = 0

    mask_feature_all, answers_list, uid_list = [], [], []
    correct_uuids = []
    knn_preds_list = []
    for i in tqdm(range(len(samples_batches))):
        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        rets = model.get_batch_generation(sentences_b,
                                          logger=logger,
                                          return_features=args.return_features
                                          or args.knn_thresh > 0)
        if len(rets) == 4:
            original_log_probs_list, token_ids_list, masked_indices_list, feature_tensor = rets
            mask_feature_all.append(feature_tensor)
        else:
            original_log_probs_list, token_ids_list, masked_indices_list = rets

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        modified_flags_list = []
        for ns, sample in enumerate(samples_b):
            obj_label_id = model.get_id(sample["obj_label"])
            answers_list.append(sample["obj_label"])
            uid_list.append(sample['uuid'])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

            if args.knn_thresh > 0:
                feature = feature_tensor[ns].view(1, -1)
                if getattr(args, 'consine_dist', True):
                    dist = torch.sum(feature * knn_dict['mask_features'],
                                     dim=1) / torch.norm(feature)
                else:
                    dist = torch.norm(feature - knn_dict['mask_features'],
                                      dim=1)
                min_dist, min_idx = torch.min(dist, dim=0)
                # print(min_dist.item())
                if min_dist < args.knn_thresh:
                    knn_pred = knn_dict['obj_labels'][min_idx.item()]
                    knn_preds_list.append(model.get_id(knn_pred)[0])
                    # if knn_dict['uuids'][min_idx.item()] == sample['uuid']:
                    #     pdb.set_trace()
                else:
                    knn_preds_list.append(-1)
                # log_probs.unsqueeze()
                # knn_preds_list.
            else:
                knn_preds_list.append(-1)

            # label whether the fact has been modified
            modified_flags_list.append(new_ans_dict is not None
                                       and sample['uuid'] in new_ans_dict)

        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
            "knn_pred": knn_pred,
            "modified": modified
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample, knn_pred, modified in
                     zip(original_log_probs_list, filtered_log_probs_list,
                         token_ids_list, masked_indices_list, label_index_list,
                         samples_b, knn_preds_list, modified_flags_list)]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        if args.use_negated_probes:
            sentences_b_negated = sentences_batches_negated[i]

            # if no negated sentences in batch
            if all(s[0] == "" for s in sentences_b_negated):
                res_negated = [(float("nan"), float("nan"), "")
                               ] * args.batch_size
            # eval negated batch
            else:
                (
                    original_log_probs_list_negated,
                    token_ids_list_negated,
                    masked_indices_list_negated,
                ) = model.get_batch_generation(sentences_b_negated,
                                               logger=logger)
                if vocab_subset is not None:
                    # filter log_probs
                    filtered_log_probs_list_negated = model.filter_logprobs(
                        original_log_probs_list_negated,
                        filter_logprob_indices)
                else:
                    filtered_log_probs_list_negated = original_log_probs_list_negated

                arguments = [{
                    "log_probs": filtered_log_probs,
                    "log_probs_negated": filtered_log_probs_negated,
                    "token_ids": token_ids,
                    "vocab": model.vocab,
                    "label_index": label_index[0],
                    "masked_indices": masked_indices,
                    "masked_indices_negated": masked_indices_negated,
                    "index_list": index_list,
                } for filtered_log_probs, filtered_log_probs_negated,
                             token_ids, masked_indices, masked_indices_negated,
                             label_index in zip(
                                 filtered_log_probs_list,
                                 filtered_log_probs_list_negated,
                                 token_ids_list,
                                 masked_indices_list,
                                 masked_indices_list_negated,
                                 label_index_list,
                             )]
                res_negated = pool.map(run_thread_negated, arguments)

        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            element["modified"] = result_masked_topk["modified"]
            if result_masked_topk["P_AT_1"] > 0:
                correct_uuids.append(element['uuid'])

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            if args.use_negated_probes:
                overlap, spearman, msg = res_negated[idx]
                # sum overlap and spearmanr if not nan
                if spearman == spearman:
                    element["spearmanr"] = spearman
                    element["overlap"] = overlap
                    Overlap += overlap
                    Spearman += spearman
                    num_valid_negation += 1.0

            MRR += sample_MRR
            Precision += sample_P
            if element["modified"]:
                Precision1_modified += element["sample_Precision1"]
            else:
                Precision1 += element["sample_Precision1"]

            # the judgment of the annotators recording whether there is
            # evidence in the sentence that indicates a relation between the two entities.
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P
            if element["modified"]:
                total_modified += 1
            else:
                list_of_results.append(element)

    pool.close()
    pool.join()

    if args.output_feature_path and len(list_of_results) == 0:
        # torch.save(out_dict, args.output_feature_path)
        # return empty results
        return Precision1, uid_list, mask_feature_all, answers_list
    elif len(list_of_results) == 0:
        pdb.set_trace()

    # stats
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    # Precision1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if args.use_negated_probes:
        Overlap /= num_valid_negation
        Spearman /= num_valid_negation
        msg += "\n"
        msg += "results negation:\n"
        msg += "all_negated_samples: {}\n".format(int(num_valid_negation))
        msg += "global spearman rank affirmative/negated: {}\n".format(
            Spearman)
        msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(list_of_results=list_of_results,
                       global_MRR=MRR,
                       global_P_at_10=Precision)
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    if args.output_feature_path:
        # torch.save(out_dict, args.output_feature_path)
        return Precision1, len(
            list_of_results
        ), Precision1_modified, total_modified, uid_list, mask_feature_all, answers_list

    return Precision1, len(
        list_of_results), Precision1_modified, total_modified, correct_uuids
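The distinguishing part of this variant is the thresholded nearest-neighbour lookup over cached [MASK] features. A small sketch of the cosine-similarity case with toy tensors (the real code also supports a Euclidean-distance mode and loads its feature bank and labels from knn_dict):

import torch

# Hypothetical memory bank: 100 cached [MASK] features and their object labels.
bank = torch.randn(100, 768)
bank = bank / torch.norm(bank, dim=1, keepdim=True)  # L2-normalize once
bank_labels = ["label_{}".format(i) for i in range(100)]

def knn_predict(feature, thresh=0.8):
    # Cosine similarity between the query feature and every cached feature.
    sims = torch.sum(feature * bank, dim=1) / torch.norm(feature)
    best_sim, best_idx = torch.max(sims, dim=0)
    # Only trust the neighbour's label if it is similar enough.
    return bank_labels[best_idx.item()] if best_sim.item() >= thresh else None

print(knn_predict(torch.randn(768)))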
Example 5
def main(args):
    verbose = False
    moral_template = "{} is a [MASK] behavior ."

    data_user_study = pickle.load(
        open('../MoRT/data/parsed_yes_no_BERTBias.p', 'rb'))
    actions = list(data_user_study.keys())
    actions = ['help people', 'help white people', 'help black people']
    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print("common vocabulary size: {}".format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    data_user_study_BERTKnowledge = dict()
    for action in actions:
        action_ = action.split(" ")
        action_[0] = verb_noun_mapping[action_[0]].capitalize()
        action_ = " ".join(action_)
        text = moral_template.format(action_)
        if args.split_sentence:
            import spacy
            # use spacy to tokenize input sentence
            nlp = spacy.load(args.spacy_model)
            tokens = nlp(text)
            print(tokens)
            sentences = []
            for s in tokens.sents:
                print(" - {}".format(s))
                sentences.append(s.text)
        else:
            sentences = [text]

        if len(sentences) > 2:
            print(
                "WARNING: only the first two sentences in the text will be considered!"
            )
            sentences = sentences[:2]

        for model_name, model in models.items():
            if model_name not in list(data_user_study_BERTKnowledge.keys()):
                data_user_study_BERTKnowledge[model_name] = {}
            if verbose:
                print("\n{}:".format(model_name))
            original_log_probs_list, [token_ids], [masked_indices] = model.get_batch_generation(
                [sentences], try_cuda=False)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list
            # rank over the subset of the vocab (if defined) for the single masked token
            if masked_indices and len(masked_indices) > 0:
                _, _, experiment_result, _ = evaluation_metrics.get_ranking(
                    filtered_log_probs_list[0],
                    masked_indices,
                    model.vocab,
                    index_list=index_list,
                    print_generation=verbose)

            experiment_result_topk = [(r['i'], r['token_word_form'],
                                       r['log_prob'])
                                      for r in experiment_result['topk'][:10]]
            data_user_study_BERTKnowledge[model_name][action] = [
                text, experiment_result_topk
            ]
            # prediction and perplexity for the whole softmax
            if verbose:
                print_sentence_predictions(original_log_probs_list[0],
                                           token_ids,
                                           model.vocab,
                                           masked_indices=masked_indices)

    print(data_user_study_BERTKnowledge)

    pickle.dump(data_user_study_BERTKnowledge,
                open('./parsed_BERTKnowledge_tests.p', 'wb'))
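The bookkeeping above amounts to filling the moral template for each action and storing the sentence together with the model's top-k predictions. A toy sketch of that structure (the top-k tuples are placeholders, not real model output):

moral_template = "{} is a [MASK] behavior ."

results = {}
for model_name in ["bert_base", "bert_large"]:
    results.setdefault(model_name, {})
    for action in ["help people", "lie"]:
        text = moral_template.format(action.capitalize())
        # (rank, token, log-probability) placeholders standing in for the
        # experiment_result['topk'] entries collected above.
        fake_topk = [(0, "good", -1.2), (1, "bad", -2.3)]
        results[model_name][action] = [text, fake_topk]

print(results["bert_base"]["lie"])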
Example 6
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            "Please specify a single language model (e.g., --lm \"bert\").")

    msg = ""

    [model_type_name] = args.models_names

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == 'fairseq':
        model_name = 'fairseq_{}'.format(args.fairseq_model_name)
    elif model_type_name == 'bert':
        model_name = 'BERT_{}'.format(args.bert_model_name)
    elif model_type_name == 'elmo':
        model_name = 'ELMo_{}'.format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), 'w') as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.
    MRR_negative = 0.
    MRR_positive = 0.

    # Precision at (default 10)
    Precision = 0.
    Precision1 = 0.
    Precision_negative = 0.
    Precision_positivie = 0.

    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
    else:
        # keep samples as they are
        all_samples = data

    # filter the (possibly lowercased) samples, not the raw data
    all_samples, ret_msg = filter_samples(model, all_samples, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if a template is active, (1) keep a single example per (sub, obj) pair and (2) replace each masked sentence with the template
    if args.template and args.template != '':
        facts = []
        for sample in all_samples:
            sub = sample['sub_label']
            obj = sample['obj_label']
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample['sub_label'] = sub
            sample['obj_label'] = obj
            # substitute all sentences with a standard template
            sample['masked_sentences'] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK)
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if 'uuid' not in sample:
            sample['uuid'] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        original_log_probs_list, token_ids_list, masked_indices_list = model.get_batch_generation(
            sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample['obj_label'])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample['obj_label']))
            elif (model.vocab[obj_label_id[0]] != sample['obj_label']):
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample['obj_label']))
            elif vocab_subset is not None and sample[
                    'obj_label'] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample['obj_label']))

            label_index_list.append(obj_label_id)

        arguments = [{
            'original_log_probs': original_log_probs,
            'filtered_log_probs': filtered_log_probs,
            'token_ids': token_ids,
            'vocab': model.vocab,
            'label_index': label_index[0],
            'masked_indices': masked_indices,
            'interactive': args.interactive,
            'index_list': index_list,
            'sample': sample
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample in zip(
                         original_log_probs_list, filtered_log_probs_list,
                         token_ids_list, masked_indices_list, label_index_list,
                         samples_b)]

        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element['sample'] = sample
            element['uuid'] = sample['uuid']
            element['token_ids'] = token_ids_list[idx]
            element['masked_indices'] = masked_indices_list[idx]
            element['label_index'] = label_index_list[idx]
            element['masked_topk'] = result_masked_topk
            element['sample_MRR'] = sample_MRR
            element['sample_Precision'] = sample_P
            element['sample_perplexity'] = sample_perplexity
            element['sample_Precision1'] = result_masked_topk["P_AT_1"]

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element['sample_Precision1']

            # the judgment of the annotators recording whether there is
            # evidence in the sentence that indicates a relation between the two entities.
            num_yes = 0
            num_no = 0

            if 'judgments' in sample:
                # only for Google-RE
                for x in sample['judgments']:
                    if (x['judgment'] == "yes"):
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element['judgement'] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element['judgement'] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P

            list_of_results.append(element)

    pool.close()
    pool.join()

    # stats
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    Precision1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results,
        global_MRR=MRR,
        global_P_at_10=Precision,
    )
    with open("{}/result.pkl".format(log_directory), 'wb') as f:
        pickle.dump(all_results, f)

    return Precision1
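batchify itself is not shown in these examples; its expected behaviour, chunking the samples and their masked sentences into aligned fixed-size batches, can be sketched as follows (a hedged approximation, not the project's helper):

def batchify_sketch(samples, batch_size):
    samples_batches, sentences_batches = [], []
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        samples_batches.append(chunk)
        sentences_batches.append([s["masked_sentences"] for s in chunk])
    return samples_batches, sentences_batches

toy = [{"masked_sentences": ["Paris is the capital of [MASK] ."]}] * 5
samples_b, sentences_b = batchify_sketch(toy, batch_size=2)
print(len(samples_b), [len(b) for b in samples_b])  # 3 [2, 2, 1]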
Example 7
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""
    [model_type_name] = args.models_names

    # print("------- Model: {}".format(model))
    # print("------- Args: {}".format(args))
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    # logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    # with open("{}/args.json".format(log_directory), "w") as outfile:
    #     json.dump(vars(args), outfile)

    # Mean reciprocal rank
    MRR = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    data = load_file(args.dataset_filename)

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # logger.info("\n" + ret_msg + "\n")

    # if a template is active, (1) keep a single example per (sub, obj) pair and (2) replace each masked sentence with the template
    if args.template and args.template != "":
        facts = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        # logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            # substitute all sentences with a standard template
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK)
            if args.use_negated_probes:
                # substitute all negated sentences with a standard template
                sample["negated"] = parse_template(
                    args.template_negated.strip(),
                    sample["sub_label"].strip(),
                    base.MASK,
                )
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    # logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)

    # list_of_results = []
    # list_of_ranks = []
    item_count = 0
    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

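        # one argument dict per sample in the batch; each is handed to
        # run_thread, which (presumably) ranks the gold object label against
        # the (filtered) vocabulary and returns MRR / P@k-style metrics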
        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample in zip(
                         original_log_probs_list,
                         filtered_log_probs_list,
                         token_ids_list,
                         masked_indices_list,
                         label_index_list,
                         samples_b,
                     )]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            # print("~~~~~~~~~~~~~~~~~~")
            # print(result_masked_topk)

            # logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            obj = sample['obj_label']
            sub = sample['sub_label']
            element["masked_sentences"] = sample["masked_sentences"][0]
            # element["uuid"] = sample["uuid"]
            element["subject"] = sub
            element["object"] = obj
            element["rank"] = int(result_masked_topk['rank'])
            # element["sample_Precision1"] = result_masked_topk["P_AT_1"]

            # element["sample"] = sample
            # element["token_ids"] = token_ids_list[idx]
            # element["masked_indices"] = masked_indices_list[idx]
            # element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk['topk'][:20]
            # element["sample_MRR"] = sample_MRR
            # element["sample_Precision"] = sample_P
            # element["sample_perplexity"] = sample_perplexity

            # list_of_results[sub + "_" + obj].append(element)
            # list_of_ranks[sub + "_" + obj].append(element["rank"])

            # print("~~~~~~ rank: {}".format(result_masked_topk['rank']))
            MRR += sample_MRR
            Precision += sample_P
            Precision1 += result_masked_topk["P_AT_1"]

            item_count += 1
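            # stream per-sample results to JSONL files keyed by args.label:
            # the full element (subject, object, rank, top-20 predictions)
            # and the bare rank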
            append_data_line_to_jsonl(
                "reproduction/data/P_AT_K/{}_rank_results.jsonl".format(
                    args.label), element)
            append_data_line_to_jsonl(
                "reproduction/data/P_AT_K/{}_rank_list.jsonl".format(
                    args.label), element["rank"])

    pool.close()
    pool.join()
    Precision1 /= item_count

    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results) # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks) # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label), list(list_of_ranks.values())) # 3122

    return Precision1
Example #8
0
def main(args,
         shuffle_data=True,
         model=None,
         model2=None,
         refine_template=False,
         get_objs=False,
         dynamic='none',
         use_prob=False,
         bt_obj=None,
         temp_model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    #print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename,
                                  lower=args.lowercase)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

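    # 'dynamic' selects how predictions from the different templates are
    # combined further down: 'none' sums log-probs across templates (unless a
    # template-weighting model is given, in which case per-template scores are
    # kept), 'all_topk' keeps per-template results, 'lm'/'real_lm' pick, per
    # sample, the template with the highest LM consistency score, and the
    # '*_topk*' variants average over the top-k most consistent templates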
    if dynamic == 'all_topk':  # save topk results for different k
        # stats
        samples_with_negative_judgement = [
            0 for _ in range(len(args.template))
        ]
        samples_with_positive_judgement = [
            0 for _ in range(len(args.template))
        ]

        # Mean reciprocal rank
        MRR = [0.0 for _ in range(len(args.template))]
        MRR_negative = [0.0 for _ in range(len(args.template))]
        MRR_positive = [0.0 for _ in range(len(args.template))]

        # Precision at k (default 10)
        Precision = [0.0 for _ in range(len(args.template))]
        Precision1 = [0.0 for _ in range(len(args.template))]
        Precision_negative = [0.0 for _ in range(len(args.template))]
        Precision_positivie = [0.0 for _ in range(len(args.template))]

        list_of_results = [[] for _ in range(len(args.template))]
        P1_li = [[] for _ in range(len(args.template))]
    else:
        # stats
        samples_with_negative_judgement = [0]
        samples_with_positive_judgement = [0]

        # Mean reciprocal rank
        MRR = [0.0]
        MRR_negative = [0.0]
        MRR_positive = [0.0]

        # Precision at k (default 10)
        Precision = [0.0]
        Precision1 = [0.0]
        Precision_negative = [0.0]
        Precision_positivie = [0.0]

        list_of_results = [[]]
        P1_li = [[]]

    data = load_file(args.dataset_filename)
    for s in data:
        s['raw_sub_label'] = s['sub_label']
        s['raw_obj_label'] = s['obj_label']

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        data = lowercase_samples(data)

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print('#head-tails {} -> {}'.format(len(data), len(all_samples)))

    samples_batches_li, sentences_batches_li = [], []
    for template in args.template:
        # if template is active (1) use a single example for (sub,obj) and (2) ...
        if template and template != "":
            facts = []
            samples = []
            for sample in all_samples:
                sub = sample["sub_label"]
                obj = sample["obj_label"]
                if (sub, obj) not in facts:
                    facts.append((sub, obj))
                    samples.append(sample)
            local_msg = "distinct template facts: {}".format(len(facts))
            logger.info("\n" + local_msg + "\n")
            new_all_samples = []
            for fact, raw_sample in zip(facts, samples):
                (sub, obj) = fact
                sample = {}
                sample["sub_label"] = sub
                sample["obj_label"] = obj
                # substitute all sentences with a standard template
                sample["masked_sentences"] = parse_template(
                    template.strip(), raw_sample["raw_sub_label"].strip()
                    if args.upper_entity else sample["sub_label"].strip(),
                    model.mask_token)
                sub_uri = raw_sample[
                    'sub_uri'] if 'sub_uri' in raw_sample else raw_sample['sub']
                sample['entity_list'] = get_entity_list(
                    template.strip(), raw_sample['raw_sub_label'].strip(),
                    sub_uri, None, None)
                if dynamic.startswith('bt_topk') or (temp_model is not None
                                                     and bt_obj):
                    sample['sub_masked_sentences'] = parse_template_tokenize(
                        template.strip(),
                        sample["sub_label"].strip(),
                        model,
                        mask_part='sub')
                if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'):
                    sample["tokenized_sentences"] = parse_template_tokenize(
                        template.strip(),
                        sample["sub_label"].strip(),
                        model,
                        mask_part='relation')
                # substitute the sub and obj placeholders in the template with the corresponding strings
                # and add brackets around the relational phrase
                sample['bracket_sentences'] = bracket_relational_phrase(
                    template.strip(), sample['sub_label'].strip(),
                    sample['obj_label'].strip())
                new_all_samples.append(sample)

        # create uuid if not present
        i = 0
        for sample in new_all_samples:
            if "uuid" not in sample:
                sample["uuid"] = i
            i += 1

        if args.lowercase and not args.upper_entity:
            # lowercase all samples
            logger.info("lowercasing all samples...")
            new_all_samples = lowercase_samples(new_all_samples)

        # shuffle data (note: the unconditional raise below means shuffling is
        # effectively not supported in this variant)
        if shuffle_data:
            perm = np.random.permutation(len(new_all_samples))
            new_all_samples = np.array(new_all_samples)[perm]
            raise Exception

        samples_batches, sentences_batches, ret_msg = batchify(
            new_all_samples, args.batch_size)
        logger.info("\n" + ret_msg + "\n")
        samples_batches_li.append(samples_batches)
        sentences_batches_li.append(sentences_batches)

        sub_obj_labels = [(sample['sub_label'], sample['obj_label'])
                          for batch in samples_batches for sample in batch]
        if get_objs:
            print('sub_obj_label {}'.format('\t'.join(
                map(lambda p: '{}\t{}'.format(*p), sub_obj_labels))))
            return

        if refine_template:
            bracket_sentences = [
                sample['bracket_sentences'] for sample in new_all_samples
            ]
            new_temp = model.refine_cloze(bracket_sentences,
                                          batch_size=32,
                                          try_cuda=True)
            new_temp = replace_template(template.strip(), ' '.join(new_temp))
            print('old temp: {}'.format(template.strip()))
            print('new temp: {}'.format(new_temp))
            return new_temp

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)

    samples_batches_li = list(zip(*samples_batches_li))
    sentences_batches_li = list(zip(*sentences_batches_li))

    c_inc_meaning = ['top12 prob gap', 'top1 prob']
    c_inc_stat = np.zeros((2, 3))  # [[*, c_num], [*, inc_num]]

    loss_list = []
    features_list = []
    features_list2 = []
    bt_features_list = []
    label_index_tensor_list = []

    for i in tqdm(range(len(samples_batches_li))):

        samples_b_all = samples_batches_li[i]
        sentences_b_all = sentences_batches_li[i]

        filter_lp_merge = None
        filter_lp_merge2 = None
        samples_b = samples_b_all[-1]
        max_score = float('-inf')
        consist_score_li = []

        samples_b_prev = None
        for sentences_b, samples_b_this in zip(sentences_b_all, samples_b_all):
            if samples_b_prev is not None:
                for ps, ts in zip(samples_b_prev, samples_b_this):
                    assert ps['uuid'] == ts['uuid']

            entity_list_b = [s['entity_list'] for s in samples_b_this]
            # TODO: add tokens_tensor and mask_tensor for more models
            original_log_probs_list, token_ids_list, masked_indices_list, tokens_tensor, mask_tensor = \
                model.get_batch_generation(sentences_b, logger=logger, entity_list=entity_list_b)
            if model2 is not None:
                original_log_probs_list2, token_ids_list2, masked_indices_list2, tokens_tensor2, mask_tensor2 = \
                    model2.get_batch_generation(sentences_b, logger=logger)

            if use_prob:  # use prob instead of log prob
                original_log_probs_list = original_log_probs_list.exp()
                if model2 is not None:
                    original_log_probs_list2 = original_log_probs_list2.exp()

            if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'):
                sentences_b_mask_rel = [
                    s['tokenized_sentences'][0] for s in samples_b_this
                ]
                relation_mask = [
                    s['tokenized_sentences'][1] for s in samples_b_this
                ]
                consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \
                    model.get_batch_generation(sentences_b_mask_rel, logger=logger, relation_mask=relation_mask)
            else:
                consist_log_probs_list = original_log_probs_list

            if dynamic == 'lm' or dynamic == 'real_lm' or dynamic.startswith(
                    'real_lm_topk'):
                # use avg prob of the templates as score
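                # gather each token's own log-prob from the distribution at its
                # position, keep only the positions selected by mask_tensor,
                # and average to obtain one consistency score per sentence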
                mask_tensor = mask_tensor.float()
                consist_log_probs_list_flat = consist_log_probs_list.view(
                    -1, consist_log_probs_list.size(-1))
                token_logprob = torch.gather(
                    consist_log_probs_list_flat,
                    dim=1,
                    index=tokens_tensor.view(
                        -1, 1)).view(*consist_log_probs_list.size()[:2])
                token_logprob = token_logprob * mask_tensor
                consist_score = token_logprob.sum(-1) / mask_tensor.sum(
                    -1)  # normalized prob
            '''
            if vocab_subset is not None:
                # filter log_probs
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices
                )
            else:
                filtered_log_probs_list = original_log_probs_list
            '''

            # get the prediction probability
            if vocab_subset is not None:
                filtered_log_probs_list = [
                    flp[masked_indices_list[ind][0]].index_select(
                        dim=-1, index=filter_logprob_indices)
                    for ind, flp in enumerate(original_log_probs_list)
                ]
                if model2 is not None:
                    filtered_log_probs_list2 = [
                        flp[masked_indices_list2[ind][0]].index_select(
                            dim=-1, index=filter_logprob_indices)
                        for ind, flp in enumerate(original_log_probs_list2)
                    ]
            else:
                filtered_log_probs_list = [
                    flp[masked_indices_list[ind][0]]
                    for ind, flp in enumerate(original_log_probs_list)
                ]
                if model2 is not None:
                    filtered_log_probs_list2 = [
                        flp[masked_indices_list2[ind][0]]
                        for ind, flp in enumerate(original_log_probs_list2)
                    ]

            if dynamic.startswith('bt_topk'):
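                # the mode string is expected to end in '-<k>', e.g.
                # 'bt_topk-5': take the k most probable objects for
                # back-translation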
                obj_topk = int(dynamic.rsplit('-', 1)[1])
                top_obj_pred = [
                    flp.topk(k=obj_topk) for flp in filtered_log_probs_list
                ]
                top_obj_logprob, top_obj_pred = zip(*top_obj_pred)

            if dynamic.startswith('obj_lm_topk'):
                # use highest obj prob as consistency score
                consist_score = torch.tensor(
                    [torch.max(flp).item() for flp in filtered_log_probs_list])
            elif dynamic.startswith('obj_lmgap_topk'):
                # the gap between the two highest predictions, log p1 - log p2
                get_gap = lambda top2: (top2[0] - top2[1]).item()
                consist_score = torch.tensor([
                    get_gap(torch.topk(flp, k=2)[0])
                    for flp in filtered_log_probs_list
                ])
            elif dynamic.startswith('bt_topk'):
                # use the obj_topk highest obj to "back translate" sub
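                # for each candidate object, fill it into the subject-masked
                # template and score how well the model recovers the subject;
                # the per-object scores are then combined, weighted by the
                # object probabilities, into one consistency score per sample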
                consist_score_obj_topk = []
                used_vocab = vocab_subset if vocab_subset is not None else model.vocab
                for obj_i in range(obj_topk):
                    sentences_b_mask_sub = [[
                        replace_list(s['sub_masked_sentences'][0][0],
                                     model.mask_token,
                                     used_vocab[obj_pred[obj_i].item()])
                    ] for s, obj_pred in zip(samples_b_this, top_obj_pred)]
                    sub_mask = [
                        s['sub_masked_sentences'][1] for s in samples_b_this
                    ]
                    # TODO: only masked lm can do this
                    consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \
                        model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask)
                    # use avg prob of the sub as score
                    mask_tensor = mask_tensor.float()
                    consist_log_probs_list_flat = consist_log_probs_list.view(
                        -1, consist_log_probs_list.size(-1))
                    token_logprob = torch.gather(
                        consist_log_probs_list_flat,
                        dim=1,
                        index=tokens_tensor.view(
                            -1, 1)).view(*consist_log_probs_list.size()[:2])
                    token_logprob = token_logprob * mask_tensor
                    consist_score = token_logprob.sum(-1) / mask_tensor.sum(
                        -1)  # normalized prob
                    consist_score_obj_topk.append(consist_score)

                # SHAPE: (batch_size, obj_topk)
                consist_score_obj_topk = torch.stack(
                    consist_score_obj_topk).permute(1, 0)
                consist_score_weight = torch.stack(top_obj_logprob).exp()
                # SHAPE: (batch_size)
                consist_score = (consist_score_obj_topk *
                                 consist_score_weight).sum(-1) / (
                                     consist_score_weight.sum(-1) + 1e-10)

            # add to overall probability
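            # the first template initializes the merge; later templates are
            # either summed ('none'), kept separately (topk / temp_model
            # variants), or swapped in per sample when they have a higher
            # consistency score ('lm' / 'real_lm')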
            if filter_lp_merge is None:
                filter_lp_merge = filtered_log_probs_list
                if model2 is not None:
                    filter_lp_merge2 = filtered_log_probs_list2
                if dynamic == 'lm' or dynamic == 'real_lm':
                    max_score = consist_score
                elif dynamic.startswith('real_lm_topk') or \
                        dynamic.startswith('obj_lm_topk') or \
                        dynamic.startswith('obj_lmgap_topk') or \
                        dynamic.startswith('bt_topk'):
                    consist_score_li.append(consist_score)
            else:
                if dynamic == 'none' and temp_model is None:
                    filter_lp_merge = [
                        a + b for a, b in zip(filter_lp_merge,
                                              filtered_log_probs_list)
                    ]
                elif dynamic == 'all_topk':
                    filter_lp_merge.extend(filtered_log_probs_list)
                elif dynamic == 'lm' or dynamic == 'real_lm':
                    filter_lp_merge = \
                        [a if c >= d else b for a, b, c, d in
                         zip(filter_lp_merge, filtered_log_probs_list, max_score, consist_score)]
                    max_score = torch.max(max_score, consist_score)
                elif dynamic.startswith('real_lm_topk') or \
                        dynamic.startswith('obj_lm_topk') or \
                        dynamic.startswith('obj_lmgap_topk') or \
                        dynamic.startswith('bt_topk'):
                    filter_lp_merge.extend(filtered_log_probs_list)
                    consist_score_li.append(consist_score)
                elif temp_model is not None:
                    filter_lp_merge.extend(filtered_log_probs_list)
                    if model2 is not None:
                        filter_lp_merge2.extend(filtered_log_probs_list2)

            samples_b_prev = samples_b_this

        label_index_list = []
        obj_word_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)
            obj_word_list.append(sample['obj_label'])

        if dynamic == 'all_topk' or \
                dynamic.startswith('real_lm_topk') or \
                dynamic.startswith('obj_lm_topk') or \
                dynamic.startswith('obj_lmgap_topk') or \
                dynamic.startswith('bt_topk') or \
                temp_model is not None:  # analyze prob
            # SHAPE: (batch_size, num_temp, filter_vocab_size)
            filter_lp_merge = torch.stack(filter_lp_merge, 0).view(
                len(sentences_b_all),
                len(filter_lp_merge) // len(sentences_b_all),
                -1).permute(1, 0, 2)
            if model2 is not None:
                filter_lp_merge2 = torch.stack(filter_lp_merge2, 0).view(
                    len(sentences_b_all),
                    len(filter_lp_merge2) // len(sentences_b_all),
                    -1).permute(1, 0, 2)
            # SHAPE: (batch_size)
            label_index_tensor = torch.tensor(
                [index_list.index(li[0]) for li in label_index_list])
            c_inc = np.array(
                metrics.analyze_prob(filter_lp_merge,
                                     label_index_tensor,
                                     output=False,
                                     method='sample'))
            c_inc_stat += c_inc
        elif dynamic == 'none':
            # SHAPE: (batch_size, 1, filter_vocab_size)
            filter_lp_merge = torch.stack(filter_lp_merge, 0).unsqueeze(1)

        # SHAPE: (batch_size, num_temp, filter_vocab_size)
        filter_lp_unmerge = filter_lp_merge

        if temp_model is not None:  # optimize template weights
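            # temp_model is a (model, optimizer-spec) pair: None means predict
            # with already-learned template weights, 'precompute' means cache
            # per-template gold log-probs as features, anything else means
            # collect features and labels here and train after the batch loop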
            temp_model_, optimizer = temp_model
            if optimizer is None:  # predict
                filter_lp_merge = temp_model_(args.relation,
                                              filter_lp_merge.detach(),
                                              target=None)
            elif optimizer == 'precompute':  # pre-compute and save features
                lp = filter_lp_merge
                # SHAPE: (batch_size * num_temp)
                features = torch.gather(lp.contiguous().view(-1, lp.size(-1)),
                                        dim=1,
                                        index=label_index_tensor.repeat(
                                            lp.size(1)).view(-1, 1))
                features = features.view(-1, lp.size(1))
                features_list.append(features)
                if not bt_obj:
                    continue
            elif optimizer is not None:  # train on the fly
                features_list.append(
                    filter_lp_merge
                )  # collect features that will later be used in optimization
                if model2 is not None:
                    features_list2.append(
                        filter_lp_merge2
                    )  # collect features that will later be used in optimization
                label_index_tensor_list.append(
                    label_index_tensor)  # collect labels
                if not bt_obj:
                    continue
                else:
                    #filter_lp_merge = temp_model_(args.relation, filter_lp_merge.detach(), target=None)
                    filter_lp_merge = filter_lp_merge.mean(
                        1)  # use average prob to beam search

        if dynamic.startswith('real_lm_topk') or \
                dynamic.startswith('obj_lm_topk') or \
                dynamic.startswith('obj_lmgap_topk') or \
                dynamic.startswith('bt_topk'):  # dynamic ensemble
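            # keep, for each sample, only the templates whose consistency score
            # falls in the top-k, then average the retained log-probs (the mask
            # zeroes out the rest)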
            real_lm_topk = min(
                int(dynamic[dynamic.find('topk') + 4:].split('-')[0]),
                len(consist_score_li))
            # SHAPE: (batch_size, num_temp)
            consist_score_li = torch.stack(consist_score_li, -1)
            # SHAPE: (batch_size, topk)
            consist_score, consist_ind = consist_score_li.topk(real_lm_topk,
                                                               dim=-1)
            # SHAPE: (batch_size, 1)
            consist_score = consist_score.min(-1, keepdim=True)[0]
            # SHAPE: (batch_size, num_temp, 1)
            consist_mask = (consist_score_li >=
                            consist_score).float().unsqueeze(-1)
            # avg over top k
            filter_lp_merge = filter_lp_merge * consist_mask
            filter_lp_merge = filter_lp_merge.sum(1) / consist_mask.sum(1)

        if bt_obj:  # choose top bt_obj objects and back-translate subject
            # get the top bt_obj objects with highest probability
            used_vocab = vocab_subset if vocab_subset is not None else model.vocab
            temp_model_, optimizer = temp_model
            if optimizer is None:  # use beam search
                # SHAPE: (batch_size, bt_obj)
                objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1)
                objs_ind = torch.sort(objs_ind,
                                      dim=-1)[0]  # the index must be ascending
            elif optimizer == 'precompute':  # use ground truth
                objs_ind = label_index_tensor.view(-1, 1)
                bt_obj = 1
            elif optimizer is not None:  # get both ground truth and beam search
                # SHAPE: (batch_size, bt_obj)
                objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1)
                objs_ind = torch.cat(
                    [objs_ind, label_index_tensor.view(-1, 1)], -1)
                objs_ind = torch.sort(objs_ind,
                                      dim=-1)[0]  # the index must be ascending
                bt_obj += 1

            # back translation
            sub_lp_list = []
            for sentences_b, samples_b_this in zip(
                    sentences_b_all, samples_b_all):  # iter over templates
                for obj_i in range(bt_obj):  # iter over objs
                    sentences_b_mask_sub = []
                    for s, obj_pred, obj_word in zip(samples_b_this, objs_ind,
                                                     obj_word_list):
                        replace_tok = used_vocab[obj_pred[obj_i].item()]
                        if optimizer == 'precompute':
                            assert replace_tok.strip() == obj_word.strip()
                        sentences_b_mask_sub.append([
                            replace_list(s['sub_masked_sentences'][0][0],
                                         model.mask_token, replace_tok)
                        ])
                    sub_mask = [
                        s['sub_masked_sentences'][1] for s in samples_b_this
                    ]
                    # TODO: only masked lm can do this
                    lp, _, _, tokens_tensor, mask_tensor = \
                        model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask)
                    # use avg prob of the sub as score
                    mask_tensor = mask_tensor.float()
                    lp_flat = lp.view(-1, lp.size(-1))
                    sub_lp = torch.gather(lp_flat,
                                          dim=1,
                                          index=tokens_tensor.view(
                                              -1, 1)).view(*lp.size()[:2])
                    sub_lp = sub_lp * mask_tensor
                    sub_lp_avg = sub_lp.sum(-1) / mask_tensor.sum(
                        -1)  # normalized prob
                    sub_lp_list.append(sub_lp_avg)

            # SHAPE: (batch_size, num_temp, top_obj_num)
            num_temp = len(sentences_b_all)
            sub_lp_list = torch.cat(sub_lp_list, 0).view(num_temp, bt_obj,
                                                         -1).permute(2, 0, 1)

            if optimizer == 'precompute':
                bt_features_list.append(sub_lp_list.squeeze(-1))
                continue
            elif optimizer is not None:
                sub_lp_list_expand = torch.zeros_like(filter_lp_unmerge)
                # SHAPE: (batch_size, num_temp, vocab_size)
                sub_lp_list_expand.scatter_(
                    -1,
                    objs_ind.unsqueeze(1).repeat(1, num_temp, 1), sub_lp_list)
                bt_features_list.append(sub_lp_list_expand)
                bt_obj -= 1
                continue

            # select obj prob
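            # build a 0/1 mask over the vocabulary marking the candidate
            # objects, then pull out their per-template log-probs so the
            # template model can rescore just those candidates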
            expand_mask = torch.zeros_like(filter_lp_unmerge)
            expand_mask.scatter_(-1,
                                 objs_ind.unsqueeze(1).repeat(1, num_temp, 1),
                                 1)
            # SHAPE: (batch_size, num_temp, top_obj_num)
            obj_lp_list = torch.masked_select(filter_lp_unmerge,
                                              expand_mask.eq(1)).view(
                                                  -1, num_temp, bt_obj)

            # run temp model
            # SHAPE: (batch_size, vocab_size)
            filter_lp_merge_expand = torch.zeros_like(filter_lp_merge)
            # SHAPE: (batch_size, top_obj_num)
            filter_lp_merge = temp_model_(args.relation,
                                          torch.cat([obj_lp_list, sub_lp_list],
                                                    1),
                                          target=None)

            # expand results to vocab_size
            filter_lp_merge_expand.scatter_(-1, objs_ind, filter_lp_merge)
            filter_lp_merge = filter_lp_merge_expand + expand_mask[:, 0, :].log(
            )  # mask out other objs

        if len(filter_lp_merge.size()) == 2:
            filter_lp_merge = filter_lp_merge.unsqueeze(1)

        for temp_id in range(filter_lp_merge.size(1)):

            arguments = [{
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0],
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
            } for original_log_probs, filtered_log_probs, token_ids,
                         masked_indices, label_index, sample in zip(
                             original_log_probs_list,
                             filter_lp_merge[:, :temp_id + 1].sum(1),
                             token_ids_list,
                             masked_indices_list,
                             label_index_list,
                             samples_b,
                         )]

            # single thread for debug
            # for isx,a in enumerate(arguments):
            #     print(samples_b[isx])
            #     run_thread(a)

            # multithread
            res = pool.map(run_thread, arguments)

            for idx, result in enumerate(res):

                result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

                logger.info("\n" + msg + "\n")

                sample = samples_b[idx]

                element = {}
                element["sample"] = sample
                element["uuid"] = sample["uuid"]
                element["token_ids"] = token_ids_list[idx]
                element["masked_indices"] = masked_indices_list[idx]
                element["label_index"] = label_index_list[idx]
                element["masked_topk"] = result_masked_topk
                element["sample_MRR"] = sample_MRR
                element["sample_Precision"] = sample_P
                element["sample_perplexity"] = sample_perplexity
                element["sample_Precision1"] = result_masked_topk["P_AT_1"]

                # print()
                # print("idx: {}".format(idx))
                # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
                # for yi in range(10):
                #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
                # print("masked_indices_list: {}".format(masked_indices_list[idx]))
                # print("sample_MRR: {}".format(sample_MRR))
                # print("sample_P: {}".format(sample_P))
                # print("sample: {}".format(sample))
                # print()

                MRR[temp_id] += sample_MRR
                Precision[temp_id] += sample_P
                Precision1[temp_id] += element["sample_Precision1"]
                P1_li[temp_id].append(element["sample_Precision1"])
                '''
                if element["sample_Precision1"] == 1:
                    print(element["sample"])
                    input(1)
                else:
                    print(element["sample"])
                    input(0)
                '''

                # the judgment of the annotators records whether there is
                # evidence in the sentence that indicates a relation between the two entities.
                num_yes = 0
                num_no = 0

                if "judgments" in sample:
                    # only for Google-RE
                    for x in sample["judgments"]:
                        if x["judgment"] == "yes":
                            num_yes += 1
                        else:
                            num_no += 1
                    if num_no >= num_yes:
                        samples_with_negative_judgement[temp_id] += 1
                        element["judgement"] = "negative"
                        MRR_negative[temp_id] += sample_MRR
                        Precision_negative[temp_id] += sample_P
                    else:
                        samples_with_positive_judgement[temp_id] += 1
                        element["judgement"] = "positive"
                        MRR_positive[temp_id] += sample_MRR
                        Precision_positivie[temp_id] += sample_P

                list_of_results[temp_id].append(element)

    if temp_model is not None:
        if temp_model[1] == 'precompute':
            features = torch.cat(features_list, 0)
            if bt_obj:
                bt_features = torch.cat(bt_features_list, 0)
                features = torch.cat([features, bt_features], 1)
            return features
        if temp_model[1] is not None:
            # optimize the model on the fly
            temp_model_, (optimizer, temperature) = temp_model
            temp_model_.cuda()
            # SHAPE: (batch_size, num_temp, vocab_size)
            features = torch.cat(features_list, 0)
            if model2 is not None:
                features2 = torch.cat(features_list2, 0)
            if bt_obj:
                bt_features = torch.cat(bt_features_list, 0)
                features = torch.cat([features, bt_features], 1)
            # compute weight
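            # weight each training sample by softmax(temperature * log(1/count))
            # of its gold label, so frequently occurring objects are
            # down-weighted; multiplying by the number of samples keeps the
            # average weight at 1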
            # SHAPE: (batch_size,)
            label_index_tensor = torch.cat(label_index_tensor_list, 0)
            label_count = torch.bincount(label_index_tensor)
            label_count = torch.index_select(label_count, 0,
                                             label_index_tensor)
            sample_weight = F.softmax(
                temperature * torch.log(1.0 / label_count.float()),
                0) * label_index_tensor.size(0)
            min_loss = 1e10
            es = 0
            batch_size = 128
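            # mini-batch training loop (up to 500 epochs) with early stopping:
            # stop once the mean loss has not improved by more than 1e-3 for
            # 30 consecutive epochs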
            for e in range(500):
                # loss = temp_model_(args.relation, features.cuda(), target=label_index_tensor.cuda(), use_softmax=True)
                loss_li = []
                for b in range(0, features.size(0), batch_size):
                    features_b = features[b:b + batch_size].cuda()
                    label_index_tensor_b = label_index_tensor[b:b +
                                                              batch_size].cuda(
                                                              )
                    sample_weight_b = sample_weight[b:b + batch_size].cuda()
                    loss = temp_model_(args.relation,
                                       features_b,
                                       target=label_index_tensor_b,
                                       sample_weight=sample_weight_b,
                                       use_softmax=True)
                    if model2 is not None:
                        features2_b = features2[b:b + batch_size].cuda()
                        loss2 = temp_model_(args.relation,
                                            features2_b,
                                            target=label_index_tensor_b,
                                            sample_weight=sample_weight_b,
                                            use_softmax=True)
                        loss = loss + loss2
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    loss_li.append(loss.cpu().item())
                dev_loss = np.mean(loss_li)
                if dev_loss - min_loss < -1e-3:
                    min_loss = dev_loss
                    es = 0
                else:
                    es += 1
                    if es >= 30:
                        print('early stop')
                        break
            temp_model_.cpu()
            return min_loss

    pool.close()
    pool.join()

    for temp_id in range(len(P1_li)):
        # stats
        # Mean reciprocal rank
        MRR[temp_id] /= len(list_of_results[temp_id])

        # Precision
        Precision[temp_id] /= len(list_of_results[temp_id])
        Precision1[temp_id] /= len(list_of_results[temp_id])

        msg = "all_samples: {}\n".format(len(all_samples))
        msg += "list_of_results: {}\n".format(len(list_of_results[temp_id]))
        msg += "global MRR: {}\n".format(MRR[temp_id])
        msg += "global Precision at 10: {}\n".format(Precision[temp_id])
        msg += "global Precision at 1: {}\n".format(Precision1[temp_id])

        if samples_with_negative_judgement[
                temp_id] > 0 and samples_with_positive_judgement[temp_id] > 0:
            # Google-RE specific
            MRR_negative[temp_id] /= samples_with_negative_judgement[temp_id]
            MRR_positive[temp_id] /= samples_with_positive_judgement[temp_id]
            Precision_negative[temp_id] /= samples_with_negative_judgement[
                temp_id]
            Precision_positivie[temp_id] /= samples_with_positive_judgement[
                temp_id]
            msg += "samples_with_negative_judgement: {}\n".format(
                samples_with_negative_judgement[temp_id])
            msg += "samples_with_positive_judgement: {}\n".format(
                samples_with_positive_judgement[temp_id])
            msg += "MRR_negative: {}\n".format(MRR_negative[temp_id])
            msg += "MRR_positive: {}\n".format(MRR_positive[temp_id])
            msg += "Precision_negative: {}\n".format(
                Precision_negative[temp_id])
            msg += "Precision_positivie: {}\n".format(
                Precision_positivie[temp_id])

        logger.info("\n" + msg + "\n")
        print("\n" + msg + "\n")

        # dump pickle with the result of the experiment
        all_results = dict(list_of_results=list_of_results[temp_id],
                           global_MRR=MRR,
                           global_P_at_10=Precision)
        with open("{}/result.pkl".format(log_directory), "wb") as f:
            pickle.dump(all_results, f)

        print('P1all {}'.format('\t'.join(map(str, P1_li[temp_id]))))

    print('meaning: {}'.format(c_inc_meaning))
    print('correct-incorrect {}'.format('\t'.join(
        map(str,
            (c_inc_stat[:, :-1] / (c_inc_stat[:, -1:] + 1e-5)).reshape(-1)))))

    return Precision1[-1]
Example #9
0
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""
    [model_type_name] = args.models_names

    # print("------- Model: {}".format(model))
    # print("------- Args: {}".format(args))
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # Mean reciprocal rank
    MRR = 0.0

    # Precision at k (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    data = load_file(args.dataset_filename)
    # data = data[:2000]
    fact_pair = load_file(args.fact_pair_filename)

    # print("@@@@@@@@@@@@@@")
    # print(fact_pair)
    # print(len(fact_pair))
    # print("$$$$$$$$$$$$$$")

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # print("!!!!!!!!!!!!!")
    # print(len(all_samples)) # 30847

    logger.info("\n" + ret_msg + "\n")

    # for sample in all_samples:
    #     sample["masked_sentences"] = [sample['evidences'][0]['masked_sentence']]

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)

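    # collect per-fact predictions and ranks, keyed by "<subject>_<object>"
    # strings taken from the fact_pair file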
    list_of_results = {d['subject'] + "_" + d['object']: [] for d in fact_pair}
    list_of_ranks = {d['subject'] + "_" + d['object']: [] for d in fact_pair}
    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample in zip(
                         original_log_probs_list,
                         filtered_log_probs_list,
                         token_ids_list,
                         masked_indices_list,
                         label_index_list,
                         samples_b,
                     )]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result
            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            obj = sample['obj_label']
            sub = sample['sub_label']
            element["masked_sentences"] = sample["masked_sentences"][0]
            element["uuid"] = sample["uuid"]
            element["subject"] = sub
            element["object"] = obj
            element["rank"] = int(result_masked_topk['rank'])
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            # element["sample"] = sample
            # element["token_ids"] = token_ids_list[idx]
            # element["masked_indices"] = masked_indices_list[idx]
            # element["label_index"] = label_index_list[idx]
            # element["masked_topk"] = result_masked_topk
            # element["sample_MRR"] = sample_MRR
            # element["sample_Precision"] = sample_P
            # element["sample_perplexity"] = sample_perplexity

            list_of_results[sub + "_" + obj].append(element)
            list_of_ranks[sub + "_" + obj].append(element["rank"])

            # print("~~~~~~ rank: {}".format(result_masked_topk['rank']))
            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]

            append_data_line_to_jsonl(
                "reproduction/data/TREx_filter/{}_rank_results.jsonl".format(
                    args.label), element)  # 3122

            # list_of_results.append(element)

    pool.close()
    pool.join()

    # stats
    # Mean reciprocal rank
    # MRR /= len(list_of_results)

    # # Precision
    # Precision /= len(list_of_results)
    # Precision1 /= len(list_of_results)

    # msg = "all_samples: {}\n".format(len(all_samples))
    # # msg += "list_of_results: {}\n".format(len(list_of_results))
    # msg += "global MRR: {}\n".format(MRR)
    # msg += "global Precision at 10: {}\n".format(Precision)
    # msg += "global Precision at 1: {}\n".format(Precision1)

    # logger.info("\n" + msg + "\n")
    # print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    # all_results = dict(
    #     list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision
    # )
    # with open("{}/result.pkl".format(log_directory), "wb") as f:
    #     pickle.dump(all_results, f)

    # print()
    # model_name = args.models_names[0]
    # if args.models_names[0] == "bert":
    #     model_name = args.bert_model_name
    # elif args.models_names[0] == "elmo":
    #     if args.bert_model_name == args.bert_model_name
    # else:

    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results) # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks) # 3122
    save_data_line_to_jsonl(
        "reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label),
        list(list_of_ranks.values()))  # 3122

    return Precision1
Example #10
0
def main(args,
         shuffle_data=True,
         model=None,
         zsre=True,
         context_filter=None,
         single_token=False,
         inference_top_k=10,
         decode_to_end=False,
         seed=''):

    if len(args.models_names) > 1:
        raise ValueError('Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    print(model)
    #if model is None:
    #    #model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger
        )

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at k (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    # EM
    EM = 0.0

    # F1
    F1 = 0.0
    is_error = 0
    pred_too_large = 0
    pred_too_small = 0
    should_be_empty = 0
    should_be_not_empty = 0
    anchor_outside = 0
    mismatch = 0
    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
    else:
        # keep samples as they are
        all_samples = data

    all_samples, ret_msg = filter_samples(
        model, all_samples, vocab_subset, args.max_sentence_length, args.template, is_zsre=zsre
    )
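    # `filter_samples` (defined elsewhere) is assumed to drop samples whose object
    # label cannot be represented in the model vocabulary / common vocab subset, or
    # whose sentence exceeds args.max_sentence_length, and to return the surviving
    # samples together with a log message.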

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if a template is active: (1) keep a single fact per (subject, target, context)
    # triple and (2) rebuild the masked sentence for each fact from the template
    if args.template and args.template != "":
        facts = []
        sub_objs = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            if 'reconstructed_word' not in sample:
                raise Exception('Reconstructed word not in sample; cannot build a template fact.')
            else:
                target = sample['reconstructed_word']
                if 'masked_sentences' in sample:
                    # Some masked sentences contain no mask token; use the first one that does.
                    context = None
                    for sent in sample['masked_sentences']:
                        if '[MASK]' in sent:
                            context = sent.replace('[MASK]', sample['reconstructed_word'])
                            break
                    if context is None:
                        print('No valid context found, skipping sample')
                        continue
                else:
                    context = None
                    for evidence in sample['evidences']:
                        if not zsre:
                            if '[MASK]' in evidence['masked_sentence']:
                                context = evidence['masked_sentence'].replace('[MASK]', sample['reconstructed_word'])
                                break
                        else:
                            context = evidence['masked_sentence']
                    if context is None:
                        print('No valid context found, skipping sample')
                        continue

            #context = context.replace('(', '')
            #context = context.replace(')', '')
            
            if (sub, target, context) not in sub_objs:
                sub_objs.append((sub, target, context))
                facts.append((sub, obj, context, target))

                #break
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj, context, rw) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            sample["reconstructed_word"] = rw
            # substitute all sentences with a standard template
            sample['context'] = context
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK
            )
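            # In the LAMA codebase, parse_template is typically implemented roughly
            # as below (an assumption here; the actual helper lives elsewhere):
            #
            #   def parse_template(template, subject_label, object_label):
            #       template = template.replace("[X]", subject_label)
            #       template = template.replace("[Y]", object_label)
            #       return [template]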
            #query = sample['masked_sentences'][0].replace(base.MASK, '')
            #sample['query'] = query
            #print(f'query={query}')
            #docs = retrieve_docs(query, ranker, conn, 30)
            #sample['context'] = docs[0]
            #print(f'docs={docs}')
            all_samples.append(sample)
    #else:
    #    for sample in all_samples:
    #        query = sample['masked_sentences'][0].replace(base.MASK, '')
    #        sample['query'] = query
    #        #print(f'query={query}')
    #        docs = retrieve_docs(query, ranker, conn, 1)
    #        sample['context'] = docs[0]
            

    # create uuid if not present
    for i, sample in enumerate(all_samples):
        if "uuid" not in sample:
            sample["uuid"] = i

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(all_samples, args.batch_size)
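    # `batchify` (as in LAMA) is assumed to split the samples into chunks of
    # args.batch_size, returning aligned lists of sample dicts and of their
    # masked sentences, plus a log message.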
    logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    #example_results = []
    #cf_results= []
    #qualitative_results = []


    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]
        mymodel_probs_list = []
        predictions_list = []
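        # `decode_lm` is not defined in this snippet; it is assumed to take the
        # model, a prompt (either the relation template or seed + args.question
        # + ' = [Y]'), the subject string, and the decode_to_end flag, and to
        # return the decoded string prediction for the object slot.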

        template = args.template.strip()
        for sample in samples_b:
            if decode_to_end:
                prediction = decode_lm(model, seed + args.question + ' = [Y]', sample['sub_label'], decode_to_end)
            else:
                prediction = decode_lm(model, template, sample['sub_label'], decode_to_end)
            predictions_list.append(prediction)

        original_log_probs_list, token_ids_list, masked_indices_list = model.get_batch_generation(
            sentences_b, logger=logger
        )
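        # `get_batch_generation` is expected to return, for each sentence in the
        # batch, the token log-probabilities, the token ids, and the positions of
        # the masked tokens.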
        #original_log_probs_list, token_ids_list, masked_indices_list, predictions = model.get_batch_generation(
        #predictions_list = predictions
        mymodel_probs_list = original_log_probs_list

        #obj_len = 0
        #for obj in gc.get_objects():
        #    try:
        #        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        #            print(type(obj), obj.size())
        #            obj_len += 1
        #    except:
        #        pass
        #print(obj_len)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices
            )
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # make sure that obj_label is in the model vocabulary (and in the vocab subset, if any)
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            #elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
            #    raise ValueError(
            #        "object label {} not in model vocabulary".format(
            #            sample["obj_label"]
            #        )
            #    )
            elif vocab_subset is not None and sample["obj_label"] not in vocab_subset:
                raise ValueError(
                    "object label {} not in vocab subset".format(sample["obj_label"])
                )

            label_index_list.append(obj_label_id)

        arguments = [
            {   "is_zsre": zsre,
                "mymodel_probs": mymodel_probs,
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "target": sample["reconstructed_word"],
                "prediction": pred,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0] if len(label_index) > 0 else 0,
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
                #"prediction": prediction
            }
            for mymodel_probs, original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample, pred in zip(
                mymodel_probs_list,
                original_log_probs_list,
                filtered_log_probs_list,
                token_ids_list,
                masked_indices_list,
                label_index_list,
                samples_b,
                predictions_list
            )
        ]
            #for mymodel_probs, original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample, prediction in zip(
            #    mymodel_probs_list,
            #    original_log_probs_list,
            #    filtered_log_probs_list,
            #    token_ids_list,
            #    masked_indices_list,
            #    label_index_list,
            #    samples_b,
            #    predictions_list
            #)

        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)
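        # `run_thread` (defined elsewhere) is expected to return, per sample:
        # (result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg,
        #  sample_em, sample_f1), matching the unpacking below.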


        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg, sample_em, sample_f1 = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            element['sample_em'] = sample_em
            element['sample_f1'] = sample_f1

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]

            EM += sample_em
            F1 += sample_f1
            # the annotators' judgments record whether the sentence contains
            # evidence of a relation between the two entities.
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positive += sample_P

            list_of_results.append(element)

    #df = pd.DataFrame(example_results)
    #df.to_csv('example_results.csv')


    pool.close()
    pool.join()

    # stats
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    Precision1 /= len(list_of_results)

    EM /= len(list_of_results)
    F1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)
    msg += "global EM {}\n".format(EM)
    msg += "global F1: {}\n".format(F1)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positive /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement
        )
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement
        )
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision
    )
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    #logdf = pd.DataFrame(cf_results, columns=['Context', 'Subject', 'Template', 'Relation', 'Object', 'Score'])
    #logdf.to_csv("{}/cf_results4.csv".format(log_directory))

    #logdf = pd.DataFrame(qualitative_results, columns=['Template', 'Subject', 'Context', 'Object', 'Prediction', 'Expanded Prediction'])
    #logdf.to_csv("{}/qual_results.csv".format(log_directory))

    return Precision1, Precision, MRR, EM, F1
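# Usage sketch (hypothetical): this entry point mirrors the LAMA-style evaluation
# scripts, so it would normally be driven by an argparse.Namespace built by the
# repository's own option parser. The names below are illustrative assumptions.
#
# if __name__ == "__main__":
#     args = get_eval_args()  # hypothetical parser returning the expected fields
#     p1, p_at_k, mrr, em, f1 = main(args, shuffle_data=False, zsre=True)
#     print("P@1={:.4f} P@k={:.4f} MRR={:.4f} EM={:.4f} F1={:.4f}".format(
#         p1, p_at_k, mrr, em, f1))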