示例#1
0
def run(spath_train,
        tpath_train,
        spath_test,
        tpath_test,
        fn_train,
        fn_predict_all,
        max_sentence_length=17,
        replace_unknown_words=True,
        use_bpe=True,
        num_operations=400,
        vocab_threshold=5,
        padding=True,
        model_name='nn'):
    """Preprocess the corpora, train a model, persist the results and evaluate.

    Args:
        spath_train, tpath_train: s-/t- training file paths (presumably
            source/target language — confirm with callers).
        spath_test, tpath_test: s-/t- test file paths.
        fn_train: training callable, called as
            ``fn_train(index_array_pairs, slang.n_words, tlang.n_words,
            max_bpe_length)``; must return
            ``(encoder, attn_decoder, plot_losses, plot_every)``.
        fn_predict_all: prediction callable, forwarded to ``_evaluate``.
        max_sentence_length, replace_unknown_words, use_bpe, num_operations,
            vocab_threshold: forwarded unchanged to ``preprocess``.
        padding: forwarded to ``dp.prepare_data``.
        model_name: prefix for every artefact written under ``../output/``.

    Returns:
        ``(encoder, attn_decoder, slang, tlang, plot_losses, max_bpe_length)``.
    """
    output_dir = '../output'

    # data preprocessing
    (spath_train_pp, tpath_train_pp, spath_test_pp,
     tpath_test_pp) = preprocess(spath_train, tpath_train, spath_test,
                                 tpath_test, max_sentence_length,
                                 replace_unknown_words, use_bpe,
                                 num_operations, vocab_threshold)

    print('Data files preprocessed ...')
    print()

    # data structures for training
    (slang, tlang, index_array_pairs, s_index_arrays_test,
     max_bpe_length) = dp.prepare_data(spath_train_pp, tpath_train_pp,
                                       spath_test_pp, padding)

    print(f'{len(index_array_pairs)} inputs constructed for training ...')
    print()

    # train and return losses for plotting
    (encoder, attn_decoder, plot_losses,
     plot_every) = fn_train(index_array_pairs, slang.n_words, tlang.n_words,
                            max_bpe_length)

    print('Training finished ...')
    print()

    # plot the losses and report where the figure actually went
    # (previously the message printed a literal "TODO" placeholder)
    losses_path = f'{output_dir}/{model_name}_losses.png'
    showLosses(plot_losses, plot_every, losses_path)
    print(f'Losses diagram saved in {losses_path}')

    # NOTE(review): this file goes through fp.path_to_outputfile while the
    # other artefacts use the hard-coded output_dir — confirm both resolve
    # to the same directory.
    persistence.save(plot_losses,
                     fp.path_to_outputfile(f'{model_name}.tl', '.trainloss'))

    # save models and data
    torch.save(encoder, f'{output_dir}/{model_name}_encoder.pt')
    torch.save(attn_decoder, f'{output_dir}/{model_name}_attn_decoder.pt')
    data = (s_index_arrays_test, slang, tlang, max_bpe_length)
    persistence.save(data, f'{output_dir}/{model_name}_data_run')
    print('Models and data saved to disk')
    print()

    _evaluate(s_index_arrays_test, tpath_test_pp, slang, tlang, encoder,
              attn_decoder, fn_predict_all, max_bpe_length, use_bpe,
              model_name)

    return encoder, attn_decoder, slang, tlang, plot_losses, max_bpe_length
def unregister_chat(chat_id):
    """Remove ``chat_id`` from the registered chats and notify it.

    If the chat is in the module-level ``chat_ids`` list, it is removed, the
    updated list is saved to ``CHAT_IDS_FILENAME`` and the chat is sent
    "stopped spamming you". Otherwise the chat is sent "not spamming you".
    """
    global chat_ids
    if chat_id in chat_ids:
        # Remove from the shared registry list, not from the id itself.
        chat_ids.remove(chat_id)
        persistence.save(chat_ids, CHAT_IDS_FILENAME)
        push_message_to_chat(chat_id, "stopped spamming you")
    else:
        push_message_to_chat(chat_id, "not spamming you")
