def main(args):
    sentences = [
        ["the cat is on the table ."],  # single-sentence instance
        ["the dog is sleeping on the sofa .",
         "he makes happy noises ."],  # two-sentence
    ]

    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    for model_name, model in models.items():
        print("\n{}:".format(model_name))
        if args.cuda:
            model.try_cuda()
        contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
            sentences)

        # contextual_embeddings is a list of tensors, one tensor per layer.
        # Each element holds one layer of the representations with shape
        # (x, y, z).
        #   x    - the batch size
        #   y    - the sequence length of the batch
        #   z    - the hidden dimension of the layer

        print(f'Number of layers: {len(contextual_embeddings)}')
        for layer_id, layer in enumerate(contextual_embeddings):
            print(f'Layer {layer_id} has shape: {layer.shape}')

        print("sentence_lengths: {}".format(sentence_lengths))
        print("tokenized_text_list: {}".format(tokenized_text_list))
Example #2
File: rc.py Project: jzbjyb/LAMA
def fill_cloze(args, input_jsonl, batch_size, beam_size):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)
    with open(input_jsonl, 'r') as fin:
        data = [json.loads(l) for l in fin]
        # only keep QA pairs whose (1) answer starts with an uppercase letter, (2) sentence is <= 200 chars, (3) sentence contains no digits
        data = [
            d for d in data
            if d['answer'][0].isupper() and len(d['sentence']) <= 200
            and not bool(re.search(r'\d', d['sentence']))
        ]
        print('#qa pairs {}'.format(len(data)))

    acc_token_li, acc_sent_li = [], []
    for b in tqdm(range(0, len(data), batch_size)):
        data_batch = data[b:b + batch_size]
        sents = []
        for d in data_batch:
            start = d['answer_start']
            end = start + len(d['answer'])
            sent = d['sentence'].replace('[', '(').replace(']', ')')
            sent = sent[:start] + '[' + sent[start:end] + ']' + sent[end:]
            sents.append(sent)
        acc_token, acc_sent = model.fill_cloze(sents,
                                               try_cuda=try_cuda,
                                               beam_size=beam_size)
        acc_token_li.append(acc_token)
        acc_sent_li.append(acc_sent)
        #print(acc_token, acc_sent)
    print('mean acc_token {}, mean acc_sent {}'.format(np.mean(acc_token_li),
                                                       np.mean(acc_sent_li)))
Example #3
def __vocab_intersection(models, filename):

    vocabularies = []

    for arg_dict in models:

        args = argparse.Namespace(**arg_dict)
        print(args)
        model = build_model_by_name(args.lm, args)

        vocabularies.append(model.vocab)
        print(type(model.vocab))

    if len(vocabularies) > 0:
        common_vocab = set(vocabularies[0])
        for vocab in vocabularies:
            common_vocab = common_vocab.intersection(set(vocab))

        # no special symbols in common_vocab
        for symbol in base.SPECIAL_SYMBOLS:
            if symbol in common_vocab:
                common_vocab.remove(symbol)

        # remove stop words
        from spacy.lang.en.stop_words import STOP_WORDS
        for stop_word in STOP_WORDS:
            if stop_word in common_vocab:
                print(stop_word)
                common_vocab.remove(stop_word)

        common_vocab = list(common_vocab)

        # remove punctuation and symbols
        nlp = spacy.load('en')
        manual_punctuation = ['(', ')', '.', ',']
        new_common_vocab = []
        for i in tqdm(range(len(common_vocab))):
            word = common_vocab[i]
            doc = nlp(word)
            token = doc[0]
            if (len(doc) != 1):
                print(word)
                for idx, tok in enumerate(doc):
                    print("{} - {}".format(idx, tok))
            elif word in manual_punctuation:
                pass
            elif token.pos_ == "PUNCT":
                print("PUNCT: {}".format(word))
            elif token.pos_ == "SYM":
                print("SYM: {}".format(word))
            else:
                new_common_vocab.append(word)
            # print("{} - {}".format(word, token.pos_))
        common_vocab = new_common_vocab

        # store common_vocab on file
        with open(filename, 'w') as f:
            for item in sorted(common_vocab):
                f.write("{}\n".format(item))
Example #4
def main(args):

    if not args.subject and not args.relation:
        raise ValueError(
            'You need to specify --subject and --relation to query language models.'
        )

    print('Language Models: {}'.format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print('Common vocabulary size: {}'.format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    prompt_file = os.path.join(args.prompts, args.relation + '.jsonl')
    if not os.path.exists(prompt_file):
        raise ValueError('Relation "{}" does not exist.'.format(args.relation))
    prompts, weights = load_prompt_weights(prompt_file)

    for model_name, model in models.items():
        print('\n{}:'.format(model_name))

        index_list = None
        if vocab_subset is not None:
            filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                vocab_subset)

        ensemble_log_probs = 0
        for prompt, weight in zip(prompts, weights):
            prompt = parse_prompt(prompt, args.subject, model.mask_token)
            log_prob, [token_ids
                       ], [masked_indices
                           ], _, _ = model.get_batch_generation([prompt],
                                                                try_cuda=True)

            if vocab_subset is not None:
                filtered_log_probs = model.filter_logprobs(
                    log_prob, filter_logprob_indices)
            else:
                filtered_log_probs = log_prob

            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                filtered_log_probs = filtered_log_probs[0][masked_indices[0]]
                ensemble_log_probs += filtered_log_probs * weight

        ensemble_log_probs = F.log_softmax(ensemble_log_probs, dim=0)
        evaluation_metrics.get_ranking(ensemble_log_probs,
                                       model.vocab,
                                       label_index=None,
                                       index_list=index_list,
                                       topk=1000,
                                       P_AT=10,
                                       print_generation=True)
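A self-contained sketch of the weighted prompt ensembling performed above, on toy log-probability vectors; the vocabulary size and weights are made up, whereas in the example they come from the model and load_prompt_weights:

import torch
import torch.nn.functional as F

# one log-probability vector over a toy 5-word vocab per prompt
log_probs_per_prompt = [F.log_softmax(torch.randn(5), dim=0) for _ in range(3)]
weights = [0.5, 0.3, 0.2]

ensemble = sum(w * lp for lp, w in zip(log_probs_per_prompt, weights))
ensemble = F.log_softmax(ensemble, dim=0)  # renormalize, as in the example above
print(ensemble.exp().sum())                # ~1.0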
Example #5
File: rc.py Project: jzbjyb/LAMA
def refine_cloze(args):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)
    sents = [
        'The theory of relativity [ is killed by ] Einstein .',
        'Windows [ is killed by ] Microsoft .'
    ]
    model.refine_cloze(sents, try_cuda=try_cuda)
Example #6
def main():
    with open(trial_2 + "priming/antonyme_adj.txt", "r") as ah:
        pairs = [(p.split()[0].strip(), p.split()[1].strip())
                 for p in ah.readlines()]
    print(pairs)
    args_stud = Args_Stud()
    bert = build_model_by_name("bert", args_stud)
    for pair in pairs:
        maskPairs(pair, bert)
    return
Example #7
File: rc.py Project: jzbjyb/LAMA
def fill_cloze_webquestion(args, input_file, batch_size, beam_size):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)
    with open(input_file, 'r') as fin:
        # keep only statements whose bracketed answer is a single word
        sents = [l.strip() for l in fin]
        sents = [s for s in sents if len(re.split(r'\[|\]', s)[1].split()) == 1]
        print('#qa pairs {}'.format(len(sents)))

    acc_token_li, acc_sent_li = [], []
    for b in tqdm(range(0, len(sents), batch_size)):
        acc_token, acc_sent = model.fill_cloze(sents[b:b + batch_size],
                                               try_cuda=try_cuda,
                                               beam_size=beam_size)
        acc_token_li.append(acc_token)
        acc_sent_li.append(acc_sent)
        #print(acc_token, acc_sent)
    print('mean acc_token {}, mean acc_sent {}'.format(np.mean(acc_token_li),
                                                       np.mean(acc_sent_li)))
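A small illustration of what the re.split filter above produces on the bracketed input format; the sentence is made up:

import re

s = 'Windows was developed by [Microsoft] .'
parts = re.split(r'\[|\]', s)
print(parts)                       # ['Windows was developed by ', 'Microsoft', ' .']
print(len(parts[1].split()) == 1)  # True: the bracketed answer is a single word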
Example #8
File: rc.py Project: jzbjyb/LAMA
def fill_cloze_lama_squad(args, input_jsonl, batch_size, beam_size):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)
    with open(input_jsonl, 'r') as fin:
        data = [json.loads(l) for l in fin]
        print('#qa pairs {}'.format(len(data)))

    acc_token_li, acc_sent_li = [], []
    for b in tqdm(range(0, len(data), batch_size)):
        data_batch = data[b:b + batch_size]
        sents = []
        for d in data_batch:
            sents.append(d['masked_sentences'][0].replace(
                '[MASK]', '[{}]'.format(d['obj_label'])))
        acc_token, acc_sent = model.fill_cloze(sents,
                                               try_cuda=try_cuda,
                                               beam_size=beam_size)
        acc_token_li.append(acc_token)
        acc_sent_li.append(acc_sent)
        #print(acc_token, acc_sent)
    print('mean acc_token {}, mean acc_sent {}'.format(np.mean(acc_token_li),
                                                       np.mean(acc_sent_li)))
Example #9
File: rc.py Project: jzbjyb/LAMA
def pattern_score(args, pattern_json, output_file):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)
    with open(pattern_json, 'r') as fin:
        pattern_json = json.load(fin)

    batch_size = 32
    pid2pattern = defaultdict(lambda: {})
    for pid in tqdm(sorted(pattern_json)):
        #if not pid.startswith('P69_'):
        #    continue
        snippets = pattern_json[pid]['snippet']
        occs = pattern_json[pid]['occs']
        for (snippet, direction), count in snippets:
            if len(snippet) <= 5 or len(snippet) >= 100:  # keep snippets longer than 5 and shorter than 100 chars
                continue
            loss = 0
            num_batch = np.ceil(len(occs) / batch_size)
            for b in range(0, len(occs), batch_size):
                occs_batch = occs[b:b + batch_size]
                sentences = [
                    '{} {} ({})'.format(h, snippet, t)
                    if direction == 1 else '{} {} ({})'.format(t, snippet, h)
                    for h, t in occs_batch
                ]
                #print((snippet, direction), count)
                #print(sentences)
                #input()
                loss += model.get_rc_loss(sentences,
                                          try_cuda=try_cuda)[0].item()
            pid2pattern[pid][snippet] = loss / num_batch
        #print(pid)
        #print(sorted(pid2pattern[pid].items(), key=lambda x: x[1]))
        #input()

    with open(output_file, 'w') as fout:
        for pid, pats in pid2pattern.items():
            pats = sorted(pats.items(), key=lambda x: x[1])
            fout.write('{}\t{}\n'.format(pid, json.dumps(pats)))
Example #10
def encode(args, sentences, sort_input=False):
    """Create an EncodedDataset from a list of sentences

    Parameters:
    sentences (list[list[string]]): list of elements. Each element is a list
                                    that contains either a single sentence
                                    or two sentences
    sort_input (bool): if true, sort sentences by number of tokens in them

    Returns:
    dataset (EncodedDataset): an object that contains the contextual
                              representations of the input sentences
    """
    print("Language Models: {}".format(args.lm))
    model = build_model_by_name(args.lm, args)

    # sort sentences by token count so that every batch contains
    # sentences with a similar number of tokens
    if sort_input:
        sentences = sorted(sentences, key=lambda k: len(" ".join(k).split()))

    encoded_sents = []
    for current_batch in tqdm(_batchify(sentences, args.batch_size)):
        embeddings, sent_lens, tokenized_sents = model.get_contextual_embeddings(
            current_batch)

        agg_embeddings = _aggregate_layers(
            embeddings)  # [#batchsize, #max_sent_len, #dim]
        sent_embeddings = [
            agg_embeddings[i, :l] for i, l in enumerate(sent_lens)
        ]
        encoded_sents.extend(
            list(zip(sent_embeddings, sent_lens, tokenized_sents)))

    dataset = EncodedDataset(encoded_sents)
    return dataset
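The helpers _batchify and _aggregate_layers are not shown in this snippet; a minimal sketch of what they could look like, assuming _aggregate_layers simply averages the per-layer tensors (the real implementations may differ):

import torch

def _batchify(items, batch_size):
    # yield consecutive slices of at most batch_size elements
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

def _aggregate_layers(layer_tensors):
    # average a list of (batch, seq_len, dim) tensors into a single (batch, seq_len, dim) tensor
    return torch.stack(layer_tensors).mean(dim=0)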
Example #11
File: rc.py Project: jzbjyb/LAMA
def main(args):
    try_cuda = torch.cuda.is_available()

    model = build_model_by_name(args.models_names[0], args)
    model.add_hooks()
    embedding_weight = model.get_embedding_weight()

    sentences = [
        'The theory of relativity was developed by Einstein.',
        'Windows was developed by Microsoft.'
    ]

    # each of the following reassignments overrides the previous list;
    # only the last (bracketed) pair of sentences is actually used below
    sentences = ['(The theory of relativity) was found by (Einstein.)']

    sentences = ['(Barack Obama) was born in (Hawaii.)']

    sentences = ['Him (speaks English.)']

    sentences = [
        '[The theory of relativity was] killed [by Einstein].',
        '[Windows was] killed [by Microsoft].'
    ]

    for _ in range(50):
        for token_to_flip in range(0, 3):  # TODO: for each token in the trigger
            # back propagation
            model.zero_grad()
            loss, tokens, _, unbracket_mask = model.get_rc_loss(
                sentences, try_cuda=try_cuda)
            # SHAPE: (batch_size, seq_len)
            unbracket_mask = unbracket_mask.bool()
            loss.backward()
            print(loss)

            # SHAPE: (batch_size, seq_len, emb_dim)
            grad = base_connector.extracted_grads[0]
            bs, _, emb_dim = grad.size()
            base_connector.extracted_grads = []

            # TODO
            # SHAPE: (batch_size, unbracket_len, emb_dim)
            #grad = grad.masked_select(F.pad(unbracket_mask, (1, 0), 'constant', False)[:, :-1].unsqueeze(-1)).view(bs, -1, emb_dim)
            grad = grad.masked_select(unbracket_mask.unsqueeze(-1)).view(
                bs, -1, emb_dim)
            # SHAPE: (1, emb_dim)
            grad = grad.sum(dim=0)[token_to_flip].unsqueeze(0)
            print((grad * grad).sum().sqrt())

            # SHAPE: (batch_size, unbracket_len)
            tokens = tokens.masked_select(unbracket_mask).view(bs, -1)
            token_tochange = tokens[0][token_to_flip].item()

            # Use hotflip (linear approximation) attack to get the top num_candidates
            candidates = attacks.hotflip_attack(grad,
                                                embedding_weight,
                                                [token_tochange],
                                                increase_loss=False,
                                                num_candidates=10)[0]
            print(model.tokenizer.convert_ids_to_tokens([token_tochange]),
                  model.tokenizer.convert_ids_to_tokens(candidates))
            input()
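attacks.hotflip_attack is not shown here; a self-contained sketch of the first-order HotFlip candidate selection it is based on: score every vocabulary embedding by its dot product with the gradient at the token to flip (the term for the current embedding is constant across candidates, so it does not affect the ranking) and take the top-k in the direction that decreases the loss. The shapes below are toy values:

import torch

def hotflip_candidates(grad, embedding_weight, num_candidates=10, increase_loss=False):
    # grad: (1, emb_dim) gradient of the loss w.r.t. the embedding of the token to flip
    # embedding_weight: (vocab_size, emb_dim)
    scores = grad @ embedding_weight.t()  # (1, vocab_size): first-order estimate of the loss change
    if not increase_loss:
        scores = -scores                  # prefer candidates that decrease the loss
    return scores.topk(num_candidates, dim=-1).indices[0]

grad = torch.randn(1, 768)
embedding_weight = torch.randn(30522, 768)
print(hotflip_candidates(grad, embedding_weight)[:5])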
Example #12
def run_experiments(
    relations,
    data_path_pre,
    data_path_post,
    input_param={
        "lm": "bert",
        "label": "bert_large",
        "models_names": ["bert"],
        "bert_model_name": "bert-large-cased",
        "bert_model_dir": "pre-trained_language_models/bert/cased_L-24_H-1024_A-16",
    },
    use_negated_probes=False,
):
    model = None
    pp = pprint.PrettyPrinter(width=41, compact=True)

    all_Precision1 = []
    type_Precision1 = defaultdict(list)
    type_count = defaultdict(list)

    results_file = open("last_results.csv", "w+")
    uid_list_all, mask_feature_list_all, answers_list_all = [], [], []
    all_correct_uuids = []
    total_modified_correct, total_unmodified_correct = 0, 0
    total_modified_num, total_unmodified_num = 0, 0
    for relation in relations:
        # if "type" not in relation or relation["type"] != "1-1":
        #     continue

        pp.pprint(relation)
        PARAMETERS = {
            "dataset_filename": "{}{}{}".format(data_path_pre, relation["relation"], data_path_post),
            "common_vocab_filename": 'pre-trained_language_models/bert/cased_L-12_H-768_A-12/vocab.txt',  #"pre-trained_language_models/common_vocab_cased.txt",
            "template": "",
            "bert_vocab_name": "vocab.txt",
            "batch_size": 32,
            "logdir": "output",
            "full_logdir": "output/results/{}/{}".format(input_param["label"], relation["relation"]),
            "lowercase": False,
            "max_sentence_length": 512,  # change to 512 later
            "threads": 2,
            "interactive": False,
            "use_negated_probes": use_negated_probes,
            "return_features": False,
            "uuid_list": []
        }

        if "template" in relation:
            PARAMETERS["template"] = relation["template"]
            if use_negated_probes:
                PARAMETERS["template_negated"] = relation["template_negated"]

        PARAMETERS.update(input_param)
        print(PARAMETERS)

        args = argparse.Namespace(**PARAMETERS)

        # see if file exists
        try:
            data = load_file(args.dataset_filename)
        except Exception as e:
            print("Relation {} excluded.".format(relation["relation"]))
            print("Exception: {}".format(e))
            continue

        if model is None:
            [model_type_name] = args.models_names
            model = build_model_by_name(model_type_name, args)

        if getattr(args, 'output_feature_path', ''):
            # Get the features for kNN-LM. Ignore this part if only obtaining the correctly predicted questions.
            Precision1, total_unmodified, Precision1_modified, total_modified, uid_list, mask_feature_list, answers_list = run_evaluation(
                args, shuffle_data=False, model=model)
            if len(uid_list) > 0:
                uid_list_all.extend(uid_list)
                mask_feature_tensor = torch.cat(mask_feature_list, dim=0)
                mask_feature_list_all.append(mask_feature_tensor)
                answers_list_all.extend(answers_list)

        else:
            Precision1, total_unmodified, Precision1_modified, total_modified, correct_uuids = run_evaluation(
                args, shuffle_data=False, model=model)
            all_correct_uuids.extend(correct_uuids)

        total_modified_correct += Precision1_modified
        total_unmodified_correct += Precision1
        total_modified_num += total_modified
        total_unmodified_num += total_unmodified
        print("P@1 : {}".format(Precision1), flush=True)
        all_Precision1.append(Precision1)

        results_file.write("{},{}\n".format(relation["relation"],
                                            round(Precision1 * 100, 2)))
        results_file.flush()

        if "type" in relation:
            type_Precision1[relation["type"]].append(Precision1)
            data = load_file(PARAMETERS["dataset_filename"])
            type_count[relation["type"]].append(len(data))

    mean_p1 = statistics.mean(all_Precision1)
    print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1))
    print("Unmodified acc: {}, modified acc: {}".format(
        total_unmodified_correct / float(total_unmodified_num),
        0 if total_modified_num == 0 else total_modified_correct /
        float(total_modified_num)))
    results_file.close()

    for t, l in type_Precision1.items():

        print(
            "@@@ ",
            input_param["label"],
            t,
            statistics.mean(l),
            sum(type_count[t]),
            len(type_count[t]),
            flush=True,
        )
    if len(uid_list_all) > 0:
        out_dict = {
            'mask_features': torch.cat(mask_feature_list_all, dim=0),
            'uuids': uid_list_all,
            'obj_labels': answers_list_all
        }
        torch.save(out_dict, 'datastore/ds_change32.pt')
    if len(all_correct_uuids) > 0:
        if not os.path.exists('modification'):
            os.makedirs('modification')
        json.dump(all_correct_uuids,
                  open('modification/correct_uuids.json', 'w'))
    return mean_p1, all_Precision1
Example #13
        #
        # print("sentence_lengths: {}".format(sentence_lengths))
        # print("tokenized_text_list: {}".format(tokenized_text_list))

        return contextual_embeddings, tokenized_text_list


### Generate modified args for the LAMA library (i.e., imitate the input of a command line)
sys.argv = ['My code for HW3 Task1', '--lm', 'bert']
parser = options.get_general_parser()
args = options.parse_args(parser)

### Build the model only once (not once per line inside the loop)
models = {}
for lm in args.models_names:
    models[lm] = build_model_by_name(lm, args)

### Open the file
with jsonlines.open('./train_testing_output.jsonl') as reader:
    for line in reader.iter():
        dictionary = line

        ### Mask the entity span in the claim
        text = dictionary['claim']
        start_masked = dictionary["entity"]['start_character']
        end_masked = dictionary["entity"]['end_character']

        text_masked = text[:start_masked] + '[MASK]' + text[end_masked:]

        ### get embeddings
Example #14
def main(args):

    if not args.text and not args.interactive:
        msg = "ERROR: either you start LAMA eval_generation with the " \
              "interactive option (--i) or you pass in input a piece of text (--t)"
        raise ValueError(msg)

    stopping_condition = True

    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print("common vocabulary size: {}".format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    while stopping_condition:
        if args.text:
            text = args.text
            stopping_condition = False
        else:
            text = input("insert text:")

        if args.split_sentence:
            import spacy
            # use spacy to tokenize input sentence
            nlp = spacy.load(args.spacy_model)
            tokens = nlp(text)
            print(tokens)
            sentences = []
            for s in tokens.sents:
                print(" - {}".format(s))
                sentences.append(s.text)
        else:
            sentences = [text]

        if len(sentences) > 2:
            print(
                "WARNING: only the first two sentences in the text will be considered!"
            )
            sentences = sentences[:2]

        for model_name, model in models.items():
            print("\n{}:".format(model_name))
            original_log_probs_list, [token_ids], [
                masked_indices
            ] = model.get_batch_generation([sentences], try_cuda=False)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list

            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                evaluation_metrics.get_ranking(filtered_log_probs_list[0],
                                               masked_indices,
                                               model.vocab,
                                               index_list=index_list)

            # prediction and perplexity for the whole softmax
            print_sentence_predictions(original_log_probs_list[0],
                                       token_ids,
                                       model.vocab,
                                       masked_indices=masked_indices)
Example #15
def main():
    args_stud = Args_Stud()
    bert = build_model_by_name("bert", args_stud)
    vocab_subset = None
    f = open('./LAMA/lama/collected_paths.json', )
    path_s = json.load(f)
    sent_path_ = path_s['sent2eval']
    prem_path = path_s['premis2eval']
    res_path_ = path_s["res_file"]
    paths = os.listdir(sent_path_)
    for path in paths:
        sent_path = sent_path_ + path
        res_path = res_path_ + path.split(".")[0].split(
            "_")[-2] + "_" + path.split(".")[0].split("_")[-2] + "/"
        os.makedirs(res_path, exist_ok=True)
        with open(sent_path, "r", encoding="utf8") as sf:
            sentences = [s.rstrip() for s in sf.readlines()]
        print(sentences)
        with open(prem_path, "r") as pf:
            premisses = [p.rstrip() for p in pf.readlines()]
        data = {}
        for s in sentences:
            data[s] = []
            original_log_probs_list, [token_ids], [
                masked_indices
            ] = bert.get_batch_generation([[s]], try_cuda=True)
            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = bert.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = bert.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list

            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                MRR, P_AT_X, experiment_result, return_msg = evaluation_metrics.get_ranking(
                    filtered_log_probs_list[0],
                    masked_indices,
                    bert.vocab,
                    index_list=index_list)
                res = experiment_result["topk"]
                for r in res:
                    data[s].append((r["token_word_form"], r["log_prob"]))
        with open(res_path + "NoPrem.json", "w+", encoding="utf-8") as f:
            json.dump(data, f)
        for pre in premisses:
            for s in sentences:
                data[s] = []
                sentence = [str(pre) + "? " + s]
                original_log_probs_list, [token_ids], [
                    masked_indices
                ] = bert.get_batch_generation([sentence], try_cuda=False)
                index_list = None
                if vocab_subset is not None:
                    # filter log_probs
                    filter_logprob_indices, index_list = bert.init_indices_for_filter_logprobs(
                        vocab_subset)
                    filtered_log_probs_list = bert.filter_logprobs(
                        original_log_probs_list, filter_logprob_indices)
                else:
                    filtered_log_probs_list = original_log_probs_list

                # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
                if masked_indices and len(masked_indices) > 0:
                    MRR, P_AT_X, experiment_result, return_msg = evaluation_metrics.get_ranking(
                        filtered_log_probs_list[0],
                        masked_indices,
                        bert.vocab,
                        index_list=index_list)
                    res = experiment_result["topk"]
                    for r in res:
                        data[s].append((r["token_word_form"], r["log_prob"]))
            with open(res_path + pre + ".json", "w+", encoding="utf-8") as f:
                json.dump(data, f)
Example #16
def main(args):

    # Load the JSON datasets.
    # For each dataset we create a NumPy array of shape (len(dataset), 768)
    # that we'll fill with BERT word embeddings.

    #Pre-processed train dataset
    with open('./new_data_train.json') as json_file:
        json_train = json.load(json_file)

    x_train = np.zeros((len(json_train), 768))

    #Pre-processed dev dataset

    with open('./new_data_dev.json') as json_file:
        json_test = json.load(json_file)

    x_test = np.zeros((len(json_test), 768))

    #Official test set

    json_test_official = []
    with open('./singletoken_test_fever_homework_NLP.jsonl') as json_file:
        for item in json_lines.reader(json_file):
            json_test_official.append(item)

    x_test_official = np.zeros((len(json_test_official), 768))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    # For each model, loop over each dataset to retrieve the word embeddings from BERT

    for model_name, model in models.items():
        for index in range(len(json_train)):

            # We pass each claim of each datapoint to the model
            sentences = [[json_train[index]['claim']]]
            print("\n{}:".format(model_name))

            contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
                sentences)
            # We select the CLS vector of the last layer
            x_train[index] = contextual_embeddings[11][0][0]
            print(tokenized_text_list)

        #We do the same for the other two datasets
        for index in range(len(json_test)):

            sentences = [[json_test[index]['claim']]]
            print("\n{}:".format(model_name))

            contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
                sentences)
            x_test[index] = contextual_embeddings[11][0][0]

            print(tokenized_text_list)

        for index in range(len(json_test_official)):

            sentences = [[json_test_official[index]['claim']]]
            print("\n{}:".format(model_name))

            contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
                sentences)
            x_test_official[index] = contextual_embeddings[11][0][0]

            print(tokenized_text_list)

    return (x_train, json_train, x_test, json_test, x_test_official,
            json_test_official)
Example #17
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""
    [model_type_name] = args.models_names

    # print("------- Model: {}".format(model))
    # print("------- Args: {}".format(args))
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # Mean reciprocal rank
    MRR = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positive = 0.0

    data = load_file(args.dataset_filename)
    # data = data[:2000]
    fact_pair = load_file(args.fact_pair_filename)

    # print("@@@@@@@@@@@@@@")
    # print(fact_pair)
    # print(len(fact_pair))
    # print("$$$$$$$$$$$$$$")

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # print("!!!!!!!!!!!!!")
    # print(len(all_samples)) # 30847

    logger.info("\n" + ret_msg + "\n")

    # for sample in all_samples:
    #     sample["masked_sentences"] = [sample['evidences'][0]['masked_sentence']]

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)

    list_of_results = {d['subject'] + "_" + d['object']: [] for d in fact_pair}
    list_of_ranks = {d['subject'] + "_" + d['object']: [] for d in fact_pair}
    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample in zip(
                         original_log_probs_list,
                         filtered_log_probs_list,
                         token_ids_list,
                         masked_indices_list,
                         label_index_list,
                         samples_b,
                     )]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result
            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            obj = sample['obj_label']
            sub = sample['sub_label']
            element["masked_sentences"] = sample["masked_sentences"][0]
            element["uuid"] = sample["uuid"]
            element["subject"] = sub
            element["object"] = obj
            element["rank"] = int(result_masked_topk['rank'])
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            # element["sample"] = sample
            # element["token_ids"] = token_ids_list[idx]
            # element["masked_indices"] = masked_indices_list[idx]
            # element["label_index"] = label_index_list[idx]
            # element["masked_topk"] = result_masked_topk
            # element["sample_MRR"] = sample_MRR
            # element["sample_Precision"] = sample_P
            # element["sample_perplexity"] = sample_perplexity

            list_of_results[sub + "_" + obj].append(element)
            list_of_ranks[sub + "_" + obj].append(element["rank"])

            # print("~~~~~~ rank: {}".format(result_masked_topk['rank']))
            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]

            append_data_line_to_jsonl(
                "reproduction/data/TREx_filter/{}_rank_results.jsonl".format(
                    args.label), element)  # 3122

            # list_of_results.append(element)

    pool.close()
    pool.join()

    # stats
    # Mean reciprocal rank
    # MRR /= len(list_of_results)

    # # Precision
    # Precision /= len(list_of_results)
    # Precision1 /= len(list_of_results)

    # msg = "all_samples: {}\n".format(len(all_samples))
    # # msg += "list_of_results: {}\n".format(len(list_of_results))
    # msg += "global MRR: {}\n".format(MRR)
    # msg += "global Precision at 10: {}\n".format(Precision)
    # msg += "global Precision at 1: {}\n".format(Precision1)

    # logger.info("\n" + msg + "\n")
    # print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    # all_results = dict(
    #     list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision
    # )
    # with open("{}/result.pkl".format(log_directory), "wb") as f:
    #     pickle.dump(all_results, f)

    # print()
    # model_name = args.models_names[0]
    # if args.models_names[0] == "bert":
    #     model_name = args.bert_model_name
    # elif args.models_names[0] == "elmo":
    #     if args.bert_model_name == args.bert_model_name
    # else:

    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results) # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks) # 3122
    save_data_line_to_jsonl(
        "reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label),
        list(list_of_ranks.values()))  # 3122

    return Precision1
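The metrics accumulated above (MRR, precision at 10, precision at 1) can all be derived from the rank of the gold object in the sorted predictions; a small self-contained illustration with made-up ranks:

ranks = [1, 3, 12, 2]  # 1-based rank of the correct object for four samples

mrr = sum(1.0 / r for r in ranks) / len(ranks)
p_at_10 = sum(1 for r in ranks if r <= 10) / len(ranks)
p_at_1 = sum(1 for r in ranks if r == 1) / len(ranks)
print(mrr, p_at_10, p_at_1)  # 0.479..., 0.75, 0.25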
Example #18
def main(args,
         shuffle_data=True,
         model=None,
         model2=None,
         refine_template=False,
         get_objs=False,
         dynamic='none',
         use_prob=False,
         bt_obj=None,
         temp_model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    #print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename,
                                  lower=args.lowercase)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    if dynamic == 'all_topk':  # save topk results for different k
        # stats
        samples_with_negative_judgement = [
            0 for _ in range(len(args.template))
        ]
        samples_with_positive_judgement = [
            0 for _ in range(len(args.template))
        ]

        # Mean reciprocal rank
        MRR = [0.0 for _ in range(len(args.template))]
        MRR_negative = [0.0 for _ in range(len(args.template))]
        MRR_positive = [0.0 for _ in range(len(args.template))]

        # Precision at (default 10)
        Precision = [0.0 for _ in range(len(args.template))]
        Precision1 = [0.0 for _ in range(len(args.template))]
        Precision_negative = [0.0 for _ in range(len(args.template))]
        Precision_positive = [0.0 for _ in range(len(args.template))]

        list_of_results = [[] for _ in range(len(args.template))]
        P1_li = [[] for _ in range(len(args.template))]
    else:
        # stats
        samples_with_negative_judgement = [0]
        samples_with_positive_judgement = [0]

        # Mean reciprocal rank
        MRR = [0.0]
        MRR_negative = [0.0]
        MRR_positive = [0.0]

        # Precision at (default 10)
        Precision = [0.0]
        Precision1 = [0.0]
        Precision_negative = [0.0]
        Precision_positive = [0.0]

        list_of_results = [[]]
        P1_li = [[]]

    data = load_file(args.dataset_filename)
    for s in data:
        s['raw_sub_label'] = s['sub_label']
        s['raw_obj_label'] = s['obj_label']

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        data = lowercase_samples(data)

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print('#head-tails {} -> {}'.format(len(data), len(all_samples)))

    samples_batches_li, sentences_batches_li = [], []
    for template in args.template:
        # if template is active (1) use a single example for (sub,obj) and (2) ...
        if template and template != "":
            facts = []
            samples = []
            for sample in all_samples:
                sub = sample["sub_label"]
                obj = sample["obj_label"]
                if (sub, obj) not in facts:
                    facts.append((sub, obj))
                    samples.append(sample)
            local_msg = "distinct template facts: {}".format(len(facts))
            logger.info("\n" + local_msg + "\n")
            new_all_samples = []
            for fact, raw_sample in zip(facts, samples):
                (sub, obj) = fact
                sample = {}
                sample["sub_label"] = sub
                sample["obj_label"] = obj
                # substitute all sentences with a standard template
                sample["masked_sentences"] = parse_template(
                    template.strip(), raw_sample["raw_sub_label"].strip()
                    if args.upper_entity else sample["sub_label"].strip(),
                    model.mask_token)
                sub_uri = raw_sample[
                    'sub_uri'] if 'sub_uri' in raw_sample else raw_sample['sub']
                sample['entity_list'] = get_entity_list(
                    template.strip(), raw_sample['raw_sub_label'].strip(),
                    sub_uri, None, None)
                if dynamic.startswith('bt_topk') or (temp_model is not None
                                                     and bt_obj):
                    sample['sub_masked_sentences'] = parse_template_tokenize(
                        template.strip(),
                        sample["sub_label"].strip(),
                        model,
                        mask_part='sub')
                if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'):
                    sample["tokenized_sentences"] = parse_template_tokenize(
                        template.strip(),
                        sample["sub_label"].strip(),
                        model,
                        mask_part='relation')
                # substitute the sub and obj placeholders in the template with the corresponding strings
                # and add brackets around the relational phrase
                sample['bracket_sentences'] = bracket_relational_phrase(
                    template.strip(), sample['sub_label'].strip(),
                    sample['obj_label'].strip())
                new_all_samples.append(sample)

        # create uuid if not present
        i = 0
        for sample in new_all_samples:
            if "uuid" not in sample:
                sample["uuid"] = i
            i += 1

        if args.lowercase and not args.upper_entity:
            # lowercase all samples
            logger.info("lowercasing all samples...")
            new_all_samples = lowercase_samples(new_all_samples)

        # shuffle data
        if shuffle_data:
            perm = np.random.permutation(len(new_all_samples))
            new_all_samples = np.array(new_all_samples)[perm]
            raise Exception

        samples_batches, sentences_batches, ret_msg = batchify(
            new_all_samples, args.batch_size)
        logger.info("\n" + ret_msg + "\n")
        samples_batches_li.append(samples_batches)
        sentences_batches_li.append(sentences_batches)

        sub_obj_labels = [(sample['sub_label'], sample['obj_label'])
                          for batch in samples_batches for sample in batch]
        if get_objs:
            print('sub_obj_label {}'.format('\t'.join(
                map(lambda p: '{}\t{}'.format(*p), sub_obj_labels))))
            return

        if refine_template:
            bracket_sentences = [
                sample['bracket_sentences'] for sample in new_all_samples
            ]
            new_temp = model.refine_cloze(bracket_sentences,
                                          batch_size=32,
                                          try_cuda=True)
            new_temp = replace_template(template.strip(), ' '.join(new_temp))
            print('old temp: {}'.format(template.strip()))
            print('new temp: {}'.format(new_temp))
            return new_temp

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)

    samples_batches_li = list(zip(*samples_batches_li))
    sentences_batches_li = list(zip(*sentences_batches_li))

    c_inc_meaning = ['top12 prob gap', 'top1 prob']
    c_inc_stat = np.zeros((2, 3))  # [[*, c_num], [*, inc_num]]

    loss_list = []
    features_list = []
    features_list2 = []
    bt_features_list = []
    label_index_tensor_list = []

    for i in tqdm(range(len(samples_batches_li))):

        samples_b_all = samples_batches_li[i]
        sentences_b_all = sentences_batches_li[i]

        filter_lp_merge = None
        filter_lp_merge2 = None
        samples_b = samples_b_all[-1]
        max_score = float('-inf')
        consist_score_li = []

        samples_b_prev = None
        for sentences_b, samples_b_this in zip(sentences_b_all, samples_b_all):
            if samples_b_prev is not None:
                for ps, ts in zip(samples_b_prev, samples_b_this):
                    assert ps['uuid'] == ts['uuid']

            entity_list_b = [s['entity_list'] for s in samples_b_this]
            # TODO: add tokens_tensor and mask_tensor for more models
            original_log_probs_list, token_ids_list, masked_indices_list, tokens_tensor, mask_tensor = \
                model.get_batch_generation(sentences_b, logger=logger, entity_list=entity_list_b)
            if model2 is not None:
                original_log_probs_list2, token_ids_list2, masked_indices_list2, tokens_tensor2, mask_tensor2 = \
                    model2.get_batch_generation(sentences_b, logger=logger)

            if use_prob:  # use prob instead of log prob
                original_log_probs_list = original_log_probs_list.exp()
                if model2 is not None:
                    original_log_probs_list2 = original_log_probs_list2.exp()

            if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'):
                sentences_b_mask_rel = [
                    s['tokenized_sentences'][0] for s in samples_b_this
                ]
                relation_mask = [
                    s['tokenized_sentences'][1] for s in samples_b_this
                ]
                consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \
                    model.get_batch_generation(sentences_b_mask_rel, logger=logger, relation_mask=relation_mask)
            else:
                consist_log_probs_list = original_log_probs_list

            if dynamic == 'lm' or dynamic == 'real_lm' or dynamic.startswith(
                    'real_lm_topk'):
                # use avg prob of the templates as score
                mask_tensor = mask_tensor.float()
                consist_log_probs_list_flat = consist_log_probs_list.view(
                    -1, consist_log_probs_list.size(-1))
                token_logprob = torch.gather(
                    consist_log_probs_list_flat,
                    dim=1,
                    index=tokens_tensor.view(
                        -1, 1)).view(*consist_log_probs_list.size()[:2])
                token_logprob = token_logprob * mask_tensor
                consist_score = token_logprob.sum(-1) / mask_tensor.sum(
                    -1)  # normalized prob
            '''
            if vocab_subset is not None:
                # filter log_probs
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices
                )
            else:
                filtered_log_probs_list = original_log_probs_list
            '''

            # get the prediction probability
            if vocab_subset is not None:
                filtered_log_probs_list = [
                    flp[masked_indices_list[ind][0]].index_select(
                        dim=-1, index=filter_logprob_indices)
                    for ind, flp in enumerate(original_log_probs_list)
                ]
                if model2 is not None:
                    filtered_log_probs_list2 = [
                        flp[masked_indices_list2[ind][0]].index_select(
                            dim=-1, index=filter_logprob_indices)
                        for ind, flp in enumerate(original_log_probs_list2)
                    ]
            else:
                filtered_log_probs_list = [
                    flp[masked_indices_list[ind][0]]
                    for ind, flp in enumerate(original_log_probs_list)
                ]
                if model2 is not None:
                    filtered_log_probs_list2 = [
                        flp[masked_indices_list2[ind][0]]
                        for ind, flp in enumerate(original_log_probs_list2)
                    ]

            if dynamic.startswith('bt_topk'):
                obj_topk = int(dynamic.rsplit('-', 1)[1])
                top_obj_pred = [
                    flp.topk(k=obj_topk) for flp in filtered_log_probs_list
                ]
                top_obj_logprob, top_obj_pred = zip(*top_obj_pred)

            if dynamic.startswith('obj_lm_topk'):
                # use highest obj prob as consistency score
                consist_score = torch.tensor(
                    [torch.max(flp).item() for flp in filtered_log_probs_list])
            elif dynamic.startswith('obj_lmgap_topk'):
                # the gap between the two highest predictions: log p1 - log p2
                get_gap = lambda top2: (top2[0] - top2[1]).item()
                consist_score = torch.tensor([
                    get_gap(torch.topk(flp, k=2)[0])
                    for flp in filtered_log_probs_list
                ])
            elif dynamic.startswith('bt_topk'):
                # use the obj_topk highest obj to "back translate" sub
                consist_score_obj_topk = []
                used_vocab = vocab_subset if vocab_subset is not None else model.vocab
                for obj_i in range(obj_topk):
                    sentences_b_mask_sub = [[
                        replace_list(s['sub_masked_sentences'][0][0],
                                     model.mask_token,
                                     used_vocab[obj_pred[obj_i].item()])
                    ] for s, obj_pred in zip(samples_b_this, top_obj_pred)]
                    sub_mask = [
                        s['sub_masked_sentences'][1] for s in samples_b_this
                    ]
                    # TODO: only masked lm can do this
                    consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \
                        model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask)
                    # use avg prob of the sub as score
                    mask_tensor = mask_tensor.float()
                    consist_log_probs_list_flat = consist_log_probs_list.view(
                        -1, consist_log_probs_list.size(-1))
                    token_logprob = torch.gather(
                        consist_log_probs_list_flat,
                        dim=1,
                        index=tokens_tensor.view(
                            -1, 1)).view(*consist_log_probs_list.size()[:2])
                    token_logprob = token_logprob * mask_tensor
                    consist_score = token_logprob.sum(-1) / mask_tensor.sum(
                        -1)  # normalized prob
                    consist_score_obj_topk.append(consist_score)

                # SHAPE: (batch_size, obj_topk)
                consist_score_obj_topk = torch.stack(
                    consist_score_obj_topk).permute(1, 0)
                consist_score_weight = torch.stack(top_obj_logprob).exp()
                # SHAPE: (batch_size)
                consist_score = (consist_score_obj_topk *
                                 consist_score_weight).sum(-1) / (
                                     consist_score_weight.sum(-1) + 1e-10)

            # add to overall probability
            if filter_lp_merge is None:
                filter_lp_merge = filtered_log_probs_list
                if model2 is not None:
                    filter_lp_merge2 = filtered_log_probs_list2
                if dynamic == 'lm' or dynamic == 'real_lm':
                    max_score = consist_score
                elif dynamic.startswith('real_lm_topk') or \
                        dynamic.startswith('obj_lm_topk') or \
                        dynamic.startswith('obj_lmgap_topk') or \
                        dynamic.startswith('bt_topk'):
                    consist_score_li.append(consist_score)
            else:
                if dynamic == 'none' and temp_model is None:
                    filter_lp_merge = [
                        a + b for a, b in zip(filter_lp_merge,
                                              filtered_log_probs_list)
                    ]
                elif dynamic == 'all_topk':
                    filter_lp_merge.extend(filtered_log_probs_list)
                elif dynamic == 'lm' or dynamic == 'real_lm':
                    filter_lp_merge = \
                        [a if c >= d else b for a, b, c, d in
                         zip(filter_lp_merge, filtered_log_probs_list, max_score, consist_score)]
                    max_score = torch.max(max_score, consist_score)
                elif dynamic.startswith('real_lm_topk') or \
                        dynamic.startswith('obj_lm_topk') or \
                        dynamic.startswith('obj_lmgap_topk') or \
                        dynamic.startswith('bt_topk'):
                    filter_lp_merge.extend(filtered_log_probs_list)
                    consist_score_li.append(consist_score)
                elif temp_model is not None:
                    filter_lp_merge.extend(filtered_log_probs_list)
                    if model2 is not None:
                        filter_lp_merge2.extend(filtered_log_probs_list2)

            samples_b_prev = samples_b_this

        label_index_list = []
        obj_word_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)
            obj_word_list.append(sample['obj_label'])

        if dynamic == 'all_topk' or \
                dynamic.startswith('real_lm_topk') or \
                dynamic.startswith('obj_lm_topk') or \
                dynamic.startswith('obj_lmgap_topk') or \
                dynamic.startswith('bt_topk') or \
                temp_model is not None:  # analyze prob
            # SHAPE: (batch_size, num_temp, filter_vocab_size)
            filter_lp_merge = torch.stack(filter_lp_merge, 0).view(
                len(sentences_b_all),
                len(filter_lp_merge) // len(sentences_b_all),
                -1).permute(1, 0, 2)
            if model2 is not None:
                filter_lp_merge2 = torch.stack(filter_lp_merge2, 0).view(
                    len(sentences_b_all),
                    len(filter_lp_merge2) // len(sentences_b_all),
                    -1).permute(1, 0, 2)
            # SHAPE: (batch_size)
            label_index_tensor = torch.tensor(
                [index_list.index(li[0]) for li in label_index_list])
            c_inc = np.array(
                metrics.analyze_prob(filter_lp_merge,
                                     label_index_tensor,
                                     output=False,
                                     method='sample'))
            c_inc_stat += c_inc
        elif dynamic == 'none':
            # SHAPE: (batch_size, 1, filter_vocab_size)
            filter_lp_merge = torch.stack(filter_lp_merge, 0).unsqueeze(1)

        # SHAPE: (batch_size, num_temp, filter_vocab_size)
        filter_lp_unmerge = filter_lp_merge

        if temp_model is not None:  # optimize template weights
            temp_model_, optimizer = temp_model
            if optimizer is None:  # predict
                filter_lp_merge = temp_model_(args.relation,
                                              filter_lp_merge.detach(),
                                              target=None)
            elif optimizer == 'precompute':  # pre-compute and save features
                lp = filter_lp_merge
                # SHAPE: (batch_size * num_temp)
                features = torch.gather(lp.contiguous().view(-1, lp.size(-1)),
                                        dim=1,
                                        index=label_index_tensor.repeat(
                                            lp.size(1)).view(-1, 1))
                features = features.view(-1, lp.size(1))
                features_list.append(features)
                if not bt_obj:
                    continue
            elif optimizer is not None:  # train on the fly
                features_list.append(
                    filter_lp_merge
                )  # collect features that will later be used in optimization
                if model2 is not None:
                    features_list2.append(
                        filter_lp_merge2
                    )  # collect features that will later be used in optimization
                label_index_tensor_list.append(
                    label_index_tensor)  # collect labels
                if not bt_obj:
                    continue
                else:
                    #filter_lp_merge = temp_model_(args.relation, filter_lp_merge.detach(), target=None)
                    filter_lp_merge = filter_lp_merge.mean(
                        1)  # use average prob to beam search

        if dynamic.startswith('real_lm_topk') or \
                dynamic.startswith('obj_lm_topk') or \
                dynamic.startswith('obj_lmgap_topk') or \
                dynamic.startswith('bt_topk'):  # dynamic ensemble
            real_lm_topk = min(
                int(dynamic[dynamic.find('topk') + 4:].split('-')[0]),
                len(consist_score_li))
            # SHAPE: (batch_size, num_temp)
            consist_score_li = torch.stack(consist_score_li, -1)
            # SHAPE: (batch_size, topk)
            consist_score, consist_ind = consist_score_li.topk(real_lm_topk,
                                                               dim=-1)
            # SHAPE: (batch_size, 1)
            consist_score = consist_score.min(-1, keepdim=True)[0]
            # SHAPE: (batch_size, num_temp, 1)
            consist_mask = (consist_score_li >=
                            consist_score).float().unsqueeze(-1)
            # avg over top k
            filter_lp_merge = filter_lp_merge * consist_mask
            filter_lp_merge = filter_lp_merge.sum(1) / consist_mask.sum(1)

        if bt_obj:  # choose the top bt_obj objects and back-translate the subject
            # get the top bt_obj objects with highest probability
            used_vocab = vocab_subset if vocab_subset is not None else model.vocab
            temp_model_, optimizer = temp_model
            if optimizer is None:  # use beam search
                # SHAPE: (batch_size, bt_obj)
                objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1)
                objs_ind = torch.sort(objs_ind,
                                      dim=-1)[0]  # the indices must be in ascending order
            elif optimizer == 'precompute':  # use ground truth
                objs_ind = label_index_tensor.view(-1, 1)
                bt_obj = 1
            elif optimizer is not None:  # get both ground truth and beam search
                # SHAPE: (batch_size, bt_obj)
                objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1)
                objs_ind = torch.cat(
                    [objs_ind, label_index_tensor.view(-1, 1)], -1)
                objs_ind = torch.sort(objs_ind,
                                      dim=-1)[0]  # the indices must be in ascending order
                bt_obj += 1

            # back translation
            sub_lp_list = []
            for sentences_b, samples_b_this in zip(
                    sentences_b_all, samples_b_all):  # iter over templates
                for obj_i in range(bt_obj):  # iter over objs
                    sentences_b_mask_sub = []
                    for s, obj_pred, obj_word in zip(samples_b_this, objs_ind,
                                                     obj_word_list):
                        replace_tok = used_vocab[obj_pred[obj_i].item()]
                        if optimizer == 'precompute':
                            assert replace_tok.strip() == obj_word.strip()
                        sentences_b_mask_sub.append([
                            replace_list(s['sub_masked_sentences'][0][0],
                                         model.mask_token, replace_tok)
                        ])
                    sub_mask = [
                        s['sub_masked_sentences'][1] for s in samples_b_this
                    ]
                    # TODO: only masked lm can do this
                    lp, _, _, tokens_tensor, mask_tensor = \
                        model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask)
                    # use the average log-prob of the subject tokens as the score
                    mask_tensor = mask_tensor.float()
                    lp_flat = lp.view(-1, lp.size(-1))
                    sub_lp = torch.gather(lp_flat,
                                          dim=1,
                                          index=tokens_tensor.view(
                                              -1, 1)).view(*lp.size()[:2])
                    sub_lp = sub_lp * mask_tensor
                    sub_lp_avg = sub_lp.sum(-1) / mask_tensor.sum(
                        -1)  # length-normalized log-prob
                    sub_lp_list.append(sub_lp_avg)

            # SHAPE: (batch_size, num_temp, top_obj_num)
            num_temp = len(sentences_b_all)
            sub_lp_list = torch.cat(sub_lp_list, 0).view(num_temp, bt_obj,
                                                         -1).permute(2, 0, 1)

            if optimizer == 'precompute':
                bt_features_list.append(sub_lp_list.squeeze(-1))
                continue
            elif optimizer is not None:
                sub_lp_list_expand = torch.zeros_like(filter_lp_unmerge)
                # SHAPE: (batch_size, num_temp, vocab_size)
                sub_lp_list_expand.scatter_(
                    -1,
                    objs_ind.unsqueeze(1).repeat(1, num_temp, 1), sub_lp_list)
                bt_features_list.append(sub_lp_list_expand)
                bt_obj -= 1
                continue

            # select obj prob
            expand_mask = torch.zeros_like(filter_lp_unmerge)
            expand_mask.scatter_(-1,
                                 objs_ind.unsqueeze(1).repeat(1, num_temp, 1),
                                 1)
            # SHAPE: (batch_size, num_temp, top_obj_num)
            obj_lp_list = torch.masked_select(filter_lp_unmerge,
                                              expand_mask.eq(1)).view(
                                                  -1, num_temp, bt_obj)

            # run temp model
            # SHAPE: (batch_size, vocab_size)
            filter_lp_merge_expand = torch.zeros_like(filter_lp_merge)
            # SHAPE: (batch_size, top_obj_num)
            filter_lp_merge = temp_model_(args.relation,
                                          torch.cat([obj_lp_list, sub_lp_list],
                                                    1),
                                          target=None)

            # expand results to vocab_size
            filter_lp_merge_expand.scatter_(-1, objs_ind, filter_lp_merge)
            filter_lp_merge = filter_lp_merge_expand + expand_mask[:, 0, :].log(
            )  # mask out other objs

        if len(filter_lp_merge.size()) == 2:
            filter_lp_merge = filter_lp_merge.unsqueeze(1)

        for temp_id in range(filter_lp_merge.size(1)):

            arguments = [{
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0],
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
            } for original_log_probs, filtered_log_probs, token_ids,
                         masked_indices, label_index, sample in zip(
                             original_log_probs_list,
                             filter_lp_merge[:, :temp_id + 1].sum(1),
                             token_ids_list,
                             masked_indices_list,
                             label_index_list,
                             samples_b,
                         )]

            # single thread for debug
            # for isx,a in enumerate(arguments):
            #     print(samples_b[isx])
            #     run_thread(a)

            # multithread
            res = pool.map(run_thread, arguments)

            for idx, result in enumerate(res):

                result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

                logger.info("\n" + msg + "\n")

                sample = samples_b[idx]

                element = {}
                element["sample"] = sample
                element["uuid"] = sample["uuid"]
                element["token_ids"] = token_ids_list[idx]
                element["masked_indices"] = masked_indices_list[idx]
                element["label_index"] = label_index_list[idx]
                element["masked_topk"] = result_masked_topk
                element["sample_MRR"] = sample_MRR
                element["sample_Precision"] = sample_P
                element["sample_perplexity"] = sample_perplexity
                element["sample_Precision1"] = result_masked_topk["P_AT_1"]

                # print()
                # print("idx: {}".format(idx))
                # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
                # for yi in range(10):
                #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
                # print("masked_indices_list: {}".format(masked_indices_list[idx]))
                # print("sample_MRR: {}".format(sample_MRR))
                # print("sample_P: {}".format(sample_P))
                # print("sample: {}".format(sample))
                # print()

                MRR[temp_id] += sample_MRR
                Precision[temp_id] += sample_P
                Precision1[temp_id] += element["sample_Precision1"]
                P1_li[temp_id].append(element["sample_Precision1"])
                '''
                if element["sample_Precision1"] == 1:
                    print(element["sample"])
                    input(1)
                else:
                    print(element["sample"])
                    input(0)
                '''

                # the judgment of the annotators recording whether there is
                # evidence in the sentence that indicates a relation between the two entities.
                num_yes = 0
                num_no = 0

                if "judgments" in sample:
                    # only for Google-RE
                    for x in sample["judgments"]:
                        if x["judgment"] == "yes":
                            num_yes += 1
                        else:
                            num_no += 1
                    if num_no >= num_yes:
                        samples_with_negative_judgement[temp_id] += 1
                        element["judgement"] = "negative"
                        MRR_negative[temp_id] += sample_MRR
                        Precision_negative[temp_id] += sample_P
                    else:
                        samples_with_positive_judgement[temp_id] += 1
                        element["judgement"] = "positive"
                        MRR_positive[temp_id] += sample_MRR
                        Precision_positivie[temp_id] += sample_P

                list_of_results[temp_id].append(element)

    if temp_model is not None:
        if temp_model[1] == 'precompute':
            features = torch.cat(features_list, 0)
            if bt_obj:
                bt_features = torch.cat(bt_features_list, 0)
                features = torch.cat([features, bt_features], 1)
            return features
        if temp_model[1] is not None:
            # optimize the model on the fly
            temp_model_, (optimizer, temperature) = temp_model
            temp_model_.cuda()
            # SHAPE: (batch_size, num_temp, vocab_size)
            features = torch.cat(features_list, 0)
            if model2 is not None:
                features2 = torch.cat(features_list2, 0)
            if bt_obj:
                bt_features = torch.cat(bt_features_list, 0)
                features = torch.cat([features, bt_features], 1)
            # compute weight
            # SHAPE: (batch_size,)
            label_index_tensor = torch.cat(label_index_tensor_list, 0)
            label_count = torch.bincount(label_index_tensor)
            label_count = torch.index_select(label_count, 0,
                                             label_index_tensor)
            sample_weight = F.softmax(
                temperature * torch.log(1.0 / label_count.float()),
                0) * label_index_tensor.size(0)
            min_loss = 1e10
            es = 0
            batch_size = 128
            for e in range(500):
                # loss = temp_model_(args.relation, features.cuda(), target=label_index_tensor.cuda(), use_softmax=True)
                loss_li = []
                for b in range(0, features.size(0), batch_size):
                    features_b = features[b:b + batch_size].cuda()
                    label_index_tensor_b = label_index_tensor[b:b +
                                                              batch_size].cuda(
                                                              )
                    sample_weight_b = sample_weight[b:b + batch_size].cuda()
                    loss = temp_model_(args.relation,
                                       features_b,
                                       target=label_index_tensor_b,
                                       sample_weight=sample_weight_b,
                                       use_softmax=True)
                    if model2 is not None:
                        features2_b = features2[b:b + batch_size].cuda()
                        loss2 = temp_model_(args.relation,
                                            features2_b,
                                            target=label_index_tensor_b,
                                            sample_weight=sample_weight_b,
                                            use_softmax=True)
                        loss = loss + loss2
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    loss_li.append(loss.cpu().item())
                dev_loss = np.mean(loss_li)
                if dev_loss - min_loss < -1e-3:
                    min_loss = dev_loss
                    es = 0
                else:
                    es += 1
                    if es >= 30:
                        print('early stop')
                        break
            temp_model_.cpu()
            return min_loss

    pool.close()
    pool.join()

    for temp_id in range(len(P1_li)):
        # stats
        # Mean reciprocal rank
        MRR[temp_id] /= len(list_of_results[temp_id])

        # Precision
        Precision[temp_id] /= len(list_of_results[temp_id])
        Precision1[temp_id] /= len(list_of_results[temp_id])

        msg = "all_samples: {}\n".format(len(all_samples))
        msg += "list_of_results: {}\n".format(len(list_of_results[temp_id]))
        msg += "global MRR: {}\n".format(MRR[temp_id])
        msg += "global Precision at 10: {}\n".format(Precision[temp_id])
        msg += "global Precision at 1: {}\n".format(Precision1[temp_id])

        if samples_with_negative_judgement[
                temp_id] > 0 and samples_with_positive_judgement[temp_id] > 0:
            # Google-RE specific
            MRR_negative[temp_id] /= samples_with_negative_judgement[temp_id]
            MRR_positive[temp_id] /= samples_with_positive_judgement[temp_id]
            Precision_negative[temp_id] /= samples_with_negative_judgement[
                temp_id]
            Precision_positivie[temp_id] /= samples_with_positive_judgement[
                temp_id]
            msg += "samples_with_negative_judgement: {}\n".format(
                samples_with_negative_judgement[temp_id])
            msg += "samples_with_positive_judgement: {}\n".format(
                samples_with_positive_judgement[temp_id])
            msg += "MRR_negative: {}\n".format(MRR_negative[temp_id])
            msg += "MRR_positive: {}\n".format(MRR_positive[temp_id])
            msg += "Precision_negative: {}\n".format(
                Precision_negative[temp_id])
            msg += "Precision_positivie: {}\n".format(
                Precision_positivie[temp_id])

        logger.info("\n" + msg + "\n")
        print("\n" + msg + "\n")

        # dump pickle with the result of the experiment
        all_results = dict(list_of_results=list_of_results[temp_id],
                           global_MRR=MRR,
                           global_P_at_10=Precision)
        with open("{}/result.pkl".format(log_directory), "wb") as f:
            pickle.dump(all_results, f)

        print('P1all {}'.format('\t'.join(map(str, P1_li[temp_id]))))

    print('meaning: {}'.format(c_inc_meaning))
    print('correct-incorrect {}'.format('\t'.join(
        map(str,
            (c_inc_stat[:, :-1] / (c_inc_stat[:, -1:] + 1e-5)).reshape(-1)))))

    return Precision1[-1]
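
Below is a minimal, self-contained sketch of the top-k "dynamic ensemble" step used in the function above: for each example, keep the k prompt templates with the highest consistency score and average their filtered log-probs. The function name, shapes, and toy tensors are illustrative only and are not part of the repository.

import torch

def topk_consistency_ensemble(filter_lp, consist, k):
    # filter_lp: (batch, num_temp, vocab_size) log-probs, one row per template
    # consist:   (batch, num_temp) consistency score per template
    k = min(k, consist.size(-1))
    # per-example threshold = the k-th highest consistency score
    kth_score = consist.topk(k, dim=-1)[0].min(-1, keepdim=True)[0]
    mask = (consist >= kth_score).float().unsqueeze(-1)  # (batch, num_temp, 1)
    return (filter_lp * mask).sum(1) / mask.sum(1)       # (batch, vocab_size)

# toy usage (shapes only, the numbers are random)
lp = torch.randn(2, 4, 10).log_softmax(-1)
score = torch.randn(2, 4)
print(topk_consistency_ensemble(lp, score, k=2).shape)  # torch.Size([2, 10])
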
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            "Please specify a single language model (e.g., --lm \"bert\").")

    msg = ""

    [model_type_name] = args.models_names

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == 'fairseq':
        model_name = 'fairseq_{}'.format(args.fairseq_model_name)
    elif model_type_name == 'bert':
        model_name = 'BERT_{}'.format(args.bert_model_name)
    elif model_type_name == 'elmo':
        model_name = 'ELMo_{}'.format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), 'w') as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.
    MRR_negative = 0.
    MRR_positive = 0.

    # Precision at (default 10)
    Precision = 0.
    Precision1 = 0.
    Precision_negative = 0.
    Precision_positivie = 0.

    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
    else:
        # keep samples as they are
        all_samples = data

    # filter using the (possibly lowercased) samples rather than the raw data
    all_samples, ret_msg = filter_samples(model, all_samples, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if template is active: (1) use a single example per (sub, obj) pair and
    # (2) substitute the masked sentences with the standard template
    if args.template and args.template != '':
        facts = []
        for sample in all_samples:
            sub = sample['sub_label']
            obj = sample['obj_label']
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample['sub_label'] = sub
            sample['obj_label'] = obj
            # substitute all sentences with a standard template
            sample['masked_sentences'] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK)
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if 'uuid' not in sample:
            sample['uuid'] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        original_log_probs_list, token_ids_list, masked_indices_list = model.get_batch_generation(
            sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample['obj_label'])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample['obj_label']))
            elif (model.vocab[obj_label_id[0]] != sample['obj_label']):
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample['obj_label']))
            elif vocab_subset is not None and sample[
                    'obj_label'] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample['obj_label']))

            label_index_list.append(obj_label_id)

        arguments = [{
            'original_log_probs': original_log_probs,
            'filtered_log_probs': filtered_log_probs,
            'token_ids': token_ids,
            'vocab': model.vocab,
            'label_index': label_index[0],
            'masked_indices': masked_indices,
            'interactive': args.interactive,
            'index_list': index_list,
            'sample': sample
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample in zip(
                         original_log_probs_list, filtered_log_probs_list,
                         token_ids_list, masked_indices_list, label_index_list,
                         samples_b)]

        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element['sample'] = sample
            element['uuid'] = sample['uuid']
            element['token_ids'] = token_ids_list[idx]
            element['masked_indices'] = masked_indices_list[idx]
            element['label_index'] = label_index_list[idx]
            element['masked_topk'] = result_masked_topk
            element['sample_MRR'] = sample_MRR
            element['sample_Precision'] = sample_P
            element['sample_perplexity'] = sample_perplexity
            element['sample_Precision1'] = result_masked_topk["P_AT_1"]

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element['sample_Precision1']

            # the judgment of the annotators recording whether there is
            # evidence in the sentence that indicates a relation between the two entities.
            num_yes = 0
            num_no = 0

            if 'judgments' in sample:
                # only for Google-RE
                for x in sample['judgments']:
                    if (x['judgment'] == "yes"):
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element['judgement'] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element['judgement'] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P

            list_of_results.append(element)

    pool.close()
    pool.join()

    # stats
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    Precision1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results,
        global_MRR=MRR,
        global_P_at_10=Precision,
    )
    with open("{}/result.pkl".format(log_directory), 'wb') as f:
        pickle.dump(all_results, f)

    return Precision1
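
run_thread is not shown in this excerpt; assuming it ranks the gold object id inside the (filtered) log-prob vector, the per-sample metrics aggregated above would be computed roughly as in this hypothetical sketch.

import torch

def sample_metrics(filtered_log_probs, label_index, k=10):
    # rank of the gold object id in the (filtered) vocabulary, 1-based
    ranking = torch.argsort(filtered_log_probs, descending=True)
    rank = (ranking == label_index).nonzero(as_tuple=True)[0].item() + 1
    # MRR contribution, P@k, P@1 for this single sample
    return 1.0 / rank, float(rank <= k), float(rank == 1)

lp = torch.tensor([0.1, 2.0, -1.0, 0.5]).log_softmax(-1)
print(sample_metrics(lp, label_index=3))  # (0.5, 1.0, 0.0): gold ranked 2nd
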
Example #20
def run_experiments(
    relations,
    data_path_pre,
    data_path_post,
    input_param={
        "lm":
        "bert",
        "label":
        "bert_large",
        "models_names": ["bert"],
        "bert_model_name":
        "bert-large-cased",
        "bert_model_dir":
        "pre-trained_language_models/bert/cased_L-24_H-1024_A-16",
    },
    use_negated_probes=False,
):
    model = None
    pp = pprint.PrettyPrinter(width=41, compact=True)

    all_Precision1 = []
    all_Precision10 = []
    type_Precision1 = defaultdict(list)
    type_count = defaultdict(list)

    # Append to results_file
    results_file = open("last_results.csv", "a", encoding='utf-8')
    results_file.write('\n')

    for relation in relations:
        pp.pprint(relation)
        PARAMETERS = {
            "dataset_filename":
            "{}{}{}".format(data_path_pre, relation["relation"],
                            data_path_post),
            "common_vocab_filename":
            "pre-trained_language_models/common_vocab_cased.txt",
            "template":
            "",
            "bert_vocab_name":
            "vocab.txt",
            "batch_size":
            32,
            "logdir":
            "output",
            "full_logdir":
            "output/results/{}/{}".format(input_param["label"],
                                          relation["relation"]),
            "lowercase":
            False,
            "max_sentence_length":
            100,
            "threads":
            -1,
            "interactive":
            False,
            "use_negated_probes":
            use_negated_probes,
        }

        if "template" in relation:
            PARAMETERS["template"] = relation["template"]
            if use_negated_probes:
                PARAMETERS["template_negated"] = relation["template_negated"]

        PARAMETERS.update(input_param)
        print(PARAMETERS)

        args = argparse.Namespace(**PARAMETERS)
        relation_name = relation["relation"]
        if relation_name == "test":
            relation_name = data_path_pre.replace("/", "") + "_test"

        # see if file exists
        try:
            data = load_file(args.dataset_filename)
        except Exception as e:
            print("Relation {} excluded.".format(relation_name))
            print("Exception: {}".format(e))
            continue

        if model is None:
            [model_type_name] = args.models_names
            model = build_model_by_name(model_type_name, args)

        Precision1, Precision10 = run_evaluation(args,
                                                 shuffle_data=False,
                                                 model=model)
        print("P@1 : {}".format(Precision1), flush=True)
        all_Precision1.append(Precision1)
        all_Precision10.append(Precision10)

        results_file.write("[{}] {}: {}, P10 = {}, P1 = {}\n".format(
            datetime.now(), input_param["label"], relation_name,
            round(Precision10 * 100, 2), round(Precision1 * 100, 2)))
        results_file.flush()

        if "type" in relation:
            type_Precision1[relation["type"]].append(Precision1)
            data = load_file(PARAMETERS["dataset_filename"])
            type_count[relation["type"]].append(len(data))

    mean_p1 = statistics.mean(all_Precision1)
    mean_p10 = statistics.mean(all_Precision10)
    summaryP1 = "@@@ {} - mean P@10 = {}, mean P@1 = {}".format(
        input_param["label"], round(mean_p10 * 100, 2),
        round(mean_p1 * 100, 2))
    print(summaryP1)
    results_file.write(f'{summaryP1}\n')
    results_file.flush()

    for t, l in type_Precision1.items():
        prec1item = f'@@@ Label={input_param["label"]}, type={t}, samples={sum(type_count[t])}, relations={len(type_count[t])}, mean prec1={round(statistics.mean(l) * 100, 2)}\n'
        print(prec1item, flush=True)
        results_file.write(prec1item)
        results_file.flush()

    results_file.close()
    return mean_p1, all_Precision1
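
The run_experiments variants below all build their per-relation arguments the same way: a PARAMETERS dict of defaults is overridden by input_param and wrapped in an argparse.Namespace. A toy example of that pattern (the keys are illustrative only):

import argparse

defaults = {"batch_size": 32, "lowercase": False, "template": ""}
input_param = {"lm": "bert", "bert_model_name": "bert-large-cased", "batch_size": 16}

params = dict(defaults)
params.update(input_param)        # input_param wins on conflicting keys
args = argparse.Namespace(**params)
print(args.batch_size, args.lm)   # 16 bert
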
def run_experiments(
    data_path_pre,
    data_path_post,
    input_param={
        "lm":
        "bert",
        "label":
        "bert_large",
        "models_names": ["bert"],
        "bert_model_name":
        "bert-large-cased",
        "bert_model_dir":
        "pre-trained_language_models/bert/cased_L-24_H-1024_A-16",
    },
    use_negated_probes=False,
):
    model = None
    pp = pprint.PrettyPrinter(width=41, compact=True)

    all_Precision1 = []
    type_Precision1 = defaultdict(list)
    type_count = defaultdict(list)

    for i in range(1):
        PARAMETERS = {
            "dataset_filename":
            "reproduction/data/TREx_filter/different_queries.jsonl",
            "fact_pair_filename":
            "reproduction/data/TREx_filter/different_queries_facts.jsonl",
            "common_vocab_filename":
            "pre-trained_language_models/common_vocab_cased.txt",
            "template":
            "",
            "bert_vocab_name":
            "vocab.txt",
            "batch_size":
            10,
            "logdir":
            "output",
            "full_logdir":
            "output/results/{}/{}".format(input_param["label"],
                                          "different_queries"),
            "lowercase":
            False,
            "max_sentence_length":
            100,
            "threads":
            -1,
            "interactive":
            False,
            "use_negated_probes":
            use_negated_probes,
        }

        PARAMETERS.update(input_param)
        args = argparse.Namespace(**PARAMETERS)

        if model is None:
            [model_type_name] = args.models_names
            model = build_model_by_name(model_type_name, args)

        Precision1 = run_evaluation(args, shuffle_data=False, model=model)
        print("P@1 : {}".format(Precision1), flush=True)
        all_Precision1.append(Precision1)

    mean_p1 = statistics.mean(all_Precision1)
    print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1))

    for t, l in type_Precision1.items():

        print(
            "@@@ ",
            input_param["label"],
            t,
            statistics.mean(l),
            sum(type_count[t]),
            len(type_count[t]),
            flush=True,
        )

    return mean_p1, all_Precision1
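
The per-type reporting loop above groups P@1 by relation type with defaultdict(list); a small sketch with made-up numbers shows the shape of the output.

import statistics
from collections import defaultdict

type_Precision1 = defaultdict(list)
type_count = defaultdict(list)
for rel_type, p1, n in [("1-1", 0.5, 900), ("N-1", 0.25, 1200), ("1-1", 0.75, 700)]:
    type_Precision1[rel_type].append(p1)
    type_count[rel_type].append(n)

for t, l in type_Precision1.items():
    print(t, statistics.mean(l), sum(type_count[t]), len(type_count[t]))
# 1-1 0.625 1600 2
# N-1 0.25 1200 1
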
Example #22
def run_experiments(
    relations,
    data_path_pre,
    data_path_post,
    refine_template,
    get_objs,
    batch_size,
    dynamic=None,
    use_prob=False,
    bt_obj=None,
    temp_model=None,
    save=None,
    load=None,
    feature_dir=None,
    enforce_prob=True,
    num_feat=1,
    temperature=0.0,
    use_model2=False,
    lowercase=False,
    upper_entity=False,
    input_param={
        "lm":
        "bert",
        "label":
        "bert_large",
        "models_names": ["bert"],
        "bert_model_name":
        "bert-large-cased",
        "bert_model_dir":
        "pre-trained_language_models/bert/cased_L-24_H-1024_A-16",
    },
):
    model, model2 = None, None
    pp = pprint.PrettyPrinter(width=41, compact=True)

    all_Precision1 = []
    type_Precision1 = defaultdict(list)
    type_count = defaultdict(list)
    print('use lowercase: {}, use upper entity: {}'.format(
        lowercase, upper_entity))

    results_file = open("last_results.csv", "w+")

    if refine_template:
        refine_temp_fout = open(refine_template, 'w')
        new_relations = []
        templates_set = set()

    rel2numtemp = {}
    for relation in relations:  # collect templates
        if 'template' in relation:
            if type(relation['template']) is not list:
                relation['template'] = [relation['template']]
        rel2numtemp[relation['relation']] = len(relation['template'])

    if temp_model is not None:
        if temp_model.startswith('mixture'):
            method = temp_model.split('_')[1]
            if method == 'optimize':  # (extract feature) + optimize
                temp_model = TempModel(rel2numtemp,
                                       enforce_prob=enforce_prob,
                                       num_feat=num_feat)
                temp_model.train()
                optimizer = optim.Adam(temp_model.parameters(), lr=1e-1)
                temp_model = (temp_model, (optimizer, temperature))
            elif method == 'precompute':  # extract feature
                temp_model = (None, 'precompute')
            elif method == 'predict':  # predict
                temp_model = TempModel(
                    rel2numtemp, enforce_prob=enforce_prob,
                    num_feat=num_feat)  # TODO: number of features
                if load is not None:
                    temp_model.load_state_dict(torch.load(load))
                temp_model.eval()
                temp_model = (temp_model, None)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    for relation in relations:
        pp.pprint(relation)
        PARAMETERS = {
            "relation":
            relation["relation"],
            "dataset_filename":
            "{}/{}{}".format(data_path_pre, relation["relation"],
                             data_path_post),
            "common_vocab_filename":
            "pre-trained_language_models/common_vocab_cased.txt",
            "template":
            "",
            "bert_vocab_name":
            "vocab.txt",
            "batch_size":
            batch_size,
            "logdir":
            "output",
            "full_logdir":
            "output/results/{}/{}".format(input_param["label"],
                                          relation["relation"]),
            "lowercase":
            lowercase,
            "upper_entity":
            upper_entity,
            "max_sentence_length":
            100,
            "threads":
            -1,
            "interactive":
            False,
        }
        dev_param = deepcopy(PARAMETERS)
        dev_param['dataset_filename'] = '{}/{}{}'.format(
            data_path_pre + '_dev', relation['relation'], data_path_post)
        bert_large_param = deepcopy(PARAMETERS)

        if 'template' in relation:
            PARAMETERS['template'] = relation['template']
            dev_param['template'] = relation['template']
            bert_large_param['template'] = relation['template']

        PARAMETERS.update(input_param)
        dev_param.update(input_param)
        bert_large_param.update(
            LM_BERT_LARGE
        )  # this is used to optimize the weights for bert-base and bert-large at the same time
        print(PARAMETERS)

        args = argparse.Namespace(**PARAMETERS)
        dev_args = argparse.Namespace(**dev_param)
        bert_large_args = argparse.Namespace(**bert_large_param)

        # see if file exists
        try:
            data = load_file(args.dataset_filename)
        except Exception as e:
            print("Relation {} excluded.".format(relation["relation"]))
            print("Exception: {}".format(e))
            continue

        if model is None:
            [model_type_name] = args.models_names
            model = build_model_by_name(model_type_name, args)
            if use_model2:
                model2 = build_model_by_name(bert_large_args.models_names[0],
                                             bert_large_args)

        if temp_model is not None:
            if temp_model[1] == 'precompute':
                features = run_evaluation(
                    args,
                    shuffle_data=False,
                    model=model,
                    refine_template=bool(refine_template),
                    get_objs=get_objs,
                    dynamic=dynamic,
                    use_prob=use_prob,
                    bt_obj=bt_obj,
                    temp_model=temp_model)
                print('save features for {}'.format(relation['relation']))
                torch.save(features,
                           os.path.join(save, relation['relation'] + '.pt'))
                continue
            elif temp_model[1] is not None:  # train temp model
                if feature_dir is None:
                    loss = run_evaluation(
                        args,
                        shuffle_data=False,
                        model=model,
                        model2=model2,
                        refine_template=bool(refine_template),
                        get_objs=get_objs,
                        dynamic=dynamic,
                        use_prob=use_prob,
                        bt_obj=bt_obj,
                        temp_model=temp_model)
                else:
                    temp_model_, (optimizer, temperature) = temp_model
                    temp_model_.cuda()
                    min_loss = 1e10
                    es = 0
                    for e in range(500):
                        # SHAPE: (num_sample, num_temp)
                        feature = torch.load(
                            os.path.join(feature_dir,
                                         args.relation + '.pt')).cuda()
                        #dev_feature = torch.load(os.path.join(feature_dir + '_dev', args.relation + '.pt')).cuda()
                        #feature = torch.cat([feature, dev_feature], 0)
                        #weight = feature.mean(0)
                        #temp_model[0].set_weight(args.relation, weight)
                        optimizer.zero_grad()
                        loss = temp_model_(args.relation, feature)
                        if os.path.exists(feature_dir +
                                          '__dev'):  # TODO: debug
                            dev_feature = torch.load(
                                os.path.join(feature_dir + '_dev',
                                             args.relation + '.pt')).cuda()
                            dev_loss = temp_model_(args.relation, dev_feature)
                        else:
                            dev_loss = loss
                        loss.backward()
                        optimizer.step()
                        if dev_loss - min_loss < -1e-3:
                            min_loss = dev_loss
                            es = 0
                        else:
                            es += 1
                            if es >= 10:
                                print('early stop')
                                break
                continue

        Precision1 = run_evaluation(args,
                                    shuffle_data=False,
                                    model=model,
                                    refine_template=bool(refine_template),
                                    get_objs=get_objs,
                                    dynamic=dynamic,
                                    use_prob=use_prob,
                                    bt_obj=bt_obj,
                                    temp_model=temp_model)

        if get_objs:
            return

        if refine_template and Precision1 is not None:
            if Precision1 in templates_set:
                continue
            templates_set.add(Precision1)
            new_relation = deepcopy(relation)
            new_relation['old_template'] = new_relation['template']
            new_relation['template'] = Precision1
            new_relations.append(new_relation)
            refine_temp_fout.write(json.dumps(new_relation) + '\n')
            refine_temp_fout.flush()
            continue

        print("P@1 : {}".format(Precision1), flush=True)
        all_Precision1.append(Precision1)

        results_file.write("{},{}\n".format(relation["relation"],
                                            round(Precision1 * 100, 2)))
        results_file.flush()

        if "type" in relation:
            type_Precision1[relation["type"]].append(Precision1)
            data = load_file(PARAMETERS["dataset_filename"])
            type_count[relation["type"]].append(len(data))

    if refine_template:
        refine_temp_fout.close()
        return

    if temp_model is not None:
        if save is not None and temp_model[0] is not None:
            torch.save(temp_model[0].state_dict(), save)
        return

    mean_p1 = statistics.mean(all_Precision1)
    print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1))
    results_file.close()

    for t, l in type_Precision1.items():

        print(
            "@@@ ",
            input_param["label"],
            t,
            statistics.mean(l),
            sum(type_count[t]),
            len(type_count[t]),
            flush=True,
        )

    return mean_p1, all_Precision1
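
When temp_model is trained, both on the fly and from precomputed features, the loops above use the same early-stopping pattern: stop once the dev loss has not improved by at least 1e-3 for a fixed number of epochs. A generic sketch of that pattern, with a plain nn.Linear standing in for the project-specific TempModel:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(5, 1)                    # stand-in for TempModel
optimizer = optim.Adam(model.parameters(), lr=1e-1)
feature = torch.randn(64, 5)
target = torch.randn(64, 1)

min_loss, es, patience = float('inf'), 0, 10
for epoch in range(500):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(feature), target)
    loss.backward()
    optimizer.step()
    dev_loss = loss.item()                 # the real code prefers a dev split when one exists
    if dev_loss - min_loss < -1e-3:        # improved by at least 1e-3
        min_loss, es = dev_loss, 0
    else:
        es += 1
        if es >= patience:
            print('early stop')
            break
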
Example #23
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""
    [model_type_name] = args.models_names

    # print("------- Model: {}".format(model))
    # print("------- Args: {}".format(args))
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    # logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    # with open("{}/args.json".format(log_directory), "w") as outfile:
    #     json.dump(vars(args), outfile)

    # Mean reciprocal rank
    MRR = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    data = load_file(args.dataset_filename)

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # logger.info("\n" + ret_msg + "\n")

    # if template is active: (1) use a single example per (sub, obj) pair and
    # (2) substitute the masked sentences with the standard template
    if args.template and args.template != "":
        facts = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        # logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            # substitute all sentences with a standard template
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK)
            if args.use_negated_probes:
                # substitute all negated sentences with a standard template
                sample["negated"] = parse_template(
                    args.template_negated.strip(),
                    sample["sub_label"].strip(),
                    base.MASK,
                )
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    # logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)

    # list_of_results = []
    # list_of_ranks = []
    item_count = 0
    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample in zip(
                         original_log_probs_list,
                         filtered_log_probs_list,
                         token_ids_list,
                         masked_indices_list,
                         label_index_list,
                         samples_b,
                     )]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            # print("~~~~~~~~~~~~~~~~~~")
            # print(result_masked_topk)

            # logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            obj = sample['obj_label']
            sub = sample['sub_label']
            element["masked_sentences"] = sample["masked_sentences"][0]
            # element["uuid"] = sample["uuid"]
            element["subject"] = sub
            element["object"] = obj
            element["rank"] = int(result_masked_topk['rank'])
            # element["sample_Precision1"] = result_masked_topk["P_AT_1"]

            # element["sample"] = sample
            # element["token_ids"] = token_ids_list[idx]
            # element["masked_indices"] = masked_indices_list[idx]
            # element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk['topk'][:20]
            # element["sample_MRR"] = sample_MRR
            # element["sample_Precision"] = sample_P
            # element["sample_perplexity"] = sample_perplexity

            # list_of_results[sub + "_" + obj].append(element)
            # list_of_ranks[sub + "_" + obj].append(element["rank"])

            # print("~~~~~~ rank: {}".format(result_masked_topk['rank']))
            MRR += sample_MRR
            Precision += sample_P
            Precision1 += result_masked_topk["P_AT_1"]

            item_count += 1
            append_data_line_to_jsonl(
                "reproduction/data/P_AT_K/{}_rank_results.jsonl".format(
                    args.label), element)
            append_data_line_to_jsonl(
                "reproduction/data/P_AT_K/{}_rank_list.jsonl".format(
                    args.label), element["rank"])

    pool.close()
    pool.join()
    Precision1 /= item_count

    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results) # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks) # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label), list(list_of_ranks.values())) # 3122

    return Precision1
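
# Sketch of the jsonl-append helper used above. The real append_data_line_to_jsonl
# is defined elsewhere in the project, so this body is only an assumption: append
# one JSON-serialized record per line.
def append_data_line_to_jsonl(path, record):
    import json
    import os
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "a") as fout:
        fout.write(json.dumps(record) + "\n")
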
Example #24
def run_experiments(
    relations,
    data_path_pre,
    data_path_post,
    input_param={
        "lm": "bert",
        "label": "bert_large",
        "models_names": ["bert"],
        "bert_model_name": "bert-large-cased",
        "bert_model_dir": "pre-trained_language_models/bert/cased_L-24_H-1024_A-16",
    },
):
    model = None
    pp = pprint.PrettyPrinter(width=41, compact=True)

    all_Precision1 = []
    type_Precision1 = defaultdict(list)
    type_count = defaultdict(list)

    results_file = open("last_results.csv", "w+")

    for relation in relations:
        pp.pprint(relation)
        PARAMETERS = {
            "dataset_filename": "{}{}{}".format(
                data_path_pre, relation["relation"], data_path_post),
            "common_vocab_filename": "pre-trained_language_models/common_vocab_cased.txt",
            "template": "",
            "bert_vocab_name": "vocab.txt",
            "batch_size": 32,
            "logdir": "output",
            "full_logdir": "output/results/{}/{}".format(
                input_param["label"], relation["relation"]),
            "lowercase": False,
            "max_sentence_length": 100,
            "threads": -1,
            "interactive": False,
            "use_negated_probes": False,
        }

        if "template" in relation:
            PARAMETERS["template"] = relation["template"]

        PARAMETERS.update(input_param)
        print(PARAMETERS)

        args = argparse.Namespace(**PARAMETERS)

        # see if file exists
        try:
            data = load_file(args.dataset_filename)
        except Exception as e:
            print("Relation {} excluded.".format(relation["relation"]))
            print("Exception: {}".format(e))
            continue

        if model is None:
            [model_type_name] = args.models_names
            model = build_model_by_name(model_type_name, args)

        Precision1 = run_evaluation(args, shuffle_data=False, model=model)
        print("P@1 : {}".format(Precision1), flush=True)
        all_Precision1.append(Precision1)

        results_file.write("{},{}\n".format(relation["relation"],
                                            round(Precision1 * 100, 2)))
        results_file.flush()

        if "type" in relation:
            type_Precision1[relation["type"]].append(Precision1)
            data = load_file(PARAMETERS["dataset_filename"])
            type_count[relation["type"]].append(len(data))

    mean_p1 = statistics.mean(all_Precision1)
    print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1))
    results_file.close()

    for t, l in type_Precision1.items():

        print(
            "@@@ ",
            input_param["label"],
            t,
            statistics.mean(l),
            sum(type_count[t]),
            len(type_count[t]),
            flush=True,
        )

    return mean_p1, all_Precision1
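
# Usage sketch for run_experiments: the relation entries and data paths below are
# hypothetical and assume the TREx-style file layout implied by the PARAMETERS dict.
if __name__ == "__main__":
    relations = [
        {"relation": "P19", "template": "[X] was born in [Y] .", "type": "N-1"},
        {"relation": "P279", "template": "[X] is a subclass of [Y] .", "type": "N-1"},
    ]
    mean_p1, all_p1 = run_experiments(relations,
                                      data_path_pre="data/TREx/",
                                      data_path_post=".jsonl")
    print("mean P@1 over {} relations: {}".format(len(all_p1), mean_p1))
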
Example #25
def main(args):
    verbose = False
    moral_template = "{} is a [MASK] behavior ."

    data_user_study = pickle.load(
        open('../MoRT/data/parsed_yes_no_BERTBias.p', 'rb'))
    actions = list(data_user_study.keys())
    actions = ['help people', 'help white people', 'help black people']
    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print("common vocabulary size: {}".format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    data_user_study_BERTKnowledge = dict()
    for action in actions:
        action_ = action.split(" ")
        action_[0] = verb_noun_mapping[action_[0]].capitalize()
        action_ = " ".join(action_)
        text = moral_template.format(action_)
        if args.split_sentence:
            import spacy
            # use spacy to tokenize input sentence
            nlp = spacy.load(args.spacy_model)
            tokens = nlp(text)
            print(tokens)
            sentences = []
            for s in tokens.sents:
                print(" - {}".format(s))
                sentences.append(s.text)
        else:
            sentences = [text]

        if len(sentences) > 2:
            print(
                "WARNING: only the first two sentences in the text will be considered!"
            )
            sentences = sentences[:2]

        for model_name, model in models.items():
            if model_name not in list(data_user_study_BERTKnowledge.keys()):
                data_user_study_BERTKnowledge[model_name] = {}
            if verbose:
                print("\n{}:".format(model_name))
            original_log_probs_list, [token_ids], [
                masked_indices
            ] = model.get_batch_generation([sentences], try_cuda=False)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list
            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                _, _, experiment_result, _ = evaluation_metrics.get_ranking(
                    filtered_log_probs_list[0],
                    masked_indices,
                    model.vocab,
                    index_list=index_list,
                    print_generation=verbose)

            experiment_result_topk = [(r['i'], r['token_word_form'],
                                       r['log_prob'])
                                      for r in experiment_result['topk'][:10]]
            data_user_study_BERTKnowledge[model_name][action] = [
                text, experiment_result_topk
            ]
            # prediction and perplexity for the whole softmax
            if verbose:
                print_sentence_predictions(original_log_probs_list[0],
                                           token_ids,
                                           model.vocab,
                                           masked_indices=masked_indices)

    print(data_user_study_BERTKnowledge)

    pickle.dump(data_user_study_BERTKnowledge,
                open('./parsed_BERTKnowledge_tests.p', 'wb'))
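
# Note on verb_noun_mapping used above: it is defined elsewhere in the original
# project (not shown here). A hypothetical stand-in mapping each verb to its
# gerund form is enough to run the snippet, e.g.:
# verb_noun_mapping = {"help": "helping", "steal": "stealing", "lie": "lying"}
# With it, the action "help people" becomes "Helping people", and moral_template
# yields "Helping people is a [MASK] behavior .".
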
Example #26
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError('Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    elif model_type_name == "roberta":
        model_name = "RoBERTa_{}".format(args.roberta_model_name)
    elif model_type_name == "hfroberta":
        model_name = "hfRoBERTa_{}".format(args.hfroberta_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger
        )

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positive = 0.0

    # spearman rank correlation
    # overlap at 1
    if args.use_negated_probes:
        Spearman = 0.0
        Overlap = 0.0
        num_valid_negation = 0.0

    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(
            data, use_negated_probes=args.use_negated_probes
        )
    else:
        # keep samples as they are
        all_samples = data
        # TREx data: build 'masked_sentences' from each sample's 'evidences' when missing
        for i, sample in enumerate(all_samples):
            if 'masked_sentences' not in sample.keys():
                sample['masked_sentences'] = []
                for evidence in sample['evidences']:
                    sample['masked_sentences'].append(evidence['masked_sentence'])
                if i == 0:
                    print("samples provide 'masked_sentence' in evidences rather than 'masked_sentences'.")

    all_samples, ret_msg = filter_samples(
        model, data, vocab_subset, args.max_sentence_length, args.template
    )

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if a template is active: (1) keep a single example per (sub, obj) pair and (2) substitute its masked sentences with the template
    if args.template and args.template != "":
        facts = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            # substitute all sentences with a standard template
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK
            )
            if args.use_negated_probes:
                # substitute all negated sentences with a standard template
                sample["negated"] = parse_template(
                    args.template_negated.strip(),
                    sample["sub_label"].strip(),
                    base.MASK,
                )
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    if args.use_negated_probes:
        sentences_batches_negated, ret_msg = batchify_negated(
            all_samples, args.batch_size
        )
        logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices
            )
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            elif vocab_subset is not None and sample["obj_label"] not in vocab_subset:
                raise ValueError(
                    "object label {} not in vocab subset".format(sample["obj_label"])
                )

            label_index_list.append(obj_label_id)

        arguments = [
            {
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0],
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
            }
            for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip(
                original_log_probs_list,
                filtered_log_probs_list,
                token_ids_list,
                masked_indices_list,
                label_index_list,
                samples_b,
            )
        ]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        if args.use_negated_probes:
            sentences_b_negated = sentences_batches_negated[i]

            # if no negated sentences in batch
            if all(s[0] == "" for s in sentences_b_negated):
                res_negated = [(float("nan"), float("nan"), "")] * args.batch_size
            # eval negated batch
            else:
                (
                    original_log_probs_list_negated,
                    token_ids_list_negated,
                    masked_indices_list_negated,
                ) = model.get_batch_generation(sentences_b_negated, logger=logger)
                if vocab_subset is not None:
                    # filter log_probs
                    filtered_log_probs_list_negated = model.filter_logprobs(
                        original_log_probs_list_negated, filter_logprob_indices
                    )
                else:
                    filtered_log_probs_list_negated = original_log_probs_list_negated

                arguments = [
                    {
                        "log_probs": filtered_log_probs,
                        "log_probs_negated": filtered_log_probs_negated,
                        "token_ids": token_ids,
                        "vocab": model.vocab,
                        "label_index": label_index[0],
                        "masked_indices": masked_indices,
                        "masked_indices_negated": masked_indices_negated,
                        "index_list": index_list,
                    }
                    for filtered_log_probs, filtered_log_probs_negated, token_ids, masked_indices, masked_indices_negated, label_index in zip(
                        filtered_log_probs_list,
                        filtered_log_probs_list_negated,
                        token_ids_list,
                        masked_indices_list,
                        masked_indices_list_negated,
                        label_index_list,
                    )
                ]
                res_negated = pool.map(run_thread_negated, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            if args.use_negated_probes:
                overlap, spearman, msg = res_negated[idx]
                # sum overlap and spearmanr if not nan
                if spearman == spearman:
                    element["spearmanr"] = spearman
                    element["overlap"] = overlap
                    Overlap += overlap
                    Spearman += spearman
                    num_valid_negation += 1.0

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]

            # the annotators' judgments record whether the sentence contains
            # evidence of a relation between the two entities.
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positive += sample_P

            list_of_results.append(element)

    pool.close()
    pool.join()

    # stats
    try:
        # Mean reciprocal rank
        MRR /= len(list_of_results)

        # Precision
        Precision /= len(list_of_results)
        Precision1 /= len(list_of_results)
    except ZeroDivisionError:
        MRR = Precision = Precision1 = 0.0

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if args.use_negated_probes:
        Overlap /= num_valid_negation
        Spearman /= num_valid_negation
        msg += "\n"
        msg += "results negation:\n"
        msg += "all_negated_samples: {}\n".format(int(num_valid_negation))
        msg += "global spearman rank affirmative/negated: {}\n".format(Spearman)
        msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positive /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement
        )
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement
        )
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision
    )
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    return Precision1
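
# Usage sketch: the argument values below are hypothetical and mirror the
# PARAMETERS dict used by run_experiments above.
def example_single_relation_run():
    args = argparse.Namespace(
        models_names=["bert"],
        bert_model_name="bert-base-cased",
        bert_model_dir="pre-trained_language_models/bert/cased_L-12_H-768_A-12",
        bert_vocab_name="vocab.txt",
        dataset_filename="data/TREx/P19.jsonl",
        common_vocab_filename="pre-trained_language_models/common_vocab_cased.txt",
        template="[X] was born in [Y] .",
        batch_size=32,
        logdir="output",
        full_logdir="",
        lowercase=False,
        max_sentence_length=100,
        threads=-1,
        interactive=False,
        use_negated_probes=False,
    )
    # returns global P@1 for the single relation described by args
    return main(args, shuffle_data=False)
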
def main(args, shuffle_data=True, model=None):

    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    [model_type_name] = args.models_names

    args.output_feature_path = getattr(args, 'output_feature_path', '')
    if getattr(args, 'knn_thresh', 0) > 0:
        assert hasattr(args, 'knn_path')
        assert hasattr(args, 'modify_ans')
    else:
        args.knn_thresh = 0

    if getattr(args, 'knn_path', ''):
        knn_dict = torch.load(args.knn_path)
        if getattr(args, 'consine_dist', True):
            knn_dict['mask_features'] = knn_dict['mask_features'] / torch.norm(
                knn_dict['mask_features'], dim=1, keepdim=True)
        else:
            knn_dict['mask_features'] = knn_dict['mask_features']
        new_ans_dict = json.load(open(args.modify_ans))
        knn_dict['obj_labels'] = [
            new_ans_dict[uuid] for uuid in knn_dict['uuids']
        ]
    else:
        new_ans_dict = None
        knn_dict = None

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open(os.path.join(log_directory, 'args.json'), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision1_modified = 0.0
    Precision_negative = 0.0
    Precision_positive = 0.0

    # spearman rank correlation
    # overlap at 1
    if args.use_negated_probes:
        Spearman = 0.0
        Overlap = 0.0
        num_valid_negation = 0.0

    data = load_file(args.dataset_filename)

    print(len(data))

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if a template is active: (1) keep a single example per (sub, obj) fact, or one per evidence sentence when use_evidences is set, and (2) substitute its masked sentences with the template
    if args.template and args.template != "":
        if getattr(args, 'use_evidences', False):
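            # expand each fact into one sample per supporting evidence sentence
            # (capped at 10 below) instead of a single templated sentence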
            new_all_samples = []
            for sample in all_samples:
                if len(args.uuid_list) > 0 and sample['uuid'] not in args.uuid_list:
                    continue
                elif len(args.uuid_list) > 0:
                    print(sample['uuid'])
                sub = sample["sub_label"]
                if new_ans_dict is not None and sample['uuid'] in new_ans_dict:
                    # we need to replace the answer in this way
                    obj = new_ans_dict[sample['uuid']]
                else:
                    obj = sample["obj_label"]

                if sample['uuid'] == '11fc104b-bba2-412c-b2d7-cf06cd2bd715':
                    sample['evidences'] = sample['evidences'][:32]

                for ne, evidence in enumerate(sample['evidences']):
                    # maximum of 10 evidences per fact
                    if ne >= 10:
                        continue
                    new_sample = {'sub_label': sub, 'obj_label': obj}
                    if '[MASK]' not in evidence['masked_sentence']:
                        continue
                    new_sample['masked_sentences'] = [
                        evidence['masked_sentence']
                    ]
                    new_sample['uuid'] = sample['uuid']
                    new_all_samples.append(new_sample)

            all_samples = new_all_samples
        else:
            facts = []
            for sample in all_samples:
                sub = sample["sub_label"]
                if new_ans_dict is not None and sample['uuid'] in new_ans_dict:
                    # we need to replace the answer in this way
                    obj = new_ans_dict[sample['uuid']]
                else:
                    obj = sample["obj_label"]
                if (sub, obj) not in facts:
                    facts.append((sample['uuid'], sub, obj))
            local_msg = "distinct template facts: {}".format(len(facts))
            logger.info("\n" + local_msg + "\n")
            print(local_msg)

            all_samples = []
            for fact in facts:
                (uuid, sub, obj) = fact
                sample = {}
                sample["sub_label"] = sub
                sample["obj_label"] = obj
                sample["uuid"] = uuid
                # substitute all sentences with a standard template
                sample["masked_sentences"] = parse_template(
                    args.template.strip(), sample["sub_label"].strip(),
                    base.MASK)
                if args.use_negated_probes:
                    # substitute all negated sentences with a standard template
                    sample["negated"] = parse_template(
                        args.template_negated.strip(),
                        sample["sub_label"].strip(),
                        base.MASK,
                    )
                all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    if args.use_negated_probes:
        sentences_batches_negated, ret_msg = batchify_negated(
            all_samples, args.batch_size)
        logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []
    total_modified = 0

    mask_feature_all, answers_list, uid_list = [], [], []
    correct_uuids = []
    knn_preds_list = []
    for i in tqdm(range(len(samples_batches))):
        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        rets = model.get_batch_generation(
            sentences_b,
            logger=logger,
            return_features=args.return_features or args.knn_thresh > 0)
        if len(rets) == 4:
            original_log_probs_list, token_ids_list, masked_indices_list, feature_tensor = rets
            mask_feature_all.append(feature_tensor)
        else:
            original_log_probs_list, token_ids_list, masked_indices_list = rets

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        modified_flags_list = []
        for ns, sample in enumerate(samples_b):
            obj_label_id = model.get_id(sample["obj_label"])
            answers_list.append(sample["obj_label"])
            uid_list.append(sample['uuid'])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

            if args.knn_thresh > 0:
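                # kNN lookup: score this sample's [MASK] feature against the
                # stored mask features and take the minimum; if that minimum is
                # below knn_thresh, use the corresponding (possibly modified)
                # object label as the kNN prediction, otherwise record -1.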
                feature = feature_tensor[ns].view(1, -1)
                if getattr(args, 'consine_dist', True):
                    dist = torch.sum(feature * knn_dict['mask_features'],
                                     dim=1) / torch.norm(feature)
                else:
                    dist = torch.norm(feature - knn_dict['mask_features'],
                                      dim=1)
                min_dist, min_idx = torch.min(dist, dim=0)
                # print(min_dist.item())
                if min_dist < args.knn_thresh:
                    knn_pred = knn_dict['obj_labels'][min_idx.item()]
                    knn_preds_list.append(model.get_id(knn_pred)[0])
                    # if knn_dict['uuids'][min_idx.item()] == sample['uuid']:
                    #     pdb.set_trace()
                else:
                    knn_preds_list.append(-1)
                # log_probs.unsqueeze()
                # knn_preds_list.
            else:
                knn_preds_list.append(-1)

            # label whether the fact has been modified
            modified_flags_list.append(new_ans_dict is not None
                                       and sample['uuid'] in new_ans_dict)

        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
            "knn_pred": knn_pred,
            "modified": modified
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample, knn_pred, modified in
                     zip(original_log_probs_list, filtered_log_probs_list,
                         token_ids_list, masked_indices_list, label_index_list,
                         samples_b, knn_preds_list, modified_flags_list)]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        if args.use_negated_probes:
            sentences_b_negated = sentences_batches_negated[i]

            # if no negated sentences in batch
            if all(s[0] == "" for s in sentences_b_negated):
                res_negated = [(float("nan"), float("nan"), "")] * args.batch_size
            # eval negated batch
            else:
                (
                    original_log_probs_list_negated,
                    token_ids_list_negated,
                    masked_indices_list_negated,
                ) = model.get_batch_generation(sentences_b_negated,
                                               logger=logger)
                if vocab_subset is not None:
                    # filter log_probs
                    filtered_log_probs_list_negated = model.filter_logprobs(
                        original_log_probs_list_negated,
                        filter_logprob_indices)
                else:
                    filtered_log_probs_list_negated = original_log_probs_list_negated

                arguments = [{
                    "log_probs": filtered_log_probs,
                    "log_probs_negated": filtered_log_probs_negated,
                    "token_ids": token_ids,
                    "vocab": model.vocab,
                    "label_index": label_index[0],
                    "masked_indices": masked_indices,
                    "masked_indices_negated": masked_indices_negated,
                    "index_list": index_list,
                } for filtered_log_probs, filtered_log_probs_negated,
                             token_ids, masked_indices, masked_indices_negated,
                             label_index in zip(
                                 filtered_log_probs_list,
                                 filtered_log_probs_list_negated,
                                 token_ids_list,
                                 masked_indices_list,
                                 masked_indices_list_negated,
                                 label_index_list,
                             )]
                res_negated = pool.map(run_thread_negated, arguments)

        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            element["modified"] = result_masked_topk["modified"]
            if result_masked_topk["P_AT_1"] > 0:
                correct_uuids.append(element['uuid'])

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            if args.use_negated_probes:
                overlap, spearman, msg = res_negated[idx]
                # sum overlap and spearmanr if not nan
                if spearman == spearman:
                    element["spearmanr"] = spearman
                    element["overlap"] = overlap
                    Overlap += overlap
                    Spearman += spearman
                    num_valid_negation += 1.0

            MRR += sample_MRR
            Precision += sample_P
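            # track P@1 separately for facts whose answer was replaced via
            # modify_ans ("modified") and for untouched facts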
            if element["modified"]:
                Precision1_modified += element["sample_Precision1"]
            else:
                Precision1 += element["sample_Precision1"]

            # the annotators' judgments record whether the sentence contains
            # evidence of a relation between the two entities.
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positive += sample_P
            if element["modified"]:
                total_modified += 1
            else:
                list_of_results.append(element)

    pool.close()
    pool.join()

    if args.output_feature_path and len(list_of_results) == 0:
        # torch.save(out_dict, args.output_feature_path)
        # return empty results
        return Precision1, uid_list, mask_feature_all, answers_list
    elif len(list_of_results) == 0:
        pdb.set_trace()

    # stats
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    # Precision1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if args.use_negated_probes:
        Overlap /= num_valid_negation
        Spearman /= num_valid_negation
        msg += "\n"
        msg += "results negation:\n"
        msg += "all_negated_samples: {}\n".format(int(num_valid_negation))
        msg += "global spearman rank affirmative/negated: {}\n".format(
            Spearman)
        msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positive /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(list_of_results=list_of_results,
                       global_MRR=MRR,
                       global_P_at_10=Precision)
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    if args.output_feature_path:
        # torch.save(out_dict, args.output_feature_path)
        return Precision1, len(
            list_of_results
        ), Precision1_modified, total_modified, uid_list, mask_feature_all, answers_list

    return Precision1, len(
        list_of_results), Precision1_modified, total_modified, correct_uuids
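
# Usage sketch: with args.output_feature_path left empty, the function above
# returns five values; the variable names below are illustrative.
# p1_sum, n_results, p1_modified_sum, n_modified, correct_uuids = main(args, shuffle_data=False)
# print("P@1 on unmodified facts: {}".format(p1_sum / max(n_results, 1)))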