def main(args):
    """Query language models with a weighted ensemble of prompts.

    For a single (``args.subject``, ``args.relation``) pair, loads the
    per-relation prompt file, builds each requested model, combines the
    masked-token log-probabilities of every prompt weighted by its score,
    and prints the top-ranked vocabulary completions.

    Raises:
        ValueError: if either --subject or --relation is missing, or if the
            prompt file for the requested relation does not exist.
    """
    # BUG FIX: original used `and`, which only raised when BOTH flags were
    # missing. Both are required (args.relation is used unconditionally
    # below to build the prompt path), so reject if EITHER is absent.
    if not args.subject or not args.relation:
        raise ValueError(
            'You need to specify --subject and --relation to query language models.'
        )

    print('Language Models: {}'.format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    # Optional common-vocab restriction shared by all models.
    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print('Common vocabulary size: {}'.format(len(common_vocab)))
        vocab_subset = list(common_vocab)

    prompt_file = os.path.join(args.prompts, args.relation + '.jsonl')
    if not os.path.exists(prompt_file):
        raise ValueError('Relation "{}" does not exist.'.format(args.relation))
    prompts, weights = load_prompt_weights(prompt_file)

    for model_name, model in models.items():
        print('\n{}:'.format(model_name))
        index_list = None
        if vocab_subset is not None:
            filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                vocab_subset)

        # Weighted sum of per-prompt masked-token log-probs; starts at the
        # scalar 0 so the first `+=` broadcasts to the tensor shape.
        ensemble_log_probs = 0
        for prompt, weight in zip(prompts, weights):
            prompt = parse_prompt(prompt, args.subject, model.mask_token)
            log_prob, [token_ids], [masked_indices], _, _ = model.get_batch_generation(
                [prompt], try_cuda=True)
            if vocab_subset is not None:
                filtered_log_probs = model.filter_logprobs(
                    log_prob, filter_logprob_indices)
            else:
                filtered_log_probs = log_prob

            # rank over the subset of the vocab (if defined) for the SINGLE masked token
            if masked_indices and len(masked_indices) > 0:
                filtered_log_probs = filtered_log_probs[0][masked_indices[0]]
            ensemble_log_probs += filtered_log_probs * weight

        # Re-normalize the weighted mixture into a proper log-distribution.
        ensemble_log_probs = F.log_softmax(ensemble_log_probs, dim=0)
        evaluation_metrics.get_ranking(ensemble_log_probs,
                                       model.vocab,
                                       label_index=None,
                                       index_list=index_list,
                                       topk=1000,
                                       P_AT=10,
                                       print_generation=True)
