コード例 #1
0
def autostop_for_large_collection(
        data_name,
        topic_set,
        topic_id,
        query_files,
        qrel_files,
        doc_id_files,
        doc_text_files,  # data parameters
        sampler_type='HTAPPriorSampler',
        epsilon=0.5,
        beta=-0.1,
        stopping_percentage=None,
        stopping_recall=None,
        target_recall=1.0,
        stopping_condition='strict1',  # autostop parameters
        random_state=0):
    """
    Run the AutoStop TAR process over a collection split into blocks.

    Each (query_file, qrel_file, doc_id_file, doc_text_file) tuple describes one
    block of the collection. Every block is processed independently with its own
    Assessor, Ranker and sampler; the assessed documents of each block, ordered by
    that block's final ranker score, are concatenated and written to one TAR run
    file.

    @param data_name: dataset name, used to name the run file
    @param topic_set: topic set name, used to name the run file
    @param topic_id: topic id
    @param query_files: per-block query files
    @param qrel_files: per-block relevance judgement files
    @param doc_id_files: per-block document id files
    @param doc_text_files: per-block document text files
        (the four lists are iterated in lockstep; zip stops at the shortest)
    @param sampler_type: one of 'HTMixtureUniformSampler', 'HTUniformSampler',
        'HTPowerLawSampler', 'HTAPPriorSampler', 'HHMixtureUniformSampler',
        'HHPowerLawSampler', 'HHAPPriorSampler'
    @param epsilon: passed to the mixture-uniform samplers' update_distribution
    @param beta: passed to the power-law samplers
    @param stopping_percentage: debug: stop a block once this fraction of it has been assessed
    @param stopping_recall: debug: stop a block once this true recall is reached
    @param target_recall: target recall used by the autostop condition
    @param stopping_condition: 'loose', 'strict1' or 'strict2'
    @param random_state: seed for numpy's global RNG; also used as the run's exp_id
    @return: None; the result is written to the TAR run file
    @raise NotImplementedError: unknown sampler_type or stopping_condition
    """
    # seed like the other TAR entry points so runs are reproducible
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'autostop' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'smp' + str(sampler_type) + '-'

    # Every sampler constructed below must be accepted here. Only the samplers
    # with a tunable parameter record it in the name; the parameter-free ones
    # (including the default 'HTAPPriorSampler') previously fell through to the
    # error branch and made the function unusable with its own defaults.
    if sampler_type in ('HTMixtureUniformSampler', 'HHMixtureUniformSampler'):
        model_name += 'epsilon' + str(epsilon) + '-'
    elif sampler_type in ('HTPowerLawSampler', 'HHPowerLawSampler'):
        model_name += 'beta' + str(beta) + '-'
    elif sampler_type not in ('HTUniformSampler', 'HTAPPriorSampler', 'HHAPPriorSampler'):
        raise NotImplementedError('Unknown sampler_type: {}'.format(sampler_type))

    model_name += 'tr' + str(target_recall) + '-'
    model_name += 'sc' + stopping_condition
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # starting the TAR process after splitting the complete collection into small blocks
    total_shown_dids = []
    for query_file, qrel_file, doc_id_file, doc_text_file in zip(
            query_files, qrel_files, doc_id_files, doc_text_files):

        # loading data
        assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
        complete_dids = assessor.get_complete_dids()
        complete_labels = assessor.get_complete_labels()
        complete_pseudo_dids = assessor.get_complete_pseudo_dids()
        complete_pseudo_texts = assessor.get_complete_pseudo_texts()
        did2label = assessor.get_did2label()
        total_true_r = assessor.get_total_rel_num()
        total_num = assessor.get_total_doc_num()

        # ranker
        ranker = Ranker()
        ranker.set_did_2_feature(dids=complete_pseudo_dids,
                                 texts=complete_pseudo_texts,
                                 corpus_texts=complete_pseudo_texts)
        ranker.set_features_by_name('complete_dids', complete_dids)

        # sampler: HT* samplers are built from the labelled did list,
        # HH* samplers from the collection size and the did->label map
        if sampler_type == 'HTMixtureUniformSampler':
            sampler = HTMixtureUniformSampler()
            sampler.init(complete_dids, complete_labels)
        elif sampler_type == 'HTUniformSampler':
            sampler = HTUniformSampler()
            sampler.init(complete_dids, complete_labels)
            sampler.update_distribution()
        elif sampler_type == 'HTPowerLawSampler':
            sampler = HTPowerLawSampler()
            sampler.init(beta, complete_dids, complete_labels)
            sampler.update_distribution(beta=beta)
        elif sampler_type == 'HTAPPriorSampler':
            sampler = HTAPPriorSampler()
            sampler.init(complete_dids, complete_labels)
            sampler.update_distribution()
        elif sampler_type == 'HHMixtureUniformSampler':
            sampler = HHMixtureUniformSampler()
            sampler.init(total_num, did2label)
        elif sampler_type == 'HHPowerLawSampler':
            sampler = HHPowerLawSampler()
            sampler.init(total_num, did2label)
            sampler.update_distribution(beta=beta)
        elif sampler_type == 'HHAPPriorSampler':
            sampler = HHAPPriorSampler()
            sampler.init(total_num, did2label)
            sampler.update_distribution()
        else:
            raise TypeError

        # local parameters
        stopping = False
        t = 0
        batch_size = 1
        temp_doc_num = 100

        while not stopping:
            t += 1
            train_dids, train_labels = assessor.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            test_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(test_features)

            zipped = sorted(zip(complete_dids, scores),
                            key=itemgetter(1),
                            reverse=True)
            ranked_dids, _ = zip(*zipped)

            # mixture samplers re-weight their distribution with the current batch size
            if sampler_type == 'HHMixtureUniformSampler' or sampler_type == 'HTMixtureUniformSampler':
                sampler.update_distribution(epsilon=epsilon, alpha=batch_size)

            sampled_dids = sampler.sample(t, ranked_dids, batch_size,
                                          stopping_condition)
            assessor.update_assess(sampled_dids)

            sampled_state = assessor.get_assessed_state()
            total_esti_r, var1, var2 = sampler.estimate(
                t, stopping_condition, sampled_state)

            # statistics
            sampled_num = assessor.get_assessed_num()
            running_true_r = assessor.get_assessed_rel_num()

            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0

            # update parameters: batches grow by ~10% per iteration
            batch_size += math.ceil(batch_size / 10)

            # autostop: stop once the relevant documents found reach target_recall
            # of the estimated total (optionally inflated by one standard deviation)
            if running_true_r > 0:
                if stopping_condition == 'loose':
                    if running_true_r >= target_recall * total_esti_r:
                        stopping = True
                elif stopping_condition == 'strict1':
                    if running_true_r >= target_recall * (total_esti_r +
                                                          np.sqrt(var1)):
                        stopping = True
                elif stopping_condition == 'strict2':
                    if running_true_r >= target_recall * (total_esti_r +
                                                          np.sqrt(var2)):
                        stopping = True
                else:
                    raise NotImplementedError

            # stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True

            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

        # order this block's assessed documents by the final ranker's score
        sampled_dids = assessor.get_assessed_dids()
        shown_features = ranker.get_feature_by_did(sampled_dids)
        shown_scores = ranker.predict(shown_features)
        zipped = sorted(zip(sampled_dids, shown_scores),
                        key=itemgetter(1),
                        reverse=True)
        total_shown_dids.extend(did for did, _ in zipped)

    tar_run_file = name_tar_run_file(data_name=data_name,
                                     model_name=model_name,
                                     topic_set=topic_set,
                                     exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        # every shown document is reported (check_func accepts all dids)
        write_tar_run_file(f=f,
                           topic_id=topic_id,
                           check_func=lambda x: True,
                           shown_dids=total_shown_dids)

    LOGGER.info('TAR is finished.')

    return
コード例 #2
0
def scal_method(data_name, topic_set, topic_id,
                query_file, qrel_file, doc_id_file, doc_text_file, sub_percentage, bound_bt, ita,  # data parameters
                stopping_percentage=1.0, stopping_recall=1.0, target_recall=1.0,  # autostop parameters
                max_or_min='min', bucket_type='samplerel',  # scal parameters
                random_state=0):
    """
    Implementation of the S-CAL method [1].

    A sub-sample of the collection is ranked and screened batch by batch with an
    SCALSampler, which also returns a per-batch estimate of relevant documents.
    After the loop, the first iteration whose cumulative estimate reaches
    target_recall of the calibrated final estimate (ita * estimate) is selected.
    A score threshold is derived from the documents chosen by bucket_type at that
    iteration. Every document of the full collection scoring at or above the
    threshold is shown, as is every document for which the assessor's check
    function returns True.

    @param data_name: dataset name; passed to Ranker.set_did_2_feature and used in output file names
    @param topic_set: topic set name, used in output file names
    @param topic_id: topic id
    @param query_file: query file of the topic
    @param qrel_file: relevance judgement file of the topic
    @param doc_id_file: document id file
    @param doc_text_file: document text file
    @param sub_percentage: fraction of the collection drawn without replacement as
        the sub-sample (1.0 uses the whole collection; > 1.0 raises NotImplementedError)
    @param bound_bt: sample size per batch
    @param ita: calibration factor multiplied into the final estimate
    @param stopping_percentage: debug: stop once this fraction of documents has been assessed
    @param stopping_recall: debug: stop once this true recall is reached
    @param target_recall: fraction of the final estimate used to select the threshold iteration
    @param max_or_min: 'max' or 'min'; how the threshold is taken from the selected documents' scores
    @param bucket_type: 'bucket', 'sample' or 'samplerel'; which documents define the threshold
    @param random_state: seed for numpy's global RNG; also used as the run's exp_id
    @return: None; writes an interaction CSV and a TAR run file
    @raise NotImplementedError: invalid sub_percentage, max_or_min or bucket_type
    """
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'scal' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'tr' + str(target_recall) + '-'
    model_name += 'spt' + str(sub_percentage) + '-'
    model_name += 'bnd' + str(bound_bt) + '-'
    model_name += 'mxn' + max_or_min + '-'
    model_name += 'bkt' + bucket_type + '-'
    model_name += 'ita' + str(ita)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    datamanager = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = datamanager.get_complete_dids()
    complete_pseudo_dids = datamanager.get_complete_pseudo_dids()
    did2label = datamanager.get_did2label()
    total_true_r = datamanager.get_total_rel_num()
    total_num = datamanager.get_total_doc_num()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids, data_name=data_name)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # SCAL sampler
    sampler = SCALSampler()

    # sampling a large sample set before the TAR process
    if sub_percentage < 1.0:
        u = int(sub_percentage * total_num)
        sub_dids = list(np.random.choice(a=complete_dids, size=u, replace=False).flatten())
    elif sub_percentage == 1.0:
        u = total_num
        sub_dids = complete_dids
    else:
        raise NotImplementedError
    ranker.set_features_by_name('sub_dids', sub_dids)

    # local parameters
    stopping = False
    t = 0
    batch_size = 1
    temp_doc_num = 100
    n = bound_bt

    total_esti_r = 0
    temp_list = []

    # starting the TAR process
    interaction_file = name_interaction_file(data_name=data_name, model_name=model_name, topic_set=topic_set,
                                             exp_id=random_state, topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1

            LOGGER.info('iteration {}, batch_size {}'.format(t, batch_size))

            # train
            train_dids, train_labels = datamanager.get_training_data(temp_doc_num=temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            # predict
            sub_features = ranker.get_features_by_name('sub_dids')
            scores = ranker.predict(sub_features)
            zipped = sorted(zip(sub_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, _ = zip(*zipped)

            bucketed_dids, sampled_dids, batch_esti_r = sampler.sample(ranked_dids, n, batch_size, did2label)

            datamanager.update_assess(sampled_dids)

            # estimating
            total_esti_r += batch_esti_r

            # statistics
            sampled_num = datamanager.get_assessed_num()
            running_true_r = datamanager.get_assessed_rel_num()
            if total_esti_r != 0:
                running_esti_recall = running_true_r / float(total_esti_r)
            else:
                running_esti_recall = 0
            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters
            if batch_size < total_num:  # avoid OverflowError
                batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r, total_esti_r, running_true_r, ap, running_esti_recall, running_true_recall))

            cum_bucketed_dids = sampler.get_bucketed_dids()
            cum_sampled_dids = sampler.get_sampled_dids()
            # NOTE(review): `ranker` is the same object on every iteration (it is
            # trained in place), so each entry references the latest model rather
            # than a snapshot. Whether the sampler returns copies of its
            # bucketed/sampled lists is not visible here. Confirm this is intended.
            temp_list.append((total_esti_r, ranker, cum_bucketed_dids, cum_sampled_dids))

            # when sub sample is exhausted, stop
            len_bucketed_dids = len(cum_bucketed_dids)
            if len_bucketed_dids == u:
                stopping = True

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # estimating rho
    final_total_esti_r = ita * total_esti_r  # calibrating the estimation in [1]

    if max_or_min == 'max':
        max_or_min_func = max
    elif max_or_min == 'min':
        max_or_min_func = min
    else:
        raise NotImplementedError

    # finding the first ranker that satisfies the stopping strategy, otherwise using the last ranker
    for total_esti_r, ranker, bucketed_dids, cum_sampled_dids in temp_list:
        if target_recall * final_total_esti_r <= total_esti_r:
            break

    if bucket_type == 'bucket':
        filtered_dids = bucketed_dids
    elif bucket_type == 'sample':
        filtered_dids = cum_sampled_dids
    elif bucket_type == 'samplerel':
        filtered_dids = [did for did in cum_sampled_dids if datamanager.get_rel_label(did) == 1]
    else:
        raise NotImplementedError

    # len() rather than `!= []`: the bucket/sample containers need not be lists,
    # and an empty tuple/set compares unequal to [] (then max/min would fail)
    if len(filtered_dids) > 0:
        features = ranker.get_feature_by_did(filtered_dids)
        scores = ranker.predict(features)
        threshold = max_or_min_func(scores)
    else:
        threshold = -1

    # rank complete dids with a ranker retrained on all assessed documents
    train_dids, train_labels = datamanager.get_training_data(temp_doc_num=0)
    train_features = ranker.get_feature_by_did(train_dids)
    ranker.train(train_features, train_labels)
    complete_features = ranker.get_feature_by_did(complete_dids)
    complete_scores = ranker.predict(complete_features)
    zipped = sorted(zip(complete_dids, complete_scores), key=itemgetter(1), reverse=True)

    # shown dids: above the threshold, or accepted by the assessor's check function
    check_func = datamanager.assess_state_check_func()
    shown_dids = [did for did, score in zipped if score >= threshold or check_func(did) is True]

    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name, topic_set=topic_set,
                                     exp_id=random_state, topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func, shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')

    return
コード例 #3
0
def autostop_method(data_name, topic_set, topic_id,
                    query_file, qrel_file, doc_id_file, doc_text_file, target_recall,  # data parameters
                    sampler_type, stopping_condition, epsilon=0.5, beta=-0.1,
                    stopping_percentage=None, stopping_recall=1.0,  # autostop parameters
                    random_state=0):
    """
    Run the AutoStop TAR process on a single topic.

    Each iteration retrains the ranker and ranks the whole collection. It then
    draws a batch with the configured sampler, records the assessments, and
    updates the sampler's estimate of the number of relevant documents. The loop
    stops once the relevant documents found reach target_recall of that estimate
    (see stopping_condition), or when a debug early-stop criterion fires.

    @param data_name: dataset name; passed to Ranker.set_did_2_feature and used in output file names
    @param topic_set: topic set name, used in output file names
    @param topic_id: topic id
    @param query_file: query file of the topic
    @param qrel_file: relevance judgement file of the topic
    @param doc_id_file: document id file
    @param doc_text_file: document text file
    @param target_recall: target recall used by the autostop condition
    @param sampler_type: one of 'HTMixtureUniformSampler', 'HTUniformSampler',
        'HTPowerLawSampler', 'HTAPPriorSampler', 'HHMixtureUniformSampler',
        'HHPowerLawSampler', 'HHAPPriorSampler'
    @param stopping_condition: 'loose', 'strict1' or 'strict2'
    @param epsilon: passed to the mixture-uniform samplers' update_distribution
    @param beta: passed to the power-law samplers
    @param stopping_percentage: debug: stop once this fraction of documents has been assessed
    @param stopping_recall: debug: stop once this true recall is reached
    @param random_state: seed for numpy's global RNG; also used as the run's exp_id
    @return: None; writes an interaction CSV and a TAR run file
    @raise TypeError: unknown sampler_type
    @raise NotImplementedError: unknown stopping_condition
    """
    np.random.seed(random_state)

    # model named with its configuration
    # NOTE(review): unlike autostop_for_large_collection, epsilon/beta are not part
    # of the name, so runs that differ only in epsilon/beta write to the same files.
    # Left unchanged because the name determines the output file paths.
    model_name = 'autostop' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'smp' + str(sampler_type) + '-'
    model_name += 'tr' + str(target_recall) + '-'
    model_name += 'sc' + stopping_condition
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    datamanager = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = datamanager.get_complete_dids()
    complete_pseudo_dids = datamanager.get_complete_pseudo_dids()
    did2label = datamanager.get_did2label()
    total_true_r = datamanager.get_total_rel_num()
    total_num = datamanager.get_total_doc_num()
    complete_labels = datamanager.get_complete_labels()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids, data_name=data_name)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # sampler: HT* samplers are built from the labelled did list,
    # HH* samplers from the collection size and the did->label map
    if sampler_type == 'HTMixtureUniformSampler':
        sampler = HTMixtureUniformSampler()
        sampler.init(complete_dids, complete_labels)
    elif sampler_type == 'HTUniformSampler':
        sampler = HTUniformSampler()
        sampler.init(complete_dids, complete_labels)
        sampler.update_distribution()
    elif sampler_type == 'HTPowerLawSampler':
        sampler = HTPowerLawSampler()
        sampler.init(beta, complete_dids, complete_labels)
        sampler.update_distribution(beta=beta)
    elif sampler_type == 'HTAPPriorSampler':
        sampler = HTAPPriorSampler()
        sampler.init(complete_dids, complete_labels)
        sampler.update_distribution()
    elif sampler_type == 'HHMixtureUniformSampler':
        sampler = HHMixtureUniformSampler()
        sampler.init(total_num, did2label)
    elif sampler_type == 'HHPowerLawSampler':
        sampler = HHPowerLawSampler()
        sampler.init(total_num, did2label)
        sampler.update_distribution(beta=beta)
    elif sampler_type == 'HHAPPriorSampler':
        sampler = HHAPPriorSampler()
        sampler.init(total_num, did2label)
        sampler.update_distribution()
    else:
        # report the bad configuration in the exception instead of printing it
        raise TypeError('Unknown sampler_type: {} (stopping_condition={})'.format(
            sampler_type, stopping_condition))

    # local parameters
    stopping = False
    t = 0
    batch_size = 1
    temp_doc_num = 100

    # starting the TAR process
    interaction_file = name_interaction_file(data_name=data_name, model_name=model_name, topic_set=topic_set,
                                             exp_id=random_state, topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1
            train_dids, train_labels = datamanager.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            test_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(test_features)

            zipped = sorted(zip(complete_dids, scores), key=itemgetter(1), reverse=True)
            ranked_dids, _ = zip(*zipped)

            # mixture samplers re-weight their distribution with the current batch size
            if sampler_type == 'HHMixtureUniformSampler' or sampler_type == 'HTMixtureUniformSampler':
                sampler.update_distribution(epsilon=epsilon, alpha=batch_size)

            sampled_dids = sampler.sample(t, ranked_dids, batch_size, stopping_condition)

            datamanager.update_assess(sampled_dids)

            sampled_state = datamanager.get_assessed_state()

            total_esti_r, var1, var2 = sampler.estimate(t, stopping_condition, sampled_state)

            # statistics
            sampled_num = datamanager.get_assessed_num()
            running_true_r = datamanager.get_assessed_rel_num()

            if total_esti_r != 0:
                running_esti_recall = running_true_r / float(total_esti_r)
            else:
                running_esti_recall = 0
            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters: batches grow by ~10% per iteration
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r, total_esti_r, var1, var2,
                 running_true_r, ap, running_esti_recall, running_true_recall))

            # autostop: stop once the relevant documents found reach target_recall
            # of the estimated total (optionally inflated by one standard deviation)
            if running_true_r > 0:
                if stopping_condition == 'loose':
                    if running_true_r >= target_recall * total_esti_r:
                        stopping = True
                elif stopping_condition == 'strict1':
                    if running_true_r >= target_recall * (total_esti_r + np.sqrt(var1)):
                        stopping = True
                elif stopping_condition == 'strict2':
                    if running_true_r >= target_recall * (total_esti_r + np.sqrt(var2)):
                        stopping = True
                else:
                    raise NotImplementedError

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # writing: assessed documents ordered by the final ranker's score
    sampled_dids = datamanager.get_assessed_dids()
    shown_features = ranker.get_feature_by_did(sampled_dids)
    shown_scores = ranker.predict(shown_features)
    zipped = sorted(zip(sampled_dids, shown_scores), key=itemgetter(1), reverse=True)
    shown_dids = tuple(did for did, _ in zipped)

    check_func = datamanager.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name, model_name=model_name, topic_set=topic_set,
                                     exp_id=random_state, topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f, topic_id=topic_id, check_func=check_func, shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')
    return
