def main(args):
    sentences = [
        ["the cat is on the table ."],  # single-sentence instance
        ["the dog is sleeping on the sofa .", "he makes happy noises ."],  # two-sentence instance
    ]

    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    for model_name, model in models.items():
        print("\n{}:".format(model_name))
        if args.cuda:
            model.try_cuda()
        contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
            sentences)

        # contextual_embeddings is a list of tensors, one tensor for each layer.
        # Each element contains one layer of the representations with shape
        # (x, y, z):
        #   x - the batch size
        #   y - the sequence length of the batch
        #   z - the dimensionality of each layer vector
        print(f'Number of layers: {len(contextual_embeddings)}')
        for layer_id, layer in enumerate(contextual_embeddings):
            print(f'Layer {layer_id} has shape: {layer.shape}')

        print("sentence_lengths: {}".format(sentence_lengths))
        print("tokenized_text_list: {}".format(tokenized_text_list))
def fill_cloze(args, input_jsonl, batch_size, beam_size):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)

    with open(input_jsonl, 'r') as fin:
        data = [json.loads(l) for l in fin]

    # only keep QA pairs where (1) the answer starts with an uppercase letter,
    # (2) the sentence is <= 200 chars, and (3) the sentence contains no digits
    data = [
        d for d in data
        if d['answer'][0].isupper() and len(d['sentence']) <= 200
        and not bool(re.search(r'\d', d['sentence']))
    ]
    print('#qa pairs {}'.format(len(data)))

    acc_token_li, acc_sent_li = [], []
    for b in tqdm(range(0, len(data), batch_size)):
        data_batch = data[b:b + batch_size]
        sents = []
        for d in data_batch:
            # bracket the answer span so the model knows which tokens to fill
            start = d['answer_start']
            end = start + len(d['answer'])
            sent = d['sentence'].replace('[', '(').replace(']', ')')
            sent = sent[:start] + '[' + sent[start:end] + ']' + sent[end:]
            sents.append(sent)
        acc_token, acc_sent = model.fill_cloze(sents,
                                               try_cuda=try_cuda,
                                               beam_size=beam_size)
        acc_token_li.append(acc_token)
        acc_sent_li.append(acc_sent)

    print('mean acc_token {}, mean acc_sent {}'.format(np.mean(acc_token_li), np.mean(acc_sent_li)))
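# Hedged illustration (added for clarity, not part of the original script): the
# record below is made up, but it shows what the bracket-masking step above is
# expected to produce for a SQuAD-style QA pair.
def _bracket_answer_example():
    d = {
        "sentence": "Einstein developed the theory of relativity.",
        "answer": "Einstein",
        "answer_start": 0,
    }
    start = d["answer_start"]
    end = start + len(d["answer"])
    sent = d["sentence"].replace("[", "(").replace("]", ")")
    sent = sent[:start] + "[" + sent[start:end] + "]" + sent[end:]
    # -> "[Einstein] developed the theory of relativity."
    return sent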
def __vocab_intersection(models, filename):
    vocabularies = []

    for arg_dict in models:
        args = argparse.Namespace(**arg_dict)
        print(args)
        model = build_model_by_name(args.lm, args)
        vocabularies.append(model.vocab)
        print(type(model.vocab))

    if len(vocabularies) > 0:
        common_vocab = set(vocabularies[0])
        for vocab in vocabularies:
            common_vocab = common_vocab.intersection(set(vocab))

        # no special symbols in common_vocab
        for symbol in base.SPECIAL_SYMBOLS:
            if symbol in common_vocab:
                common_vocab.remove(symbol)

        # remove stop words
        from spacy.lang.en.stop_words import STOP_WORDS
        for stop_word in STOP_WORDS:
            if stop_word in common_vocab:
                print(stop_word)
                common_vocab.remove(stop_word)

        common_vocab = list(common_vocab)

        # remove punctuation and symbols
        nlp = spacy.load('en')
        manual_punctuation = ['(', ')', '.', ',']
        new_common_vocab = []
        for i in tqdm(range(len(common_vocab))):
            word = common_vocab[i]
            doc = nlp(word)
            token = doc[0]
            if len(doc) != 1:
                print(word)
                for idx, tok in enumerate(doc):
                    print("{} - {}".format(idx, tok))
            elif word in manual_punctuation:
                pass
            elif token.pos_ == "PUNCT":
                print("PUNCT: {}".format(word))
            elif token.pos_ == "SYM":
                print("SYM: {}".format(word))
            else:
                new_common_vocab.append(word)
        common_vocab = new_common_vocab

        # store common_vocab on file
        with open(filename, 'w') as f:
            for item in sorted(common_vocab):
                f.write("{}\n".format(item))
def main(args):
    if not args.subject or not args.relation:
        raise ValueError(
            'You need to specify --subject and --relation to query language models.'
        )

    print('Language Models: {}'.format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print('Common vocabulary size: {}'.format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    prompt_file = os.path.join(args.prompts, args.relation + '.jsonl')
    if not os.path.exists(prompt_file):
        raise ValueError('Relation "{}" does not exist.'.format(args.relation))
    prompts, weights = load_prompt_weights(prompt_file)

    for model_name, model in models.items():
        print('\n{}:'.format(model_name))

        index_list = None
        if vocab_subset is not None:
            filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                vocab_subset)

        # weighted ensemble of the per-prompt log-probabilities
        ensemble_log_probs = 0
        for prompt, weight in zip(prompts, weights):
            prompt = parse_prompt(prompt, args.subject, model.mask_token)
            log_prob, [token_ids], [masked_indices], _, _ = model.get_batch_generation(
                [prompt], try_cuda=True)
            if vocab_subset is not None:
                filtered_log_probs = model.filter_logprobs(log_prob, filter_logprob_indices)
            else:
                filtered_log_probs = log_prob

            # rank over the subset of the vocab (if defined) for the SINGLE masked token
            if masked_indices and len(masked_indices) > 0:
                filtered_log_probs = filtered_log_probs[0][masked_indices[0]]
            ensemble_log_probs += filtered_log_probs * weight

        ensemble_log_probs = F.log_softmax(ensemble_log_probs, dim=0)
        evaluation_metrics.get_ranking(ensemble_log_probs,
                                       model.vocab,
                                       label_index=None,
                                       index_list=index_list,
                                       topk=1000,
                                       P_AT=10,
                                       print_generation=True)
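# Hedged sketch (an assumption about the prompt file format, not the repo's
# actual load_prompt_weights): each line of the per-relation .jsonl file is
# assumed to carry a prompt template and an ensemble weight.
import json

def load_prompt_weights_sketch(prompt_file):
    prompts, weights = [], []
    with open(prompt_file) as fin:
        for line in fin:
            rec = json.loads(line)
            prompts.append(rec["template"])             # e.g. "[X] was born in [Y] ."
            weights.append(float(rec.get("weight", 1.0)))
    return prompts, weights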
def refine_cloze(args):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)
    sents = [
        'The theory of relativity [ is killed by ] Einstein .',
        'Windows [ is killed by ] Microsoft .'
    ]
    model.refine_cloze(sents, try_cuda=try_cuda)
def main():
    with open(trial_2 + "priming/antonyme_adj.txt", "r") as ah:
        pairs = [(p.split()[0].strip(), p.split()[1].strip()) for p in ah.readlines()]
    print(pairs)

    args_stud = Args_Stud()
    bert = build_model_by_name("bert", args_stud)
    for pair in pairs:
        maskPairs(pair, bert)
    return
def fill_cloze_webquestion(args, input_file, batch_size, beam_size):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)

    with open(input_file, 'r') as fin:
        sents = [l.strip() for l in fin]
    # keep only statements whose bracketed answer is a single word
    sents = [s for s in sents if len(re.split(r'\[|\]', s)[1].split()) == 1]
    print('#qa pairs {}'.format(len(sents)))

    acc_token_li, acc_sent_li = [], []
    for b in tqdm(range(0, len(sents), batch_size)):
        acc_token, acc_sent = model.fill_cloze(sents[b:b + batch_size],
                                               try_cuda=try_cuda,
                                               beam_size=beam_size)
        acc_token_li.append(acc_token)
        acc_sent_li.append(acc_sent)

    print('mean acc_token {}, mean acc_sent {}'.format(np.mean(acc_token_li), np.mean(acc_sent_li)))
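# Hedged illustration (added for clarity, not part of the original script):
# re.split(r'\[|\]', s)[1] extracts the bracketed answer span, so the filter
# above keeps only statements whose answer is a single word.
import re

def _single_word_answer_examples():
    keep = "Einstein developed [relativity] ."
    drop = "Einstein developed [the theory of relativity] ."
    assert len(re.split(r'\[|\]', keep)[1].split()) == 1
    assert len(re.split(r'\[|\]', drop)[1].split()) != 1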
def fill_cloze_lama_squad(args, input_jsonl, batch_size, beam_size):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)

    with open(input_jsonl, 'r') as fin:
        data = [json.loads(l) for l in fin]
    print('#qa pairs {}'.format(len(data)))

    acc_token_li, acc_sent_li = [], []
    for b in tqdm(range(0, len(data), batch_size)):
        data_batch = data[b:b + batch_size]
        sents = []
        for d in data_batch:
            # replace the [MASK] placeholder with the bracketed gold object
            sents.append(d['masked_sentences'][0].replace(
                '[MASK]', '[{}]'.format(d['obj_label'])))
        acc_token, acc_sent = model.fill_cloze(sents,
                                               try_cuda=try_cuda,
                                               beam_size=beam_size)
        acc_token_li.append(acc_token)
        acc_sent_li.append(acc_sent)

    print('mean acc_token {}, mean acc_sent {}'.format(np.mean(acc_token_li), np.mean(acc_sent_li)))
def pattern_score(args, pattern_json, output_file):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)

    with open(pattern_json, 'r') as fin:
        pattern_json = json.load(fin)

    batch_size = 32
    pid2pattern = defaultdict(lambda: {})
    for pid in tqdm(sorted(pattern_json)):
        snippets = pattern_json[pid]['snippet']
        occs = pattern_json[pid]['occs']
        for (snippet, direction), count in snippets:
            # only keep snippets between 6 and 99 characters long
            if len(snippet) <= 5 or len(snippet) >= 100:
                continue
            loss = 0
            num_batch = np.ceil(len(occs) / batch_size)
            for b in range(0, len(occs), batch_size):
                occs_batch = occs[b:b + batch_size]
                sentences = [
                    '{} {} ({})'.format(h, snippet, t)
                    if direction == 1 else '{} {} ({})'.format(t, snippet, h)
                    for h, t in occs_batch
                ]
                loss += model.get_rc_loss(sentences, try_cuda=try_cuda)[0].item()
            # average loss over batches; a lower loss means a better-scoring pattern
            pid2pattern[pid][snippet] = loss / num_batch

    with open(output_file, 'w') as fout:
        for pid, pats in pid2pattern.items():
            pats = sorted(pats.items(), key=lambda x: x[1])
            fout.write('{}\t{}\n'.format(pid, json.dumps(pats)))
def encode(args, sentences, sort_input=False):
    """Create an EncodedDataset from a list of sentences

    Parameters:
        sentences (list[list[string]]): list of elements. Each element is a list
            that contains either a single sentence or two sentences
        sort_input (bool): if true, sort sentences by the number of tokens in them

    Returns:
        dataset (EncodedDataset): an object that contains the contextual
            representations of the input sentences
    """
    print("Language Models: {}".format(args.lm))
    model = build_model_by_name(args.lm, args)

    # sort sentences by number of tokens so that every batch contains
    # sentences with a similar number of tokens
    if sort_input:
        sentences = sorted(sentences, key=lambda k: len(" ".join(k).split()))

    encoded_sents = []
    for current_batch in tqdm(_batchify(sentences, args.batch_size)):
        embeddings, sent_lens, tokenized_sents = model.get_contextual_embeddings(
            current_batch)
        agg_embeddings = _aggregate_layers(embeddings)  # [#batch_size, #max_sent_len, #dim]
        sent_embeddings = [agg_embeddings[i, :l] for i, l in enumerate(sent_lens)]
        encoded_sents.extend(list(zip(sent_embeddings, sent_lens, tokenized_sents)))

    dataset = EncodedDataset(encoded_sents)
    return dataset
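# Hedged sketches (assumptions, not the repo's actual helpers) of the two
# utilities encode() relies on: _batchify, which yields fixed-size chunks of
# the input list, and _aggregate_layers, which collapses the per-layer
# embeddings into a single [batch, max_sent_len, dim] tensor (here by
# averaging across layers).
import torch

def _batchify_sketch(items, batch_size):
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

def _aggregate_layers_sketch(layer_embeddings):
    # layer_embeddings: list of tensors, each of shape [batch, max_sent_len, dim]
    return torch.stack(layer_embeddings, dim=0).mean(dim=0)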
def main(args):
    try_cuda = torch.cuda.is_available()
    model = build_model_by_name(args.models_names[0], args)
    model.add_hooks()
    embedding_weight = model.get_embedding_weight()

    # earlier test inputs are kept for reference; only the last assignment is used
    sentences = [
        'The theory of relativity was developed by Einstein.',
        'Windows was developed by Microsoft.'
    ]
    sentences = ['(The theory of relativity) was found by (Einstein.)']
    sentences = ['(Barack Obama) was born in (Hawaii.)']
    sentences = ['Him (speaks English.)']
    sentences = [
        '[The theory of relativity was] killed [by Einstein].',
        '[Windows was] killed [by Microsoft].'
    ]

    for _ in range(50):
        for token_to_flip in range(0, 3):  # TODO: for each token in the trigger
            # back propagation
            model.zero_grad()
            loss, tokens, _, unbracket_mask = model.get_rc_loss(sentences, try_cuda=try_cuda)
            # SHAPE: (batch_size, seq_len)
            unbracket_mask = unbracket_mask.bool()
            loss.backward()
            print(loss)

            # SHAPE: (batch_size, seq_len, emb_dim)
            grad = base_connector.extracted_grads[0]
            bs, _, emb_dim = grad.size()
            base_connector.extracted_grads = []  # TODO
            # SHAPE: (batch_size, unbracket_len, emb_dim)
            grad = grad.masked_select(unbracket_mask.unsqueeze(-1)).view(bs, -1, emb_dim)
            # SHAPE: (1, emb_dim)
            grad = grad.sum(dim=0)[token_to_flip].unsqueeze(0)
            print((grad * grad).sum().sqrt())

            # SHAPE: (batch_size, unbracket_len)
            tokens = tokens.masked_select(unbracket_mask).view(bs, -1)
            token_tochange = tokens[0][token_to_flip].item()

            # Use the hotflip (linear approximation) attack to get the top num_candidates
            candidates = attacks.hotflip_attack(grad,
                                                embedding_weight,
                                                [token_tochange],
                                                increase_loss=False,
                                                num_candidates=10)[0]
            print(model.tokenizer.convert_ids_to_tokens([token_tochange]),
                  model.tokenizer.convert_ids_to_tokens(candidates))
            input()
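# Hedged sketch (not attacks.hotflip_attack itself): the first-order HotFlip
# approximation used above scores every vocabulary embedding by its dot
# product with the gradient at the position being flipped and, since
# increase_loss=False, returns the ids expected to decrease the loss most.
import torch

def hotflip_candidates_sketch(grad, embedding_weight, num_candidates=10):
    # grad: [1, emb_dim] gradient w.r.t. the embedding of the token to flip
    # embedding_weight: [vocab_size, emb_dim]
    scores = embedding_weight @ grad.squeeze(0)  # [vocab_size]
    return torch.topk(-scores, k=num_candidates).indices.tolist()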
def run_experiments( relations, data_path_pre, data_path_post, input_param={ "lm": "bert", "label": "bert_large", "models_names": ["bert"], "bert_model_name": "bert-large-cased", "bert_model_dir": "pre-trained_language_models/bert/cased_L-24_H-1024_A-16", }, use_negated_probes=False, ): model = None pp = pprint.PrettyPrinter(width=41, compact=True) all_Precision1 = [] type_Precision1 = defaultdict(list) type_count = defaultdict(list) results_file = open("last_results.csv", "w+") uid_list_all, mask_feature_list_all, answers_list_all = [], [], [] all_correct_uuids = [] total_modified_correct, total_unmodified_correct = 0, 0 total_modified_num, total_unmodified_num = 0, 0 for relation in relations: # if "type" not in relation or relation["type"] != "1-1": # continue pp.pprint(relation) PARAMETERS = { "dataset_filename": "{}{}{}".format(data_path_pre, relation["relation"], data_path_post), "common_vocab_filename": 'pre-trained_language_models/bert/cased_L-12_H-768_A-12/vocab.txt', #"pre-trained_language_models/common_vocab_cased.txt", "template": "", "bert_vocab_name": "vocab.txt", "batch_size": 32, "logdir": "output", "full_logdir": "output/results/{}/{}".format(input_param["label"], relation["relation"]), "lowercase": False, "max_sentence_length": 512, # change to 512 later "threads": 2, "interactive": False, "use_negated_probes": use_negated_probes, "return_features": False, "uuid_list": [] } if "template" in relation: PARAMETERS["template"] = relation["template"] if use_negated_probes: PARAMETERS["template_negated"] = relation["template_negated"] PARAMETERS.update(input_param) print(PARAMETERS) args = argparse.Namespace(**PARAMETERS) # see if file exists try: data = load_file(args.dataset_filename) except Exception as e: print("Relation {} excluded.".format(relation["relation"])) print("Exception: {}".format(e)) continue if model is None: [model_type_name] = args.models_names model = build_model_by_name(model_type_name, args) if getattr(args, 'output_feature_path', ''): # Get the features for kNN-LM. Ignore this part if only obtaining the correct-predicted questions. 
Precision1, total_unmodified, Precision1_modified, total_modified, uid_list, mask_feature_list, answers_list = run_evaluation( args, shuffle_data=False, model=model) if len(uid_list) > 0: uid_list_all.extend(uid_list) mask_feature_tensor = torch.cat(mask_feature_list, dim=0) mask_feature_list_all.append(mask_feature_tensor) answers_list_all.extend(answers_list) else: Precision1, total_unmodified, Precision1_modified, total_modified, correct_uuids = run_evaluation( args, shuffle_data=False, model=model) all_correct_uuids.extend(correct_uuids) total_modified_correct += Precision1_modified total_unmodified_correct += Precision1 total_modified_num += total_modified total_unmodified_num += total_unmodified print("P@1 : {}".format(Precision1), flush=True) all_Precision1.append(Precision1) results_file.write("{},{}\n".format(relation["relation"], round(Precision1 * 100, 2))) results_file.flush() if "type" in relation: type_Precision1[relation["type"]].append(Precision1) data = load_file(PARAMETERS["dataset_filename"]) type_count[relation["type"]].append(len(data)) mean_p1 = statistics.mean(all_Precision1) print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1)) print("Unmodified acc: {}, modified acc: {}".format( total_unmodified_correct / float(total_unmodified_num), 0 if total_modified_num == 0 else total_modified_correct / float(total_modified_num))) results_file.close() for t, l in type_Precision1.items(): print( "@@@ ", input_param["label"], t, statistics.mean(l), sum(type_count[t]), len(type_count[t]), flush=True, ) if len(uid_list_all) > 0: out_dict = { 'mask_features': torch.cat(mask_feature_list_all, dim=0), 'uuids': uid_list_all, 'obj_labels': answers_list_all } torch.save(out_dict, 'datastore/ds_change32.pt') if len(all_correct_uuids) > 0: if not os.path.exists('modification'): os.makedirs('modification') json.dump(all_correct_uuids, open('modification/correct_uuids.json', 'w')) return mean_p1, all_Precision1
# # print("sentence_lengths: {}".format(sentence_lengths)) # print("tokenized_text_list: {}".format(tokenized_text_list)) return contextual_embeddings, tokenized_text_list ###Generate modified args for the lama library (aka imitate the input of a command line) sys.argv = ['My code for HW3 Task1', '--lm', 'bert'] parser = options.get_general_parser() args = options.parse_args(parser) ###building the model only once (not inside the method for each line) models = {} for lm in args.models_names: models[lm] = build_model_by_name(lm, args) ###opening the file with jsonlines.open('./train_testing_output.jsonl') as reader: for line in reader.iter(): dictionary = line ###masking the text text = dictionary['claim'] start_masked = dictionary["entity"]['start_character'] end_masked = dictionary["entity"]['end_character'] text_masked = text[0:start_masked] + '[MASK]' + text[ end_masked:len(text)] ### get embeddings
def main(args):
    if not args.text and not args.interactive:
        msg = "ERROR: either you start LAMA eval_generation with the " \
              "interactive option (--i) or you pass in input a piece of text (--t)"
        raise ValueError(msg)

    stopping_condition = True

    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print("common vocabulary size: {}".format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    while stopping_condition:
        if args.text:
            text = args.text
            stopping_condition = False
        else:
            text = input("insert text:")

        if args.split_sentence:
            import spacy
            # use spacy to tokenize input sentence
            nlp = spacy.load(args.spacy_model)
            tokens = nlp(text)
            print(tokens)

            sentences = []
            for s in tokens.sents:
                print(" - {}".format(s))
                sentences.append(s.text)
        else:
            sentences = [text]

        if len(sentences) > 2:
            print("WARNING: only the first two sentences in the text will be considered!")
            sentences = sentences[:2]

        for model_name, model in models.items():
            print("\n{}:".format(model_name))
            original_log_probs_list, [token_ids], [masked_indices] = model.get_batch_generation(
                [sentences], try_cuda=False)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list

            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                evaluation_metrics.get_ranking(filtered_log_probs_list[0],
                                               masked_indices,
                                               model.vocab,
                                               index_list=index_list)

            # prediction and perplexity for the whole softmax
            print_sentence_predictions(original_log_probs_list[0],
                                       token_ids,
                                       model.vocab,
                                       masked_indices=masked_indices)
def main():
    args_stud = Args_Stud()
    bert = build_model_by_name("bert", args_stud)
    vocab_subset = None

    with open('./LAMA/lama/collected_paths.json') as f:
        path_s = json.load(f)
    sent_path_ = path_s['sent2eval']
    prem_path = path_s['premis2eval']
    res_path_ = path_s["res_file"]

    paths = os.listdir(sent_path_)
    for path in paths:
        sent_path = sent_path_ + path
        res_path = res_path_ + path.split(".")[0].split("_")[-2] + "_" + \
            path.split(".")[0].split("_")[-2] + "/"
        os.makedirs(res_path, exist_ok=True)

        with open(sent_path, "r", encoding="utf8") as sf:
            sentences = [s.rstrip() for s in sf.readlines()]
        print(sentences)
        with open(prem_path, "r") as pf:
            premisses = [p.rstrip() for p in pf.readlines()]

        # first evaluate every sentence without a premise
        data = {}
        for s in sentences:
            data[s] = []
            original_log_probs_list, [token_ids], [masked_indices] = bert.get_batch_generation(
                [[s]], try_cuda=True)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = bert.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = bert.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list

            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                MRR, P_AT_X, experiment_result, return_msg = evaluation_metrics.get_ranking(
                    filtered_log_probs_list[0],
                    masked_indices,
                    bert.vocab,
                    index_list=index_list)
                res = experiment_result["topk"]
                for r in res:
                    data[s].append((r["token_word_form"], r["log_prob"]))

        with open(res_path + "NoPrem.json", "w+", encoding="utf-8") as f:
            json.dump(data, f)

        # then evaluate every sentence primed with each premise
        for pre in premisses:
            for s in sentences:
                data[s] = []
                sentence = [str(pre) + "? " + s]
                original_log_probs_list, [token_ids], [masked_indices] = bert.get_batch_generation(
                    [sentence], try_cuda=False)

                index_list = None
                if vocab_subset is not None:
                    # filter log_probs
                    filter_logprob_indices, index_list = bert.init_indices_for_filter_logprobs(
                        vocab_subset)
                    filtered_log_probs_list = bert.filter_logprobs(
                        original_log_probs_list, filter_logprob_indices)
                else:
                    filtered_log_probs_list = original_log_probs_list

                # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
                if masked_indices and len(masked_indices) > 0:
                    MRR, P_AT_X, experiment_result, return_msg = evaluation_metrics.get_ranking(
                        filtered_log_probs_list[0],
                        masked_indices,
                        bert.vocab,
                        index_list=index_list)
                    res = experiment_result["topk"]
                    for r in res:
                        data[s].append((r["token_word_form"], r["log_prob"]))

            with open(res_path + pre + ".json", "w+", encoding="utf-8") as f:
                json.dump(data, f)
def main(args):
    # Load the JSON datasets. For each dataset we create a numpy array of
    # shape (len(dataset), 768) that we fill with BERT embeddings.

    # Pre-processed train dataset
    with open('./new_data_train.json') as json_file:
        json_train = json.load(json_file)
    x_train = np.zeros((len(json_train), 768))

    # Pre-processed dev dataset
    with open('./new_data_dev.json') as json_file:
        json_test = json.load(json_file)
    x_test = np.zeros((len(json_test), 768))

    # Official test set
    json_test_official = []
    with open('./singletoken_test_fever_homework_NLP.jsonl') as json_file:
        for item in json_lines.reader(json_file):
            json_test_official.append(item)
    x_test_official = np.zeros((len(json_test_official), 768))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    # For each model, loop over each dataset and retrieve the BERT embeddings
    for model_name, model in models.items():
        for index in range(len(json_train)):
            # pass each claim of each datapoint to the model
            sentences = [[json_train[index]['claim']]]
            print("\n{}:".format(model_name))
            contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
                sentences)
            # select the [CLS] vector of the last layer
            x_train[index] = contextual_embeddings[11][0][0]

        # We do the same for the other two datasets
        for index in range(len(json_test)):
            sentences = [[json_test[index]['claim']]]
            print("\n{}:".format(model_name))
            contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
                sentences)
            x_test[index] = contextual_embeddings[11][0][0]
            print(tokenized_text_list)

        for index in range(len(json_test_official)):
            sentences = [[json_test_official[index]['claim']]]
            print("\n{}:".format(model_name))
            contextual_embeddings, sentence_lengths, tokenized_text_list = model.get_contextual_embeddings(
                sentences)
            x_test_official[index] = contextual_embeddings[11][0][0]
            print(tokenized_text_list)

    return (x_train, json_train, x_test, json_test, x_test_official, json_test_official)
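# Hedged follow-up sketch (not part of the original script): one way the
# returned [CLS] features could be used, fitting a simple classifier on the
# training embeddings. The 'label' field name is an assumption about the
# pre-processed JSON records.
from sklearn.linear_model import LogisticRegression

def train_claim_classifier_sketch(x_train, json_train, x_test, json_test):
    y_train = [d['label'] for d in json_train]  # assumed label field
    y_test = [d['label'] for d in json_test]
    clf = LogisticRegression(max_iter=1000).fit(x_train, y_train)
    return clf, clf.score(x_test, y_test)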
def main(args, shuffle_data=True, model=None): if len(args.models_names) > 1: raise ValueError( 'Please specify a single language model (e.g., --lm "bert").') msg = "" [model_type_name] = args.models_names # print("------- Model: {}".format(model)) # print("------- Args: {}".format(args)) if model is None: model = build_model_by_name(model_type_name, args) if model_type_name == "fairseq": model_name = "fairseq_{}".format(args.fairseq_model_name) elif model_type_name == "bert": model_name = "BERT_{}".format(args.bert_model_name) elif model_type_name == "elmo": model_name = "ELMo_{}".format(args.elmo_model_name) else: model_name = model_type_name.title() # initialize logging if args.full_logdir: log_directory = args.full_logdir else: log_directory = create_logdir_with_timestamp(args.logdir, model_name) logger = init_logging(log_directory) msg += "model name: {}\n".format(model_name) # deal with vocab subset vocab_subset = None index_list = None msg += "args: {}\n".format(args) if args.common_vocab_filename is not None: vocab_subset = load_vocab(args.common_vocab_filename) msg += "common vocabulary size: {}\n".format(len(vocab_subset)) # optimization for some LM (such as ELMo) model.optimize_top_layer(vocab_subset) filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset, logger) logger.info("\n" + msg + "\n") # dump arguments on file for log with open("{}/args.json".format(log_directory), "w") as outfile: json.dump(vars(args), outfile) # Mean reciprocal rank MRR = 0.0 # Precision at (default 10) Precision = 0.0 Precision1 = 0.0 Precision_negative = 0.0 Precision_positivie = 0.0 data = load_file(args.dataset_filename) # data = data[:2000] fact_pair = load_file(args.fact_pair_filename) # print("@@@@@@@@@@@@@@") # print(fact_pair) # print(len(fact_pair)) # print("$$$$$$$$$$$$$$") all_samples, ret_msg = filter_samples(model, data, vocab_subset, args.max_sentence_length, args.template) # print("!!!!!!!!!!!!!") # print(len(all_samples)) # 30847 logger.info("\n" + ret_msg + "\n") # for sample in all_samples: # sample["masked_sentences"] = [sample['evidences'][0]['masked_sentence']] # create uuid if not present i = 0 for sample in all_samples: if "uuid" not in sample: sample["uuid"] = i i += 1 # shuffle data if shuffle_data: shuffle(all_samples) samples_batches, sentences_batches, ret_msg = batchify( all_samples, args.batch_size) logger.info("\n" + ret_msg + "\n") # ThreadPool num_threads = args.threads if num_threads <= 0: # use all available threads num_threads = multiprocessing.cpu_count() pool = ThreadPool(num_threads) list_of_results = {d['subject'] + "_" + d['object']: [] for d in fact_pair} list_of_ranks = {d['subject'] + "_" + d['object']: [] for d in fact_pair} for i in tqdm(range(len(samples_batches))): samples_b = samples_batches[i] sentences_b = sentences_batches[i] ( original_log_probs_list, token_ids_list, masked_indices_list, ) = model.get_batch_generation(sentences_b, logger=logger) if vocab_subset is not None: # filter log_probs filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices) else: filtered_log_probs_list = original_log_probs_list label_index_list = [] for sample in samples_b: obj_label_id = model.get_id(sample["obj_label"]) # MAKE SURE THAT obj_label IS IN VOCABULARIES if obj_label_id is None: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif model.vocab[obj_label_id[0]] != sample["obj_label"]: raise ValueError( "object label {} not in model 
vocabulary".format( sample["obj_label"])) elif vocab_subset is not None and sample[ "obj_label"] not in vocab_subset: raise ValueError("object label {} not in vocab subset".format( sample["obj_label"])) label_index_list.append(obj_label_id) arguments = [{ "original_log_probs": original_log_probs, "filtered_log_probs": filtered_log_probs, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "interactive": args.interactive, "index_list": index_list, "sample": sample, } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip( original_log_probs_list, filtered_log_probs_list, token_ids_list, masked_indices_list, label_index_list, samples_b, )] # single thread for debug # for isx,a in enumerate(arguments): # print(samples_b[isx]) # run_thread(a) # multithread res = pool.map(run_thread, arguments) for idx, result in enumerate(res): result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result logger.info("\n" + msg + "\n") sample = samples_b[idx] element = {} obj = sample['obj_label'] sub = sample['sub_label'] element["masked_sentences"] = sample["masked_sentences"][0] element["uuid"] = sample["uuid"] element["subject"] = sub element["object"] = obj element["rank"] = int(result_masked_topk['rank']) element["sample_Precision1"] = result_masked_topk["P_AT_1"] # element["sample"] = sample # element["token_ids"] = token_ids_list[idx] # element["masked_indices"] = masked_indices_list[idx] # element["label_index"] = label_index_list[idx] # element["masked_topk"] = result_masked_topk # element["sample_MRR"] = sample_MRR # element["sample_Precision"] = sample_P # element["sample_perplexity"] = sample_perplexity list_of_results[sub + "_" + obj].append(element) list_of_ranks[sub + "_" + obj].append(element["rank"]) # print("~~~~~~ rank: {}".format(result_masked_topk['rank'])) MRR += sample_MRR Precision += sample_P Precision1 += element["sample_Precision1"] append_data_line_to_jsonl( "reproduction/data/TREx_filter/{}_rank_results.jsonl".format( args.label), element) # 3122 # list_of_results.append(element) pool.close() pool.join() # stats # Mean reciprocal rank # MRR /= len(list_of_results) # # Precision # Precision /= len(list_of_results) # Precision1 /= len(list_of_results) # msg = "all_samples: {}\n".format(len(all_samples)) # # msg += "list_of_results: {}\n".format(len(list_of_results)) # msg += "global MRR: {}\n".format(MRR) # msg += "global Precision at 10: {}\n".format(Precision) # msg += "global Precision at 1: {}\n".format(Precision1) # logger.info("\n" + msg + "\n") # print("\n" + msg + "\n") # dump pickle with the result of the experiment # all_results = dict( # list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision # ) # with open("{}/result.pkl".format(log_directory), "wb") as f: # pickle.dump(all_results, f) # print() # model_name = args.models_names[0] # if args.models_names[0] == "bert": # model_name = args.bert_model_name # elif args.models_names[0] == "elmo": # if args.bert_model_name == args.bert_model_name # else: # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results) # 3122 # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks) # 3122 save_data_line_to_jsonl( "reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label), list(list_of_ranks.values())) # 3122 return Precision1
def main(args, shuffle_data=True, model=None, model2=None, refine_template=False, get_objs=False, dynamic='none', use_prob=False, bt_obj=None, temp_model=None): if len(args.models_names) > 1: raise ValueError( 'Please specify a single language model (e.g., --lm "bert").') msg = "" [model_type_name] = args.models_names #print(model) if model is None: model = build_model_by_name(model_type_name, args) if model_type_name == "fairseq": model_name = "fairseq_{}".format(args.fairseq_model_name) elif model_type_name == "bert": model_name = "BERT_{}".format(args.bert_model_name) elif model_type_name == "elmo": model_name = "ELMo_{}".format(args.elmo_model_name) else: model_name = model_type_name.title() # initialize logging if args.full_logdir: log_directory = args.full_logdir else: log_directory = create_logdir_with_timestamp(args.logdir, model_name) logger = init_logging(log_directory) msg += "model name: {}\n".format(model_name) # deal with vocab subset vocab_subset = None index_list = None msg += "args: {}\n".format(args) if args.common_vocab_filename is not None: vocab_subset = load_vocab(args.common_vocab_filename, lower=args.lowercase) msg += "common vocabulary size: {}\n".format(len(vocab_subset)) # optimization for some LM (such as ELMo) model.optimize_top_layer(vocab_subset) filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset, logger) logger.info("\n" + msg + "\n") # dump arguments on file for log with open("{}/args.json".format(log_directory), "w") as outfile: json.dump(vars(args), outfile) if dynamic == 'all_topk': # save topk results for different k # stats samples_with_negative_judgement = [ 0 for _ in range(len(args.template)) ] samples_with_positive_judgement = [ 0 for _ in range(len(args.template)) ] # Mean reciprocal rank MRR = [0.0 for _ in range(len(args.template))] MRR_negative = [0.0 for _ in range(len(args.template))] MRR_positive = [0.0 for _ in range(len(args.template))] # Precision at (default 10) Precision = [0.0 for _ in range(len(args.template))] Precision1 = [0.0 for _ in range(len(args.template))] Precision_negative = [0.0 for _ in range(len(args.template))] Precision_positivie = [0.0 for _ in range(len(args.template))] list_of_results = [[] for _ in range(len(args.template))] P1_li = [[] for _ in range(len(args.template))] else: # stats samples_with_negative_judgement = [0] samples_with_positive_judgement = [0] # Mean reciprocal rank MRR = [0.0] MRR_negative = [0.0] MRR_positive = [0.0] # Precision at (default 10) Precision = [0.0] Precision1 = [0.0] Precision_negative = [0.0] Precision_positivie = [0.0] list_of_results = [[]] P1_li = [[]] data = load_file(args.dataset_filename) for s in data: s['raw_sub_label'] = s['sub_label'] s['raw_obj_label'] = s['obj_label'] if args.lowercase: # lowercase all samples logger.info("lowercasing all samples...") data = lowercase_samples(data) all_samples, ret_msg = filter_samples(model, data, vocab_subset, args.max_sentence_length, args.template) # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename) # with open(OUT_FILENAME, 'w') as outfile: # for entry in all_samples: # json.dump(entry, outfile) # outfile.write('\n') logger.info("\n" + ret_msg + "\n") print('#head-tails {} -> {}'.format(len(data), len(all_samples))) samples_batches_li, sentences_batches_li = [], [] for template in args.template: # if template is active (1) use a single example for (sub,obj) and (2) ... 
if template and template != "": facts = [] samples = [] for sample in all_samples: sub = sample["sub_label"] obj = sample["obj_label"] if (sub, obj) not in facts: facts.append((sub, obj)) samples.append(sample) local_msg = "distinct template facts: {}".format(len(facts)) logger.info("\n" + local_msg + "\n") new_all_samples = [] for fact, raw_sample in zip(facts, samples): (sub, obj) = fact sample = {} sample["sub_label"] = sub sample["obj_label"] = obj # sobstitute all sentences with a standard template sample["masked_sentences"] = parse_template( template.strip(), raw_sample["raw_sub_label"].strip() if args.upper_entity else sample["sub_label"].strip(), model.mask_token) sub_uri = raw_sample[ 'sub_uri'] if 'sub_uri' in raw_sample else raw_sample['sub'] sample['entity_list'] = get_entity_list( template.strip(), raw_sample['raw_sub_label'].strip(), sub_uri, None, None) if dynamic.startswith('bt_topk') or (temp_model is not None and bt_obj): sample['sub_masked_sentences'] = parse_template_tokenize( template.strip(), sample["sub_label"].strip(), model, mask_part='sub') if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'): sample["tokenized_sentences"] = parse_template_tokenize( template.strip(), sample["sub_label"].strip(), model, mask_part='relation') # substitute sub and obj placeholder in template with corresponding str # and add bracket to the relational phrase sample['bracket_sentences'] = bracket_relational_phrase( template.strip(), sample['sub_label'].strip(), sample['obj_label'].strip()) new_all_samples.append(sample) # create uuid if not present i = 0 for sample in new_all_samples: if "uuid" not in sample: sample["uuid"] = i i += 1 if args.lowercase and not args.upper_entity: # lowercase all samples logger.info("lowercasing all samples...") new_all_samples = lowercase_samples(new_all_samples) # shuffle data if shuffle_data: perm = np.random.permutation(len(new_all_samples)) new_all_samples = np.array(new_all_samples)[perm] raise Exception samples_batches, sentences_batches, ret_msg = batchify( new_all_samples, args.batch_size) logger.info("\n" + ret_msg + "\n") samples_batches_li.append(samples_batches) sentences_batches_li.append(sentences_batches) sub_obj_labels = [(sample['sub_label'], sample['obj_label']) for batch in samples_batches for sample in batch] if get_objs: print('sub_obj_label {}'.format('\t'.join( map(lambda p: '{}\t{}'.format(*p), sub_obj_labels)))) return if refine_template: bracket_sentences = [ sample['bracket_sentences'] for sample in new_all_samples ] new_temp = model.refine_cloze(bracket_sentences, batch_size=32, try_cuda=True) new_temp = replace_template(template.strip(), ' '.join(new_temp)) print('old temp: {}'.format(template.strip())) print('new temp: {}'.format(new_temp)) return new_temp # ThreadPool num_threads = args.threads if num_threads <= 0: # use all available threads num_threads = multiprocessing.cpu_count() pool = ThreadPool(num_threads) samples_batches_li = list(zip(*samples_batches_li)) sentences_batches_li = list(zip(*sentences_batches_li)) c_inc_meaning = ['top12 prob gap', 'top1 prob'] c_inc_stat = np.zeros((2, 3)) # [[*, c_num], [*, inc_num]] loss_list = [] features_list = [] features_list2 = [] bt_features_list = [] label_index_tensor_list = [] for i in tqdm(range(len(samples_batches_li))): samples_b_all = samples_batches_li[i] sentences_b_all = sentences_batches_li[i] filter_lp_merge = None filter_lp_merge2 = None samples_b = samples_b_all[-1] max_score = float('-inf') consist_score_li = [] samples_b_prev = None for sentences_b, 
samples_b_this in zip(sentences_b_all, samples_b_all): if samples_b_prev is not None: for ps, ts in zip(samples_b_prev, samples_b_this): assert ps['uuid'] == ts['uuid'] entity_list_b = [s['entity_list'] for s in samples_b_this] # TODO: add tokens_tensor and mask_tensor for more models original_log_probs_list, token_ids_list, masked_indices_list, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b, logger=logger, entity_list=entity_list_b) if model2 is not None: original_log_probs_list2, token_ids_list2, masked_indices_list2, tokens_tensor2, mask_tensor2 = \ model2.get_batch_generation(sentences_b, logger=logger) if use_prob: # use prob instead of log prob original_log_probs_list = original_log_probs_list.exp() if model2 is not None: original_log_probs_list2 = original_log_probs_list2.exp() if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'): sentences_b_mask_rel = [ s['tokenized_sentences'][0] for s in samples_b_this ] relation_mask = [ s['tokenized_sentences'][1] for s in samples_b_this ] consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b_mask_rel, logger=logger, relation_mask=relation_mask) else: consist_log_probs_list = original_log_probs_list if dynamic == 'lm' or dynamic == 'real_lm' or dynamic.startswith( 'real_lm_topk'): # use avg prob of the templates as score mask_tensor = mask_tensor.float() consist_log_probs_list_flat = consist_log_probs_list.view( -1, consist_log_probs_list.size(-1)) token_logprob = torch.gather( consist_log_probs_list_flat, dim=1, index=tokens_tensor.view( -1, 1)).view(*consist_log_probs_list.size()[:2]) token_logprob = token_logprob * mask_tensor consist_score = token_logprob.sum(-1) / mask_tensor.sum( -1) # normalized prob ''' if vocab_subset is not None: # filter log_probs filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices ) else: filtered_log_probs_list = original_log_probs_list ''' # get the prediction probability if vocab_subset is not None: filtered_log_probs_list = [ flp[masked_indices_list[ind][0]].index_select( dim=-1, index=filter_logprob_indices) for ind, flp in enumerate(original_log_probs_list) ] if model2 is not None: filtered_log_probs_list2 = [ flp[masked_indices_list2[ind][0]].index_select( dim=-1, index=filter_logprob_indices) for ind, flp in enumerate(original_log_probs_list2) ] else: filtered_log_probs_list = [ flp[masked_indices_list[ind][0]] for ind, flp in enumerate(original_log_probs_list) ] if model2 is not None: filtered_log_probs_list2 = [ flp[masked_indices_list2[ind][0]] for ind, flp in enumerate(original_log_probs_list2) ] if dynamic.startswith('bt_topk'): obj_topk = int(dynamic.rsplit('-', 1)[1]) top_obj_pred = [ flp.topk(k=obj_topk) for flp in filtered_log_probs_list ] top_obj_logprob, top_obj_pred = zip(*top_obj_pred) if dynamic.startswith('obj_lm_topk'): # use highest obj prob as consistency score consist_score = torch.tensor( [torch.max(flp).item() for flp in filtered_log_probs_list]) elif dynamic.startswith('obj_lmgap_topk'): # the gap between the highest prediction log p1 - log p2 get_gap = lambda top2: (top2[0] - top2[1]).item() consist_score = torch.tensor([ get_gap(torch.topk(flp, k=2)[0]) for flp in filtered_log_probs_list ]) elif dynamic.startswith('bt_topk'): # use the obj_topk highest obj to "back translate" sub consist_score_obj_topk = [] used_vocab = vocab_subset if vocab_subset is not None else model.vocab for obj_i in range(obj_topk): sentences_b_mask_sub = [[ 
replace_list(s['sub_masked_sentences'][0][0], model.mask_token, used_vocab[obj_pred[obj_i].item()]) ] for s, obj_pred in zip(samples_b_this, top_obj_pred)] sub_mask = [ s['sub_masked_sentences'][1] for s in samples_b_this ] # TODO: only masked lm can do this consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask) # use avg prob of the sub as score mask_tensor = mask_tensor.float() consist_log_probs_list_flat = consist_log_probs_list.view( -1, consist_log_probs_list.size(-1)) token_logprob = torch.gather( consist_log_probs_list_flat, dim=1, index=tokens_tensor.view( -1, 1)).view(*consist_log_probs_list.size()[:2]) token_logprob = token_logprob * mask_tensor consist_score = token_logprob.sum(-1) / mask_tensor.sum( -1) # normalized prob consist_score_obj_topk.append(consist_score) # SHAPE: (batch_size, obj_topk) consist_score_obj_topk = torch.stack( consist_score_obj_topk).permute(1, 0) consist_score_weight = torch.stack(top_obj_logprob).exp() # SHAPE: (batch_size) consist_score = (consist_score_obj_topk * consist_score_weight).sum(-1) / ( consist_score_weight.sum(-1) + 1e-10) # add to overall probability if filter_lp_merge is None: filter_lp_merge = filtered_log_probs_list if model2 is not None: filter_lp_merge2 = filtered_log_probs_list2 if dynamic == 'lm' or dynamic == 'real_lm': max_score = consist_score elif dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk'): consist_score_li.append(consist_score) else: if dynamic == 'none' and temp_model is None: filter_lp_merge = [ a + b for a, b in zip(filter_lp_merge, filtered_log_probs_list) ] elif dynamic == 'all_topk': filter_lp_merge.extend(filtered_log_probs_list) elif dynamic == 'lm' or dynamic == 'real_lm': filter_lp_merge = \ [a if c >= d else b for a, b, c, d in zip(filter_lp_merge, filtered_log_probs_list, max_score, consist_score)] max_score = torch.max(max_score, consist_score) elif dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk'): filter_lp_merge.extend(filtered_log_probs_list) consist_score_li.append(consist_score) elif temp_model is not None: filter_lp_merge.extend(filtered_log_probs_list) if model2 is not None: filter_lp_merge2.extend(filtered_log_probs_list2) samples_b_prev = samples_b_this label_index_list = [] obj_word_list = [] for sample in samples_b: obj_label_id = model.get_id(sample["obj_label"]) # MAKE SURE THAT obj_label IS IN VOCABULARIES if obj_label_id is None: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif model.vocab[obj_label_id[0]] != sample["obj_label"]: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif vocab_subset is not None and sample[ "obj_label"] not in vocab_subset: raise ValueError("object label {} not in vocab subset".format( sample["obj_label"])) label_index_list.append(obj_label_id) obj_word_list.append(sample['obj_label']) if dynamic == 'all_topk' or \ dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk') or \ temp_model is not None: # analyze prob # SHAPE: (batch_size, num_temp, filter_vocab_size) filter_lp_merge = torch.stack(filter_lp_merge, 0).view( len(sentences_b_all), len(filter_lp_merge) // len(sentences_b_all), 
-1).permute(1, 0, 2) if model2 is not None: filter_lp_merge2 = torch.stack(filter_lp_merge2, 0).view( len(sentences_b_all), len(filter_lp_merge2) // len(sentences_b_all), -1).permute(1, 0, 2) # SHAPE: (batch_size) label_index_tensor = torch.tensor( [index_list.index(li[0]) for li in label_index_list]) c_inc = np.array( metrics.analyze_prob(filter_lp_merge, label_index_tensor, output=False, method='sample')) c_inc_stat += c_inc elif dynamic == 'none': # SHAPE: (batch_size, 1, filter_vocab_size) filter_lp_merge = torch.stack(filter_lp_merge, 0).unsqueeze(1) # SHAPE: (batch_size, num_temp, filter_vocab_size) filter_lp_unmerge = filter_lp_merge if temp_model is not None: # optimize template weights temp_model_, optimizer = temp_model if optimizer is None: # predict filter_lp_merge = temp_model_(args.relation, filter_lp_merge.detach(), target=None) elif optimizer == 'precompute': # pre-compute and save featuers lp = filter_lp_merge # SHAPE: (batch_size * num_temp) features = torch.gather(lp.contiguous().view(-1, lp.size(-1)), dim=1, index=label_index_tensor.repeat( lp.size(1)).view(-1, 1)) features = features.view(-1, lp.size(1)) features_list.append(features) if not bt_obj: continue elif optimizer is not None: # train on the fly features_list.append( filter_lp_merge ) # collect features that will later be used in optimization if model2 is not None: features_list2.append( filter_lp_merge2 ) # collect features that will later be used in optimization label_index_tensor_list.append( label_index_tensor) # collect labels if not bt_obj: continue else: #filter_lp_merge = temp_model_(args.relation, filter_lp_merge.detach(), target=None) filter_lp_merge = filter_lp_merge.mean( 1) # use average prob to beam search if dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk'): # dynamic ensemble real_lm_topk = min( int(dynamic[dynamic.find('topk') + 4:].split('-')[0]), len(consist_score_li)) # SHAPE: (batch_size, num_temp) consist_score_li = torch.stack(consist_score_li, -1) # SHAPE: (batch_size, topk) consist_score, consist_ind = consist_score_li.topk(real_lm_topk, dim=-1) # SHAPE: (batch_size, 1) consist_score = consist_score.min(-1, keepdim=True)[0] # SHAPE: (batch_size, num_temp, 1) consist_mask = (consist_score_li >= consist_score).float().unsqueeze(-1) # avg over top k filter_lp_merge = filter_lp_merge * consist_mask filter_lp_merge = filter_lp_merge.sum(1) / consist_mask.sum(1) if bt_obj: # choose top bt_obj objects and bach-translate subject # get the top bt_obj objects with highest probability used_vocab = vocab_subset if vocab_subset is not None else model.vocab temp_model_, optimizer = temp_model if optimizer is None: # use beam search # SHAPE: (batch_size, bt_obj) objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1) objs_ind = torch.sort(objs_ind, dim=-1)[0] # the index must be ascending elif optimizer == 'precompute': # use ground truth objs_ind = label_index_tensor.view(-1, 1) bt_obj = 1 elif optimizer is not None: # get both ground truth and beam search # SHAPE: (batch_size, bt_obj) objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1) objs_ind = torch.cat( [objs_ind, label_index_tensor.view(-1, 1)], -1) objs_ind = torch.sort(objs_ind, dim=-1)[0] # the index must be ascending bt_obj += 1 # bach translation sub_lp_list = [] for sentences_b, samples_b_this in zip( sentences_b_all, samples_b_all): # iter over templates for obj_i in range(bt_obj): # iter over objs sentences_b_mask_sub = [] 
for s, obj_pred, obj_word in zip(samples_b_this, objs_ind, obj_word_list): replace_tok = used_vocab[obj_pred[obj_i].item()] if optimizer == 'precompute': assert replace_tok.strip() == obj_word.strip() sentences_b_mask_sub.append([ replace_list(s['sub_masked_sentences'][0][0], model.mask_token, replace_tok) ]) sub_mask = [ s['sub_masked_sentences'][1] for s in samples_b_this ] # TODO: only masked lm can do this lp, _, _, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask) # use avg prob of the sub as score mask_tensor = mask_tensor.float() lp_flat = lp.view(-1, lp.size(-1)) sub_lp = torch.gather(lp_flat, dim=1, index=tokens_tensor.view( -1, 1)).view(*lp.size()[:2]) sub_lp = sub_lp * mask_tensor sub_lp_avg = sub_lp.sum(-1) / mask_tensor.sum( -1) # normalized prob sub_lp_list.append(sub_lp_avg) # SHAPE: (batch_size, num_temp, top_obj_num) num_temp = len(sentences_b_all) sub_lp_list = torch.cat(sub_lp_list, 0).view(num_temp, bt_obj, -1).permute(2, 0, 1) if optimizer == 'precompute': bt_features_list.append(sub_lp_list.squeeze(-1)) continue elif optimizer is not None: sub_lp_list_expand = torch.zeros_like(filter_lp_unmerge) # SHAPE: (batch_size, num_temp, vocab_size) sub_lp_list_expand.scatter_( -1, objs_ind.unsqueeze(1).repeat(1, num_temp, 1), sub_lp_list) bt_features_list.append(sub_lp_list_expand) bt_obj -= 1 continue # select obj prob expand_mask = torch.zeros_like(filter_lp_unmerge) expand_mask.scatter_(-1, objs_ind.unsqueeze(1).repeat(1, num_temp, 1), 1) # SHAPE: (batch_size, num_temp, top_obj_num) obj_lp_list = torch.masked_select(filter_lp_unmerge, expand_mask.eq(1)).view( -1, num_temp, bt_obj) # run temp model # SHAPE: (batch_size, vocab_size) filter_lp_merge_expand = torch.zeros_like(filter_lp_merge) # SHAPE: (batch_size, top_obj_num) filter_lp_merge = temp_model_(args.relation, torch.cat([obj_lp_list, sub_lp_list], 1), target=None) # expand results to vocab_size filter_lp_merge_expand.scatter_(-1, objs_ind, filter_lp_merge) filter_lp_merge = filter_lp_merge_expand + expand_mask[:, 0, :].log( ) # mask out other objs if len(filter_lp_merge.size()) == 2: filter_lp_merge = filter_lp_merge.unsqueeze(1) for temp_id in range(filter_lp_merge.size(1)): arguments = [{ "original_log_probs": original_log_probs, "filtered_log_probs": filtered_log_probs, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "interactive": args.interactive, "index_list": index_list, "sample": sample, } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip( original_log_probs_list, filter_lp_merge[:, :temp_id + 1].sum(1), token_ids_list, masked_indices_list, label_index_list, samples_b, )] # single thread for debug # for isx,a in enumerate(arguments): # print(samples_b[isx]) # run_thread(a) # multithread res = pool.map(run_thread, arguments) for idx, result in enumerate(res): result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result logger.info("\n" + msg + "\n") sample = samples_b[idx] element = {} element["sample"] = sample element["uuid"] = sample["uuid"] element["token_ids"] = token_ids_list[idx] element["masked_indices"] = masked_indices_list[idx] element["label_index"] = label_index_list[idx] element["masked_topk"] = result_masked_topk element["sample_MRR"] = sample_MRR element["sample_Precision"] = sample_P element["sample_perplexity"] = sample_perplexity element["sample_Precision1"] = result_masked_topk["P_AT_1"] # print() # 
print("idx: {}".format(idx)) # print("masked_entity: {}".format(result_masked_topk['masked_entity'])) # for yi in range(10): # print("\t{} {}".format(yi,result_masked_topk['topk'][yi])) # print("masked_indices_list: {}".format(masked_indices_list[idx])) # print("sample_MRR: {}".format(sample_MRR)) # print("sample_P: {}".format(sample_P)) # print("sample: {}".format(sample)) # print() MRR[temp_id] += sample_MRR Precision[temp_id] += sample_P Precision1[temp_id] += element["sample_Precision1"] P1_li[temp_id].append(element["sample_Precision1"]) ''' if element["sample_Precision1"] == 1: print(element["sample"]) input(1) else: print(element["sample"]) input(0) ''' # the judgment of the annotators recording whether they are # evidence in the sentence that indicates a relation between two entities. num_yes = 0 num_no = 0 if "judgments" in sample: # only for Google-RE for x in sample["judgments"]: if x["judgment"] == "yes": num_yes += 1 else: num_no += 1 if num_no >= num_yes: samples_with_negative_judgement[temp_id] += 1 element["judgement"] = "negative" MRR_negative[temp_id] += sample_MRR Precision_negative[temp_id] += sample_P else: samples_with_positive_judgement[temp_id] += 1 element["judgement"] = "positive" MRR_positive[temp_id] += sample_MRR Precision_positivie[temp_id] += sample_P list_of_results[temp_id].append(element) if temp_model is not None: if temp_model[1] == 'precompute': features = torch.cat(features_list, 0) if bt_obj: bt_features = torch.cat(bt_features_list, 0) features = torch.cat([features, bt_features], 1) return features if temp_model[1] is not None: # optimize the model on the fly temp_model_, (optimizer, temperature) = temp_model temp_model_.cuda() # SHAPE: (batch_size, num_temp, vocab_size) features = torch.cat(features_list, 0) if model2 is not None: features2 = torch.cat(features_list2, 0) if bt_obj: bt_features = torch.cat(bt_features_list, 0) features = torch.cat([features, bt_features], 1) # compute weight # SHAPE: (batch_size,) label_index_tensor = torch.cat(label_index_tensor_list, 0) label_count = torch.bincount(label_index_tensor) label_count = torch.index_select(label_count, 0, label_index_tensor) sample_weight = F.softmax( temperature * torch.log(1.0 / label_count.float()), 0) * label_index_tensor.size(0) min_loss = 1e10 es = 0 batch_size = 128 for e in range(500): # loss = temp_model_(args.relation, features.cuda(), target=label_index_tensor.cuda(), use_softmax=True) loss_li = [] for b in range(0, features.size(0), batch_size): features_b = features[b:b + batch_size].cuda() label_index_tensor_b = label_index_tensor[b:b + batch_size].cuda( ) sample_weight_b = sample_weight[b:b + batch_size].cuda() loss = temp_model_(args.relation, features_b, target=label_index_tensor_b, sample_weight=sample_weight_b, use_softmax=True) if model2 is not None: features2_b = features2[b:b + batch_size].cuda() loss2 = temp_model_(args.relation, features2_b, target=label_index_tensor_b, sample_weight=sample_weight_b, use_softmax=True) loss = loss + loss2 optimizer.zero_grad() loss.backward() optimizer.step() loss_li.append(loss.cpu().item()) dev_loss = np.mean(loss_li) if dev_loss - min_loss < -1e-3: min_loss = dev_loss es = 0 else: es += 1 if es >= 30: print('early stop') break temp_model_.cpu() return min_loss pool.close() pool.join() for temp_id in range(len(P1_li)): # stats # Mean reciprocal rank MRR[temp_id] /= len(list_of_results[temp_id]) # Precision Precision[temp_id] /= len(list_of_results[temp_id]) Precision1[temp_id] /= len(list_of_results[temp_id]) msg = "all_samples: 
{}\n".format(len(all_samples)) msg += "list_of_results: {}\n".format(len(list_of_results[temp_id])) msg += "global MRR: {}\n".format(MRR[temp_id]) msg += "global Precision at 10: {}\n".format(Precision[temp_id]) msg += "global Precision at 1: {}\n".format(Precision1[temp_id]) if samples_with_negative_judgement[ temp_id] > 0 and samples_with_positive_judgement[temp_id] > 0: # Google-RE specific MRR_negative[temp_id] /= samples_with_negative_judgement[temp_id] MRR_positive[temp_id] /= samples_with_positive_judgement[temp_id] Precision_negative[temp_id] /= samples_with_negative_judgement[ temp_id] Precision_positivie[temp_id] /= samples_with_positive_judgement[ temp_id] msg += "samples_with_negative_judgement: {}\n".format( samples_with_negative_judgement[temp_id]) msg += "samples_with_positive_judgement: {}\n".format( samples_with_positive_judgement[temp_id]) msg += "MRR_negative: {}\n".format(MRR_negative[temp_id]) msg += "MRR_positive: {}\n".format(MRR_positive[temp_id]) msg += "Precision_negative: {}\n".format( Precision_negative[temp_id]) msg += "Precision_positivie: {}\n".format( Precision_positivie[temp_id]) logger.info("\n" + msg + "\n") print("\n" + msg + "\n") # dump pickle with the result of the experiment all_results = dict(list_of_results=list_of_results[temp_id], global_MRR=MRR, global_P_at_10=Precision) with open("{}/result.pkl".format(log_directory), "wb") as f: pickle.dump(all_results, f) print('P1all {}'.format('\t'.join(map(str, P1_li[temp_id])))) print('meaning: {}'.format(c_inc_meaning)) print('correct-incorrect {}'.format('\t'.join( map(str, (c_inc_stat[:, :-1] / (c_inc_stat[:, -1:] + 1e-5)).reshape(-1))))) return Precision1[-1]
def main(args, shuffle_data=True, model=None): if len(args.models_names) > 1: raise ValueError( "Please specify a single language model (e.g., --lm \"bert\").") msg = "" [model_type_name] = args.models_names print(model) if model is None: model = build_model_by_name(model_type_name, args) if model_type_name == 'fairseq': model_name = 'fairseq_{}'.format(args.fairseq_model_name) elif model_type_name == 'bert': model_name = 'BERT_{}'.format(args.bert_model_name) elif model_type_name == 'elmo': model_name = 'ELMo_{}'.format(args.elmo_model_name) else: model_name = model_type_name.title() # initialize logging if args.full_logdir: log_directory = args.full_logdir else: log_directory = create_logdir_with_timestamp(args.logdir, model_name) logger = init_logging(log_directory) msg += "model name: {}\n".format(model_name) # deal with vocab subset vocab_subset = None index_list = None msg += "args: {}\n".format(args) if args.common_vocab_filename is not None: vocab_subset = load_vocab(args.common_vocab_filename) msg += "common vocabulary size: {}\n".format(len(vocab_subset)) # optimization for some LM (such as ELMo) model.optimize_top_layer(vocab_subset) filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset, logger) logger.info("\n" + msg + "\n") # dump arguments on file for log with open("{}/args.json".format(log_directory), 'w') as outfile: json.dump(vars(args), outfile) # stats samples_with_negative_judgement = 0 samples_with_positive_judgement = 0 # Mean reciprocal rank MRR = 0. MRR_negative = 0. MRR_positive = 0. # Precision at (default 10) Precision = 0. Precision1 = 0. Precision_negative = 0. Precision_positivie = 0. data = load_file(args.dataset_filename) print(len(data)) if args.lowercase: # lowercase all samples logger.info("lowercasing all samples...") all_samples = lowercase_samples(data) else: # keep samples as they are all_samples = data all_samples, ret_msg = filter_samples(model, data, vocab_subset, args.max_sentence_length, args.template) # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename) # with open(OUT_FILENAME, 'w') as outfile: # for entry in all_samples: # json.dump(entry, outfile) # outfile.write('\n') logger.info("\n" + ret_msg + "\n") print(len(all_samples)) # if template is active (1) use a single example for (sub,obj) and (2) ... 
if args.template and args.template != '': facts = [] for sample in all_samples: sub = sample['sub_label'] obj = sample['obj_label'] if (sub, obj) not in facts: facts.append((sub, obj)) local_msg = "distinct template facts: {}".format(len(facts)) logger.info("\n" + local_msg + "\n") print(local_msg) all_samples = [] for fact in facts: (sub, obj) = fact sample = {} sample['sub_label'] = sub sample['obj_label'] = obj # sobstitute all sentences with a standard template sample['masked_sentences'] = parse_template( args.template.strip(), sample["sub_label"].strip(), base.MASK) all_samples.append(sample) # create uuid if not present i = 0 for sample in all_samples: if 'uuid' not in sample: sample['uuid'] = i i += 1 # shuffle data if shuffle_data: shuffle(all_samples) samples_batches, sentences_batches, ret_msg = batchify( all_samples, args.batch_size) logger.info("\n" + ret_msg + "\n") # ThreadPool num_threads = args.threads if num_threads <= 0: # use all available threads num_threads = multiprocessing.cpu_count() pool = ThreadPool(num_threads) list_of_results = [] for i in tqdm(range(len(samples_batches))): samples_b = samples_batches[i] sentences_b = sentences_batches[i] original_log_probs_list, token_ids_list, masked_indices_list = model.get_batch_generation( sentences_b, logger=logger) if vocab_subset is not None: # filter log_probs filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices) else: filtered_log_probs_list = original_log_probs_list label_index_list = [] for sample in samples_b: obj_label_id = model.get_id(sample['obj_label']) # MAKE SURE THAT obj_label IS IN VOCABULARIES if obj_label_id is None: raise ValueError( "object label {} not in model vocabulary".format( sample['obj_label'])) elif (model.vocab[obj_label_id[0]] != sample['obj_label']): raise ValueError( "object label {} not in model vocabulary".format( sample['obj_label'])) elif vocab_subset is not None and sample[ 'obj_label'] not in vocab_subset: raise ValueError("object label {} not in vocab subset".format( sample['obj_label'])) label_index_list.append(obj_label_id) arguments = [{ 'original_log_probs': original_log_probs, 'filtered_log_probs': filtered_log_probs, 'token_ids': token_ids, 'vocab': model.vocab, 'label_index': label_index[0], 'masked_indices': masked_indices, 'interactive': args.interactive, 'index_list': index_list, 'sample': sample } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip( original_log_probs_list, filtered_log_probs_list, token_ids_list, masked_indices_list, label_index_list, samples_b)] # single thread for debug # for isx,a in enumerate(arguments): # print(samples_b[isx]) # run_thread(a) # multithread res = pool.map(run_thread, arguments) for idx, result in enumerate(res): result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result logger.info("\n" + msg + "\n") sample = samples_b[idx] element = {} element['sample'] = sample element['uuid'] = sample['uuid'] element['token_ids'] = token_ids_list[idx] element['masked_indices'] = masked_indices_list[idx] element['label_index'] = label_index_list[idx] element['masked_topk'] = result_masked_topk element['sample_MRR'] = sample_MRR element['sample_Precision'] = sample_P element['sample_perplexity'] = sample_perplexity element['sample_Precision1'] = result_masked_topk["P_AT_1"] # print() # print("idx: {}".format(idx)) # print("masked_entity: {}".format(result_masked_topk['masked_entity'])) # for yi in range(10): # print("\t{} 
{}".format(yi,result_masked_topk['topk'][yi])) # print("masked_indices_list: {}".format(masked_indices_list[idx])) # print("sample_MRR: {}".format(sample_MRR)) # print("sample_P: {}".format(sample_P)) # print("sample: {}".format(sample)) # print() MRR += sample_MRR Precision += sample_P Precision1 += element['sample_Precision1'] # the judgment of the annotators recording whether they are # evidence in the sentence that indicates a relation between two entities. num_yes = 0 num_no = 0 if 'judgments' in sample: # only for Google-RE for x in sample['judgments']: if (x['judgment'] == "yes"): num_yes += 1 else: num_no += 1 if num_no >= num_yes: samples_with_negative_judgement += 1 element['judgement'] = "negative" MRR_negative += sample_MRR Precision_negative += sample_P else: samples_with_positive_judgement += 1 element['judgement'] = "positive" MRR_positive += sample_MRR Precision_positivie += sample_P list_of_results.append(element) pool.close() pool.join() # stats # Mean reciprocal rank MRR /= len(list_of_results) # Precision Precision /= len(list_of_results) Precision1 /= len(list_of_results) msg = "all_samples: {}\n".format(len(all_samples)) msg += "list_of_results: {}\n".format(len(list_of_results)) msg += "global MRR: {}\n".format(MRR) msg += "global Precision at 10: {}\n".format(Precision) msg += "global Precision at 1: {}\n".format(Precision1) if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0: # Google-RE specific MRR_negative /= samples_with_negative_judgement MRR_positive /= samples_with_positive_judgement Precision_negative /= samples_with_negative_judgement Precision_positivie /= samples_with_positive_judgement msg += "samples_with_negative_judgement: {}\n".format( samples_with_negative_judgement) msg += "samples_with_positive_judgement: {}\n".format( samples_with_positive_judgement) msg += "MRR_negative: {}\n".format(MRR_negative) msg += "MRR_positive: {}\n".format(MRR_positive) msg += "Precision_negative: {}\n".format(Precision_negative) msg += "Precision_positivie: {}\n".format(Precision_positivie) logger.info("\n" + msg + "\n") print("\n" + msg + "\n") # dump pickle with the result of the experiment all_results = dict( list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision, ) with open("{}/result.pkl".format(log_directory), 'wb') as f: pickle.dump(all_results, f) return Precision1
def run_experiments( relations, data_path_pre, data_path_post, input_param={ "lm": "bert", "label": "bert_large", "models_names": ["bert"], "bert_model_name": "bert-large-cased", "bert_model_dir": "pre-trained_language_models/bert/cased_L-24_H-1024_A-16", }, use_negated_probes=False, ): model = None pp = pprint.PrettyPrinter(width=41, compact=True) all_Precision1 = [] all_Precision10 = [] type_Precision1 = defaultdict(list) type_count = defaultdict(list) # Append to results_file results_file = open("last_results.csv", "a", encoding='utf-8') results_file.write('\n') for relation in relations: pp.pprint(relation) PARAMETERS = { "dataset_filename": "{}{}{}".format(data_path_pre, relation["relation"], data_path_post), "common_vocab_filename": "pre-trained_language_models/common_vocab_cased.txt", "template": "", "bert_vocab_name": "vocab.txt", "batch_size": 32, "logdir": "output", "full_logdir": "output/results/{}/{}".format(input_param["label"], relation["relation"]), "lowercase": False, "max_sentence_length": 100, "threads": -1, "interactive": False, "use_negated_probes": use_negated_probes, } if "template" in relation: PARAMETERS["template"] = relation["template"] if use_negated_probes: PARAMETERS["template_negated"] = relation["template_negated"] PARAMETERS.update(input_param) print(PARAMETERS) args = argparse.Namespace(**PARAMETERS) relation_name = relation["relation"] if relation_name == "test": relation_name = data_path_pre.replace("/", "") + "_test" # see if file exists try: data = load_file(args.dataset_filename) except Exception as e: print("Relation {} excluded.".format(relation_name)) print("Exception: {}".format(e)) continue if model is None: [model_type_name] = args.models_names model = build_model_by_name(model_type_name, args) Precision1, Precision10 = run_evaluation(args, shuffle_data=False, model=model) print("P@1 : {}".format(Precision1), flush=True) all_Precision1.append(Precision1) all_Precision10.append(Precision10) results_file.write("[{}] {}: {}, P10 = {}, P1 = {}\n".format( datetime.now(), input_param["label"], relation_name, round(Precision10 * 100, 2), round(Precision1 * 100, 2))) results_file.flush() if "type" in relation: type_Precision1[relation["type"]].append(Precision1) data = load_file(PARAMETERS["dataset_filename"]) type_count[relation["type"]].append(len(data)) mean_p1 = statistics.mean(all_Precision1) mean_p10 = statistics.mean(all_Precision10) summaryP1 = "@@@ {} - mean P@10 = {}, mean P@1 = {}".format( input_param["label"], round(mean_p10 * 100, 2), round(mean_p1 * 100, 2)) print(summaryP1) results_file.write(f'{summaryP1}\n') results_file.flush() for t, l in type_Precision1.items(): prec1item = f'@@@ Label={input_param["label"]}, type={t}, samples={sum(type_count[t])}, relations={len(type_count[t])}, mean prec1={round(statistics.mean(l) * 100, 2)}\n' print(prec1item, flush=True) results_file.write(prec1item) results_file.flush() results_file.close() return mean_p1, all_Precision1
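# Example invocation of the driver above (hedged: the model names, directories
# and data paths are placeholders; adjust them to your local checkout):
#
#     mean_p1, all_p1 = run_experiments(
#         relations,
#         data_path_pre="data/TREx/",
#         data_path_post=".jsonl",
#         input_param={
#             "lm": "bert",
#             "label": "bert_base",
#             "models_names": ["bert"],
#             "bert_model_name": "bert-base-cased",
#             "bert_model_dir": "pre-trained_language_models/bert/cased_L-12_H-768_A-12",
#         },
#     )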
def run_experiments( data_path_pre, data_path_post, input_param={ "lm": "bert", "label": "bert_large", "models_names": ["bert"], "bert_model_name": "bert-large-cased", "bert_model_dir": "pre-trained_language_models/bert/cased_L-24_H-1024_A-16", }, use_negated_probes=False, ): model = None pp = pprint.PrettyPrinter(width=41, compact=True) all_Precision1 = [] type_Precision1 = defaultdict(list) type_count = defaultdict(list) for i in range(1): PARAMETERS = { "dataset_filename": "reproduction/data/TREx_filter/different_queries.jsonl", "fact_pair_filename": "reproduction/data/TREx_filter/different_queries_facts.jsonl", "common_vocab_filename": "pre-trained_language_models/common_vocab_cased.txt", "template": "", "bert_vocab_name": "vocab.txt", "batch_size": 10, "logdir": "output", "full_logdir": "output/results/{}/{}".format(input_param["label"], "different_queries"), "lowercase": False, "max_sentence_length": 100, "threads": -1, "interactive": False, "use_negated_probes": use_negated_probes, } PARAMETERS.update(input_param) args = argparse.Namespace(**PARAMETERS) if model is None: [model_type_name] = args.models_names model = build_model_by_name(model_type_name, args) Precision1 = run_evaluation(args, shuffle_data=False, model=model) print("P@1 : {}".format(Precision1), flush=True) all_Precision1.append(Precision1) mean_p1 = statistics.mean(all_Precision1) print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1)) for t, l in type_Precision1.items(): print( "@@@ ", input_param["label"], t, statistics.mean(l), sum(type_count[t]), len(type_count[t]), flush=True, ) return mean_p1, all_Precision1
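# NOTE: the variant above ignores data_path_pre / data_path_post and always
# reads the hard-coded reproduction/data/TREx_filter files; it also expects
# run_evaluation to return a single P@1 value rather than a (P@1, P@10) tuple.
# type_Precision1 / type_count are never populated here, so the final per-type
# "@@@" loop is effectively a no-op.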
def run_experiments( relations, data_path_pre, data_path_post, refine_template, get_objs, batch_size, dynamic=None, use_prob=False, bt_obj=None, temp_model=None, save=None, load=None, feature_dir=None, enforce_prob=True, num_feat=1, temperature=0.0, use_model2=False, lowercase=False, upper_entity=False, input_param={ "lm": "bert", "label": "bert_large", "models_names": ["bert"], "bert_model_name": "bert-large-cased", "bert_model_dir": "pre-trained_language_models/bert/cased_L-24_H-1024_A-16", }, ): model, model2 = None, None pp = pprint.PrettyPrinter(width=41, compact=True) all_Precision1 = [] type_Precision1 = defaultdict(list) type_count = defaultdict(list) print('use lowercase: {}, use upper entity: {}'.format( lowercase, upper_entity)) results_file = open("last_results.csv", "w+") if refine_template: refine_temp_fout = open(refine_template, 'w') new_relations = [] templates_set = set() rel2numtemp = {} for relation in relations: # collect templates if 'template' in relation: if type(relation['template']) is not list: relation['template'] = [relation['template']] rel2numtemp[relation['relation']] = len(relation['template']) if temp_model is not None: if temp_model.startswith('mixture'): method = temp_model.split('_')[1] if method == 'optimize': # (extract feature) + optimize temp_model = TempModel(rel2numtemp, enforce_prob=enforce_prob, num_feat=num_feat) temp_model.train() optimizer = optim.Adam(temp_model.parameters(), lr=1e-1) temp_model = (temp_model, (optimizer, temperature)) elif method == 'precompute': # extract feature temp_model = (None, 'precompute') elif method == 'predict': # predict temp_model = TempModel( rel2numtemp, enforce_prob=enforce_prob, num_feat=num_feat) # TODO: number of feature if load is not None: temp_model.load_state_dict(torch.load(load)) temp_model.eval() temp_model = (temp_model, None) else: raise NotImplementedError else: raise NotImplementedError for relation in relations: pp.pprint(relation) PARAMETERS = { "relation": relation["relation"], "dataset_filename": "{}/{}{}".format(data_path_pre, relation["relation"], data_path_post), "common_vocab_filename": "pre-trained_language_models/common_vocab_cased.txt", "template": "", "bert_vocab_name": "vocab.txt", "batch_size": batch_size, "logdir": "output", "full_logdir": "output/results/{}/{}".format(input_param["label"], relation["relation"]), "lowercase": lowercase, "upper_entity": upper_entity, "max_sentence_length": 100, "threads": -1, "interactive": False, } dev_param = deepcopy(PARAMETERS) dev_param['dataset_filename'] = '{}/{}{}'.format( data_path_pre + '_dev', relation['relation'], data_path_post) bert_large_param = deepcopy(PARAMETERS) if 'template' in relation: PARAMETERS['template'] = relation['template'] dev_param['template'] = relation['template'] bert_large_param['template'] = relation['template'] PARAMETERS.update(input_param) dev_param.update(input_param) bert_large_param.update( LM_BERT_LARGE ) # this is used to optimize the weights for bert-base and bert-large at the same time print(PARAMETERS) args = argparse.Namespace(**PARAMETERS) dev_args = argparse.Namespace(**dev_param) bert_large_args = argparse.Namespace(**bert_large_param) # see if file exists try: data = load_file(args.dataset_filename) except Exception as e: print("Relation {} excluded.".format(relation["relation"])) print("Exception: {}".format(e)) continue if model is None: [model_type_name] = args.models_names model = build_model_by_name(model_type_name, args) if use_model2: model2 = 
build_model_by_name(bert_large_args.models_names[0], bert_large_args) if temp_model is not None: if temp_model[1] == 'precompute': features = run_evaluation( args, shuffle_data=False, model=model, refine_template=bool(refine_template), get_objs=get_objs, dynamic=dynamic, use_prob=use_prob, bt_obj=bt_obj, temp_model=temp_model) print('save features for {}'.format(relation['relation'])) torch.save(features, os.path.join(save, relation['relation'] + '.pt')) continue elif temp_model[1] is not None: # train temp model if feature_dir is None: loss = run_evaluation( args, shuffle_data=False, model=model, model2=model2, refine_template=bool(refine_template), get_objs=get_objs, dynamic=dynamic, use_prob=use_prob, bt_obj=bt_obj, temp_model=temp_model) else: temp_model_, (optimizer, temperature) = temp_model temp_model_.cuda() min_loss = 1e10 es = 0 for e in range(500): # SHAPE: (num_sample, num_temp) feature = torch.load( os.path.join(feature_dir, args.relation + '.pt')).cuda() #dev_feature = torch.load(os.path.join(feature_dir + '_dev', args.relation + '.pt')).cuda() #feature = torch.cat([feature, dev_feature], 0) #weight = feature.mean(0) #temp_model[0].set_weight(args.relation, weight) optimizer.zero_grad() loss = temp_model_(args.relation, feature) if os.path.exists(feature_dir + '__dev'): # TODO: debug dev_feature = torch.load( os.path.join(feature_dir + '_dev', args.relation + '.pt')).cuda() dev_loss = temp_model_(args.relation, dev_feature) else: dev_loss = loss loss.backward() optimizer.step() if dev_loss - min_loss < -1e-3: min_loss = dev_loss es = 0 else: es += 1 if es >= 10: print('early stop') break continue Precision1 = run_evaluation(args, shuffle_data=False, model=model, refine_template=bool(refine_template), get_objs=get_objs, dynamic=dynamic, use_prob=use_prob, bt_obj=bt_obj, temp_model=temp_model) if get_objs: return if refine_template and Precision1 is not None: if Precision1 in templates_set: continue templates_set.add(Precision1) new_relation = deepcopy(relation) new_relation['old_template'] = new_relation['template'] new_relation['template'] = Precision1 new_relations.append(new_relation) refine_temp_fout.write(json.dumps(new_relation) + '\n') refine_temp_fout.flush() continue print("P@1 : {}".format(Precision1), flush=True) all_Precision1.append(Precision1) results_file.write("{},{}\n".format(relation["relation"], round(Precision1 * 100, 2))) results_file.flush() if "type" in relation: type_Precision1[relation["type"]].append(Precision1) data = load_file(PARAMETERS["dataset_filename"]) type_count[relation["type"]].append(len(data)) if refine_template: refine_temp_fout.close() return if temp_model is not None: if save is not None and temp_model[0] is not None: torch.save(temp_model[0].state_dict(), save) return mean_p1 = statistics.mean(all_Precision1) print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1)) results_file.close() for t, l in type_Precision1.items(): print( "@@@ ", input_param["label"], t, statistics.mean(l), sum(type_count[t]), len(type_count[t]), flush=True, ) return mean_p1, all_Precision1
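# Summary of the temp_model modes handled above (as implemented in this file):
#   * "mixture_precompute": run_evaluation returns per-sample features for each
#     relation, which are saved to <save>/<relation>.pt; no training happens.
#   * "mixture_optimize": a TempModel is trained with Adam (lr=1e-1), either
#     end-to-end through run_evaluation or, when feature_dir is given, on the
#     precomputed <relation>.pt features with early stopping (patience 10).
#   * "mixture_predict": a trained TempModel is loaded from `load` and used in
#     evaluation mode.
# When refine_template is set, refined templates are written as JSON lines to
# that file and the function returns without computing a mean P@1.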
def main(args, shuffle_data=True, model=None): if len(args.models_names) > 1: raise ValueError( 'Please specify a single language model (e.g., --lm "bert").') msg = "" [model_type_name] = args.models_names # print("------- Model: {}".format(model)) # print("------- Args: {}".format(args)) if model is None: model = build_model_by_name(model_type_name, args) if model_type_name == "fairseq": model_name = "fairseq_{}".format(args.fairseq_model_name) elif model_type_name == "bert": model_name = "BERT_{}".format(args.bert_model_name) elif model_type_name == "elmo": model_name = "ELMo_{}".format(args.elmo_model_name) else: model_name = model_type_name.title() # initialize logging if args.full_logdir: log_directory = args.full_logdir else: log_directory = create_logdir_with_timestamp(args.logdir, model_name) logger = init_logging(log_directory) msg += "model name: {}\n".format(model_name) # deal with vocab subset vocab_subset = None index_list = None msg += "args: {}\n".format(args) if args.common_vocab_filename is not None: vocab_subset = load_vocab(args.common_vocab_filename) msg += "common vocabulary size: {}\n".format(len(vocab_subset)) # optimization for some LM (such as ELMo) model.optimize_top_layer(vocab_subset) filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset, logger) # logger.info("\n" + msg + "\n") # dump arguments on file for log # with open("{}/args.json".format(log_directory), "w") as outfile: # json.dump(vars(args), outfile) # Mean reciprocal rank MRR = 0.0 # Precision at (default 10) Precision = 0.0 Precision1 = 0.0 Precision_negative = 0.0 Precision_positivie = 0.0 data = load_file(args.dataset_filename) all_samples, ret_msg = filter_samples(model, data, vocab_subset, args.max_sentence_length, args.template) # logger.info("\n" + ret_msg + "\n") # if template is active (1) use a single example for (sub,obj) and (2) ... 
if args.template and args.template != "": facts = [] for sample in all_samples: sub = sample["sub_label"] obj = sample["obj_label"] if (sub, obj) not in facts: facts.append((sub, obj)) local_msg = "distinct template facts: {}".format(len(facts)) # logger.info("\n" + local_msg + "\n") print(local_msg) all_samples = [] for fact in facts: (sub, obj) = fact sample = {} sample["sub_label"] = sub sample["obj_label"] = obj # sobstitute all sentences with a standard template sample["masked_sentences"] = parse_template( args.template.strip(), sample["sub_label"].strip(), base.MASK) if args.use_negated_probes: # substitute all negated sentences with a standard template sample["negated"] = parse_template( args.template_negated.strip(), sample["sub_label"].strip(), base.MASK, ) all_samples.append(sample) # create uuid if not present i = 0 for sample in all_samples: if "uuid" not in sample: sample["uuid"] = i i += 1 # shuffle data if shuffle_data: shuffle(all_samples) samples_batches, sentences_batches, ret_msg = batchify( all_samples, args.batch_size) # logger.info("\n" + ret_msg + "\n") # ThreadPool num_threads = args.threads if num_threads <= 0: # use all available threads num_threads = multiprocessing.cpu_count() pool = ThreadPool(num_threads) # list_of_results = [] # list_of_ranks = [] item_count = 0 for i in tqdm(range(len(samples_batches))): samples_b = samples_batches[i] sentences_b = sentences_batches[i] ( original_log_probs_list, token_ids_list, masked_indices_list, ) = model.get_batch_generation(sentences_b, logger=logger) if vocab_subset is not None: # filter log_probs filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices) else: filtered_log_probs_list = original_log_probs_list label_index_list = [] for sample in samples_b: obj_label_id = model.get_id(sample["obj_label"]) # MAKE SURE THAT obj_label IS IN VOCABULARIES if obj_label_id is None: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif model.vocab[obj_label_id[0]] != sample["obj_label"]: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif vocab_subset is not None and sample[ "obj_label"] not in vocab_subset: raise ValueError("object label {} not in vocab subset".format( sample["obj_label"])) label_index_list.append(obj_label_id) arguments = [{ "original_log_probs": original_log_probs, "filtered_log_probs": filtered_log_probs, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "interactive": args.interactive, "index_list": index_list, "sample": sample, } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip( original_log_probs_list, filtered_log_probs_list, token_ids_list, masked_indices_list, label_index_list, samples_b, )] # single thread for debug # for isx,a in enumerate(arguments): # print(samples_b[isx]) # run_thread(a) # multithread res = pool.map(run_thread, arguments) for idx, result in enumerate(res): result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result # print("~~~~~~~~~~~~~~~~~~") # print(result_masked_topk) # logger.info("\n" + msg + "\n") sample = samples_b[idx] element = {} obj = sample['obj_label'] sub = sample['sub_label'] element["masked_sentences"] = sample["masked_sentences"][0] # element["uuid"] = sample["uuid"] element["subject"] = sub element["object"] = obj element["rank"] = int(result_masked_topk['rank']) # element["sample_Precision1"] = 
result_masked_topk["P_AT_1"] # element["sample"] = sample # element["token_ids"] = token_ids_list[idx] # element["masked_indices"] = masked_indices_list[idx] # element["label_index"] = label_index_list[idx] element["masked_topk"] = result_masked_topk['topk'][:20] # element["sample_MRR"] = sample_MRR # element["sample_Precision"] = sample_P # element["sample_perplexity"] = sample_perplexity # list_of_results[sub + "_" + obj].append(element) # list_of_ranks[sub + "_" + obj].append(element["rank"]) # print("~~~~~~ rank: {}".format(result_masked_topk['rank'])) MRR += sample_MRR Precision += sample_P Precision1 += result_masked_topk["P_AT_1"] item_count += 1 append_data_line_to_jsonl( "reproduction/data/P_AT_K/{}_rank_results.jsonl".format( args.label), element) append_data_line_to_jsonl( "reproduction/data/P_AT_K/{}_rank_list.jsonl".format( args.label), element["rank"]) pool.close() pool.join() Precision1 /= item_count # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results) # 3122 # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks) # 3122 # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label), list(list_of_ranks.values())) # 3122 return Precision1
def run_experiments( relations, data_path_pre, data_path_post, input_param={ "lm": "bert", "label": "bert_large", "models_names": ["bert"], "bert_model_name": "bert-large-cased", "bert_model_dir": "pre-trained_language_models/bert/cased_L-24_H-1024_A-16", }, ): model = None pp = pprint.PrettyPrinter(width=41, compact=True) all_Precision1 = [] type_Precision1 = defaultdict(list) type_count = defaultdict(list) results_file = open("last_results.csv", "w+") for relation in relations: pp.pprint(relation) PARAMETERS = { "dataset_filename": "{}{}{}".format(data_path_pre, relation["relation"], data_path_post), "common_vocab_filename": "pre-trained_language_models/common_vocab_cased.txt", "template": "", "bert_vocab_name": "vocab.txt", "batch_size": 32, "logdir": "output", "full_logdir": "output/results/{}/{}".format(input_param["label"], relation["relation"]), "lowercase": False, "max_sentence_length": 100, "threads": -1, "interactive": False, "use_negated_probes": False, } if "template" in relation: PARAMETERS["template"] = relation["template"] PARAMETERS.update(input_param) print(PARAMETERS) args = argparse.Namespace(**PARAMETERS) # see if file exists try: data = load_file(args.dataset_filename) except Exception as e: print("Relation {} excluded.".format(relation["relation"])) print("Exception: {}".format(e)) continue if model is None: [model_type_name] = args.models_names model = build_model_by_name(model_type_name, args) Precision1 = run_evaluation(args, shuffle_data=False, model=model) print("P@1 : {}".format(Precision1), flush=True) all_Precision1.append(Precision1) results_file.write("{},{}\n".format(relation["relation"], round(Precision1 * 100, 2))) results_file.flush() if "type" in relation: type_Precision1[relation["type"]].append(Precision1) data = load_file(PARAMETERS["dataset_filename"]) type_count[relation["type"]].append(len(data)) mean_p1 = statistics.mean(all_Precision1) print("@@@ {} - mean P@1: {}".format(input_param["label"], mean_p1)) results_file.close() for t, l in type_Precision1.items(): print( "@@@ ", input_param["label"], t, statistics.mean(l), sum(type_count[t]), len(type_count[t]), flush=True, ) return mean_p1, all_Precision1
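# The driver above is the plain per-relation evaluation loop: relations whose
# dataset file cannot be loaded are skipped with a warning, and each evaluated
# relation appends a "<relation>,<P@1 in percent>" row to last_results.csv.
# A quick way to inspect that file afterwards (illustrative only):
#
#     import csv
#     with open("last_results.csv") as f:
#         rows = list(csv.reader(f))   # e.g. [["P19", "23.45"], ...]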
def main(args): verbose = False moral_template = "{} is a [MASK] behavior ." data_user_study = pickle.load( open('../MoRT/data/parsed_yes_no_BERTBias.p', 'rb')) actions = list(data_user_study.keys()) actions = ['help people', 'help white people', 'help black people'] print("Language Models: {}".format(args.models_names)) models = {} for lm in args.models_names: models[lm] = build_model_by_name(lm, args) vocab_subset = None if args.common_vocab_filename is not None: common_vocab = load_vocab(args.common_vocab_filename) print("common vocabulary size: {}".format(len(common_vocab))) vocab_subset = [x for x in common_vocab] data_user_study_BERTKnowledge = dict() for action in actions: action_ = action.split(" ") action_[0] = verb_noun_mapping[action_[0]].capitalize() action_ = " ".join(action_) text = moral_template.format(action_) if args.split_sentence: import spacy # use spacy to tokenize input sentence nlp = spacy.load(args.spacy_model) tokens = nlp(text) print(tokens) sentences = [] for s in tokens.sents: print(" - {}".format(s)) sentences.append(s.text) else: sentences = [text] if len(sentences) > 2: print( "WARNING: only the first two sentences in the text will be considered!" ) sentences = sentences[:2] for model_name, model in models.items(): if model_name not in list(data_user_study_BERTKnowledge.keys()): data_user_study_BERTKnowledge[model_name] = {} if verbose: print("\n{}:".format(model_name)) original_log_probs_list, [token_ids], [ masked_indices ] = model.get_batch_generation([sentences], try_cuda=False) index_list = None if vocab_subset is not None: # filter log_probs filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset) filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices) else: filtered_log_probs_list = original_log_probs_list # rank over the subset of the vocab (if defined) for the SINGLE masked tokens if masked_indices and len(masked_indices) > 0: _, _, experiment_result, _ = evaluation_metrics.get_ranking( filtered_log_probs_list[0], masked_indices, model.vocab, index_list=index_list, print_generation=verbose) experiment_result_topk = [(r['i'], r['token_word_form'], r['log_prob']) for r in experiment_result['topk'][:10]] data_user_study_BERTKnowledge[model_name][action] = [ text, experiment_result_topk ] # prediction and perplexity for the whole softmax if verbose: print_sentence_predictions(original_log_probs_list[0], token_ids, model.vocab, masked_indices=masked_indices) print(data_user_study_BERTKnowledge) pickle.dump(data_user_study_BERTKnowledge, open('./parsed_BERTKnowledge_tests.p', 'wb'))
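# NOTE: in the probe above, the actions loaded from
# ../MoRT/data/parsed_yes_no_BERTBias.p are immediately overridden by the
# hard-coded three-action list, so only those prompts are scored. Each action
# (with its verb mapped through verb_noun_mapping, which is expected to be
# defined elsewhere in this module) is inserted into the template
# "{} is a [MASK] behavior .", and for every model the top-10
# (vocab index, token, log-prob) predictions for the [MASK] slot are collected
# and pickled to ./parsed_BERTKnowledge_tests.p.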
def main(args, shuffle_data=True, model=None): if len(args.models_names) > 1: raise ValueError('Please specify a single language model (e.g., --lm "bert").') msg = "" [model_type_name] = args.models_names print(model) if model is None: model = build_model_by_name(model_type_name, args) if model_type_name == "fairseq": model_name = "fairseq_{}".format(args.fairseq_model_name) elif model_type_name == "bert": model_name = "BERT_{}".format(args.bert_model_name) elif model_type_name == "elmo": model_name = "ELMo_{}".format(args.elmo_model_name) elif model_type_name == "roberta": model_name = "RoBERTa_{}".format(args.roberta_model_name) elif model_type_name == "hfroberta": model_name = "hfRoBERTa_{}".format(args.hfroberta_model_name) else: model_name = model_type_name.title() # initialize logging if args.full_logdir: log_directory = args.full_logdir else: log_directory = create_logdir_with_timestamp(args.logdir, model_name) logger = init_logging(log_directory) msg += "model name: {}\n".format(model_name) # deal with vocab subset vocab_subset = None index_list = None msg += "args: {}\n".format(args) if args.common_vocab_filename is not None: vocab_subset = load_vocab(args.common_vocab_filename) msg += "common vocabulary size: {}\n".format(len(vocab_subset)) # optimization for some LM (such as ELMo) model.optimize_top_layer(vocab_subset) filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset, logger ) logger.info("\n" + msg + "\n") # dump arguments on file for log with open("{}/args.json".format(log_directory), "w") as outfile: json.dump(vars(args), outfile) # stats samples_with_negative_judgement = 0 samples_with_positive_judgement = 0 # Mean reciprocal rank MRR = 0.0 MRR_negative = 0.0 MRR_positive = 0.0 # Precision at (default 10) Precision = 0.0 Precision1 = 0.0 Precision_negative = 0.0 Precision_positivie = 0.0 # spearman rank correlation # overlap at 1 if args.use_negated_probes: Spearman = 0.0 Overlap = 0.0 num_valid_negation = 0.0 data = load_file(args.dataset_filename) print(len(data)) if args.lowercase: # lowercase all samples logger.info("lowercasing all samples...") all_samples = lowercase_samples( data, use_negated_probes=args.use_negated_probes ) else: # keep samples as they are all_samples = data # TREx data for i, sample in enumerate(all_samples): if 'masked_sentences' not in sample.keys(): sample['masked_sentences'] = [] for evidence in sample['evidences']: sample['masked_sentences'].append(evidence['masked_sentence']) if i == 0: print("no masked_sentences field; building it from the masked_sentence of each evidence.") all_samples, ret_msg = filter_samples( model, all_samples, vocab_subset, args.max_sentence_length, args.template ) # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename) # with open(OUT_FILENAME, 'w') as outfile: # for entry in all_samples: # json.dump(entry, outfile) # outfile.write('\n') logger.info("\n" + ret_msg + "\n") print(len(all_samples)) # if template is active (1) use a single example for (sub,obj) and (2) ...
if args.template and args.template != "": facts = [] for sample in all_samples: sub = sample["sub_label"] obj = sample["obj_label"] if (sub, obj) not in facts: facts.append((sub, obj)) local_msg = "distinct template facts: {}".format(len(facts)) logger.info("\n" + local_msg + "\n") print(local_msg) all_samples = [] for fact in facts: (sub, obj) = fact sample = {} sample["sub_label"] = sub sample["obj_label"] = obj # sobstitute all sentences with a standard template sample["masked_sentences"] = parse_template( args.template.strip(), sample["sub_label"].strip(), base.MASK ) if args.use_negated_probes: # substitute all negated sentences with a standard template sample["negated"] = parse_template( args.template_negated.strip(), sample["sub_label"].strip(), base.MASK, ) all_samples.append(sample) # create uuid if not present i = 0 for sample in all_samples: if "uuid" not in sample: sample["uuid"] = i i += 1 # shuffle data if shuffle_data: shuffle(all_samples) samples_batches, sentences_batches, ret_msg = batchify(all_samples, args.batch_size) logger.info("\n" + ret_msg + "\n") if args.use_negated_probes: sentences_batches_negated, ret_msg = batchify_negated( all_samples, args.batch_size ) logger.info("\n" + ret_msg + "\n") # ThreadPool num_threads = args.threads if num_threads <= 0: # use all available threads num_threads = multiprocessing.cpu_count() pool = ThreadPool(num_threads) list_of_results = [] for i in tqdm(range(len(samples_batches))): samples_b = samples_batches[i] sentences_b = sentences_batches[i] ( original_log_probs_list, token_ids_list, masked_indices_list, ) = model.get_batch_generation(sentences_b, logger=logger) if vocab_subset is not None: # filter log_probs filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices ) else: filtered_log_probs_list = original_log_probs_list label_index_list = [] for sample in samples_b: obj_label_id = model.get_id(sample["obj_label"]) # MAKE SURE THAT obj_label IS IN VOCABULARIES if obj_label_id is None: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"] ) ) elif model.vocab[obj_label_id[0]] != sample["obj_label"]: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"] ) ) elif vocab_subset is not None and sample["obj_label"] not in vocab_subset: raise ValueError( "object label {} not in vocab subset".format(sample["obj_label"]) ) label_index_list.append(obj_label_id) arguments = [ { "original_log_probs": original_log_probs, "filtered_log_probs": filtered_log_probs, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "interactive": args.interactive, "index_list": index_list, "sample": sample, } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip( original_log_probs_list, filtered_log_probs_list, token_ids_list, masked_indices_list, label_index_list, samples_b, ) ] # single thread for debug # for isx,a in enumerate(arguments): # print(samples_b[isx]) # run_thread(a) # multithread res = pool.map(run_thread, arguments) if args.use_negated_probes: sentences_b_negated = sentences_batches_negated[i] # if no negated sentences in batch if all(s[0] == "" for s in sentences_b_negated): res_negated = [(float("nan"), float("nan"), "")] * args.batch_size # eval negated batch else: ( original_log_probs_list_negated, token_ids_list_negated, masked_indices_list_negated, ) = model.get_batch_generation(sentences_b_negated, logger=logger) if 
vocab_subset is not None: # filter log_probs filtered_log_probs_list_negated = model.filter_logprobs( original_log_probs_list_negated, filter_logprob_indices ) else: filtered_log_probs_list_negated = original_log_probs_list_negated arguments = [ { "log_probs": filtered_log_probs, "log_probs_negated": filtered_log_probs_negated, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "masked_indices_negated": masked_indices_negated, "index_list": index_list, } for filtered_log_probs, filtered_log_probs_negated, token_ids, masked_indices, masked_indices_negated, label_index in zip( filtered_log_probs_list, filtered_log_probs_list_negated, token_ids_list, masked_indices_list, masked_indices_list_negated, label_index_list, ) ] res_negated = pool.map(run_thread_negated, arguments) for idx, result in enumerate(res): result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result logger.info("\n" + msg + "\n") sample = samples_b[idx] element = {} element["sample"] = sample element["uuid"] = sample["uuid"] element["token_ids"] = token_ids_list[idx] element["masked_indices"] = masked_indices_list[idx] element["label_index"] = label_index_list[idx] element["masked_topk"] = result_masked_topk element["sample_MRR"] = sample_MRR element["sample_Precision"] = sample_P element["sample_perplexity"] = sample_perplexity element["sample_Precision1"] = result_masked_topk["P_AT_1"] # print() # print("idx: {}".format(idx)) # print("masked_entity: {}".format(result_masked_topk['masked_entity'])) # for yi in range(10): # print("\t{} {}".format(yi,result_masked_topk['topk'][yi])) # print("masked_indices_list: {}".format(masked_indices_list[idx])) # print("sample_MRR: {}".format(sample_MRR)) # print("sample_P: {}".format(sample_P)) # print("sample: {}".format(sample)) # print() if args.use_negated_probes: overlap, spearman, msg = res_negated[idx] # sum overlap and spearmanr if not nan if spearman == spearman: element["spearmanr"] = spearman element["overlap"] = overlap Overlap += overlap Spearman += spearman num_valid_negation += 1.0 MRR += sample_MRR Precision += sample_P Precision1 += element["sample_Precision1"] # the judgment of the annotators recording whether they are # evidence in the sentence that indicates a relation between two entities. 
num_yes = 0 num_no = 0 if "judgments" in sample: # only for Google-RE for x in sample["judgments"]: if x["judgment"] == "yes": num_yes += 1 else: num_no += 1 if num_no >= num_yes: samples_with_negative_judgement += 1 element["judgement"] = "negative" MRR_negative += sample_MRR Precision_negative += sample_P else: samples_with_positive_judgement += 1 element["judgement"] = "positive" MRR_positive += sample_MRR Precision_positivie += sample_P list_of_results.append(element) pool.close() pool.join() # stats try: # Mean reciprocal rank MRR /= len(list_of_results) # Precision Precision /= len(list_of_results) Precision1 /= len(list_of_results) except ZeroDivisionError: MRR = Precision = Precision1 = 0.0 msg = "all_samples: {}\n".format(len(all_samples)) msg += "list_of_results: {}\n".format(len(list_of_results)) msg += "global MRR: {}\n".format(MRR) msg += "global Precision at 10: {}\n".format(Precision) msg += "global Precision at 1: {}\n".format(Precision1) if args.use_negated_probes: Overlap /= num_valid_negation Spearman /= num_valid_negation msg += "\n" msg += "results negation:\n" msg += "all_negated_samples: {}\n".format(int(num_valid_negation)) msg += "global spearman rank affirmative/negated: {}\n".format(Spearman) msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap) if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0: # Google-RE specific MRR_negative /= samples_with_negative_judgement MRR_positive /= samples_with_positive_judgement Precision_negative /= samples_with_negative_judgement Precision_positivie /= samples_with_positive_judgement msg += "samples_with_negative_judgement: {}\n".format( samples_with_negative_judgement ) msg += "samples_with_positive_judgement: {}\n".format( samples_with_positive_judgement ) msg += "MRR_negative: {}\n".format(MRR_negative) msg += "MRR_positive: {}\n".format(MRR_positive) msg += "Precision_negative: {}\n".format(Precision_negative) msg += "Precision_positivie: {}\n".format(Precision_positivie) logger.info("\n" + msg + "\n") print("\n" + msg + "\n") # dump pickle with the result of the experiment all_results = dict( list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision ) with open("{}/result.pkl".format(log_directory), "wb") as f: pickle.dump(all_results, f) return Precision1
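# Negation consistency (used above when args.use_negated_probes is set): each
# batch is re-scored on the negated templates and run_thread_negated returns an
# (overlap, spearman, msg) triple per sample. The `spearman == spearman` test
# simply filters out NaN values, and the reported aggregates are the mean
# Spearman rank correlation and mean overlap@1 between the affirmative and
# negated predictions, averaged over the valid negated samples.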
def main(args, shuffle_data=True, model=None): if len(args.models_names) > 1: raise ValueError( 'Please specify a single language model (e.g., --lm "bert").') msg = "" [model_type_name] = args.models_names args.output_feature_path = getattr(args, 'output_feature_path', '') if getattr(args, 'knn_thresh', 0) > 0: assert hasattr(args, 'knn_path') assert hasattr(args, 'modify_ans') else: args.knn_thresh = 0 if getattr(args, 'knn_path', ''): knn_dict = torch.load(args.knn_path) if getattr(args, 'consine_dist', True): knn_dict['mask_features'] = knn_dict['mask_features'] / torch.norm( knn_dict['mask_features'], dim=1, keepdim=True) else: knn_dict['mask_features'] = knn_dict['mask_features'] new_ans_dict = json.load(open(args.modify_ans)) knn_dict['obj_labels'] = [ new_ans_dict[uuid] for uuid in knn_dict['uuids'] ] else: new_ans_dict = None knn_dict = None print(model) if model is None: model = build_model_by_name(model_type_name, args) if model_type_name == "fairseq": model_name = "fairseq_{}".format(args.fairseq_model_name) elif model_type_name == "bert": model_name = "BERT_{}".format(args.bert_model_name) elif model_type_name == "elmo": model_name = "ELMo_{}".format(args.elmo_model_name) else: model_name = model_type_name.title() # initialize logging if args.full_logdir: log_directory = args.full_logdir else: log_directory = create_logdir_with_timestamp(args.logdir, model_name) logger = init_logging(log_directory) msg += "model name: {}\n".format(model_name) # deal with vocab subset vocab_subset = None index_list = None msg += "args: {}\n".format(args) if args.common_vocab_filename is not None: vocab_subset = load_vocab(args.common_vocab_filename) msg += "common vocabulary size: {}\n".format(len(vocab_subset)) # optimization for some LM (such as ELMo) model.optimize_top_layer(vocab_subset) filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset, logger) logger.info("\n" + msg + "\n") # dump arguments on file for log with open(os.path.join(log_directory, 'args.json'), "w") as outfile: json.dump(vars(args), outfile) # stats samples_with_negative_judgement = 0 samples_with_positive_judgement = 0 # Mean reciprocal rank MRR = 0.0 MRR_negative = 0.0 MRR_positive = 0.0 # Precision at (default 10) Precision = 0.0 Precision1 = 0.0 Precision1_modified = 0.0 Precision_negative = 0.0 Precision_positivie = 0.0 # spearman rank correlation # overlap at 1 if args.use_negated_probes: Spearman = 0.0 Overlap = 0.0 num_valid_negation = 0.0 data = load_file(args.dataset_filename) print(len(data)) all_samples, ret_msg = filter_samples(model, data, vocab_subset, args.max_sentence_length, args.template) logger.info("\n" + ret_msg + "\n") print(len(all_samples)) # if template is active (1) use a single example for (sub,obj) and (2) ... 
if args.template and args.template != "": if getattr(args, 'use_evidences', False): new_all_samples = [] for sample in all_samples: if len(args.uuid_list ) > 0 and sample['uuid'] not in args.uuid_list: continue elif len(args.uuid_list) > 0: print(sample['uuid']) sub = sample["sub_label"] if new_ans_dict is not None and sample['uuid'] in new_ans_dict: # we need to replace the answer in this way obj = new_ans_dict[sample['uuid']] else: obj = sample["obj_label"] if sample['uuid'] == '11fc104b-bba2-412c-b2d7-cf06cd2bd715': sample['evidences'] = sample['evidences'][:32] for ne, evidence in enumerate(sample['evidences']): # maximum of 10 evidences per fact if ne >= 10: continue new_sample = {'sub_label': sub, 'obj_label': obj} if '[MASK]' not in evidence['masked_sentence']: continue new_sample['masked_sentences'] = [ evidence['masked_sentence'] ] new_sample['uuid'] = sample['uuid'] new_all_samples.append(new_sample) all_samples = new_all_samples else: facts = [] for sample in all_samples: sub = sample["sub_label"] if new_ans_dict is not None and sample['uuid'] in new_ans_dict: # we need to replace the answer in this way obj = new_ans_dict[sample['uuid']] else: obj = sample["obj_label"] if (sub, obj) not in facts: facts.append((sample['uuid'], sub, obj)) local_msg = "distinct template facts: {}".format(len(facts)) logger.info("\n" + local_msg + "\n") print(local_msg) all_samples = [] for fact in facts: (uuid, sub, obj) = fact sample = {} sample["sub_label"] = sub sample["obj_label"] = obj sample["uuid"] = uuid # sobstitute all sentences with a standard template sample["masked_sentences"] = parse_template( args.template.strip(), sample["sub_label"].strip(), base.MASK) if args.use_negated_probes: # substitute all negated sentences with a standard template sample["negated"] = parse_template( args.template_negated.strip(), sample["sub_label"].strip(), base.MASK, ) all_samples.append(sample) # create uuid if not present i = 0 for sample in all_samples: if "uuid" not in sample: sample["uuid"] = i i += 1 # shuffle data if shuffle_data: shuffle(all_samples) samples_batches, sentences_batches, ret_msg = batchify( all_samples, args.batch_size) logger.info("\n" + ret_msg + "\n") if args.use_negated_probes: sentences_batches_negated, ret_msg = batchify_negated( all_samples, args.batch_size) logger.info("\n" + ret_msg + "\n") # ThreadPool num_threads = args.threads if num_threads <= 0: # use all available threads num_threads = multiprocessing.cpu_count() pool = ThreadPool(num_threads) list_of_results = [] total_modified = 0 mask_feature_all, answers_list, uid_list = [], [], [] correct_uuids = [] knn_preds_list = [] for i in tqdm(range(len(samples_batches))): samples_b = samples_batches[i] sentences_b = sentences_batches[i] rets = model.get_batch_generation(sentences_b, logger=logger, return_features=args.return_features or args.knn_thresh > 0) if len(rets) == 4: original_log_probs_list, token_ids_list, masked_indices_list, feature_tensor = rets mask_feature_all.append(feature_tensor) else: original_log_probs_list, token_ids_list, masked_indices_list = rets if vocab_subset is not None: # filter log_probs filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices) else: filtered_log_probs_list = original_log_probs_list label_index_list = [] modified_flags_list = [] for ns, sample in enumerate(samples_b): obj_label_id = model.get_id(sample["obj_label"]) answers_list.append(sample["obj_label"]) uid_list.append(sample['uuid']) # MAKE SURE THAT obj_label IS IN VOCABULARIES if 
obj_label_id is None: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif model.vocab[obj_label_id[0]] != sample["obj_label"]: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif vocab_subset is not None and sample[ "obj_label"] not in vocab_subset: raise ValueError("object label {} not in vocab subset".format( sample["obj_label"])) label_index_list.append(obj_label_id) if args.knn_thresh > 0: feature = feature_tensor[ns].view(1, -1) if getattr(args, 'consine_dist', True): dist = torch.sum(feature * knn_dict['mask_features'], dim=1) / torch.norm(feature) else: dist = torch.norm(feature - knn_dict['mask_features'], dim=1) min_dist, min_idx = torch.min(dist, dim=0) # print(min_dist.item()) if min_dist < args.knn_thresh: knn_pred = knn_dict['obj_labels'][min_idx.item()] knn_preds_list.append(model.get_id(knn_pred)[0]) # if knn_dict['uuids'][min_idx.item()] == sample['uuid']: # pdb.set_trace() else: knn_preds_list.append(-1) # log_probs.unsqueeze() # knn_preds_list. else: knn_preds_list.append(-1) # label whether the fact has been modified modified_flags_list.append(new_ans_dict is not None and sample['uuid'] in new_ans_dict) arguments = [{ "original_log_probs": original_log_probs, "filtered_log_probs": filtered_log_probs, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "interactive": args.interactive, "index_list": index_list, "sample": sample, "knn_pred": knn_pred, "modified": modified } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample, knn_pred, modified in zip(original_log_probs_list, filtered_log_probs_list, token_ids_list, masked_indices_list, label_index_list, samples_b, knn_preds_list, modified_flags_list)] # single thread for debug # for isx,a in enumerate(arguments): # print(samples_b[isx]) # run_thread(a) # multithread res = pool.map(run_thread, arguments) if args.use_negated_probes: sentences_b_negated = sentences_batches_negated[i] # if no negated sentences in batch if all(s[0] == "" for s in sentences_b_negated): res_negated = [(float("nan"), float("nan"), "") ] * args.batch_size # eval negated batch else: ( original_log_probs_list_negated, token_ids_list_negated, masked_indices_list_negated, ) = model.get_batch_generation(sentences_b_negated, logger=logger) if vocab_subset is not None: # filter log_probs filtered_log_probs_list_negated = model.filter_logprobs( original_log_probs_list_negated, filter_logprob_indices) else: filtered_log_probs_list_negated = original_log_probs_list_negated arguments = [{ "log_probs": filtered_log_probs, "log_probs_negated": filtered_log_probs_negated, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "masked_indices_negated": masked_indices_negated, "index_list": index_list, } for filtered_log_probs, filtered_log_probs_negated, token_ids, masked_indices, masked_indices_negated, label_index in zip( filtered_log_probs_list, filtered_log_probs_list_negated, token_ids_list, masked_indices_list, masked_indices_list_negated, label_index_list, )] res_negated = pool.map(run_thread_negated, arguments) for idx, result in enumerate(res): result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result logger.info("\n" + msg + "\n") sample = samples_b[idx] element = {} element["sample"] = sample element["uuid"] = sample["uuid"] element["token_ids"] = token_ids_list[idx] element["masked_indices"] = 
masked_indices_list[idx] element["label_index"] = label_index_list[idx] element["masked_topk"] = result_masked_topk element["sample_MRR"] = sample_MRR element["sample_Precision"] = sample_P element["sample_perplexity"] = sample_perplexity element["sample_Precision1"] = result_masked_topk["P_AT_1"] element["modified"] = result_masked_topk["modified"] if result_masked_topk["P_AT_1"] > 0: correct_uuids.append(element['uuid']) # print() # print("idx: {}".format(idx)) # print("masked_entity: {}".format(result_masked_topk['masked_entity'])) # for yi in range(10): # print("\t{} {}".format(yi,result_masked_topk['topk'][yi])) # print("masked_indices_list: {}".format(masked_indices_list[idx])) # print("sample_MRR: {}".format(sample_MRR)) # print("sample_P: {}".format(sample_P)) # print("sample: {}".format(sample)) # print() if args.use_negated_probes: overlap, spearman, msg = res_negated[idx] # sum overlap and spearmanr if not nan if spearman == spearman: element["spearmanr"] = spearman element["overlap"] = overlap Overlap += overlap Spearman += spearman num_valid_negation += 1.0 MRR += sample_MRR Precision += sample_P if element["modified"]: Precision1_modified += element["sample_Precision1"] else: Precision1 += element["sample_Precision1"] # the judgment of the annotators recording whether they are # evidence in the sentence that indicates a relation between two entities. num_yes = 0 num_no = 0 if "judgments" in sample: # only for Google-RE for x in sample["judgments"]: if x["judgment"] == "yes": num_yes += 1 else: num_no += 1 if num_no >= num_yes: samples_with_negative_judgement += 1 element["judgement"] = "negative" MRR_negative += sample_MRR Precision_negative += sample_P else: samples_with_positive_judgement += 1 element["judgement"] = "positive" MRR_positive += sample_MRR Precision_positivie += sample_P if element["modified"]: total_modified += 1 else: list_of_results.append(element) pool.close() pool.join() if args.output_feature_path and len(list_of_results) == 0: # torch.save(out_dict, args.output_feature_path) # return empty results return Precision1, uid_list, mask_feature_all, answers_list elif len(list_of_results) == 0: pdb.set_trace() # stats # Mean reciprocal rank MRR /= len(list_of_results) # Precision Precision /= len(list_of_results) # Precision1 /= len(list_of_results) msg = "all_samples: {}\n".format(len(all_samples)) msg += "list_of_results: {}\n".format(len(list_of_results)) msg += "global MRR: {}\n".format(MRR) msg += "global Precision at 10: {}\n".format(Precision) msg += "global Precision at 1: {}\n".format(Precision1) if args.use_negated_probes: Overlap /= num_valid_negation Spearman /= num_valid_negation msg += "\n" msg += "results negation:\n" msg += "all_negated_samples: {}\n".format(int(num_valid_negation)) msg += "global spearman rank affirmative/negated: {}\n".format( Spearman) msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap) if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0: # Google-RE specific MRR_negative /= samples_with_negative_judgement MRR_positive /= samples_with_positive_judgement Precision_negative /= samples_with_negative_judgement Precision_positivie /= samples_with_positive_judgement msg += "samples_with_negative_judgement: {}\n".format( samples_with_negative_judgement) msg += "samples_with_positive_judgement: {}\n".format( samples_with_positive_judgement) msg += "MRR_negative: {}\n".format(MRR_negative) msg += "MRR_positive: {}\n".format(MRR_positive) msg += "Precision_negative: 
{}\n".format(Precision_negative) msg += "Precision_positivie: {}\n".format(Precision_positivie) logger.info("\n" + msg + "\n") print("\n" + msg + "\n") # dump pickle with the result of the experiment all_results = dict(list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision) with open("{}/result.pkl".format(log_directory), "wb") as f: pickle.dump(all_results, f) if args.output_feature_path: # torch.save(out_dict, args.output_feature_path) return Precision1, len( list_of_results ), Precision1_modified, total_modified, uid_list, mask_feature_all, answers_list return Precision1, len( list_of_results), Precision1_modified, total_modified, correct_uuids