def main(args, shuffle_data=True, model=None):
    """Run the LAMA cloze-style probe over one dataset with one language model.

    Loads the dataset, optionally restricts scoring to a common vocabulary,
    batches the masked sentences, scores them with the model (optionally also
    their negated variants), accumulates MRR / P@10 / P@1 statistics, logs and
    pickles the per-sample results, and returns global Precision@1.

    Args:
        args: parsed command-line namespace (models_names, dataset_filename,
            template, batch_size, threads, logdir, ...).
        shuffle_data: shuffle samples before batching.
        model: pre-built model instance; built from args when None.

    Returns:
        Global Precision@1 (float; 0.0 when no results were produced).
    """
    if len(args.models_names) > 1:
        raise ValueError('Please specify a single language model (e.g., --lm "bert").')
    msg = ""

    # exactly one model name expected (single-element unpack enforces it)
    [model_type_name] = args.models_names

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    # human-readable model name, used to build the log directory
    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    elif model_type_name == "roberta":
        model_name = "RoBERTa_{}".format(args.roberta_model_name)
    elif model_type_name == "hfroberta":
        model_name = "hfRoBERTa_{}".format(args.hfroberta_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))
        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)
        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger
        )

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    # spearman rank correlation and overlap at 1, only tracked when the
    # negated-probe evaluation is enabled
    if args.use_negated_probes:
        Spearman = 0.0
        Overlap = 0.0
        num_valid_negation = 0.0

    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(
            data, use_negated_probes=args.use_negated_probes
        )
    else:
        # keep samples as they are
        all_samples = data

    # TREx data: some samples carry per-evidence 'masked_sentence' fields
    # instead of a top-level 'masked_sentences' list — normalize in place
    for i, sample in enumerate(all_samples):
        if 'masked_sentences' not in sample.keys():
            sample['masked_sentences'] = []
            for evidence in sample['evidences']:
                sample['masked_sentences'].append(evidence['masked_sentence'])
            if i == 0:
                print('not masked_sentences, but masked_sentence.')

    # NOTE(review): `data` (not the lowercased/normalized `all_samples`) is
    # passed here, so the lowercasing above only affects `all_samples` if the
    # sample dicts are shared — confirm against filter_samples/lowercase_samples.
    all_samples, ret_msg = filter_samples(
        model, data, vocab_subset, args.max_sentence_length, args.template
    )

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if template is active (1) use a single example for (sub,obj) and (2) ...
    if args.template and args.template != "":
        facts = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            # substitute all sentences with a standard template
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK
            )
            if args.use_negated_probes:
                # substitute all negated sentences with a standard template
                sample["negated"] = parse_template(
                    args.template_negated.strip(),
                    sample["sub_label"].strip(),
                    base.MASK,
                )
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    if args.use_negated_probes:
        sentences_batches_negated, ret_msg = batchify_negated(
            all_samples, args.batch_size
        )
        logger.info("\n" + ret_msg + "\n")

    # ThreadPool used to parallelize the per-sample metric computation
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices
            )
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            elif vocab_subset is not None and sample["obj_label"] not in vocab_subset:
                raise ValueError(
                    "object label {} not in vocab subset".format(sample["obj_label"])
                )

            label_index_list.append(obj_label_id)

        arguments = [
            {
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0],
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
            }
            for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip(
                original_log_probs_list,
                filtered_log_probs_list,
                token_ids_list,
                masked_indices_list,
                label_index_list,
                samples_b,
            )
        ]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        if args.use_negated_probes:
            sentences_b_negated = sentences_batches_negated[i]

            # if no negated sentences in batch, emit NaN placeholders
            if all(s[0] == "" for s in sentences_b_negated):
                res_negated = [(float("nan"), float("nan"), "")] * args.batch_size
            # eval negated batch
            else:
                (
                    original_log_probs_list_negated,
                    token_ids_list_negated,
                    masked_indices_list_negated,
                ) = model.get_batch_generation(sentences_b_negated, logger=logger)
                if vocab_subset is not None:
                    # filter log_probs
                    filtered_log_probs_list_negated = model.filter_logprobs(
                        original_log_probs_list_negated, filter_logprob_indices
                    )
                else:
                    filtered_log_probs_list_negated = original_log_probs_list_negated

                arguments = [
                    {
                        "log_probs": filtered_log_probs,
                        "log_probs_negated": filtered_log_probs_negated,
                        "token_ids": token_ids,
                        "vocab": model.vocab,
                        "label_index": label_index[0],
                        "masked_indices": masked_indices,
                        "masked_indices_negated": masked_indices_negated,
                        "index_list": index_list,
                    }
                    for filtered_log_probs, filtered_log_probs_negated, token_ids, masked_indices, masked_indices_negated, label_index in zip(
                        filtered_log_probs_list,
                        filtered_log_probs_list_negated,
                        token_ids_list,
                        masked_indices_list,
                        masked_indices_list_negated,
                        label_index_list,
                    )
                ]
                res_negated = pool.map(run_thread_negated, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            # per-sample record appended to list_of_results
            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            if args.use_negated_probes:
                overlap, spearman, msg = res_negated[idx]
                # sum overlap and spearmanr if not nan (NaN != NaN)
                if spearman == spearman:
                    element["spearmanr"] = spearman
                    element["overlap"] = overlap
                    Overlap += overlap
                    Spearman += spearman
                    num_valid_negation += 1.0

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]

            # the judgment of the annotators recording whether there is
            # evidence in the sentence indicating a relation between two entities
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P

            list_of_results.append(element)

    pool.close()
    pool.join()

    # stats: turn accumulated sums into averages
    try:
        # Mean reciprocal rank
        MRR /= len(list_of_results)
        # Precision
        Precision /= len(list_of_results)
        Precision1 /= len(list_of_results)
    except ZeroDivisionError:
        # no results at all — report zeros instead of crashing
        MRR = Precision = Precision1 = 0.0

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if args.use_negated_probes:
        Overlap /= num_valid_negation
        Spearman /= num_valid_negation
        msg += "\n"
        msg += "results negation:\n"
        msg += "all_negated_samples: {}\n".format(int(num_valid_negation))
        msg += "global spearman rank affirmative/negated: {}\n".format(Spearman)
        msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement
        )
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement
        )
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision
    )
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    return Precision1
def main(args):
    """Interactive / one-shot LAMA generation evaluation.

    Either scores a single piece of text passed via --text, or loops reading
    text from stdin (--interactive). For each input: optionally splits it into
    sentences with spaCy (at most two are kept), then for every requested model
    prints the masked-token ranking and the whole-softmax sentence predictions.

    Raises:
        ValueError: when neither --text nor --interactive is given.
    """
    if not args.text and not args.interactive:
        msg = "ERROR: either you start LAMA eval_generation with the " \
              "interactive option (--i) or you pass in input a piece of text (--t)"
        raise ValueError(msg)

    stopping_condition = True

    print("Language Models: {}".format(args.models_names))

    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    # optional restriction of the ranking to a common vocabulary
    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print("common vocabulary size: {}".format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    while stopping_condition:

        if args.text:
            # one-shot mode: process the given text once, then stop
            text = args.text
            stopping_condition = False
        else:
            text = input("insert text:")

        if args.split_sentence:
            import spacy

            # use spacy to tokenize input sentence
            nlp = spacy.load(args.spacy_model)
            tokens = nlp(text)
            print(tokens)

            sentences = []
            for s in tokens.sents:
                print(" - {}".format(s))
                sentences.append(s.text)
        else:
            sentences = [text]

        if len(sentences) > 2:
            print(
                "WARNING: only the first two sentences in the text will be considered!"
            )
            sentences = sentences[:2]

        for model_name, model in models.items():
            print("\n{}:".format(model_name))
            # single-element batch; unpack the one token-id / masked-index list
            original_log_probs_list, [token_ids], [
                masked_indices
            ] = model.get_batch_generation([sentences], try_cuda=False)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list

            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                evaluation_metrics.get_ranking(filtered_log_probs_list[0],
                                               masked_indices, model.vocab,
                                               index_list=index_list)

            # prediction and perplexity for the whole softmax
            print_sentence_predictions(original_log_probs_list[0], token_ids,
                                       model.vocab,
                                       masked_indices=masked_indices)
def main(args, shuffle_data=True, model=None):
    """LAMA evaluation variant with kNN feature lookup and answer modification.

    Extends the standard probe loop with: (1) an optional kNN dictionary of
    mask features (loaded from ``args.knn_path``) used to override predictions
    whose nearest-neighbor distance is below ``args.knn_thresh``; (2) an
    optional answer-replacement mapping (``args.modify_ans``) keyed by sample
    uuid; (3) optional per-evidence sample expansion (``args.use_evidences``);
    and (4) optional export of mask features (``args.output_feature_path``).

    Returns:
        When list_of_results is empty and output_feature_path is set:
            (Precision1, uid_list, mask_feature_all, answers_list).
        When output_feature_path is set:
            (Precision1, n_results, Precision1_modified, total_modified,
             uid_list, mask_feature_all, answers_list).
        Otherwise:
            (Precision1, n_results, Precision1_modified, total_modified,
             correct_uuids).
    """
    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')
    msg = ""

    # exactly one model name expected (single-element unpack enforces it)
    [model_type_name] = args.models_names

    args.output_feature_path = getattr(args, 'output_feature_path', '')
    # kNN override requires both the feature dictionary and the answer map
    if getattr(args, 'knn_thresh', 0) > 0:
        assert hasattr(args, 'knn_path')
        assert hasattr(args, 'modify_ans')
    else:
        args.knn_thresh = 0
    if getattr(args, 'knn_path', ''):
        knn_dict = torch.load(args.knn_path)
        # NOTE: flag name 'consine_dist' (sic) selects cosine similarity;
        # features are L2-normalized so the dot product below is cosine.
        if getattr(args, 'consine_dist', True):
            knn_dict['mask_features'] = knn_dict['mask_features'] / torch.norm(
                knn_dict['mask_features'], dim=1, keepdim=True)
        else:
            knn_dict['mask_features'] = knn_dict['mask_features']
        new_ans_dict = json.load(open(args.modify_ans))
        # re-key stored answers by the uuid order of the feature bank
        knn_dict['obj_labels'] = [
            new_ans_dict[uuid] for uuid in knn_dict['uuids']
        ]
    else:
        new_ans_dict = None
        knn_dict = None

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    # human-readable model name, used to build the log directory
    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))
        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)
        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open(os.path.join(log_directory, 'args.json'), "w") as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision1_modified = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    # spearman rank correlation and overlap at 1, only for negated probes
    if args.use_negated_probes:
        Spearman = 0.0
        Overlap = 0.0
        num_valid_negation = 0.0

    data = load_file(args.dataset_filename)

    print(len(data))

    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if template is active (1) use a single example for (sub,obj) and (2) ...
    if args.template and args.template != "":
        if getattr(args, 'use_evidences', False):
            # expand each fact into one sample per evidence sentence
            new_all_samples = []
            for sample in all_samples:
                # restrict to an explicit uuid whitelist when provided
                if len(args.uuid_list
                       ) > 0 and sample['uuid'] not in args.uuid_list:
                    continue
                elif len(args.uuid_list) > 0:
                    print(sample['uuid'])
                sub = sample["sub_label"]
                if new_ans_dict is not None and sample['uuid'] in new_ans_dict:
                    # we need to replace the answer in this way
                    obj = new_ans_dict[sample['uuid']]
                else:
                    obj = sample["obj_label"]
                # special-case cap for one known oversized sample
                if sample['uuid'] == '11fc104b-bba2-412c-b2d7-cf06cd2bd715':
                    sample['evidences'] = sample['evidences'][:32]
                for ne, evidence in enumerate(sample['evidences']):
                    # maximum of 10 evidences per fact
                    if ne >= 10:
                        continue
                    new_sample = {'sub_label': sub, 'obj_label': obj}
                    # skip evidences without a mask token
                    if '[MASK]' not in evidence['masked_sentence']:
                        continue
                    new_sample['masked_sentences'] = [
                        evidence['masked_sentence']
                    ]
                    new_sample['uuid'] = sample['uuid']
                    new_all_samples.append(new_sample)
            all_samples = new_all_samples
        else:
            facts = []
            for sample in all_samples:
                sub = sample["sub_label"]
                if new_ans_dict is not None and sample['uuid'] in new_ans_dict:
                    # we need to replace the answer in this way
                    obj = new_ans_dict[sample['uuid']]
                else:
                    obj = sample["obj_label"]
                # NOTE(review): membership is tested on a 2-tuple but 3-tuples
                # are appended, so the check never matches and duplicates are
                # kept — confirm whether dedup was intended here.
                if (sub, obj) not in facts:
                    facts.append((sample['uuid'], sub, obj))
            local_msg = "distinct template facts: {}".format(len(facts))
            logger.info("\n" + local_msg + "\n")
            print(local_msg)
            all_samples = []
            for fact in facts:
                (uuid, sub, obj) = fact
                sample = {}
                sample["sub_label"] = sub
                sample["obj_label"] = obj
                sample["uuid"] = uuid
                # substitute all sentences with a standard template
                sample["masked_sentences"] = parse_template(
                    args.template.strip(), sample["sub_label"].strip(),
                    base.MASK)
                if args.use_negated_probes:
                    # substitute all negated sentences with a standard template
                    sample["negated"] = parse_template(
                        args.template_negated.strip(),
                        sample["sub_label"].strip(),
                        base.MASK,
                    )
                all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    if args.use_negated_probes:
        sentences_batches_negated, ret_msg = batchify_negated(
            all_samples, args.batch_size)
        logger.info("\n" + ret_msg + "\n")

    # ThreadPool used to parallelize the per-sample metric computation
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []
    total_modified = 0
    # accumulated across ALL batches (exported with output_feature_path)
    mask_feature_all, answers_list, uid_list = [], [], []
    correct_uuids = []
    knn_preds_list = []

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        # when features are requested the model returns a 4th element
        rets = model.get_batch_generation(sentences_b,
                                          logger=logger,
                                          return_features=args.return_features
                                          or args.knn_thresh > 0)
        if len(rets) == 4:
            original_log_probs_list, token_ids_list, masked_indices_list, feature_tensor = rets
            mask_feature_all.append(feature_tensor)
        else:
            original_log_probs_list, token_ids_list, masked_indices_list = rets

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        modified_flags_list = []
        for ns, sample in enumerate(samples_b):
            obj_label_id = model.get_id(sample["obj_label"])
            answers_list.append(sample["obj_label"])
            uid_list.append(sample['uuid'])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

            if args.knn_thresh > 0:
                # nearest-neighbor lookup of this sample's mask feature
                feature = feature_tensor[ns].view(1, -1)
                if getattr(args, 'consine_dist', True):
                    # cosine similarity (bank rows pre-normalized above)
                    dist = torch.sum(feature * knn_dict['mask_features'],
                                     dim=1) / torch.norm(feature)
                else:
                    dist = torch.norm(feature - knn_dict['mask_features'],
                                      dim=1)
                min_dist, min_idx = torch.min(dist, dim=0)
                # print(min_dist.item())
                if min_dist < args.knn_thresh:
                    knn_pred = knn_dict['obj_labels'][min_idx.item()]
                    knn_preds_list.append(model.get_id(knn_pred)[0])
                    # if knn_dict['uuids'][min_idx.item()] == sample['uuid']:
                    #     pdb.set_trace()
                else:
                    # -1 marks "no kNN override for this sample"
                    knn_preds_list.append(-1)
                # log_probs.unsqueeze()
                # knn_preds_list.
            else:
                knn_preds_list.append(-1)

            # label whether the fact has been modified
            modified_flags_list.append(new_ans_dict is not None
                                       and sample['uuid'] in new_ans_dict)

        # NOTE(review): knn_preds_list is never reset per batch, yet it is
        # zipped below with per-batch lists — from the second batch on, zip
        # pairs this batch's samples with the FIRST batch's kNN predictions.
        # Confirm whether it should be re-initialized inside the batch loop.
        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
            "knn_pred": knn_pred,
            "modified": modified
        } for original_log_probs, filtered_log_probs, token_ids,
                     masked_indices, label_index, sample, knn_pred, modified in
                     zip(original_log_probs_list, filtered_log_probs_list,
                         token_ids_list, masked_indices_list, label_index_list,
                         samples_b, knn_preds_list, modified_flags_list)]

        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        if args.use_negated_probes:
            sentences_b_negated = sentences_batches_negated[i]

            # if no negated sentences in batch, emit NaN placeholders
            if all(s[0] == "" for s in sentences_b_negated):
                res_negated = [(float("nan"), float("nan"), "")
                               ] * args.batch_size
            # eval negated batch
            else:
                (
                    original_log_probs_list_negated,
                    token_ids_list_negated,
                    masked_indices_list_negated,
                ) = model.get_batch_generation(sentences_b_negated,
                                               logger=logger)
                if vocab_subset is not None:
                    # filter log_probs
                    filtered_log_probs_list_negated = model.filter_logprobs(
                        original_log_probs_list_negated,
                        filter_logprob_indices)
                else:
                    filtered_log_probs_list_negated = original_log_probs_list_negated

                arguments = [{
                    "log_probs": filtered_log_probs,
                    "log_probs_negated": filtered_log_probs_negated,
                    "token_ids": token_ids,
                    "vocab": model.vocab,
                    "label_index": label_index[0],
                    "masked_indices": masked_indices,
                    "masked_indices_negated": masked_indices_negated,
                    "index_list": index_list,
                } for filtered_log_probs, filtered_log_probs_negated,
                             token_ids, masked_indices, masked_indices_negated,
                             label_index in zip(
                                 filtered_log_probs_list,
                                 filtered_log_probs_list_negated,
                                 token_ids_list,
                                 masked_indices_list,
                                 masked_indices_list_negated,
                                 label_index_list,
                             )]
                res_negated = pool.map(run_thread_negated, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            # per-sample record
            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            element["modified"] = result_masked_topk["modified"]
            if result_masked_topk["P_AT_1"] > 0:
                correct_uuids.append(element['uuid'])

            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()

            if args.use_negated_probes:
                overlap, spearman, msg = res_negated[idx]
                # sum overlap and spearmanr if not nan (NaN != NaN)
                if spearman == spearman:
                    element["spearmanr"] = spearman
                    element["overlap"] = overlap
                    Overlap += overlap
                    Spearman += spearman
                    num_valid_negation += 1.0

            MRR += sample_MRR
            Precision += sample_P
            # modified facts are tracked in a separate P@1 bucket
            if element["modified"]:
                Precision1_modified += element["sample_Precision1"]
            else:
                Precision1 += element["sample_Precision1"]

            # the judgment of the annotators recording whether there is
            # evidence in the sentence indicating a relation between two entities
            num_yes = 0
            num_no = 0

            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P

            # modified facts are counted but excluded from list_of_results
            if element["modified"]:
                total_modified += 1
            else:
                list_of_results.append(element)

    pool.close()
    pool.join()

    if args.output_feature_path and len(list_of_results) == 0:
        # torch.save(out_dict, args.output_feature_path)
        # return empty results
        return Precision1, uid_list, mask_feature_all, answers_list
    elif len(list_of_results) == 0:
        # unexpected: no results and no feature export requested
        pdb.set_trace()

    # stats: turn accumulated sums into averages
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    # Precision1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if args.use_negated_probes:
        Overlap /= num_valid_negation
        Spearman /= num_valid_negation
        msg += "\n"
        msg += "results negation:\n"
        msg += "all_negated_samples: {}\n".format(int(num_valid_negation))
        msg += "global spearman rank affirmative/negated: {}\n".format(
            Spearman)
        msg += "global overlap at 1 affirmative/negated: {}\n".format(Overlap)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(list_of_results=list_of_results,
                       global_MRR=MRR,
                       global_P_at_10=Precision)
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)

    if args.output_feature_path:
        # torch.save(out_dict, args.output_feature_path)
        return Precision1, len(
            list_of_results
        ), Precision1_modified, total_modified, uid_list, mask_feature_all, answers_list

    return Precision1, len(
        list_of_results), Precision1_modified, total_modified, correct_uuids