コード例 #4
0
def knee_method(
        data_name,
        topic_set,
        topic_id,
        query_file,
        qrel_file,
        doc_id_file,
        doc_text_file,
        text_metrics,  # data parameters
        stopping_beta=100,
        stopping_percentage=1.0,
        stopping_recall=None,  # autostop parameters
        rho='6',
        random_state=0):
    """
    Implementation of the Knee method [1].

    Documents are taken from the ranking in batches that grow by ~10% per
    iteration. After every batch a knee is searched on the curve
    (#documents assessed, #relevant found). The process stops once the slope ratio
    before/after the knee exceeds rho and more than stopping_beta documents have
    been assessed, or when a debug early-stop criterion fires.

    @param data_name: dataset name, used in output file names
    @param topic_set: topic set name, used in output file names
    @param topic_id: topic id
    @param query_file: query file of the topic
    @param qrel_file: relevance judgement file of the topic
    @param doc_id_file: document id file
    @param doc_text_file: document text file
    @param text_metrics: passed to Assessor; its per-document metrics are used as ranker features
    @param stopping_beta: only stop once more than stopping_beta documents have been assessed
    @param stopping_percentage: debug: stop once this fraction of documents has been assessed
    @param stopping_recall: debug: stop once this true recall is reached
    @param rho: slope-ratio threshold. Either any float-convertible value, or
        'dynamic' for 156 - min(#relevant found, 150), recomputed at every iteration.
    @param random_state: seed for numpy's global RNG; also used as the run's exp_id
    @return: None; writes an interaction CSV and a TAR run file
    """
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'knee' + '-'
    model_name += 'sb' + str(stopping_beta) + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'

    model_name += 'rho' + str(rho)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file,
                        text_metrics)
    complete_dids = assessor.get_complete_dids()
    complete_pseudo_dids = assessor.get_complete_pseudo_dids()

    complete_pseudo_texts = assessor.get_complete_pseudo_texts()
    complete_pseudo_metrics = assessor.get_complete_pseudo_metrics()
    did2label = assessor.get_did2label()
    total_true_r = assessor.get_total_rel_num()
    total_num = assessor.get_total_doc_num()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids,
                             texts=complete_pseudo_texts,
                             corpus_texts=complete_pseudo_texts,
                             metrics=complete_pseudo_metrics)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # Decide once whether rho is dynamic. The original rebound `rho` to a number
    # at the first knee, which froze the "dynamic" threshold at its first value.
    dynamic_rho = rho == 'dynamic'

    # local parameters
    stopping = False
    t = 0
    batch_size = 1
    temp_doc_num = 100
    knee_data = []

    # starting the TAR process
    interaction_file = name_interaction_file(data_name=data_name,
                                             model_name=model_name,
                                             topic_set=topic_set,
                                             exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1
            LOGGER.info('TAR: iteration={}'.format(t))
            train_dids, train_labels = assessor.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            train_features_metrics = ranker.get_feature_metric_by_did(
                train_dids)
            ranker.train(train_features, train_features_metrics, train_labels)

            test_features = ranker.get_features_by_name('complete_dids')

            # NOTE(review): the features are fetched by the name 'complete_dids' but
            # the metrics by the did list itself. Confirm this matches
            # Ranker.get_metrics_by_name's signature.
            test_metrics = ranker.get_metrics_by_name(complete_dids)
            scores = ranker.predict(test_features, test_metrics)

            zipped = sorted(zip(complete_dids, scores),
                            key=itemgetter(1),
                            reverse=True)
            ranked_dids, _ = zip(*zipped)

            # cutting off instead of sampling
            selected_dids = assessor.get_top_assessed_dids(
                ranked_dids, batch_size)
            assessor.update_assess(selected_dids)

            # statistics (guard topics without any relevant document)
            sampled_num = assessor.get_assessed_num()
            running_true_r = assessor.get_assessed_rel_num()
            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters: batches grow by ~10% per iteration
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r,
                 running_true_r, ap, running_true_recall))

            # detect knee
            knee_data.append((sampled_num, running_true_r))
            knee_indice = detect_knee(
                knee_data)  # x: sampled_num, y: running_true_r

            if knee_indice is not None and len(knee_indice) > 0:
                knee_index = knee_indice[-1]
                rank1, r1 = knee_data[knee_index]
                rank2, r2 = knee_data[-1]

                # slope before the knee divided by (smoothed) slope after it
                try:
                    current_rho = float(r1 / rank1) / float(
                        (r2 - r1 + 1) / (rank2 - rank1))
                except ZeroDivisionError:
                    # the knee is the latest point (rank2 == rank1) or at rank 0
                    LOGGER.warning(
                        '(rank1, r1) = ({} {}), (rank2, r2) = ({} {})'.format(
                            rank1, r1, rank2, r2))
                    current_rho = 0  # do not stop

                if dynamic_rho:
                    rho_threshold = 156 - min(running_true_r,
                                              150)  # rho is in [6, 156], see [1]
                else:
                    rho_threshold = float(rho)

                if current_rho > rho_threshold:
                    if sampled_num > stopping_beta:
                        stopping = True

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    shown_dids = assessor.get_assessed_dids()
    check_func = assessor.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name,
                                     model_name=model_name,
                                     topic_set=topic_set,
                                     exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f,
                           topic_id=topic_id,
                           check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')

    return