def register_chat(chat_id):
    """Add ``chat_id`` to the registered chats (once) and notify it.

    An already-registered chat is only told "already spamming you". A new chat
    is appended to the module-level ``chat_ids`` list, the list is saved to
    ``CHAT_IDS_FILENAME``, and the chat is told "starting to spam you".
    """
    global chat_ids
    if chat_id in chat_ids:
        push_message_to_chat(chat_id, "already spamming you")
        return
    chat_ids.append(chat_id)
    persistence.save(chat_ids, CHAT_IDS_FILENAME)
    push_message_to_chat(chat_id, "starting to spam you")
示例#4
0
def convert_end_date():
    """Rewrite the end-date field of every stored job posting.

    Loads all postings with ``load()``. In each one, the value at index
    ``Columns.end_date`` is replaced by
    ``convert_end_date_to_sortable_format`` of that value. The resulting list
    of tuples is written back with ``save``.
    """
    def with_sortable_end_date(posting):
        # Copy into a list so the single field can be replaced, then freeze
        # it back into a tuple for saving.
        fields = list(posting)
        fields[Columns.end_date] = convert_end_date_to_sortable_format(
            fields[Columns.end_date])
        return tuple(fields)

    save([with_sortable_end_date(posting) for posting in load()])
def check_projects():
    """Poll GitLab once, report interesting changes and persist the new state.

    A client is built from ``GITLAB_ADDRESS``/``GITLAB_TOKEN``. If that raises,
    the error and its traceback are printed and this poll is skipped,
    leaving ``projects_state`` unchanged.

    Otherwise, ``get_projects_state`` fetches the current state and
    ``find_interesting_events`` diffs it against the previous
    ``projects_state``. Any events are printed and passed to ``send_events``.
    Finally, the new state is saved to ``PROJECT_STATE_FILENAME`` and becomes
    the module-level ``projects_state``.
    """
    global projects_state
    try:
        gl = gitlab.Gitlab(GITLAB_ADDRESS, private_token=GITLAB_TOKEN)
    except Exception as e:
        # Deliberately best-effort: log the failure and skip this poll
        # instead of propagating the error.
        print("ERROR: failed to connect to gitlab: {}".format(e))
        traceback.print_exc()
        return

    new_projects_state = get_projects_state(gl)
    events = find_interesting_events(projects_state, new_projects_state)
    if events:
        print("EVENTS: {}".format(events))
        send_events(events)
    persistence.save(new_projects_state, PROJECT_STATE_FILENAME)
    projects_state = new_projects_state
    def test_consolidate(self):
        """Values from several stores are all present after consolidation."""
        sources = [self.temp_file + suffix for suffix in ("1", "2", "3")]
        merged = self.temp_file + "4"
        first, second, third = sources

        persistence.save(first, "a", 1)
        persistence.save(second, "b", 2)
        persistence.save(third, "c", 3)
        persistence.save(third, "d", 4)
        persistence.consolidate(sources, merged)

        self.assertEqual({1, 2, 3, 4}, set(persistence.values(merged)))
    def test_values(self):
        """persistence.values yields exactly the saved values, once each."""
        persistence.save(self.temp_file, "a", 1)
        persistence.save(self.temp_file, "b", 2)
        persistence.save(self.temp_file, "c", 3)

        # assertCountEqual compares elements with multiplicity, so it also
        # catches duplicated values. The old len() check ran on an already
        # de-duplicated set and could never fail independently.
        self.assertCountEqual([1, 2, 3], persistence.values(self.temp_file))
 def test_get_shelf(self):
     """A shelf opened via get_shelf exposes all three saved entries."""
     for key, value in (("a", 1), ("b", 2), ("c", 3)):
         persistence.save(self.temp_file, key, value)
     with persistence.get_shelf(self.temp_file) as shelf:
         self.assertEqual(len(shelf), 3)
 def test_save_load_org(self):
     """An org.Organism survives a save/load round trip unchanged."""
     shelf_key, organism = "test_key", org.Organism(3)
     persistence.save(self.temp_file, shelf_key, organism)
     self.assertEqual(organism, persistence.load(self.temp_file, shelf_key))
 def test_save_load(self):
     """A plain int survives a save/load round trip unchanged."""
     shelf_key, expected = "test_key", 42
     persistence.save(self.temp_file, shelf_key, expected)
     self.assertEqual(expected, persistence.load(self.temp_file, shelf_key))