Beispiel #1
0
def _check_save_load_state(preproc, repeat=1, recreate_from_state=False):
    # copy simple attribute states
    simple_state_attrs = ('language', 'stopwords', 'punctuation',
                          'special_chars', 'n_workers', 'tokenized',
                          'pos_tagged', 'ngrams_generated', 'ngrams_as_tokens')
    pre_state = {
        attr: deepcopy(getattr(preproc, attr))
        for attr in simple_state_attrs
    }

    # copy complex attribute states
    pre_state['docs'] = deepcopy(preproc.docs)

    if preproc.tokenized:
        pre_state['tokens'] = preproc.tokens
        pre_state['vocabulary'] = preproc.vocabulary
    if preproc.pos_tagged:
        pre_state['tokens_with_pos_tags'] = preproc.tokens_with_pos_tags
    if preproc.ngrams_generated:
        pre_state['ngrams'] = preproc.ngrams

    # save and then load the same state
    for _ in range(repeat):
        if recreate_from_state:
            preproc.save_state(TMPREPROC_TEMP_STATE_FILE)
            preproc = TMPreproc.from_state(TMPREPROC_TEMP_STATE_FILE)
        else:
            preproc.save_state(TMPREPROC_TEMP_STATE_FILE).load_state(
                TMPREPROC_TEMP_STATE_FILE)

    # check if states are the same now
    for attr in simple_state_attrs:
        assert pre_state[attr] == getattr(preproc, attr)

    assert set(pre_state['docs'].keys()) == set(preproc.docs.keys())
    assert preproc.n_docs == len(pre_state['docs'])
    assert all(pre_state['docs'][k] == preproc.docs[k]
               for k in preproc.docs.keys())

    if preproc.tokenized:
        assert set(pre_state['tokens'].keys()) == set(preproc.tokens.keys())
        assert all(pre_state['tokens'][k] == preproc.tokens[k]
                   for k in preproc.tokens.keys())

        assert pre_state['vocabulary'] == preproc.vocabulary

    if preproc.pos_tagged:
        assert set(pre_state['tokens_with_pos_tags'].keys()) == set(
            preproc.tokens_with_pos_tags.keys())
        assert all(pre_state['tokens_with_pos_tags'][k] ==
                   preproc.tokens_with_pos_tags[k]
                   for k in preproc.tokens_with_pos_tags.keys())

    if preproc.ngrams_generated:
        assert set(pre_state['ngrams'].keys()) == set(preproc.ngrams.keys())
        assert all(pre_state['ngrams'][k] == preproc.ngrams[k]
                   for k in preproc.ngrams.keys())
# Benchmark steps on an existing `preproc` object. `add_timing` is defined
# outside this excerpt; presumably it records a measurement under the given
# label -- confirm against its definition.
preproc.pos_tag()
add_timing('pos_tag')

preproc.lemmatize()
add_timing('lemmatize')

# Each copy below is only created for the measurement: its workers are shut
# down and the object is dropped immediately afterwards.
preproc_copy = preproc.copy()
preproc_copy.shutdown_workers()
del preproc_copy
add_timing('copy')

# mkstemp() returns an already opened OS-level file descriptor together with
# the path. Only the path is needed here, so the descriptor is closed right
# away instead of being leaked. The file itself is kept because it is read
# back below (and possibly later in the script).
import os  # used only to close the descriptor returned by mkstemp()
statepickle_fd, statepickle = mkstemp('.pickle')
os.close(statepickle_fd)
preproc.save_state(statepickle)
add_timing('save_state')

preproc_copy = TMPreproc.from_state(statepickle)
preproc_copy.shutdown_workers()
del preproc_copy
add_timing('from_state')

preproc_copy = TMPreproc.from_tokens(preproc.tokens_with_metadata,
                                     language='en')
preproc_copy.shutdown_workers()
del preproc_copy
add_timing('from_tokens')

preproc_copy = TMPreproc.from_tokens_datatable(preproc.tokens_datatable,
                                               language='en')
preproc_copy.shutdown_workers()
del preproc_copy
add_timing('from_tokens_datatable')