コード例 #5
0
ファイル: target.py プロジェクト: dli1/auto-stop-tar
def target_method(
        data_name,
        topic_set,
        topic_id,
        query_file,
        qrel_file,
        doc_id_file,
        doc_text_file,  # data parameters
        stopping_percentage=None,
        stopping_recall=None,  # autostop parameters
        target_rel_num=10,  # target parameter
        random_state=0):
    """
    Implementation of the Target method.

    Before screening, documents are drawn uniformly at random (up to 100 per draw,
    without replacement) into a target set. Drawing continues until the set
    contains target_rel_num relevant documents or the collection is exhausted.
    The TAR process then takes documents from the ranking in batches growing by
    ~10% per iteration. It stops once every relevant document of the target set
    has been assessed, all documents are covered, or a debug early-stop criterion
    fires. The shown documents are the assessed ones plus the target set,
    ordered by the final ranker's score.

    @param data_name: dataset name, used in output file names
    @param topic_set: topic set name, used in output file names
    @param topic_id: topic id
    @param query_file: query file of the topic
    @param qrel_file: relevance judgement file of the topic
    @param doc_id_file: document id file
    @param doc_text_file: document text file
    @param stopping_percentage: debug: stop once this fraction of documents has been covered
    @param stopping_recall: debug: stop once this true recall is reached
    @param target_rel_num: number of relevant documents the target set should contain
    @param random_state: seed for numpy's global RNG; also used as the run's exp_id
    @return: None; writes an interaction CSV and a TAR run file
    """
    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'target' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'trn' + str(target_rel_num)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    assessor = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = assessor.get_complete_dids()
    complete_pseudo_dids = assessor.get_complete_pseudo_dids()
    complete_pseudo_texts = assessor.get_complete_pseudo_texts()
    did2label = assessor.get_did2label()
    total_true_r = assessor.get_total_rel_num()
    total_num = assessor.get_total_doc_num()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids,
                             texts=complete_pseudo_texts,
                             corpus_texts=complete_pseudo_texts)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # Sample target set: the target set should not affect the interaction process.
    # The population keeps complete_dids' (de-duplicated) order. Building it from a
    # set made its order, and hence the seeded draw, depend on string hash
    # randomization, so runs with the same random_state were not reproducible.
    unique_dids = list(dict.fromkeys(complete_dids))
    population = unique_dids
    target_set = set()
    sample_size = 100
    sample_rel_num = 0
    while not (len(population) == 0 or sample_rel_num >= target_rel_num):
        population = [did for did in unique_dids if did not in target_set]
        sample_size = min(len(population), sample_size)
        sampled_dids = set(
            np.random.choice(a=population, size=sample_size,
                             replace=False))  # unique random elements
        sample_rel_num += len(
            [did for did in sampled_dids if assessor.get_rel_label(did) == 1])
        target_set = target_set.union(sampled_dids)

    target_rel_set = {did for did in target_set if assessor.get_rel_label(did) == 1}

    # starting the TAR process
    stopping = False
    t = 0
    batch_size = 100
    temp_doc_num = 100

    interaction_file = name_interaction_file(data_name=data_name,
                                             model_name=model_name,
                                             topic_set=topic_set,
                                             exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1

            # train
            train_dids, train_labels = assessor.get_training_data(temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            # predict
            complete_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(complete_features)
            zipped = sorted(zip(complete_dids, scores),
                            key=itemgetter(1),
                            reverse=True)
            ranked_dids, scores = zip(*zipped)

            # cutting off instead of sampling
            selected_dids = assessor.get_top_assessed_dids(
                ranked_dids, batch_size)
            assessor.update_assess(selected_dids)

            # statistics: the target set counts as already covered
            sampled_dids = assessor.get_assessed_dids()
            sampled_num = len(set(sampled_dids).union(target_set))
            running_true_r = assessor.get_assessed_rel_num()
            if total_true_r != 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                running_true_recall = 0
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters: batches grow by ~10% per iteration
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r,
                 running_true_r, ap, running_true_recall))

            sampled_rel_set = set(assessor.get_assessed_rel_dids())

            # stop once every relevant target document has been assessed
            if target_rel_set.issubset(sampled_rel_set):
                stopping = True

            if sampled_num >= total_num:  # dids are exhausted
                stopping = True

            # stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # INTERACTION METHOD: apply ranked_dids order.
    # Shown documents = assessed documents plus the target set. They are built in
    # a fixed order (no set iteration) so that score ties keep a reproducible
    # order in the stable sort below.
    sampled_dids = assessor.get_assessed_dids()
    shown_dids = list(dict.fromkeys(sampled_dids))
    shown_set = set(shown_dids)
    shown_dids.extend(did for did in unique_dids
                      if did in target_set and did not in shown_set)

    shown_features = ranker.get_feature_by_did(shown_dids)
    shown_scores = ranker.predict(shown_features)
    zipped = sorted(zip(shown_dids, shown_scores),
                    key=itemgetter(1),
                    reverse=True)
    shown_dids = tuple(did for did, _ in zipped)

    check_func = assessor.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name,
                                     model_name=model_name,
                                     topic_set=topic_set,
                                     exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f,
                           topic_id=topic_id,
                           check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')

    return