def main(args):
    """Probe models with a moral-judgement cloze template for a list of actions.

    For each action phrase, fills the template "<Action> is a [MASK] behavior ."
    (mapping the leading verb through ``verb_noun_mapping``), scores the masked
    token with every requested model, and stores the top-10 predictions per
    (model, action) in a dict that is printed and pickled to
    './parsed_BERTKnowledge_tests.p'.
    """
    verbose = False
    moral_template = "{} is a [MASK] behavior ."
    # user-study pickle is loaded from a hard-coded relative path
    data_user_study = pickle.load(
        open('../MoRT/data/parsed_yes_no_BERTBias.p', 'rb'))
    actions = list(data_user_study.keys())
    # NOTE(review): the keys above are immediately discarded — only this
    # hard-coded action list is evaluated; confirm this override is intended.
    actions = ['help people', 'help white people', 'help black people']
    print("Language Models: {}".format(args.models_names))
    models = {}
    for lm in args.models_names:
        models[lm] = build_model_by_name(lm, args)

    # optional restriction of the ranking to a common vocabulary
    vocab_subset = None
    if args.common_vocab_filename is not None:
        common_vocab = load_vocab(args.common_vocab_filename)
        print("common vocabulary size: {}".format(len(common_vocab)))
        vocab_subset = [x for x in common_vocab]

    data_user_study_BERTKnowledge = dict()
    for action in actions:
        # replace the leading verb via verb_noun_mapping, e.g. "help ..." ->
        # "<Noun> ...", then fill the moral template
        action_ = action.split(" ")
        action_[0] = verb_noun_mapping[action_[0]].capitalize()
        action_ = " ".join(action_)
        text = moral_template.format(action_)

        if args.split_sentence:
            import spacy

            # use spacy to tokenize input sentence
            nlp = spacy.load(args.spacy_model)
            tokens = nlp(text)
            print(tokens)

            sentences = []
            for s in tokens.sents:
                print(" - {}".format(s))
                sentences.append(s.text)
        else:
            sentences = [text]

        if len(sentences) > 2:
            print(
                "WARNING: only the first two sentences in the text will be considered!"
            )
            sentences = sentences[:2]

        for model_name, model in models.items():
            if model_name not in list(data_user_study_BERTKnowledge.keys()):
                data_user_study_BERTKnowledge[model_name] = {}
            if verbose:
                print("\n{}:".format(model_name))
            # single-element batch; unpack the one token-id / masked-index list
            original_log_probs_list, [token_ids], [
                masked_indices
            ] = model.get_batch_generation([sentences], try_cuda=False)

            index_list = None
            if vocab_subset is not None:
                # filter log_probs
                filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
                    vocab_subset)
                filtered_log_probs_list = model.filter_logprobs(
                    original_log_probs_list, filter_logprob_indices)
            else:
                filtered_log_probs_list = original_log_probs_list

            # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
            if masked_indices and len(masked_indices) > 0:
                _, _, experiment_result, _ = evaluation_metrics.get_ranking(
                    filtered_log_probs_list[0],
                    masked_indices,
                    model.vocab,
                    index_list=index_list,
                    print_generation=verbose)
                # keep (rank, word form, log prob) for the top-10 predictions
                experiment_result_topk = [(r['i'], r['token_word_form'],
                                           r['log_prob'])
                                          for r in experiment_result['topk'][:10]]
                data_user_study_BERTKnowledge[model_name][action] = [
                    text, experiment_result_topk
                ]

            # prediction and perplexity for the whole softmax
            if verbose:
                print_sentence_predictions(original_log_probs_list[0],
                                           token_ids,
                                           model.vocab,
                                           masked_indices=masked_indices)
    print(data_user_study_BERTKnowledge)
    pickle.dump(data_user_study_BERTKnowledge,
                open('./parsed_BERTKnowledge_tests.p', 'wb'))
def main(args, shuffle_data=True, model=None):
    """Run the LAMA cloze evaluation for a single language model.

    Loads the dataset, (optionally) lowercases and filters samples against
    the model vocabulary / common-vocab subset, batches them, scores every
    masked sentence with the LM, and aggregates MRR / Precision@k via a
    thread pool of ``run_thread`` workers.

    Args:
        args: parsed CLI namespace (models_names, dataset_filename,
            template, logdir, batch_size, threads, ...).
        shuffle_data: shuffle samples before batching (default True).
        model: optionally pass a pre-built model to avoid rebuilding it.

    Returns:
        Global Precision@1 over all evaluated samples.

    Raises:
        ValueError: if more than one model is requested, or an object
            label is missing from the model vocabulary / vocab subset.

    Side effects: writes ``args.json`` and ``result.pkl`` into the log
    directory and logs progress via the created logger.
    """
    if len(args.models_names) > 1:
        raise ValueError(
            "Please specify a single language model (e.g., --lm \"bert\").")

    msg = ""

    # Exactly one model name is expected at this point.
    [model_type_name] = args.models_names

    print(model)
    if model is None:
        model = build_model_by_name(model_type_name, args)

    # Human-readable model name used for the log directory.
    if model_type_name == 'fairseq':
        model_name = 'fairseq_{}'.format(args.fairseq_model_name)
    elif model_type_name == 'bert':
        model_name = 'BERT_{}'.format(args.bert_model_name)
    elif model_type_name == 'elmo':
        model_name = 'ELMo_{}'.format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), 'w') as outfile:
        json.dump(vars(args), outfile)

    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0

    # Mean reciprocal rank
    MRR = 0.
    MRR_negative = 0.
    MRR_positive = 0.

    # Precision at (default 10)
    Precision = 0.
    Precision1 = 0.
    Precision_negative = 0.
    Precision_positivie = 0.

    data = load_file(args.dataset_filename)

    print(len(data))

    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
    else:
        # keep samples as they are
        all_samples = data

    # BUG FIX: previously `data` was passed here, silently discarding the
    # lowercasing performed above. Filter the (possibly lowercased)
    # `all_samples` instead.
    all_samples, ret_msg = filter_samples(model, all_samples, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)

    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    # if template is active (1) use a single example for (sub,obj) and (2)
    # rebuild each sample's masked sentence from the template
    if args.template and args.template != '':
        facts = []
        for sample in all_samples:
            sub = sample['sub_label']
            obj = sample['obj_label']
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample['sub_label'] = sub
            sample['obj_label'] = obj
            # substitute all sentences with a standard template
            sample['masked_sentences'] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK)
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if 'uuid' not in sample:
            sample['uuid'] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        original_log_probs_list, token_ids_list, masked_indices_list = model.get_batch_generation(
            sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample['obj_label'])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample['obj_label']))
            elif (model.vocab[obj_label_id[0]] != sample['obj_label']):
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample['obj_label']))
            elif vocab_subset is not None and sample[
                    'obj_label'] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample['obj_label']))

            label_index_list.append(obj_label_id)

        # Per-sample work items consumed by run_thread in the pool.
        arguments = [{
            'original_log_probs': original_log_probs,
            'filtered_log_probs': filtered_log_probs,
            'token_ids': token_ids,
            'vocab': model.vocab,
            'label_index': label_index[0],
            'masked_indices': masked_indices,
            'interactive': args.interactive,
            'index_list': index_list,
            'sample': sample
        } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip(
            original_log_probs_list, filtered_log_probs_list, token_ids_list,
            masked_indices_list, label_index_list, samples_b)]

        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result

            logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            element['sample'] = sample
            element['uuid'] = sample['uuid']
            element['token_ids'] = token_ids_list[idx]
            element['masked_indices'] = masked_indices_list[idx]
            element['label_index'] = label_index_list[idx]
            element['masked_topk'] = result_masked_topk
            element['sample_MRR'] = sample_MRR
            element['sample_Precision'] = sample_P
            element['sample_perplexity'] = sample_perplexity
            element['sample_Precision1'] = result_masked_topk["P_AT_1"]

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element['sample_Precision1']

            # the judgment of the annotators recording whether they are
            # evidence in the sentence that indicates a relation between two
            # entities.
            num_yes = 0
            num_no = 0

            if 'judgments' in sample:
                # only for Google-RE
                for x in sample['judgments']:
                    if (x['judgment'] == "yes"):
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element['judgement'] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element['judgement'] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P

            list_of_results.append(element)

    pool.close()
    pool.join()

    # stats
    # Mean reciprocal rank
    MRR /= len(list_of_results)

    # Precision
    Precision /= len(list_of_results)
    Precision1 /= len(list_of_results)

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)

    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement

        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement)
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement)
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results,
        global_MRR=MRR,
        global_P_at_10=Precision,
    )
    with open("{}/result.pkl".format(log_directory), 'wb') as f:
        pickle.dump(all_results, f)

    return Precision1
