def score_distribution_training_fitting(
        data_name, topic_set, topic_id,
        query_file, qrel_file, doc_id_file, doc_text_file,  # data parameters
        training_query_files, training_qrel_files,
        training_doc_id_files, training_doc_text_files,  # training data
        target_recall=0.99,  # if target_recall=1.0, score_bound will be -inf
        random_state=0):
    """
    Implementation of the score distribution method -- training-fitting:
    a Gaussian is fitted to the normalized scores of the relevant documents
    of the training topics, and the test-topic ranking is cut at the
    (1 - target_recall) quantile of that distribution.
    :param data_name: dataset name
    :param topic_set: parameter-tuning set or test set
    :param topic_id: topic id
    :param training_query_files: training data, zipped in parallel with the
        other three training_* lists
    :param target_recall: desired recall level
    :param random_state: random seed
    :return:
    """
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'sdtf' + '-'
    model_name += 'tr' + str(target_recall)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # collecting rel scores from the training topic set
    training_rel_scores = []
    for tquery_file, tqrel_file, tdoc_id_file, tdoc_text_file in zip(
            training_query_files, training_qrel_files,
            training_doc_id_files, training_doc_text_files):
        tquery = DataLoader.read_title(tquery_file)
        did2label = DataLoader.read_qrels(tqrel_file)
        complete_dids = DataLoader.read_doc_ids(tdoc_id_file)
        did2text = DataLoader.read_doc_texts(tdoc_text_file)
        complete_texts = [did2text[did] for did in complete_dids]

        # ranked dids, ranking scores
        ranked_dids, ranked_scores = bm25_okapi_rank(complete_dids, complete_texts, tquery)

        # normalized scores
        scaler = StandardScaler()
        ranked_scores = np.array(ranked_scores).reshape(-1, 1)
        norm_scores = list(scaler.fit_transform(ranked_scores).flatten())

        # rel scores
        rel_scores = [score for did, score in zip(ranked_dids, norm_scores)
                      if did2label[did] == REL]
        training_rel_scores.extend(rel_scores)

    # ranking dids, ranking scores
    assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = assessor.get_complete_dids()
    complete_texts = assessor.get_complete_texts()
    query = assessor.get_title()
    ranked_dids, ranked_scores = bm25_okapi_rank(complete_dids, complete_texts, query)

    # normalizing scores
    scaler = StandardScaler()
    ranked_scores = np.array(ranked_scores).reshape(-1, 1)
    norm_scores = list(scaler.fit_transform(ranked_scores).flatten())

    # calculating the cutoff
    loc, scale = norm.fit(training_rel_scores)
    score_bound = norm.ppf(1 - target_recall, loc=loc, scale=scale)
    small2big_ranked_scores = list(reversed(norm_scores))
    cutoff = bisect.bisect_left(small2big_ranked_scores, score_bound)
    cutoff = max(1, cutoff)
    screen_dids = ranked_dids[:-cutoff]

    # output run file
    check_func = assessor.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func,
                           shown_dids=screen_dids)

    LOGGER.info('TAR is finished.')
    return
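
# Standalone sketch of the cutoff rule used by the score-distribution methods:
# fit a Gaussian to the normalized scores of known-relevant documents and take
# the (1 - target_recall) quantile as the screening bound. Illustrative only
# (synthetic data, numpy/scipy); not part of the pipeline above.
def _demo_gaussian_score_bound(target_recall=0.99, seed=0):
    import numpy as np
    from scipy.stats import norm as _norm

    rng = np.random.default_rng(seed)
    rel_scores = rng.normal(loc=1.5, scale=0.8, size=200)  # synthetic relevant scores
    loc, scale = _norm.fit(rel_scores)                     # MLE of mean and std
    # Under the Gaussian assumption, at most (1 - target_recall) of the
    # relevant score mass lies below this bound; with target_recall=1.0 the
    # bound is -inf, i.e. nothing gets screened out.
    score_bound = _norm.ppf(1 - target_recall, loc=loc, scale=scale)
    return loc, scale, score_bound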

def autostop_for_large_collection(
        data_name, topic_set, topic_id,
        query_files, qrel_files, doc_id_files, doc_text_files,  # data parameters
        sampler_type='HTAPPriorSampler', epsilon=0.5, beta=-0.1,
        stopping_percentage=None, stopping_recall=None,
        target_recall=1.0, stopping_condition='strict1',  # autostop parameters
        random_state=0):
    # model named with its configuration
    model_name = 'autostop' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'smp' + str(sampler_type) + '-'
    if sampler_type in ('HTMixtureUniformSampler', 'HHMixtureUniformSampler'):
        model_name += 'epsilon' + str(epsilon) + '-'
    elif sampler_type in ('HTPowerLawSampler', 'HHPowerLawSampler'):
        model_name += 'beta' + str(beta) + '-'
    # the remaining samplers (uniform and AP-prior) take no extra parameters
    model_name += 'tr' + str(target_recall) + '-'
    model_name += 'sc' + stopping_condition
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # starting the TAR process after splitting the complete collection into small blocks
    total_shown_dids = []
    for query_file, qrel_file, doc_id_file, doc_text_file in zip(
            query_files, qrel_files, doc_id_files, doc_text_files):
        # loading data
        assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
        complete_dids = assessor.get_complete_dids()
        complete_labels = assessor.get_complete_labels()
        complete_pseudo_dids = assessor.get_complete_pseudo_dids()
        complete_pseudo_texts = assessor.get_complete_pseudo_texts()
        did2label = assessor.get_did2label()
        total_true_r = assessor.get_total_rel_num()
        total_num = assessor.get_total_doc_num()

        # ranker
        ranker = Ranker()
        ranker.set_did_2_feature(dids=complete_pseudo_dids,
                                 texts=complete_pseudo_texts,
                                 corpus_texts=complete_pseudo_texts)
        ranker.set_features_by_name('complete_dids', complete_dids)

        # sampler
        if sampler_type == 'HTMixtureUniformSampler':
            sampler = HTMixtureUniformSampler()
            sampler.init(complete_dids, complete_labels)
        elif sampler_type == 'HTUniformSampler':
            sampler = HTUniformSampler()
            sampler.init(complete_dids, complete_labels)
            sampler.update_distribution()
        elif sampler_type == 'HTPowerLawSampler':
            sampler = HTPowerLawSampler()
            sampler.init(beta, complete_dids, complete_labels)
            sampler.update_distribution(beta=beta)
        elif sampler_type == 'HTAPPriorSampler':
            sampler = HTAPPriorSampler()
            sampler.init(complete_dids, complete_labels)
            sampler.update_distribution()
        elif sampler_type == 'HHMixtureUniformSampler':
            sampler = HHMixtureUniformSampler()
            sampler.init(total_num, did2label)
        elif sampler_type == 'HHPowerLawSampler':
            sampler = HHPowerLawSampler()
            sampler.init(total_num, did2label)
            sampler.update_distribution(beta=beta)
        elif sampler_type == 'HHAPPriorSampler':
            sampler = HHAPPriorSampler()
            sampler.init(total_num, did2label)
            sampler.update_distribution()
        else:
            raise TypeError('unknown sampler_type: {}'.format(sampler_type))

        # local parameters
        stopping = False
        t = 0
        batch_size = 1
        temp_doc_num = 100

        while not stopping:
            t += 1
            train_dids, train_labels = assessor.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            test_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(test_features)
            zipped = sorted(zip(complete_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, _ = zip(*zipped)

            if sampler_type in ('HHMixtureUniformSampler', 'HTMixtureUniformSampler'):
                sampler.update_distribution(epsilon=epsilon, alpha=batch_size)
            sampled_dids = sampler.sample(t, ranked_dids, batch_size, stopping_condition)
            assessor.update_assess(sampled_dids)
            sampled_state = assessor.get_assessed_state()
            total_esti_r, var1, var2 = sampler.estimate(t, stopping_condition, sampled_state)

            # statistics
            sampled_num = assessor.get_assessed_num()
            running_true_r = assessor.get_assessed_rel_num()
            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0

            # update parameters
            batch_size += math.ceil(batch_size / 10)

            # autostop
            if running_true_r > 0:
                if stopping_condition == 'loose':
                    if running_true_r >= target_recall * total_esti_r:
                        stopping = True
                elif stopping_condition == 'strict1':
                    if running_true_r >= target_recall * (total_esti_r + np.sqrt(var1)):
                        stopping = True
                elif stopping_condition == 'strict2':
                    if running_true_r >= target_recall * (total_esti_r + np.sqrt(var2)):
                        stopping = True
                else:
                    raise NotImplementedError

            # stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

        sampled_dids = assessor.get_assessed_dids()
        shown_features = ranker.get_feature_by_did(sampled_dids)
        shown_scores = ranker.predict(shown_features)
        zipped = sorted(zip(sampled_dids, shown_scores), key=itemgetter(1), reverse=True)
        shown_dids, scores = zip(*zipped)
        total_shown_dids.extend(shown_dids)

    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id,
                           check_func=lambda x: True,
                           shown_dids=total_shown_dids)

    LOGGER.info('TAR is finished.')
    return
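
# Standalone sketch of the Horvitz-Thompson idea behind the HT* samplers: with
# known first-order inclusion probabilities pi_i, the total number of relevant
# documents R is estimated unbiasedly by summing 1/pi_i over the sampled
# relevant documents. Synthetic data; illustrative only.
def _demo_horvitz_thompson_estimate(seed=0):
    import numpy as np

    rng = np.random.default_rng(seed)
    labels = (rng.random(10000) < 0.05).astype(int)  # synthetic relevance labels
    pi = np.clip(rng.random(10000), 0.05, 1.0)       # inclusion probabilities
    sampled = rng.random(10000) < pi                 # Poisson sampling design
    ht_estimate = np.sum(labels[sampled] / pi[sampled])
    return int(labels.sum()), float(ht_estimate)     # true R vs. HT estimate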

def score_distribution_feedback_uniform(
        data_name, topic_set, topic_id,
        query_file, qrel_file, doc_id_file, doc_text_file,  # data parameters
        sample_percentage=0.1, target_recall=0.99, random_state=0):
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'sdfu' + '-'
    model_name += 'smp' + str(sample_percentage) + '-'
    model_name += 'tr' + str(target_recall)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = assessor.get_complete_dids()
    complete_texts = assessor.get_complete_texts()
    query = assessor.get_title()
    total_num = assessor.get_total_doc_num()

    # ranking dids, ranking scores
    ranked_dids, ranked_scores = bm25_okapi_rank(complete_dids, complete_texts, query)

    # normalizing scores
    scaler = StandardScaler()
    ranked_scores = np.array(ranked_scores).reshape(-1, 1)
    norm_scores = list(scaler.fit_transform(ranked_scores).flatten())

    # uniformly sampling some documents to fit the Gaussian
    sample_num = int(sample_percentage * total_num)
    sampled_dids = list(np.random.choice(a=complete_dids, size=sample_num,
                                         replace=False).flatten())
    sampled_rel_scores = [norm_scores[ranked_dids.index(did)]
                          for did in sampled_dids
                          if assessor.get_rel_label(did) == REL]
    if not sampled_rel_scores:
        sampled_rel_scores = [0]  # fall back towards the standard normal distribution

    # calculating the cutoff
    mean, std = norm.fit(sampled_rel_scores)
    try:
        score_bound = norm.ppf(1 - target_recall, loc=mean, scale=std)
    except Exception:  # degenerate fit, e.g. a single sampled score
        score_bound = norm.ppf(1 - target_recall)
    cutoff = bisect.bisect_left(list(reversed(norm_scores)), score_bound)
    cutoff = max(1, cutoff)
    screen_dids = ranked_dids[:-cutoff]

    # output run file
    # NO INTERACTION: only the ranked_dids order is applied; documents below
    # the cutoff that were already assessed during sampling are appended
    check_func = assessor.assess_state_check_func()
    shown_dids = list(screen_dids)
    for did in ranked_dids[-cutoff:]:
        if check_func(did) is True:
            shown_dids.append(did)
    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')
    return
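
# Standalone sketch of the index arithmetic shared by both score-distribution
# variants: scores are ranked in descending order, so bisect_left runs on the
# reversed (ascending) list, and the resulting count is trimmed off the tail
# of the descending ranking. Illustrative only.
def _demo_bisect_cutoff():
    import bisect

    norm_scores = [2.1, 1.4, 0.9, 0.2, -0.3, -1.1]  # descending, as ranked
    score_bound = 0.0
    # two scores (-0.3, -1.1) fall below the bound -> cutoff == 2
    cutoff = bisect.bisect_left(list(reversed(norm_scores)), score_bound)
    cutoff = max(1, cutoff)                          # always trim at least one
    screened = norm_scores[:-cutoff]                 # documents kept for showing
    assert screened == [2.1, 1.4, 0.9, 0.2]
    return screened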

def scal_method(data_name, topic_set, topic_id,
                query_file, qrel_file, doc_id_file, doc_text_file,
                sub_percentage, bound_bt, ita,  # data parameters
                stopping_percentage=1.0, stopping_recall=1.0, target_recall=1.0,  # autostop parameters
                max_or_min='min', bucket_type='samplerel',  # scal parameters
                random_state=0):
    """
    Implementation of the S-CAL method [1].
    @param data_name:
    @param topic_set:
    @param topic_id:
    @param stopping_percentage:
    @param stopping_recall:
    @param target_recall:
    @param sub_percentage:
    @param bound_bt: sample size per batch
    @param max_or_min:
    @param bucket_type:
    @param ita:
    @param random_state:
    @return:
    """
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'scal' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'tr' + str(target_recall) + '-'
    model_name += 'spt' + str(sub_percentage) + '-'
    model_name += 'bnd' + str(bound_bt) + '-'
    model_name += 'mxn' + max_or_min + '-'
    model_name += 'bkt' + bucket_type + '-'
    model_name += 'ita' + str(ita)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    datamanager = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = datamanager.get_complete_dids()
    complete_pseudo_dids = datamanager.get_complete_pseudo_dids()
    #complete_pseudo_texts = datamanager.get_complete_pseudo_texts()
    #corpus_texts = complete_pseudo_texts
    did2label = datamanager.get_did2label()
    total_true_r = datamanager.get_total_rel_num()
    total_num = datamanager.get_total_doc_num()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids, data_name=data_name)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # SCAL sampler
    sampler = SCALSampler()

    # sampling a large sample set before the TAR process
    if sub_percentage < 1.0:
        u = int(sub_percentage * total_num)
        sub_dids = list(np.random.choice(a=complete_dids, size=u, replace=False).flatten())
    elif sub_percentage == 1.0:
        u = total_num
        sub_dids = complete_dids
    else:
        raise NotImplementedError
    ranker.set_features_by_name('sub_dids', sub_dids)

    # local parameters
    stopping = False
    t = 0
    batch_size = 1
    temp_doc_num = 100
    n = bound_bt
    total_esti_r = 0
    temp_list = []

    # starting the TAR process
    interaction_file = name_interaction_file(data_name=data_name, model_name=model_name,
                                             topic_set=topic_set, exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1
            LOGGER.info('iteration {}, batch_size {}'.format(t, batch_size))

            # train
            train_dids, train_labels = datamanager.get_training_data(temp_doc_num=temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            # predict
            sub_features = ranker.get_features_by_name('sub_dids')
            scores = ranker.predict(sub_features)
            zipped = sorted(zip(sub_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, _ = zip(*zipped)

            bucketed_dids, sampled_dids, batch_esti_r = sampler.sample(
                ranked_dids, n, batch_size, did2label)
            datamanager.update_assess(sampled_dids)

            # estimating
            total_esti_r += batch_esti_r

            # statistics
            sampled_num = datamanager.get_assessed_num()
            running_true_r = datamanager.get_assessed_rel_num()
            if total_esti_r != 0:
                running_esti_recall = running_true_r / float(total_esti_r)
            else:
                running_esti_recall = 0
            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters
            if batch_size < total_num:  # avoid OverflowError
                batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r, total_esti_r,
                 running_true_r, ap, running_esti_recall, running_true_recall))

            cum_bucketed_dids = sampler.get_bucketed_dids()
            cum_sampled_dids = sampler.get_sampled_dids()
            temp_list.append((total_esti_r, ranker, cum_bucketed_dids, cum_sampled_dids))

            # when the sub sample is exhausted, stop
            len_bucketed_dids = len(cum_bucketed_dids)
            if len_bucketed_dids == u:
                stopping = True

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # estimating rho
    final_total_esti_r = ita * total_esti_r  # calibrating the estimation in [1]

    if max_or_min == 'max':
        max_or_min_func = max
    elif max_or_min == 'min':
        max_or_min_func = min
    else:
        raise NotImplementedError

    # finding the first ranker that satisfies the stopping strategy,
    # otherwise using the last ranker
    for total_esti_r, ranker, bucketed_dids, cum_sampled_dids in temp_list:
        if target_recall * final_total_esti_r <= total_esti_r:
            break

    if bucket_type == 'bucket':
        filtered_dids = bucketed_dids
    elif bucket_type == 'sample':
        filtered_dids = cum_sampled_dids
    elif bucket_type == 'samplerel':
        filtered_dids = [did for did in cum_sampled_dids
                         if datamanager.get_rel_label(did) == 1]
    else:
        raise NotImplementedError

    if filtered_dids:
        features = ranker.get_feature_by_did(filtered_dids)
        scores = ranker.predict(features)
        threshold = max_or_min_func(scores)
    else:
        threshold = -1

    # rank complete dids
    train_dids, train_labels = datamanager.get_training_data(temp_doc_num=0)
    train_features = ranker.get_feature_by_did(train_dids)
    ranker.train(train_features, train_labels)
    complete_features = ranker.get_feature_by_did(complete_dids)
    complete_scores = ranker.predict(complete_features)
    zipped = sorted(zip(complete_dids, complete_scores), key=itemgetter(1), reverse=True)

    # shown dids
    shown_dids = []
    check_func = datamanager.assess_state_check_func()
    for did, score in zipped:
        if score >= threshold or check_func(did) is True:
            shown_dids.append(did)

    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')
    return
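
# Standalone sketch of the S-CAL per-batch estimate implemented by the real
# SCALSampler elsewhere in this repository: from each bucket of top-ranked
# documents, assess at most n, and scale the relevant count found in the
# sample back up to the bucket size. Synthetic data; illustrative only.
def _demo_scal_batch_estimate(seed=0):
    import numpy as np

    rng = np.random.default_rng(seed)
    bucket = list(range(300))                        # one bucket of doc ids
    labels = {did: int(rng.random() < 0.1) for did in bucket}  # synthetic qrels
    n = 30                                           # assessment budget per bucket
    sampled = rng.choice(bucket, size=min(n, len(bucket)), replace=False)
    rel_in_sample = sum(labels[int(did)] for did in sampled)
    batch_esti_r = rel_in_sample * len(bucket) / len(sampled)  # scale up
    return batch_esti_r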

def autostop_method(data_name, topic_set, topic_id,
                    query_file, qrel_file, doc_id_file, doc_text_file,
                    target_recall,  # data parameters
                    sampler_type, stopping_condition, epsilon=0.5, beta=-0.1,
                    stopping_percentage=None, stopping_recall=1.0,  # autostop parameters
                    random_state=0):
    np.random.seed(random_state)

    # loading data
    datamanager = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = datamanager.get_complete_dids()
    complete_pseudo_dids = datamanager.get_complete_pseudo_dids()
    #complete_pseudo_texts = datamanager.get_complete_pseudo_texts()
    #corpus_texts = complete_pseudo_texts
    did2label = datamanager.get_did2label()
    total_true_r = datamanager.get_total_rel_num()
    total_num = datamanager.get_total_doc_num()
    complete_labels = datamanager.get_complete_labels()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids, data_name=data_name)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # model named with its configuration
    model_name = 'autostop' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'smp' + str(sampler_type) + '-'

    # sampler
    if sampler_type == 'HTMixtureUniformSampler':
        sampler = HTMixtureUniformSampler()
        sampler.init(complete_dids, complete_labels)
    elif sampler_type == 'HTUniformSampler':
        sampler = HTUniformSampler()
        sampler.init(complete_dids, complete_labels)
        sampler.update_distribution()
    elif sampler_type == 'HTPowerLawSampler':
        sampler = HTPowerLawSampler()
        sampler.init(beta, complete_dids, complete_labels)
        sampler.update_distribution(beta=beta)
    elif sampler_type == 'HTAPPriorSampler':
        sampler = HTAPPriorSampler()
        sampler.init(complete_dids, complete_labels)
        sampler.update_distribution()
    elif sampler_type == 'HHMixtureUniformSampler':
        sampler = HHMixtureUniformSampler()
        sampler.init(total_num, did2label)
    elif sampler_type == 'HHPowerLawSampler':
        sampler = HHPowerLawSampler()
        sampler.init(total_num, did2label)
        sampler.update_distribution(beta=beta)
    elif sampler_type == 'HHAPPriorSampler':
        sampler = HHAPPriorSampler()
        sampler.init(total_num, did2label)
        sampler.update_distribution()
    else:
        raise TypeError('unknown sampler_type {} (stopping_condition {})'.format(
            sampler_type, stopping_condition))

    model_name += 'tr' + str(target_recall) + '-'
    model_name += 'sc' + stopping_condition
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # local parameters
    stopping = False
    t = 0
    batch_size = 1
    temp_doc_num = 100

    # starting the TAR process
    interaction_file = name_interaction_file(data_name=data_name, model_name=model_name,
                                             topic_set=topic_set, exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1
            train_dids, train_labels = datamanager.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            test_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(test_features)
            zipped = sorted(zip(complete_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, _ = zip(*zipped)

            if sampler_type in ('HHMixtureUniformSampler', 'HTMixtureUniformSampler'):
                sampler.update_distribution(epsilon=epsilon, alpha=batch_size)
            sampled_dids = sampler.sample(t, ranked_dids, batch_size, stopping_condition)
            datamanager.update_assess(sampled_dids)
            sampled_state = datamanager.get_assessed_state()
            total_esti_r, var1, var2 = sampler.estimate(t, stopping_condition, sampled_state)

            # statistics
            sampled_num = datamanager.get_assessed_num()
            running_true_r = datamanager.get_assessed_rel_num()
            if total_esti_r != 0:
                running_esti_recall = running_true_r / float(total_esti_r)
            else:
                running_esti_recall = 0
            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r, total_esti_r,
                 var1, var2, running_true_r, ap, running_esti_recall, running_true_recall))

            # autostop
            if running_true_r > 0:
                if stopping_condition == 'loose':
                    if running_true_r >= target_recall * total_esti_r:
                        stopping = True
                elif stopping_condition == 'strict1':
                    if running_true_r >= target_recall * (total_esti_r + np.sqrt(var1)):
                        stopping = True
                elif stopping_condition == 'strict2':
                    if running_true_r >= target_recall * (total_esti_r + np.sqrt(var2)):
                        stopping = True
                else:
                    raise NotImplementedError

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # writing the run file
    sampled_dids = datamanager.get_assessed_dids()
    shown_features = ranker.get_feature_by_did(sampled_dids)
    shown_scores = ranker.predict(shown_features)
    zipped = sorted(zip(sampled_dids, shown_scores), key=itemgetter(1), reverse=True)
    shown_dids, scores = zip(*zipped)

    check_func = datamanager.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')
    return
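
# Standalone sketch of the three autostop conditions checked in the loop
# above: 'loose' compares against the point estimate of R, while 'strict1'
# and 'strict2' inflate the estimate by one standard deviation under two
# different variance estimators. Illustrative only.
def _demo_stopping_rule(running_true_r, total_esti_r, var1, var2,
                        target_recall=1.0, stopping_condition='strict1'):
    import numpy as np

    if running_true_r <= 0:
        return False                                 # never stop before the first relevant doc
    if stopping_condition == 'loose':
        bound = total_esti_r
    elif stopping_condition == 'strict1':
        bound = total_esti_r + np.sqrt(var1)
    elif stopping_condition == 'strict2':
        bound = total_esti_r + np.sqrt(var2)
    else:
        raise NotImplementedError(stopping_condition)
    return running_true_r >= target_recall * bound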

def knee_method(
        data_name, topic_set, topic_id,
        query_file, qrel_file, doc_id_file, doc_text_file, text_metrics,  # data parameters
        stopping_beta=100, stopping_percentage=1.0, stopping_recall=None,  # autostop parameters
        rho='6', random_state=0):
    """
    Implementation of the Knee method.
    @param data_name:
    @param topic_set:
    @param topic_id:
    @param stopping_beta: only stop the TAR process after at least beta documents have been screened
    @param stopping_percentage:
    @param stopping_recall:
    @param rho:
    @param random_state:
    @return:
    """
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'knee' + '-'
    model_name += 'sb' + str(stopping_beta) + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'rho' + str(rho)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file, text_metrics)
    complete_dids = assessor.get_complete_dids()
    complete_pseudo_dids = assessor.get_complete_pseudo_dids()
    complete_pseudo_texts = assessor.get_complete_pseudo_texts()
    complete_pseudo_metrics = assessor.get_complete_pseudo_metrics()
    did2label = assessor.get_did2label()
    total_true_r = assessor.get_total_rel_num()
    total_num = assessor.get_total_doc_num()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids,
                             texts=complete_pseudo_texts,
                             corpus_texts=complete_pseudo_texts,
                             metrics=complete_pseudo_metrics)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # local parameters
    stopping = False
    t = 0
    batch_size = 1
    temp_doc_num = 100
    knee_data = []

    # starting the TAR process
    interaction_file = name_interaction_file(data_name=data_name, model_name=model_name,
                                             topic_set=topic_set, exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1
            LOGGER.info('TAR: iteration={}'.format(t))
            train_dids, train_labels = assessor.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            train_features_metrics = ranker.get_feature_metric_by_did(train_dids)
            ranker.train(train_features, train_features_metrics, train_labels)

            test_features = ranker.get_features_by_name('complete_dids')
            test_metrics = ranker.get_metrics_by_name(complete_dids)
            scores = ranker.predict(test_features, test_metrics)
            zipped = sorted(zip(complete_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, _ = zip(*zipped)

            # cutting off instead of sampling
            selected_dids = assessor.get_top_assessed_dids(ranked_dids, batch_size)
            assessor.update_assess(selected_dids)

            # statistics
            sampled_num = assessor.get_assessed_num()
            sampled_percentage = sampled_num / total_num
            running_true_r = assessor.get_assessed_rel_num()
            running_true_recall = running_true_r / float(total_true_r)
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r,
                 running_true_r, ap, running_true_recall))

            # detect knee
            knee_data.append((sampled_num, running_true_r))
            knee_indice = detect_knee(knee_data)  # x: sampled_num, y: running_true_r
            if knee_indice is not None:
                knee_index = knee_indice[-1]
                rank1, r1 = knee_data[knee_index]
                rank2, r2 = knee_data[-1]

                try:
                    current_rho = float(r1 / rank1) / float((r2 - r1 + 1) / (rank2 - rank1))
                except ZeroDivisionError:
                    LOGGER.info('(rank1, r1) = ({} {}), (rank2, r2) = ({} {})'.format(
                        rank1, r1, rank2, r2))
                    current_rho = 0  # do not stop

                if rho == 'dynamic':
                    rho_bound = 156 - min(running_true_r, 150)  # rho is in [6, 156], see [1]
                else:
                    rho_bound = float(rho)

                if current_rho > rho_bound:
                    if sampled_num > stopping_beta:
                        stopping = True

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    shown_dids = assessor.get_assessed_dids()
    check_func = assessor.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')
    return
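
# Standalone sketch of a knee detector compatible with the call above (the
# real detect_knee is defined elsewhere in this repository). It normalizes
# the gain curve (x: docs assessed, y: relevant found) and returns the index
# of the point lying farthest above the chord from first to last point, or
# None if the curve has no knee yet. Illustrative only.
def _demo_detect_knee(knee_data):
    if len(knee_data) < 3:
        return None
    xs = [float(x) for x, _ in knee_data]
    ys = [float(y) for _, y in knee_data]
    if xs[-1] == xs[0] or ys[-1] == ys[0]:
        return None                                  # degenerate (flat) curve
    # vertical distance of each normalized point above the normalized chord
    dists = [(y - ys[0]) / (ys[-1] - ys[0]) - (x - xs[0]) / (xs[-1] - xs[0])
             for x, y in zip(xs, ys)]
    best = max(range(len(dists)), key=dists.__getitem__)
    return [best] if dists[best] > 0 else None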

def target_method(
        data_name, topic_set, topic_id,
        query_file, qrel_file, doc_id_file, doc_text_file,  # data parameters
        stopping_percentage=None, stopping_recall=None,  # autostop parameters
        target_rel_num=10,  # target parameter
        random_state=0):
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'target' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'trn' + str(target_rel_num)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = assessor.get_complete_dids()
    complete_pseudo_dids = assessor.get_complete_pseudo_dids()
    complete_pseudo_texts = assessor.get_complete_pseudo_texts()
    did2label = assessor.get_did2label()
    total_true_r = assessor.get_total_rel_num()
    total_num = assessor.get_total_doc_num()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids,
                             texts=complete_pseudo_texts,
                             corpus_texts=complete_pseudo_texts)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # sampling the target set: the target set should not affect the interaction process
    population = complete_dids
    target_set = set()
    sample_size = 100
    sample_rel_num = 0
    while population and sample_rel_num < target_rel_num:
        population = list(set(complete_dids).difference(target_set))
        sample_size = min(len(population), sample_size)
        sampled_dids = set(np.random.choice(a=population, size=sample_size,
                                            replace=False))  # unique random elements
        sample_rel_num += len([did for did in sampled_dids
                               if assessor.get_rel_label(did) == 1])
        target_set = target_set.union(sampled_dids)
    target_rel_set = set([did for did in target_set
                          if assessor.get_rel_label(did) == 1])

    # starting the TAR process
    stopping = False
    t = 0
    batch_size = 100
    temp_doc_num = 100
    interaction_file = name_interaction_file(data_name=data_name, model_name=model_name,
                                             topic_set=topic_set, exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1
            # train
            train_dids, train_labels = assessor.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            # predict
            complete_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(complete_features)
            zipped = sorted(zip(complete_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, scores = zip(*zipped)

            # cutting off instead of sampling
            selected_dids = assessor.get_top_assessed_dids(ranked_dids, batch_size)
            assessor.update_assess(selected_dids)

            # statistics
            sampled_dids = assessor.get_assessed_dids()
            sampled_num = len(set(sampled_dids).union(target_set))
            running_true_r = assessor.get_assessed_rel_num()
            running_true_recall = running_true_r / float(total_true_r)
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r,
                 running_true_r, ap, running_true_recall))

            sampled_rel_set = set(assessor.get_assessed_rel_dids())
            if set(target_rel_set).issubset(sampled_rel_set):
                stopping = True
            if sampled_num >= total_num:  # dids are exhausted
                stopping = True

            # stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # INTERACTION METHOD: apply the ranked_dids order
    sampled_dids = assessor.get_assessed_dids()
    shown_dids = list(target_set.union(set(sampled_dids)))
    shown_features = ranker.get_feature_by_did(shown_dids)
    shown_scores = ranker.predict(shown_features)
    zipped = sorted(zip(shown_dids, shown_scores), key=itemgetter(1), reverse=True)
    shown_dids, scores = zip(*zipped)

    check_func = assessor.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')
    return
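
# Standalone sketch of the target method's precondition, isolated from the
# pipeline: sample documents uniformly without replacement until the sample
# contains target_rel_num relevant ones; the TAR process may only stop once
# every relevant member of this target set has been assessed. Illustrative only.
def _demo_build_target_set(labels, target_rel_num=10, sample_size=100, seed=0):
    """labels: dict mapping doc id -> 0/1 relevance."""
    import numpy as np

    rng = np.random.default_rng(seed)
    population = list(labels)
    target_set, rel_found = set(), 0
    while population and rel_found < target_rel_num:
        batch = rng.choice(population, size=min(sample_size, len(population)),
                           replace=False)
        rel_found += sum(labels[did] for did in batch)
        target_set.update(batch)
        population = [did for did in labels if did not in target_set]
    return target_set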

def autotar_method(
        data_name, topic_set, topic_id,
        query_file, qrel_file, doc_id_file, doc_text_file,  # data parameters
        stopping_percentage=1.0, stopping_recall=None,  # autostop parameters, for debug
        ranker_tfidf_corpus_files=[], classifier='lr', min_df=2, C=1.0,  # ranker parameters
        random_state=0):
    """
    Implementation of the TAR process.
    @param data_name: dataset name
    @param topic_set: parameter-tuning set or test set
    @param topic_id: topic id
    @param stopping_percentage: stop TAR when x percentage of documents have been screened
    @param stopping_recall: stop TAR when x recall is achieved
    @param ranker_tfidf_corpus_files: indicates which corpus to use when building features, see Ranker
    @param classifier: parameter of Ranker
    @param min_df: parameter of Ranker
    @param C: parameter of Ranker
    @param random_state: random seed
    @return:
    """
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'autotar' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'ct' + str(ranker_tfidf_corpus_files) + '-'
    model_name += 'csf' + classifier + '-'
    model_name += 'md' + str(min_df) + '-'
    model_name += 'c' + str(C)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    datamanager = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = datamanager.get_complete_dids()
    complete_pseudo_dids = datamanager.get_complete_pseudo_dids()
    #complete_pseudo_texts = datamanager.get_complete_pseudo_texts()
    #corpus_texts = complete_pseudo_texts
    did2label = datamanager.get_did2label()
    total_true_r = datamanager.get_total_rel_num()
    total_num = datamanager.get_total_doc_num()
    complete_labels = datamanager.get_complete_labels()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids, data_name=data_name)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # local parameters are set according to [1]
    stopping = False
    t = 0
    batch_size = 1
    temp_doc_num = 100

    interaction_file = name_interaction_file(data_name=data_name, model_name=model_name,
                                             topic_set=topic_set, exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1
            train_dids, train_labels = datamanager.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            test_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(test_features)
            zipped = sorted(zip(complete_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, scores = zip(*zipped)

            # cutting off instead of sampling
            selected_dids = datamanager.get_top_assessed_dids(ranked_dids, batch_size)
            datamanager.update_assess(selected_dids)

            # statistics
            sampled_num = datamanager.get_assessed_num()
            running_true_r = datamanager.get_assessed_rel_num()
            running_true_recall = running_true_r / float(total_true_r)
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r,
                 running_true_r, ap, running_true_recall))

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # tar run file
    shown_dids = datamanager.get_assessed_dids()
    check_func = datamanager.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name,
                                     topic_set=topic_set, exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')
    return
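
# Standalone sketch of one AutoTAR-style iteration, assuming an sklearn
# logistic-regression ranker in place of this repository's Ranker class:
# train on the assessed documents plus 100 temporarily-labeled negatives,
# rank the collection, assess the next batch, grow the batch size by ~10%.
# Synthetic features and labels; illustrative only.
def _demo_autotar_iteration(seed=0):
    import math
    import numpy as np
    from sklearn.linear_model import LogisticRegression

    rng = np.random.default_rng(seed)
    X = rng.normal(size=(500, 20))                   # synthetic doc features
    y = (X[:, 0] > 1.0).astype(int)                  # synthetic relevance labels
    y[0], y[1] = 1, 0                                # ensure both classes exist
    assessed = {0: 1, 1: 0}                          # seed assessments
    batch_size = 1

    # 100 random unassessed documents temporarily labeled non-relevant
    pool = [i for i in range(500) if i not in assessed]
    temp = rng.choice(pool, size=100, replace=False)
    train_idx = list(assessed) + [int(i) for i in temp]
    train_y = [assessed.get(i, 0) for i in train_idx]

    clf = LogisticRegression(C=1.0).fit(X[train_idx], train_y)
    ranked = np.argsort(-clf.decision_function(X))   # descending by score
    batch = [int(i) for i in ranked if int(i) not in assessed][:batch_size]
    for i in batch:
        assessed[i] = int(y[i])                      # simulated assessment
    batch_size += math.ceil(batch_size / 10)         # exponential growth schedule
    return batch_size, len(assessed)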