コード例 #1
0
def test_stop3():
    """Check that ``stop3`` accumulates query cost and stops at a total of 10.

    After syncing with ``example_saver`` the accumulated cost is 0.  Each
    extra query (cost 3, then cost 7) is appended to a private copy of the
    saver.  ``stop3`` must not stop at a total of 3 and must stop at 10.
    """
    assert not stop3.is_stop()
    stop3.update_information(example_saver)
    saver_copy = copy.deepcopy(example_saver)
    assert stop3._accum_cost == 0

    # (queried index, cost of that query, expected running total, expected stop flag)
    steps = (
        (2, 3, 3, False),
        (3, 7, 10, True),
    )
    for query_idx, query_cost, expected_total, should_stop in steps:
        saver_copy.add_state(
            State(select_index=[query_idx], performance=0.89, cost=[query_cost]))
        stop3.update_information(saver_copy)
        assert stop3._accum_cost == expected_total
        assert bool(stop3.is_stop()) is should_stop
コード例 #2
0
def test_stop4():
    """Check that ``stop4`` tracks a percentage and stops after one extra query.

    The percentage starts at 0.  Adding a single query to a copy of
    ``example_saver`` must bring it to exactly 1/6, at which point
    ``is_stop()`` is true.
    """
    assert not stop4.is_stop()
    stop4.update_information(example_saver)
    working_saver = copy.deepcopy(example_saver)
    assert stop4._percent == 0

    extra_query = State(select_index=[2], performance=0.89, cost=[3])
    working_saver.add_state(extra_query)
    stop4.update_information(working_saver)

    expected_percent = 1 / 6
    assert stop4._percent == expected_percent
    assert stop4.is_stop()
コード例 #3
0
def test_stop2():
    """Check that ``stop2`` counts iterations and stops at a high count.

    One extra query advances ``_current_iter`` to 1 without stopping.
    Forcing the counter to 10 makes ``is_stop()`` true.
    """
    assert not stop2.is_stop()
    stop2.update_information(example_saver)
    local_saver = copy.deepcopy(example_saver)
    assert stop2._current_iter == 0

    one_query = State(select_index=[2], performance=0.89)
    local_saver.add_state(one_query)
    stop2.update_information(local_saver)
    assert stop2._current_iter == 1
    assert not stop2.is_stop()

    # Jump the counter directly instead of issuing many more queries.
    stop2._current_iter = 10
    assert stop2.is_stop()
コード例 #4
0
def target_func(round, train_id, test_id, Lcollection, Ucollection, saver,
                examples, labels, global_parameters):
    """Run a QBC (vote-entropy) active-learning loop for one data split.

    A logistic-regression model is trained on the labeled set.  While more
    than 30 unlabeled instances remain, the loop does the following:
    queries a batch, moves it from ``Ucollection`` to ``Lcollection``,
    retrains, scores test accuracy, and appends a ``State`` to ``saver``.
    The saver is saved after every query.

    ``round``, ``train_id`` and ``global_parameters`` are not used here.
    """
    strategy = QueryInstanceQBC(examples, labels, disagreement='vote_entropy')
    model = linear_model.LogisticRegression()
    model.fit(X=examples[Lcollection.index, :], y=labels[Lcollection.index])

    # Stopping criterion: keep querying until at most 30 unlabeled remain.
    while len(Ucollection) > 30:
        picked = strategy.select(Lcollection, Ucollection, model, n_jobs=1)
        Ucollection.difference_update(picked)
        Lcollection.update(picked)

        # Retrain on the enlarged labeled set and measure test accuracy.
        model.fit(X=examples[Lcollection.index, :], y=labels[Lcollection.index])
        test_pred = model.predict(examples[test_id, :])
        hits = sum(test_pred == labels[test_id])
        score = hits / len(test_id)

        # Record this query and persist the saver after each step.
        saver.add_state(State(select_index=picked, performance=score))
        saver.save()