def main(args, shuffle_data=True, model=None):
    """Reproduction variant of the LAMA cloze evaluation for one model.

    Scores every masked sentence of the dataset, appends each sample's
    rank/top-20 predictions to jsonl files under
    ``reproduction/data/P_AT_K/`` as it goes, and returns Precision@1.

    Args:
        args: parsed CLI namespace (models_names, dataset_filename,
            template, label, batch_size, threads, ...).
        shuffle_data: shuffle samples before batching (default True).
        model: optionally pass a pre-built model to avoid rebuilding it.

    Returns:
        Global Precision@1 over all evaluated samples.

    Raises:
        ValueError: if more than one model is requested, or an object
            label is missing from the model vocabulary / vocab subset.
    """
    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')

    msg = ""

    # Exactly one model name is expected at this point.
    [model_type_name] = args.models_names

    # print("------- Model: {}".format(model))
    # print("------- Args: {}".format(args))
    if model is None:
        model = build_model_by_name(model_type_name, args)

    # Human-readable model name used for the log directory.
    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()

    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)

    # deal with vocab subset
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))

        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)

        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)

    # logger.info("\n" + msg + "\n")

    # dump arguments on file for log
    # with open("{}/args.json".format(log_directory), "w") as outfile:
    #     json.dump(vars(args), outfile)

    # Mean reciprocal rank
    MRR = 0.0

    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0

    data = load_file(args.dataset_filename)

    # NOTE(review): unlike the sibling evaluation main(), this variant never
    # lowercases the data, even if args.lowercase is set — confirm intended.
    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)
    # logger.info("\n" + ret_msg + "\n")

    # if template is active (1) use a single example for (sub,obj) and (2)
    # rebuild each sample's masked sentence from the template
    if args.template and args.template != "":
        facts = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            if (sub, obj) not in facts:
                facts.append((sub, obj))
        local_msg = "distinct template facts: {}".format(len(facts))
        # logger.info("\n" + local_msg + "\n")
        print(local_msg)
        all_samples = []
        for fact in facts:
            (sub, obj) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            # sobstitute all sentences with a standard template
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK)
            if args.use_negated_probes:
                # substitute all negated sentences with a standard template
                sample["negated"] = parse_template(
                    args.template_negated.strip(),
                    sample["sub_label"].strip(),
                    base.MASK,
                )
            all_samples.append(sample)

    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
        i += 1

    # shuffle data
    if shuffle_data:
        shuffle(all_samples)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    # logger.info("\n" + ret_msg + "\n")

    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    # list_of_results = []
    # list_of_ranks = []
    item_count = 0

    for i in tqdm(range(len(samples_batches))):

        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]

        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)

        if vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list

        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])

            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))

            label_index_list.append(obj_label_id)

        # Per-sample work items consumed by run_thread in the pool.
        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
        } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip(
            original_log_probs_list,
            filtered_log_probs_list,
            token_ids_list,
            masked_indices_list,
            label_index_list,
            samples_b,
        )]

        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)

        # multithread
        res = pool.map(run_thread, arguments)

        for idx, result in enumerate(res):

            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result
            # print("~~~~~~~~~~~~~~~~~~")
            # print(result_masked_topk)
            # logger.info("\n" + msg + "\n")

            sample = samples_b[idx]

            element = {}
            obj = sample['obj_label']
            sub = sample['sub_label']
            element["masked_sentences"] = sample["masked_sentences"][0]
            # element["uuid"] = sample["uuid"]
            element["subject"] = sub
            element["object"] = obj
            element["rank"] = int(result_masked_topk['rank'])
            # element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            # element["sample"] = sample
            # element["token_ids"] = token_ids_list[idx]
            # element["masked_indices"] = masked_indices_list[idx]
            # element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk['topk'][:20]
            # element["sample_MRR"] = sample_MRR
            # element["sample_Precision"] = sample_P
            # element["sample_perplexity"] = sample_perplexity
            # list_of_results[sub + "_" + obj].append(element)
            # list_of_ranks[sub + "_" + obj].append(element["rank"])
            # print("~~~~~~ rank: {}".format(result_masked_topk['rank']))

            MRR += sample_MRR
            Precision += sample_P
            Precision1 += result_masked_topk["P_AT_1"]
            item_count += 1

            # Stream per-sample results to disk as they are produced; the
            # output path is keyed by args.label.
            append_data_line_to_jsonl(
                "reproduction/data/P_AT_K/{}_rank_results.jsonl".format(
                    args.label), element)
            append_data_line_to_jsonl(
                "reproduction/data/P_AT_K/{}_rank_list.jsonl".format(
                    args.label), element["rank"])

    pool.close()
    pool.join()

    # NOTE(review): raises ZeroDivisionError if no samples survive
    # filtering (item_count == 0) — confirm whether that can occur.
    Precision1 /= item_count

    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results)  # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks)  # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label), list(list_of_ranks.values()))  # 3122

    return Precision1