コード例 #6
0
def autotar_method(
        data_name,
        topic_set,
        topic_id,
        query_file,
        qrel_file,
        doc_id_file,
        doc_text_file,  # data parameters
        stopping_percentage=1.0,
        stopping_recall=None,  # autostop parameters, for debug
        ranker_tfidf_corpus_files=None,
        classifier='lr',
        min_df=2,
        C=1.0,  # ranker parameters
        random_state=0):
    """
    Run the AutoTAR process for one topic and write its interaction and run files.

    Each round trains the ranker on the Assessor's current training data,
    ranks the complete collection by predicted score (highest first), and
    assesses the top `batch_size` documents chosen by the Assessor.
    `batch_size` then grows by ceil(batch_size / 10). One CSV row of
    statistics is written per round to the interaction file. Afterwards,
    all assessed documents are written to the TAR run file.

    The loop stops when any of these holds:
      - every document has been assessed;
      - `stopping_recall` is set and the running recall reaches it;
      - `stopping_percentage` is set and that fraction of the collection
        has been assessed.

    @param data_name: dataset name
    @param topic_set: parameter-tuning set or test set
    @param topic_id: topic id
    @param query_file: query file, passed to Assessor
    @param qrel_file: relevance-judgement file, passed to Assessor
    @param doc_id_file: document-id file, passed to Assessor
    @param doc_text_file: document-text file, passed to Assessor
    @param stopping_percentage: stop TAR when this fraction of documents has been screened
    @param stopping_recall: stop TAR when this recall is achieved (debug option);
        recall is computed from the Assessor's relevance counts
    @param ranker_tfidf_corpus_files: only recorded in the model name here;
        None (the default) is treated as an empty list
    @param classifier: classifier tag recorded in the model name
    @param min_df: min_df tag recorded in the model name
    @param C: C tag recorded in the model name
    @param random_state: numpy random seed; also the experiment id in output file names
    @return: None
    """
    # NOTE(review): classifier, min_df and C are not passed to Ranker() below,
    # so they only affect the model name. Confirm whether this is intended.
    if ranker_tfidf_corpus_files is None:
        # The None default avoids a shared mutable default list. It is
        # normalised to [] so the model name stays 'ct[]' as before.
        ranker_tfidf_corpus_files = []

    np.random.seed(random_state)

    # model named with its configuration
    model_name = 'autotar' + '-'
    model_name += 'sp' + str(stopping_percentage) + '-'
    model_name += 'sr' + str(stopping_recall) + '-'
    model_name += 'ct' + str(ranker_tfidf_corpus_files) + '-'
    model_name += 'csf' + classifier + '-'
    model_name += 'md' + str(min_df) + '-'
    model_name += 'c' + str(C)
    LOGGER.info('Model configuration: {}.'.format(model_name))

    # loading data
    datamanager = Assessor(query_file, qrel_file, doc_id_file, doc_text_file)
    complete_dids = datamanager.get_complete_dids()
    complete_pseudo_dids = datamanager.get_complete_pseudo_dids()
    did2label = datamanager.get_did2label()
    total_true_r = datamanager.get_total_rel_num()
    total_num = datamanager.get_total_doc_num()

    # preparing document features
    ranker = Ranker()
    ranker.set_did_2_feature(dids=complete_pseudo_dids, data_name=data_name)
    ranker.set_features_by_name('complete_dids', complete_dids)

    # local parameters are set according to [1]
    stopping = False
    t = 0
    batch_size = 1
    # Passed to get_training_data(). Presumably the number of temporary
    # training documents added by AutoTAR; confirm against Assessor.
    temp_doc_num = 100

    interaction_file = name_interaction_file(data_name=data_name,
                                             model_name=model_name,
                                             topic_set=topic_set,
                                             exp_id=random_state,
                                             topic_id=topic_id)
    with open(interaction_file, 'w', encoding='utf8') as f:
        csvwriter = csv.writer(f)
        while not stopping:
            t += 1

            # train the ranker on the Assessor's current training data
            train_dids, train_labels = datamanager.get_training_data(
                temp_doc_num)
            train_features = ranker.get_feature_by_did(train_dids)
            ranker.train(train_features, train_labels)

            # score the complete collection and rank it, highest score first
            test_features = ranker.get_features_by_name('complete_dids')
            scores = ranker.predict(test_features)
            zipped = sorted(zip(complete_dids, scores),
                            key=itemgetter(1),
                            reverse=True)
            ranked_dids, _ = zip(*zipped)

            # cutting off instead of sampling
            selected_dids = datamanager.get_top_assessed_dids(
                ranked_dids, batch_size)
            datamanager.update_assess(selected_dids)

            # statistics
            sampled_num = datamanager.get_assessed_num()
            running_true_r = datamanager.get_assessed_rel_num()
            if total_true_r > 0:
                running_true_recall = running_true_r / float(total_true_r)
            else:
                # A topic with no relevant documents would otherwise raise
                # ZeroDivisionError. Report 0.0 so recall-based early
                # stopping does not fire, and let the other conditions end
                # the run.
                running_true_recall = 0.0
            ap = calculate_ap(did2label, ranked_dids)

            # update parameters: batch grows by ceil(batch_size / 10) each round
            batch_size += math.ceil(batch_size / 10)

            # debug: writing values (batch_size is the already-updated size)
            csvwriter.writerow(
                (t, batch_size, total_num, sampled_num, total_true_r,
                 running_true_r, ap, running_true_recall))

            # All documents have been screened. Without this check the loop
            # never ends when stopping_percentage and stopping_recall are
            # both unset or zero.
            if sampled_num >= total_num:
                stopping = True

            # debug: stop early
            if stopping_recall:
                if running_true_recall >= stopping_recall:
                    stopping = True
            if stopping_percentage:
                if sampled_num >= int(total_num * stopping_percentage):
                    stopping = True

    # tar run file built from every assessed document
    shown_dids = datamanager.get_assessed_dids()
    check_func = datamanager.assess_state_check_func()
    tar_run_file = name_tar_run_file(data_name=data_name,
                                     model_name=model_name,
                                     topic_set=topic_set,
                                     exp_id=random_state,
                                     topic_id=topic_id)
    with open(tar_run_file, 'w', encoding='utf8') as f:
        write_tar_run_file(f=f,
                           topic_id=topic_id,
                           check_func=check_func,
                           shown_dids=shown_dids)

    LOGGER.info('TAR is finished.')