コード例 #5
0
def run_thread(round, train_id, test_id, Lcollection, Ucollection, saver,
               examples, labels, global_parameters):
    """Run one active-learning thread using the module-level ``reg`` and ``qs``.

    First the model is fitted on the initial labeled set and its test accuracy
    is recorded as the saver's initial point.  Then, while more than 30
    unlabeled instances remain, the loop does the following: queries a batch
    with ``qs``, moves it into ``Lcollection``, retrains, and stores a
    ``State`` with the new accuracy.  The saver is saved after every query.

    NOTE(review): ``reg`` (model) and ``qs`` (query strategy) are not defined
    in this function; they are expected to exist at module level.
    ``round``, ``train_id`` and ``global_parameters`` are unused here.
    """
    def _fit_and_evaluate():
        # Train on the current labeled set, then return accuracy on test_id.
        reg.fit(X=examples[Lcollection.index, :], y=labels[Lcollection.index])
        predicted = reg.predict(examples[test_id, :])
        return sum(predicted == labels[test_id]) / len(test_id)

    saver.set_initial_point(_fit_and_evaluate())
    while len(Ucollection) > 30:
        chosen = qs.select(Lcollection, Ucollection, reg, n_jobs=1)
        Ucollection.difference_update(chosen)
        Lcollection.update(chosen)

        record = State(select_index=chosen, performance=_fit_and_evaluate())
        # User-defined entries can be attached to the State here, e.g.
        # record.add_element(key='sub_ind', value=sub_ind)
        saver.add_state(record)
        saver.save()
コード例 #6
0
ファイル: StateIO_usage.py プロジェクト: ningkp/acepy
# Walk-through of the StateIO API: split the data with a ToolBox, create a
# StateIO saver for round 0, record a few query States, save them, and read
# one back.
# NOTE(review): `X`, `y`, `split_count`, `os`, `ToolBox`, `StateIO` and
# `State` are not defined in this snippet; they must come from earlier in
# the file.
cur_path = os.path.abspath('.')
toolbox = ToolBox(X=X, y=y, query_type='AllLabels', saving_path=cur_path)

# split data
toolbox.split_AL(test_ratio=0.3, initial_label_rate=0.1, split_count=split_count)
# Index sets for round 0: train, test, initially labeled (L), unlabeled (U).
train_ind, test_ind, L_ind, U_ind = toolbox.get_split(round=0)
# -------Initialize StateIO----------
# NOTE(review): this saver uses saving_path='.', while the toolbox above uses
# cur_path (the absolute form of '.').
saver = StateIO(round=0, train_idx=train_ind, test_idx=test_ind, init_L=L_ind, init_U=U_ind, saving_path='.')
# or by using toolbox
# saver = toolbox.get_stateio(round=0)

# Move instances 0, 1 and 2 out of the initial labeled set and into the
# unlabeled set. Presumably this lets the example States below "query" them.
saver.init_L.difference_update([0, 1, 2])
saver.init_U.update([0, 1, 2])

# -------Basic operations------------
# Each State pairs the selected indices with a performance value.
st1_batch1 = State(select_index=[1], performance=0.89)
my_value = 'my_entry_info'
# Attach an arbitrary user-defined entry to a State.
st1_batch1.add_element(key='my_entry', value=my_value)
st1_batch2 = State(select_index=[0, 1], performance=0.89)
st2_batch1 = State(select_index=[0], performance=0.89)
# NOTE(review): st3_batch1 is created but never added to `saver` below.
st3_batch1 = State(select_index=[2], performance=0.89)

saver.add_state(st1_batch1)
saver.add_state(st1_batch2)
saver.add_state(st2_batch1)

# Persist the recorded states (presumably under the saver's saving_path).
saver.save()

prev_st = saver.get_state(index=1) # get 2nd query
# or use the index operation directly
prev_st = saver[1]
コード例 #7
0

# Labels and per-instance query costs for a toy pool of three instances.
labels = [0, 1 ,0]
cost = [2, 1, 2]
from acepy.oracle import Oracle
oracle = Oracle(labels=labels, cost=cost)

# Query the oracle for instance 1. This rebinds `labels` and `cost` to the
# values returned for that query; the original lists are no longer referenced.
labels, cost = oracle.query_by_index(indexes=[1])

from acepy.experiment import State
# NOTE(review): `select_ind` and `accuracy` are not defined in this snippet;
# they are expected to come from a surrounding query loop.
st = State(select_index=select_ind, performance=accuracy, cost=cost)