def main(args, shuffle_data=True, model=None, model2=None, refine_template=False, get_objs=False, dynamic='none', use_prob=False, bt_obj=None, temp_model=None): if len(args.models_names) > 1: raise ValueError( 'Please specify a single language model (e.g., --lm "bert").') msg = "" [model_type_name] = args.models_names #print(model) if model is None: model = build_model_by_name(model_type_name, args) if model_type_name == "fairseq": model_name = "fairseq_{}".format(args.fairseq_model_name) elif model_type_name == "bert": model_name = "BERT_{}".format(args.bert_model_name) elif model_type_name == "elmo": model_name = "ELMo_{}".format(args.elmo_model_name) else: model_name = model_type_name.title() # initialize logging if args.full_logdir: log_directory = args.full_logdir else: log_directory = create_logdir_with_timestamp(args.logdir, model_name) logger = init_logging(log_directory) msg += "model name: {}\n".format(model_name) # deal with vocab subset vocab_subset = None index_list = None msg += "args: {}\n".format(args) if args.common_vocab_filename is not None: vocab_subset = load_vocab(args.common_vocab_filename, lower=args.lowercase) msg += "common vocabulary size: {}\n".format(len(vocab_subset)) # optimization for some LM (such as ELMo) model.optimize_top_layer(vocab_subset) filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( vocab_subset, logger) logger.info("\n" + msg + "\n") # dump arguments on file for log with open("{}/args.json".format(log_directory), "w") as outfile: json.dump(vars(args), outfile) if dynamic == 'all_topk': # save topk results for different k # stats samples_with_negative_judgement = [ 0 for _ in range(len(args.template)) ] samples_with_positive_judgement = [ 0 for _ in range(len(args.template)) ] # Mean reciprocal rank MRR = [0.0 for _ in range(len(args.template))] MRR_negative = [0.0 for _ in range(len(args.template))] MRR_positive = [0.0 for _ in range(len(args.template))] # Precision at (default 10) Precision = 
[0.0 for _ in range(len(args.template))] Precision1 = [0.0 for _ in range(len(args.template))] Precision_negative = [0.0 for _ in range(len(args.template))] Precision_positivie = [0.0 for _ in range(len(args.template))] list_of_results = [[] for _ in range(len(args.template))] P1_li = [[] for _ in range(len(args.template))] else: # stats samples_with_negative_judgement = [0] samples_with_positive_judgement = [0] # Mean reciprocal rank MRR = [0.0] MRR_negative = [0.0] MRR_positive = [0.0] # Precision at (default 10) Precision = [0.0] Precision1 = [0.0] Precision_negative = [0.0] Precision_positivie = [0.0] list_of_results = [[]] P1_li = [[]] data = load_file(args.dataset_filename) for s in data: s['raw_sub_label'] = s['sub_label'] s['raw_obj_label'] = s['obj_label'] if args.lowercase: # lowercase all samples logger.info("lowercasing all samples...") data = lowercase_samples(data) all_samples, ret_msg = filter_samples(model, data, vocab_subset, args.max_sentence_length, args.template) # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename) # with open(OUT_FILENAME, 'w') as outfile: # for entry in all_samples: # json.dump(entry, outfile) # outfile.write('\n') logger.info("\n" + ret_msg + "\n") print('#head-tails {} -> {}'.format(len(data), len(all_samples))) samples_batches_li, sentences_batches_li = [], [] for template in args.template: # if template is active (1) use a single example for (sub,obj) and (2) ... 
if template and template != "": facts = [] samples = [] for sample in all_samples: sub = sample["sub_label"] obj = sample["obj_label"] if (sub, obj) not in facts: facts.append((sub, obj)) samples.append(sample) local_msg = "distinct template facts: {}".format(len(facts)) logger.info("\n" + local_msg + "\n") new_all_samples = [] for fact, raw_sample in zip(facts, samples): (sub, obj) = fact sample = {} sample["sub_label"] = sub sample["obj_label"] = obj # sobstitute all sentences with a standard template sample["masked_sentences"] = parse_template( template.strip(), raw_sample["raw_sub_label"].strip() if args.upper_entity else sample["sub_label"].strip(), model.mask_token) sub_uri = raw_sample[ 'sub_uri'] if 'sub_uri' in raw_sample else raw_sample['sub'] sample['entity_list'] = get_entity_list( template.strip(), raw_sample['raw_sub_label'].strip(), sub_uri, None, None) if dynamic.startswith('bt_topk') or (temp_model is not None and bt_obj): sample['sub_masked_sentences'] = parse_template_tokenize( template.strip(), sample["sub_label"].strip(), model, mask_part='sub') if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'): sample["tokenized_sentences"] = parse_template_tokenize( template.strip(), sample["sub_label"].strip(), model, mask_part='relation') # substitute sub and obj placeholder in template with corresponding str # and add bracket to the relational phrase sample['bracket_sentences'] = bracket_relational_phrase( template.strip(), sample['sub_label'].strip(), sample['obj_label'].strip()) new_all_samples.append(sample) # create uuid if not present i = 0 for sample in new_all_samples: if "uuid" not in sample: sample["uuid"] = i i += 1 if args.lowercase and not args.upper_entity: # lowercase all samples logger.info("lowercasing all samples...") new_all_samples = lowercase_samples(new_all_samples) # shuffle data if shuffle_data: perm = np.random.permutation(len(new_all_samples)) new_all_samples = np.array(new_all_samples)[perm] raise Exception 
samples_batches, sentences_batches, ret_msg = batchify( new_all_samples, args.batch_size) logger.info("\n" + ret_msg + "\n") samples_batches_li.append(samples_batches) sentences_batches_li.append(sentences_batches) sub_obj_labels = [(sample['sub_label'], sample['obj_label']) for batch in samples_batches for sample in batch] if get_objs: print('sub_obj_label {}'.format('\t'.join( map(lambda p: '{}\t{}'.format(*p), sub_obj_labels)))) return if refine_template: bracket_sentences = [ sample['bracket_sentences'] for sample in new_all_samples ] new_temp = model.refine_cloze(bracket_sentences, batch_size=32, try_cuda=True) new_temp = replace_template(template.strip(), ' '.join(new_temp)) print('old temp: {}'.format(template.strip())) print('new temp: {}'.format(new_temp)) return new_temp # ThreadPool num_threads = args.threads if num_threads <= 0: # use all available threads num_threads = multiprocessing.cpu_count() pool = ThreadPool(num_threads) samples_batches_li = list(zip(*samples_batches_li)) sentences_batches_li = list(zip(*sentences_batches_li)) c_inc_meaning = ['top12 prob gap', 'top1 prob'] c_inc_stat = np.zeros((2, 3)) # [[*, c_num], [*, inc_num]] loss_list = [] features_list = [] features_list2 = [] bt_features_list = [] label_index_tensor_list = [] for i in tqdm(range(len(samples_batches_li))): samples_b_all = samples_batches_li[i] sentences_b_all = sentences_batches_li[i] filter_lp_merge = None filter_lp_merge2 = None samples_b = samples_b_all[-1] max_score = float('-inf') consist_score_li = [] samples_b_prev = None for sentences_b, samples_b_this in zip(sentences_b_all, samples_b_all): if samples_b_prev is not None: for ps, ts in zip(samples_b_prev, samples_b_this): assert ps['uuid'] == ts['uuid'] entity_list_b = [s['entity_list'] for s in samples_b_this] # TODO: add tokens_tensor and mask_tensor for more models original_log_probs_list, token_ids_list, masked_indices_list, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b, logger=logger, 
entity_list=entity_list_b) if model2 is not None: original_log_probs_list2, token_ids_list2, masked_indices_list2, tokens_tensor2, mask_tensor2 = \ model2.get_batch_generation(sentences_b, logger=logger) if use_prob: # use prob instead of log prob original_log_probs_list = original_log_probs_list.exp() if model2 is not None: original_log_probs_list2 = original_log_probs_list2.exp() if dynamic == 'real_lm' or dynamic.startswith('real_lm_topk'): sentences_b_mask_rel = [ s['tokenized_sentences'][0] for s in samples_b_this ] relation_mask = [ s['tokenized_sentences'][1] for s in samples_b_this ] consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b_mask_rel, logger=logger, relation_mask=relation_mask) else: consist_log_probs_list = original_log_probs_list if dynamic == 'lm' or dynamic == 'real_lm' or dynamic.startswith( 'real_lm_topk'): # use avg prob of the templates as score mask_tensor = mask_tensor.float() consist_log_probs_list_flat = consist_log_probs_list.view( -1, consist_log_probs_list.size(-1)) token_logprob = torch.gather( consist_log_probs_list_flat, dim=1, index=tokens_tensor.view( -1, 1)).view(*consist_log_probs_list.size()[:2]) token_logprob = token_logprob * mask_tensor consist_score = token_logprob.sum(-1) / mask_tensor.sum( -1) # normalized prob ''' if vocab_subset is not None: # filter log_probs filtered_log_probs_list = model.filter_logprobs( original_log_probs_list, filter_logprob_indices ) else: filtered_log_probs_list = original_log_probs_list ''' # get the prediction probability if vocab_subset is not None: filtered_log_probs_list = [ flp[masked_indices_list[ind][0]].index_select( dim=-1, index=filter_logprob_indices) for ind, flp in enumerate(original_log_probs_list) ] if model2 is not None: filtered_log_probs_list2 = [ flp[masked_indices_list2[ind][0]].index_select( dim=-1, index=filter_logprob_indices) for ind, flp in enumerate(original_log_probs_list2) ] else: filtered_log_probs_list = [ 
flp[masked_indices_list[ind][0]] for ind, flp in enumerate(original_log_probs_list) ] if model2 is not None: filtered_log_probs_list2 = [ flp[masked_indices_list2[ind][0]] for ind, flp in enumerate(original_log_probs_list2) ] if dynamic.startswith('bt_topk'): obj_topk = int(dynamic.rsplit('-', 1)[1]) top_obj_pred = [ flp.topk(k=obj_topk) for flp in filtered_log_probs_list ] top_obj_logprob, top_obj_pred = zip(*top_obj_pred) if dynamic.startswith('obj_lm_topk'): # use highest obj prob as consistency score consist_score = torch.tensor( [torch.max(flp).item() for flp in filtered_log_probs_list]) elif dynamic.startswith('obj_lmgap_topk'): # the gap between the highest prediction log p1 - log p2 get_gap = lambda top2: (top2[0] - top2[1]).item() consist_score = torch.tensor([ get_gap(torch.topk(flp, k=2)[0]) for flp in filtered_log_probs_list ]) elif dynamic.startswith('bt_topk'): # use the obj_topk highest obj to "back translate" sub consist_score_obj_topk = [] used_vocab = vocab_subset if vocab_subset is not None else model.vocab for obj_i in range(obj_topk): sentences_b_mask_sub = [[ replace_list(s['sub_masked_sentences'][0][0], model.mask_token, used_vocab[obj_pred[obj_i].item()]) ] for s, obj_pred in zip(samples_b_this, top_obj_pred)] sub_mask = [ s['sub_masked_sentences'][1] for s in samples_b_this ] # TODO: only masked lm can do this consist_log_probs_list, _, _, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask) # use avg prob of the sub as score mask_tensor = mask_tensor.float() consist_log_probs_list_flat = consist_log_probs_list.view( -1, consist_log_probs_list.size(-1)) token_logprob = torch.gather( consist_log_probs_list_flat, dim=1, index=tokens_tensor.view( -1, 1)).view(*consist_log_probs_list.size()[:2]) token_logprob = token_logprob * mask_tensor consist_score = token_logprob.sum(-1) / mask_tensor.sum( -1) # normalized prob consist_score_obj_topk.append(consist_score) # SHAPE: 
(batch_size, obj_topk) consist_score_obj_topk = torch.stack( consist_score_obj_topk).permute(1, 0) consist_score_weight = torch.stack(top_obj_logprob).exp() # SHAPE: (batch_size) consist_score = (consist_score_obj_topk * consist_score_weight).sum(-1) / ( consist_score_weight.sum(-1) + 1e-10) # add to overall probability if filter_lp_merge is None: filter_lp_merge = filtered_log_probs_list if model2 is not None: filter_lp_merge2 = filtered_log_probs_list2 if dynamic == 'lm' or dynamic == 'real_lm': max_score = consist_score elif dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk'): consist_score_li.append(consist_score) else: if dynamic == 'none' and temp_model is None: filter_lp_merge = [ a + b for a, b in zip(filter_lp_merge, filtered_log_probs_list) ] elif dynamic == 'all_topk': filter_lp_merge.extend(filtered_log_probs_list) elif dynamic == 'lm' or dynamic == 'real_lm': filter_lp_merge = \ [a if c >= d else b for a, b, c, d in zip(filter_lp_merge, filtered_log_probs_list, max_score, consist_score)] max_score = torch.max(max_score, consist_score) elif dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk'): filter_lp_merge.extend(filtered_log_probs_list) consist_score_li.append(consist_score) elif temp_model is not None: filter_lp_merge.extend(filtered_log_probs_list) if model2 is not None: filter_lp_merge2.extend(filtered_log_probs_list2) samples_b_prev = samples_b_this label_index_list = [] obj_word_list = [] for sample in samples_b: obj_label_id = model.get_id(sample["obj_label"]) # MAKE SURE THAT obj_label IS IN VOCABULARIES if obj_label_id is None: raise ValueError( "object label {} not in model vocabulary".format( sample["obj_label"])) elif model.vocab[obj_label_id[0]] != sample["obj_label"]: raise ValueError( "object label {} not in model vocabulary".format( 
sample["obj_label"])) elif vocab_subset is not None and sample[ "obj_label"] not in vocab_subset: raise ValueError("object label {} not in vocab subset".format( sample["obj_label"])) label_index_list.append(obj_label_id) obj_word_list.append(sample['obj_label']) if dynamic == 'all_topk' or \ dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk') or \ temp_model is not None: # analyze prob # SHAPE: (batch_size, num_temp, filter_vocab_size) filter_lp_merge = torch.stack(filter_lp_merge, 0).view( len(sentences_b_all), len(filter_lp_merge) // len(sentences_b_all), -1).permute(1, 0, 2) if model2 is not None: filter_lp_merge2 = torch.stack(filter_lp_merge2, 0).view( len(sentences_b_all), len(filter_lp_merge2) // len(sentences_b_all), -1).permute(1, 0, 2) # SHAPE: (batch_size) label_index_tensor = torch.tensor( [index_list.index(li[0]) for li in label_index_list]) c_inc = np.array( metrics.analyze_prob(filter_lp_merge, label_index_tensor, output=False, method='sample')) c_inc_stat += c_inc elif dynamic == 'none': # SHAPE: (batch_size, 1, filter_vocab_size) filter_lp_merge = torch.stack(filter_lp_merge, 0).unsqueeze(1) # SHAPE: (batch_size, num_temp, filter_vocab_size) filter_lp_unmerge = filter_lp_merge if temp_model is not None: # optimize template weights temp_model_, optimizer = temp_model if optimizer is None: # predict filter_lp_merge = temp_model_(args.relation, filter_lp_merge.detach(), target=None) elif optimizer == 'precompute': # pre-compute and save featuers lp = filter_lp_merge # SHAPE: (batch_size * num_temp) features = torch.gather(lp.contiguous().view(-1, lp.size(-1)), dim=1, index=label_index_tensor.repeat( lp.size(1)).view(-1, 1)) features = features.view(-1, lp.size(1)) features_list.append(features) if not bt_obj: continue elif optimizer is not None: # train on the fly features_list.append( filter_lp_merge ) # collect features that will later be used in 
optimization if model2 is not None: features_list2.append( filter_lp_merge2 ) # collect features that will later be used in optimization label_index_tensor_list.append( label_index_tensor) # collect labels if not bt_obj: continue else: #filter_lp_merge = temp_model_(args.relation, filter_lp_merge.detach(), target=None) filter_lp_merge = filter_lp_merge.mean( 1) # use average prob to beam search if dynamic.startswith('real_lm_topk') or \ dynamic.startswith('obj_lm_topk') or \ dynamic.startswith('obj_lmgap_topk') or \ dynamic.startswith('bt_topk'): # dynamic ensemble real_lm_topk = min( int(dynamic[dynamic.find('topk') + 4:].split('-')[0]), len(consist_score_li)) # SHAPE: (batch_size, num_temp) consist_score_li = torch.stack(consist_score_li, -1) # SHAPE: (batch_size, topk) consist_score, consist_ind = consist_score_li.topk(real_lm_topk, dim=-1) # SHAPE: (batch_size, 1) consist_score = consist_score.min(-1, keepdim=True)[0] # SHAPE: (batch_size, num_temp, 1) consist_mask = (consist_score_li >= consist_score).float().unsqueeze(-1) # avg over top k filter_lp_merge = filter_lp_merge * consist_mask filter_lp_merge = filter_lp_merge.sum(1) / consist_mask.sum(1) if bt_obj: # choose top bt_obj objects and bach-translate subject # get the top bt_obj objects with highest probability used_vocab = vocab_subset if vocab_subset is not None else model.vocab temp_model_, optimizer = temp_model if optimizer is None: # use beam search # SHAPE: (batch_size, bt_obj) objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1) objs_ind = torch.sort(objs_ind, dim=-1)[0] # the index must be ascending elif optimizer == 'precompute': # use ground truth objs_ind = label_index_tensor.view(-1, 1) bt_obj = 1 elif optimizer is not None: # get both ground truth and beam search # SHAPE: (batch_size, bt_obj) objs_score, objs_ind = filter_lp_merge.topk(bt_obj, dim=-1) objs_ind = torch.cat( [objs_ind, label_index_tensor.view(-1, 1)], -1) objs_ind = torch.sort(objs_ind, dim=-1)[0] # the index must be 
ascending bt_obj += 1 # bach translation sub_lp_list = [] for sentences_b, samples_b_this in zip( sentences_b_all, samples_b_all): # iter over templates for obj_i in range(bt_obj): # iter over objs sentences_b_mask_sub = [] for s, obj_pred, obj_word in zip(samples_b_this, objs_ind, obj_word_list): replace_tok = used_vocab[obj_pred[obj_i].item()] if optimizer == 'precompute': assert replace_tok.strip() == obj_word.strip() sentences_b_mask_sub.append([ replace_list(s['sub_masked_sentences'][0][0], model.mask_token, replace_tok) ]) sub_mask = [ s['sub_masked_sentences'][1] for s in samples_b_this ] # TODO: only masked lm can do this lp, _, _, tokens_tensor, mask_tensor = \ model.get_batch_generation(sentences_b_mask_sub, logger=logger, relation_mask=sub_mask) # use avg prob of the sub as score mask_tensor = mask_tensor.float() lp_flat = lp.view(-1, lp.size(-1)) sub_lp = torch.gather(lp_flat, dim=1, index=tokens_tensor.view( -1, 1)).view(*lp.size()[:2]) sub_lp = sub_lp * mask_tensor sub_lp_avg = sub_lp.sum(-1) / mask_tensor.sum( -1) # normalized prob sub_lp_list.append(sub_lp_avg) # SHAPE: (batch_size, num_temp, top_obj_num) num_temp = len(sentences_b_all) sub_lp_list = torch.cat(sub_lp_list, 0).view(num_temp, bt_obj, -1).permute(2, 0, 1) if optimizer == 'precompute': bt_features_list.append(sub_lp_list.squeeze(-1)) continue elif optimizer is not None: sub_lp_list_expand = torch.zeros_like(filter_lp_unmerge) # SHAPE: (batch_size, num_temp, vocab_size) sub_lp_list_expand.scatter_( -1, objs_ind.unsqueeze(1).repeat(1, num_temp, 1), sub_lp_list) bt_features_list.append(sub_lp_list_expand) bt_obj -= 1 continue # select obj prob expand_mask = torch.zeros_like(filter_lp_unmerge) expand_mask.scatter_(-1, objs_ind.unsqueeze(1).repeat(1, num_temp, 1), 1) # SHAPE: (batch_size, num_temp, top_obj_num) obj_lp_list = torch.masked_select(filter_lp_unmerge, expand_mask.eq(1)).view( -1, num_temp, bt_obj) # run temp model # SHAPE: (batch_size, vocab_size) filter_lp_merge_expand = 
torch.zeros_like(filter_lp_merge) # SHAPE: (batch_size, top_obj_num) filter_lp_merge = temp_model_(args.relation, torch.cat([obj_lp_list, sub_lp_list], 1), target=None) # expand results to vocab_size filter_lp_merge_expand.scatter_(-1, objs_ind, filter_lp_merge) filter_lp_merge = filter_lp_merge_expand + expand_mask[:, 0, :].log( ) # mask out other objs if len(filter_lp_merge.size()) == 2: filter_lp_merge = filter_lp_merge.unsqueeze(1) for temp_id in range(filter_lp_merge.size(1)): arguments = [{ "original_log_probs": original_log_probs, "filtered_log_probs": filtered_log_probs, "token_ids": token_ids, "vocab": model.vocab, "label_index": label_index[0], "masked_indices": masked_indices, "interactive": args.interactive, "index_list": index_list, "sample": sample, } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip( original_log_probs_list, filter_lp_merge[:, :temp_id + 1].sum(1), token_ids_list, masked_indices_list, label_index_list, samples_b, )] # single thread for debug # for isx,a in enumerate(arguments): # print(samples_b[isx]) # run_thread(a) # multithread res = pool.map(run_thread, arguments) for idx, result in enumerate(res): result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result logger.info("\n" + msg + "\n") sample = samples_b[idx] element = {} element["sample"] = sample element["uuid"] = sample["uuid"] element["token_ids"] = token_ids_list[idx] element["masked_indices"] = masked_indices_list[idx] element["label_index"] = label_index_list[idx] element["masked_topk"] = result_masked_topk element["sample_MRR"] = sample_MRR element["sample_Precision"] = sample_P element["sample_perplexity"] = sample_perplexity element["sample_Precision1"] = result_masked_topk["P_AT_1"] # print() # print("idx: {}".format(idx)) # print("masked_entity: {}".format(result_masked_topk['masked_entity'])) # for yi in range(10): # print("\t{} {}".format(yi,result_masked_topk['topk'][yi])) # print("masked_indices_list: 
{}".format(masked_indices_list[idx])) # print("sample_MRR: {}".format(sample_MRR)) # print("sample_P: {}".format(sample_P)) # print("sample: {}".format(sample)) # print() MRR[temp_id] += sample_MRR Precision[temp_id] += sample_P Precision1[temp_id] += element["sample_Precision1"] P1_li[temp_id].append(element["sample_Precision1"]) ''' if element["sample_Precision1"] == 1: print(element["sample"]) input(1) else: print(element["sample"]) input(0) ''' # the judgment of the annotators recording whether they are # evidence in the sentence that indicates a relation between two entities. num_yes = 0 num_no = 0 if "judgments" in sample: # only for Google-RE for x in sample["judgments"]: if x["judgment"] == "yes": num_yes += 1 else: num_no += 1 if num_no >= num_yes: samples_with_negative_judgement[temp_id] += 1 element["judgement"] = "negative" MRR_negative[temp_id] += sample_MRR Precision_negative[temp_id] += sample_P else: samples_with_positive_judgement[temp_id] += 1 element["judgement"] = "positive" MRR_positive[temp_id] += sample_MRR Precision_positivie[temp_id] += sample_P list_of_results[temp_id].append(element) if temp_model is not None: if temp_model[1] == 'precompute': features = torch.cat(features_list, 0) if bt_obj: bt_features = torch.cat(bt_features_list, 0) features = torch.cat([features, bt_features], 1) return features if temp_model[1] is not None: # optimize the model on the fly temp_model_, (optimizer, temperature) = temp_model temp_model_.cuda() # SHAPE: (batch_size, num_temp, vocab_size) features = torch.cat(features_list, 0) if model2 is not None: features2 = torch.cat(features_list2, 0) if bt_obj: bt_features = torch.cat(bt_features_list, 0) features = torch.cat([features, bt_features], 1) # compute weight # SHAPE: (batch_size,) label_index_tensor = torch.cat(label_index_tensor_list, 0) label_count = torch.bincount(label_index_tensor) label_count = torch.index_select(label_count, 0, label_index_tensor) sample_weight = F.softmax( temperature * 
torch.log(1.0 / label_count.float()), 0) * label_index_tensor.size(0) min_loss = 1e10 es = 0 batch_size = 128 for e in range(500): # loss = temp_model_(args.relation, features.cuda(), target=label_index_tensor.cuda(), use_softmax=True) loss_li = [] for b in range(0, features.size(0), batch_size): features_b = features[b:b + batch_size].cuda() label_index_tensor_b = label_index_tensor[b:b + batch_size].cuda( ) sample_weight_b = sample_weight[b:b + batch_size].cuda() loss = temp_model_(args.relation, features_b, target=label_index_tensor_b, sample_weight=sample_weight_b, use_softmax=True) if model2 is not None: features2_b = features2[b:b + batch_size].cuda() loss2 = temp_model_(args.relation, features2_b, target=label_index_tensor_b, sample_weight=sample_weight_b, use_softmax=True) loss = loss + loss2 optimizer.zero_grad() loss.backward() optimizer.step() loss_li.append(loss.cpu().item()) dev_loss = np.mean(loss_li) if dev_loss - min_loss < -1e-3: min_loss = dev_loss es = 0 else: es += 1 if es >= 30: print('early stop') break temp_model_.cpu() return min_loss pool.close() pool.join() for temp_id in range(len(P1_li)): # stats # Mean reciprocal rank MRR[temp_id] /= len(list_of_results[temp_id]) # Precision Precision[temp_id] /= len(list_of_results[temp_id]) Precision1[temp_id] /= len(list_of_results[temp_id]) msg = "all_samples: {}\n".format(len(all_samples)) msg += "list_of_results: {}\n".format(len(list_of_results[temp_id])) msg += "global MRR: {}\n".format(MRR[temp_id]) msg += "global Precision at 10: {}\n".format(Precision[temp_id]) msg += "global Precision at 1: {}\n".format(Precision1[temp_id]) if samples_with_negative_judgement[ temp_id] > 0 and samples_with_positive_judgement[temp_id] > 0: # Google-RE specific MRR_negative[temp_id] /= samples_with_negative_judgement[temp_id] MRR_positive[temp_id] /= samples_with_positive_judgement[temp_id] Precision_negative[temp_id] /= samples_with_negative_judgement[ temp_id] Precision_positivie[temp_id] /= 
samples_with_positive_judgement[ temp_id] msg += "samples_with_negative_judgement: {}\n".format( samples_with_negative_judgement[temp_id]) msg += "samples_with_positive_judgement: {}\n".format( samples_with_positive_judgement[temp_id]) msg += "MRR_negative: {}\n".format(MRR_negative[temp_id]) msg += "MRR_positive: {}\n".format(MRR_positive[temp_id]) msg += "Precision_negative: {}\n".format( Precision_negative[temp_id]) msg += "Precision_positivie: {}\n".format( Precision_positivie[temp_id]) logger.info("\n" + msg + "\n") print("\n" + msg + "\n") # dump pickle with the result of the experiment all_results = dict(list_of_results=list_of_results[temp_id], global_MRR=MRR, global_P_at_10=Precision) with open("{}/result.pkl".format(log_directory), "wb") as f: pickle.dump(all_results, f) print('P1all {}'.format('\t'.join(map(str, P1_li[temp_id])))) print('meaning: {}'.format(c_inc_meaning)) print('correct-incorrect {}'.format('\t'.join( map(str, (c_inc_stat[:, :-1] / (c_inc_stat[:, -1:] + 1e-5)).reshape(-1))))) return Precision1[-1]
def main(args, shuffle_data=True, model=None):
    """Evaluate a single masked language model on a fact-pair dataset,
    recording the rank of the gold object token for every sample.

    Per-sample results and ranks are grouped per (subject, object) fact pair
    and appended/saved as JSONL under ``reproduction/data/TREx_filter/``.

    Args:
        args: parsed command-line namespace (models_names, dataset_filename,
            fact_pair_filename, common_vocab_filename, template, batch_size,
            threads, label, logdir/full_logdir, interactive, ...).
        shuffle_data: shuffle the filtered samples before batching.
        model: optional pre-built LM; built via build_model_by_name when None.

    Returns:
        The accumulated (NOT averaged — averaging below is commented out)
        sum of per-sample precision-at-1 values.
    """
    if len(args.models_names) > 1:
        raise ValueError(
            'Please specify a single language model (e.g., --lm "bert").')
    msg = ""
    # single-element unpack: raises if models_names is empty
    [model_type_name] = args.models_names
    # print("------- Model: {}".format(model))
    # print("------- Args: {}".format(args))
    if model is None:
        model = build_model_by_name(model_type_name, args)
    # derive a human-readable model name for the log directory
    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()
    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)
    # deal with vocab subset: restrict candidate predictions to a common vocab
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))
        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)
        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger)
    logger.info("\n" + msg + "\n")
    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)
    # Mean reciprocal rank
    MRR = 0.0
    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    # NOTE(review): initialized but never updated in this function; also note
    # the long-standing typo "positivie" kept for consistency with siblings.
    Precision_negative = 0.0
    Precision_positivie = 0.0
    data = load_file(args.dataset_filename)
    # data = data[:2000]
    fact_pair = load_file(args.fact_pair_filename)
    # print("@@@@@@@@@@@@@@")
    # print(fact_pair)
    # print(len(fact_pair))
    # print("$$$$$$$$$$$$$$")
    # drop samples whose object token is out-of-vocabulary / too long, etc.
    all_samples, ret_msg = filter_samples(model, data, vocab_subset,
                                          args.max_sentence_length,
                                          args.template)
    # print("!!!!!!!!!!!!!")
    # print(len(all_samples))  # 30847
    logger.info("\n" + ret_msg + "\n")
    # for sample in all_samples:
    #     sample["masked_sentences"] = [sample['evidences'][0]['masked_sentence']]
    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
            i += 1
    # shuffle data
    if shuffle_data:
        shuffle(all_samples)
    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    # one bucket per (subject, object) fact pair; keys are "sub_obj" strings
    list_of_results = {d['subject'] + "_" + d['object']: [] for d in fact_pair}
    list_of_ranks = {d['subject'] + "_" + d['object']: [] for d in fact_pair}
    for i in tqdm(range(len(samples_batches))):
        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]
        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = model.get_batch_generation(sentences_b, logger=logger)
        if vocab_subset is not None:
            # filter log_probs down to the common-vocab indices
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices)
        else:
            filtered_log_probs_list = original_log_probs_list
        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])
            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]))
            elif vocab_subset is not None and sample[
                    "obj_label"] not in vocab_subset:
                raise ValueError("object label {} not in vocab subset".format(
                    sample["obj_label"]))
            label_index_list.append(obj_label_id)
        # one work item per sample; consumed by run_thread (defined elsewhere
        # in this file — presumably computes per-sample ranking metrics)
        arguments = [{
            "original_log_probs": original_log_probs,
            "filtered_log_probs": filtered_log_probs,
            "token_ids": token_ids,
            "vocab": model.vocab,
            "label_index": label_index[0],
            "masked_indices": masked_indices,
            "interactive": args.interactive,
            "index_list": index_list,
            "sample": sample,
        } for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip(
            original_log_probs_list,
            filtered_log_probs_list,
            token_ids_list,
            masked_indices_list,
            label_index_list,
            samples_b,
        )]
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)
        # multithread
        res = pool.map(run_thread, arguments)
        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result
            logger.info("\n" + msg + "\n")
            sample = samples_b[idx]
            element = {}
            obj = sample['obj_label']
            sub = sample['sub_label']
            element["masked_sentences"] = sample["masked_sentences"][0]
            element["uuid"] = sample["uuid"]
            element["subject"] = sub
            element["object"] = obj
            element["rank"] = int(result_masked_topk['rank'])
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            # element["sample"] = sample
            # element["token_ids"] = token_ids_list[idx]
            # element["masked_indices"] = masked_indices_list[idx]
            # element["label_index"] = label_index_list[idx]
            # element["masked_topk"] = result_masked_topk
            # element["sample_MRR"] = sample_MRR
            # element["sample_Precision"] = sample_P
            # element["sample_perplexity"] = sample_perplexity
            list_of_results[sub + "_" + obj].append(element)
            list_of_ranks[sub + "_" + obj].append(element["rank"])
            # print("~~~~~~ rank: {}".format(result_masked_topk['rank']))
            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]
            # NOTE(review): hard-coded relative output path; assumes the
            # process CWD is the repo root — TODO confirm / parameterize
            append_data_line_to_jsonl(
                "reproduction/data/TREx_filter/{}_rank_results.jsonl".format(
                    args.label), element)  # 3122
            # list_of_results.append(element)
    pool.close()
    pool.join()
    # stats (averaging intentionally disabled — raw sums are kept)
    # Mean reciprocal rank
    # MRR /= len(list_of_results)
    # # Precision
    # Precision /= len(list_of_results)
    # Precision1 /= len(list_of_results)
    # msg = "all_samples: {}\n".format(len(all_samples))
    # # msg += "list_of_results: {}\n".format(len(list_of_results))
    # msg += "global MRR: {}\n".format(MRR)
    # msg += "global Precision at 10: {}\n".format(Precision)
    # msg += "global Precision at 1: {}\n".format(Precision1)
    # logger.info("\n" + msg + "\n")
    # print("\n" + msg + "\n")
    # dump pickle with the result of the experiment
    # all_results = dict(
    #     list_of_results=list_of_results, global_MRR=MRR, global_P_at_10=Precision
    # )
    # with open("{}/result.pkl".format(log_directory), "wb") as f:
    #     pickle.dump(all_results, f)
    # print()
    # model_name = args.models_names[0]
    # if args.models_names[0] == "bert":
    #     model_name = args.bert_model_name
    # elif args.models_names[0] == "elmo":
    #     if args.bert_model_name == args.bert_model_name
    # else:
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_results.jsonl".format(args.label), list_of_results)  # 3122
    # save_data_line_to_jsonl("reproduction/data/TREx_filter/{}_rank_dic.jsonl".format(args.label), list_of_ranks)  # 3122
    save_data_line_to_jsonl(
        "reproduction/data/TREx_filter/{}_rank_list.jsonl".format(args.label),
        list(list_of_ranks.values()))  # 3122
    return Precision1
def main(args, shuffle_data=True, model=None, zsre=True, context_filter=None,
         single_token=False, inference_top_k=10, decode_to_end=False, seed=''):
    """Evaluate a masked/generative LM on (zsRE-style) relation data,
    reporting MRR, precision@k, precision@1, exact match (EM) and F1.

    When ``args.template`` is set, one sample per distinct
    (subject, target, context) triple is rebuilt from the template.

    Args:
        args: parsed command-line namespace (see siblings for fields used).
        shuffle_data: shuffle the samples before batching.
        model: pre-built LM. NOTE(review): the ``model is None`` fallback is
            commented out below, so passing ``model=None`` will fail at the
            first ``model.`` attribute access — confirm callers always pass one.
        zsre: dataset is zsRE-formatted (changes context extraction and
            metric plumbing).
        context_filter, single_token, inference_top_k: accepted but unused in
            this body — presumably consumed by other variants; verify.
        decode_to_end: when truthy, decode from ``seed + args.question``
            instead of the template.
        seed: text prepended to the question when decode_to_end is set.

    Returns:
        Tuple ``(Precision1, Precision, MRR, EM, F1)``, each averaged over
        the collected results.
    """
    if len(args.models_names) > 1:
        raise ValueError('Please specify a single language model (e.g., --lm "bert").')
    msg = ""
    # single-element unpack: raises if models_names is empty
    [model_type_name] = args.models_names
    print(model)
    #if model is None:
    #    #model = build_model_by_name(model_type_name, args)
    # derive a human-readable model name for the log directory
    if model_type_name == "fairseq":
        model_name = "fairseq_{}".format(args.fairseq_model_name)
    elif model_type_name == "bert":
        model_name = "BERT_{}".format(args.bert_model_name)
    elif model_type_name == "elmo":
        model_name = "ELMo_{}".format(args.elmo_model_name)
    else:
        model_name = model_type_name.title()
    # initialize logging
    if args.full_logdir:
        log_directory = args.full_logdir
    else:
        log_directory = create_logdir_with_timestamp(args.logdir, model_name)
    logger = init_logging(log_directory)
    msg += "model name: {}\n".format(model_name)
    # deal with vocab subset: restrict candidate predictions to a common vocab
    vocab_subset = None
    index_list = None
    msg += "args: {}\n".format(args)
    if args.common_vocab_filename is not None:
        vocab_subset = load_vocab(args.common_vocab_filename)
        msg += "common vocabulary size: {}\n".format(len(vocab_subset))
        # optimization for some LM (such as ELMo)
        model.optimize_top_layer(vocab_subset)
        filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs(
            vocab_subset, logger
        )
    logger.info("\n" + msg + "\n")
    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)
    # stats
    samples_with_negative_judgement = 0
    samples_with_positive_judgement = 0
    # Mean reciprocal rank
    MRR = 0.0
    MRR_negative = 0.0
    MRR_positive = 0.0
    # Precision at (default 10)
    Precision = 0.0
    Precision1 = 0.0
    Precision_negative = 0.0
    Precision_positivie = 0.0
    # EM
    EM = 0.0
    # F1
    F1 = 0.0
    # error-analysis counters; initialized here but not updated in this body
    is_error = 0
    pred_too_large = 0
    pred_too_small = 0
    should_be_empty = 0
    should_be_not_empty = 0
    anchor_outside = 0
    mismatch = 0
    data = load_file(args.dataset_filename)
    print(len(data))
    if args.lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
    else:
        # keep samples as they are
        all_samples = data
    # NOTE(review): filter_samples is called on `data`, not `all_samples`,
    # so the lowercasing above is effectively discarded — confirm intent.
    all_samples, ret_msg = filter_samples(
        model, data, vocab_subset, args.max_sentence_length, args.template, is_zsre=zsre
    )
    # OUT_FILENAME = "{}.jsonl".format(args.dataset_filename)
    # with open(OUT_FILENAME, 'w') as outfile:
    #     for entry in all_samples:
    #         json.dump(entry, outfile)
    #         outfile.write('\n')
    logger.info("\n" + ret_msg + "\n")
    print(len(all_samples))
    # if template is active (1) use a single example for (sub,obj) and (2) ...
    if args.template and args.template != "":
        facts = []
        sub_objs = []
        for sample in all_samples:
            sub = sample["sub_label"]
            obj = sample["obj_label"]
            # NOTE(review): this lookup runs BEFORE the guard below, so a
            # sample missing 'reconstructed_word' raises KeyError here and
            # the intended Exception is unreachable — consider reordering.
            target = sample['reconstructed_word']
            if 'reconstructed_word' not in sample:
                raise Exception('Reconstructed word not in sample... fix this')
            else:
                if 'masked_sentences' in sample:
                    # Some of the masked sentences don't have a mask in them, need to find first with mask
                    context = None
                    for sent in sample['masked_sentences']:
                        if '[MASK]' in sent:
                            context = sent.replace('[MASK]', sample['reconstructed_word'])
                            break
                    if context is None:
                        print('No valid context found, skipping sample')
                        continue
                else:
                    # zsRE: take the (last) evidence sentence verbatim;
                    # otherwise require a [MASK] to fill in.
                    context = None
                    for evidence in sample['evidences']:
                        if not zsre:
                            if '[MASK]' in evidence['masked_sentence']:
                                context = evidence['masked_sentence'].replace('[MASK]', sample['reconstructed_word'])
                                break
                        else:
                            context = evidence['masked_sentence']
                    if context is None:
                        print('No valid context found, skipping sample')
                        continue
            #context = context.replace('(', '')
            #context = context.replace(')', '')
            # deduplicate on (subject, target, context)
            if (sub, target, context) not in sub_objs:
                sub_objs.append((sub, target, context))
                if 'reconstructed_word' in sample:
                    facts.append((sub, obj, context, sample['reconstructed_word']))
                else:
                    facts.append((sub, obj, context, obj))
            #break
        local_msg = "distinct template facts: {}".format(len(facts))
        logger.info("\n" + local_msg + "\n")
        print(local_msg)
        # rebuild the sample list from the deduplicated facts
        all_samples = []
        for fact in facts:
            (sub, obj, context, rw) = fact
            sample = {}
            sample["sub_label"] = sub
            sample["obj_label"] = obj
            sample["reconstructed_word"] = rw
            # sobstitute all sentences with a standard template
            sample['context'] = context
            sample["masked_sentences"] = parse_template(
                args.template.strip(), sample["sub_label"].strip(), base.MASK
            )
            #query = sample['masked_sentences'][0].replace(base.MASK, '')
            #sample['query'] = query
            #print(f'query={query}')
            #docs = retrieve_docs(query, ranker, conn, 30)
            #sample['context'] = docs[0]
            #print(f'docs={docs}')
            all_samples.append(sample)
    #else:
    #    for sample in all_samples:
    #        query = sample['masked_sentences'][0].replace(base.MASK, '')
    #        sample['query'] = query
    #        #print(f'query={query}')
    #        docs = retrieve_docs(query, ranker, conn, 1)
    #        sample['context'] = docs[0]
    # create uuid if not present
    i = 0
    for sample in all_samples:
        if "uuid" not in sample:
            sample["uuid"] = i
            i += 1
    # shuffle data
    if shuffle_data:
        shuffle(all_samples)
    samples_batches, sentences_batches, ret_msg = batchify(all_samples, args.batch_size)
    logger.info("\n" + ret_msg + "\n")
    # ThreadPool
    num_threads = args.threads
    if num_threads <= 0:
        # use all available threads
        num_threads = multiprocessing.cpu_count()
    pool = ThreadPool(num_threads)
    list_of_results = []
    #example_results = []
    #cf_results= []
    #qualitative_results = []
    for i in tqdm(range(len(samples_batches))):
        samples_b = samples_batches[i]
        sentences_b = sentences_batches[i]
        mymodel_probs_list = []  # overwritten with original_log_probs_list below
        predictions_list = []
        template = args.template.strip()
        # generate a free-form prediction per sample (decode_lm defined elsewhere)
        for sample in samples_b:
            if decode_to_end:
                prediction = decode_lm(model, seed + args.question + ' = [Y]', sample['sub_label'], decode_to_end)
                predictions_list.append(prediction)
            else:
                prediction = decode_lm(model, args.template.strip(), sample['sub_label'], decode_to_end)
                predictions_list.append(prediction)
        original_log_probs_list, token_ids_list, masked_indices_list = model.get_batch_generation(
            sentences_b, logger=logger
        )
        #original_log_probs_list, token_ids_list, masked_indices_list, predictions = model.get_batch_generation(
        #predictions_list = predictions
        mymodel_probs_list = original_log_probs_list
        #obj_len = 0
        #for obj in gc.get_objects():
        #    try:
        #        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        #            print(type(obj), obj.size())
        #            obj_len += 1
        #    except:
        #        pass
        #print(obj_len)
        if vocab_subset is not None:
            # filter log_probs down to the common-vocab indices
            filtered_log_probs_list = model.filter_logprobs(
                original_log_probs_list, filter_logprob_indices
            )
        else:
            filtered_log_probs_list = original_log_probs_list
        label_index_list = []
        for sample in samples_b:
            obj_label_id = model.get_id(sample["obj_label"])
            # MAKE SURE THAT obj_label IS IN VOCABULARIES
            if obj_label_id is None:
                raise ValueError(
                    "object label {} not in model vocabulary".format(
                        sample["obj_label"]
                    )
                )
            #elif model.vocab[obj_label_id[0]] != sample["obj_label"]:
            #    raise ValueError(
            #        "object label {} not in model vocabulary".format(
            #            sample["obj_label"]
            #        )
            #    )
            elif vocab_subset is not None and sample["obj_label"] not in vocab_subset:
                raise ValueError(
                    "object label {} not in vocab subset".format(sample["obj_label"])
                )
            label_index_list.append(obj_label_id)
        # one work item per sample; consumed by run_thread (defined elsewhere
        # in this file — presumably computes per-sample ranking/EM/F1 metrics)
        arguments = [
            {
                "is_zsre": zsre,
                "mymodel_probs": mymodel_probs,
                "original_log_probs": original_log_probs,
                "filtered_log_probs": filtered_log_probs,
                "target": sample["reconstructed_word"],
                "prediction": pred,
                "token_ids": token_ids,
                "vocab": model.vocab,
                "label_index": label_index[0] if len(label_index) > 0 else 0,
                "masked_indices": masked_indices,
                "interactive": args.interactive,
                "index_list": index_list,
                "sample": sample,
                #"prediction": prediction
            }
            for mymodel_probs, original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample, pred in zip(
                mymodel_probs_list,
                original_log_probs_list,
                filtered_log_probs_list,
                token_ids_list,
                masked_indices_list,
                label_index_list,
                samples_b,
                predictions_list
            )
        ]
        #for mymodel_probs, original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample, prediction in zip(
        #    mymodel_probs_list,
        #    original_log_probs_list,
        #    filtered_log_probs_list,
        #    token_ids_list,
        #    masked_indices_list,
        #    label_index_list,
        #    samples_b,
        #    predictions_list
        #)
        # single thread for debug
        # for isx,a in enumerate(arguments):
        #     print(samples_b[isx])
        #     run_thread(a)
        # multithread
        res = pool.map(run_thread, arguments)
        for idx, result in enumerate(res):
            result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg, sample_em, sample_f1 = result
            logger.info("\n" + msg + "\n")
            sample = samples_b[idx]
            element = {}
            element["sample"] = sample
            element["uuid"] = sample["uuid"]
            element["token_ids"] = token_ids_list[idx]
            element["masked_indices"] = masked_indices_list[idx]
            element["label_index"] = label_index_list[idx]
            element["masked_topk"] = result_masked_topk
            element["sample_MRR"] = sample_MRR
            element["sample_Precision"] = sample_P
            element["sample_perplexity"] = sample_perplexity
            element["sample_Precision1"] = result_masked_topk["P_AT_1"]
            element['sample_em'] = sample_em
            element['sample_f1'] = sample_f1
            # print()
            # print("idx: {}".format(idx))
            # print("masked_entity: {}".format(result_masked_topk['masked_entity']))
            # for yi in range(10):
            #     print("\t{} {}".format(yi,result_masked_topk['topk'][yi]))
            # print("masked_indices_list: {}".format(masked_indices_list[idx]))
            # print("sample_MRR: {}".format(sample_MRR))
            # print("sample_P: {}".format(sample_P))
            # print("sample: {}".format(sample))
            # print()
            MRR += sample_MRR
            Precision += sample_P
            Precision1 += element["sample_Precision1"]
            EM += sample_em
            F1 += sample_f1
            # the judgment of the annotators recording whether there is
            # evidence in the sentence that indicates a relation between two entities.
            num_yes = 0
            num_no = 0
            if "judgments" in sample:
                # only for Google-RE
                for x in sample["judgments"]:
                    if x["judgment"] == "yes":
                        num_yes += 1
                    else:
                        num_no += 1
                if num_no >= num_yes:
                    samples_with_negative_judgement += 1
                    element["judgement"] = "negative"
                    MRR_negative += sample_MRR
                    Precision_negative += sample_P
                else:
                    samples_with_positive_judgement += 1
                    element["judgement"] = "positive"
                    MRR_positive += sample_MRR
                    Precision_positivie += sample_P
            list_of_results.append(element)
    #df = pd.DataFrame(example_results)
    #df.to_csv('example_results.csv')
    pool.close()
    pool.join()
    # stats
    # NOTE(review): ZeroDivisionError if list_of_results is empty (e.g. all
    # samples filtered out) — consider guarding before averaging.
    # Mean reciprocal rank
    MRR /= len(list_of_results)
    # Precision
    Precision /= len(list_of_results)
    Precision1 /= len(list_of_results)
    EM /= len(list_of_results)
    F1 /= len(list_of_results)
    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))
    msg += "global MRR: {}\n".format(MRR)
    msg += "global Precision at 10: {}\n".format(Precision)
    msg += "global Precision at 1: {}\n".format(Precision1)
    msg += "global EM {}\n".format(EM)
    msg += "global F1: {}\n".format(F1)
    if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0:
        # Google-RE specific
        MRR_negative /= samples_with_negative_judgement
        MRR_positive /= samples_with_positive_judgement
        Precision_negative /= samples_with_negative_judgement
        Precision_positivie /= samples_with_positive_judgement
        msg += "samples_with_negative_judgement: {}\n".format(
            samples_with_negative_judgement
        )
        msg += "samples_with_positive_judgement: {}\n".format(
            samples_with_positive_judgement
        )
        msg += "MRR_negative: {}\n".format(MRR_negative)
        msg += "MRR_positive: {}\n".format(MRR_positive)
        msg += "Precision_negative: {}\n".format(Precision_negative)
        msg += "Precision_positivie: {}\n".format(Precision_positivie)
    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")
    # dump pickle with the result of the experiment
    all_results = dict(
        list_of_results=list_of_results, global_MRR=MRR,
        global_P_at_10=Precision
    )
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(all_results, f)
    #logdf = pd.DataFrame(cf_results, columns=['Context', 'Subject', 'Template', 'Relation', 'Object', 'Score'])
    #logdf.to_csv("{}/cf_results4.csv".format(log_directory))
    #logdf = pd.DataFrame(qualitative_results, columns=['Template', 'Subject', 'Context', 'Object', 'Prediction', 'Expanded Prediction'])
    #logdf.to_csv("{}/qual_results.csv".format(log_directory))
    return Precision1, Precision, MRR, EM, F1