# Results for one method given as raw tuples, presumably (cost, performance)
# pairs, given the analyser's x_axis='cost' below.
# NOTE(review): 'radom' is a typo for 'random'. The name is kept unchanged in
# case other code refers to it.
radom_result = [[(1, 0.6), (2, 0.7), (2, 0.8), (1, 0.9)],
                [(1, 0.7), (1, 0.7), (1.5, 0.75), (2.5, 0.85)]]  # 2 folds, 4 queries for each fold.
# NOTE(review): saver1 / saver2 are not defined here; presumably they are
# StateIO savers produced elsewhere.
uncertainty_result = [saver1, saver2]  # each State object in the saver must have the 'cost' entry.
from acepy.experiment import ExperimentAnalyser

# Use query cost as the x-axis when comparing methods.
analyser = ExperimentAnalyser(x_axis='cost')
analyser.add_method('random', radom_result)
analyser.add_method('uncertainty', uncertainty_result)
コード例 #8
0
    pred = model.predict(X[test_idx, :])
    accuracy = sum(pred == y[test_idx]) / len(test_idx)

    saver.set_initial_point(accuracy)
    while not stopping_criterion.is_stop():
        select_ind = QBCStrategy.select(Lind, Uind, model=model)
        Lind.update(select_ind)
        Uind.difference_update(select_ind)

        # update model and calc performance
        model.fit(X=X[Lind.index, :], y=y[Lind.index])
        pred = model.predict(X[test_idx, :])
        accuracy = sum(pred == y[test_idx]) / len(test_idx)

        # save intermediate result
        st = State(select_index=select_ind, performance=accuracy)
        saver.add_state(st)
        saver.save()

        # update stopping_criteria
        stopping_criterion.update_information(saver)
    stopping_criterion.reset()
    QBC_result.append(copy.deepcopy(saver))

# Per-split loop, presumably for a random-query baseline (inferred from the
# variable name).
# NOTE(review): this snippet appears truncated. Nothing is appended to
# random_result here, and `split_count`, `acebox`, `model`, `X` and `y`
# come from outside this view.
random_result = []
# NOTE(review): the loop variable `round` shadows the builtin round().
for round in range(split_count):
    train_idx, test_idx, Lind, Uind = acebox.get_split(round)
    saver = acebox.get_stateio(round)

    # calc the initial point
    model.fit(X=X[Lind.index, :], y=y[Lind.index])
コード例 #9
0
from acepy.toolbox import ToolBox as acebox

# Module-level setup shared by the StateIO test below.
# NOTE(review): `load_iris`, `os` and `State` are not imported in this view;
# they must be imported elsewhere in the file.
X, y = load_iris(return_X_y=True)
split_count = 5
cur_path = os.path.abspath('.')
# `acebox` is the ToolBox class alias imported above, so this constructs a
# ToolBox instance.
toolbox = acebox(X=X, y=y, query_type='AllLabels', saving_path=cur_path)

# split data
toolbox.split_AL(test_ratio=0.3,
                 initial_label_rate=0.1,
                 split_count=split_count)
saver = toolbox.get_stateio(round=0)
# Move instances 0, 1 and 2 from the initial labeled set into the unlabeled
# set before recording the example queries.
saver.init_L.difference_update([0, 1, 2])
saver.init_U.update([0, 1, 2])

# Example query States with batch sizes 2, 1, 1 and 1 respectively.
st1_batch2 = State(select_index=[0, 1], performance=0.89)
st1_batch1 = State(select_index=[1], performance=0.89)
st2_batch1 = State(select_index=[0], performance=0.89)
st3_batch1 = State(select_index=[2], performance=0.89)


def test_stateio_validity_checking():
    """Check StateIO's batch-size check and its query/cost totals.

    Three States with batch sizes 1, 2 and 1 are added to the module-level
    ``saver``.  The mixed sizes must fail ``check_batch_size()``.  No State
    carries a cost, so the cost totals are 0, while ``refresh_info()`` counts
    4 queried indices in total.

    NOTE(review): this mutates the shared module-level ``saver``.
    """
    for recorded in (st1_batch1, st1_batch2, st2_batch1):
        saver.add_state(recorded)

    assert not saver.check_batch_size()
    assert saver.cost_inall == 0

    num_queried, total_cost = saver.refresh_info()
    assert num_queried == 4
    assert total_cost == 0